Skip to content
Snippets Groups Projects
Commit 89feb45e authored by anzhengqi's avatar anzhengqi
Browse files

modify networks

parent 263022a5
No related branches found
No related tags found
No related merge requests found
...@@ -28,7 +28,7 @@ if config.device_target == "Ascend": ...@@ -28,7 +28,7 @@ if config.device_target == "Ascend":
def run_export(): def run_export():
"""run export.""" """run export."""
if config.network_dataset == 'se-resnet50_imagenet2012': if config.network_dataset == 'se-resnet50_imagenet2012':
from src.resnet import resnet50 as resnet from src.resnet import se_resnet50 as resnet
elif config.network_dataset == 'se-resnet101_imagenet2012': elif config.network_dataset == 'se-resnet101_imagenet2012':
from src.resnet import resnet101 as resnet from src.resnet import resnet101 as resnet
else: else:
......
...@@ -31,7 +31,6 @@ args = parser.parse_args() ...@@ -31,7 +31,6 @@ args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if __name__ == '__main__': if __name__ == '__main__':
context.set_context(device_id="Ascend")
net = U2NET() net = U2NET()
param_dict = load_checkpoint(args.ckpt_file) param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
......
...@@ -27,9 +27,6 @@ def swish(x): ...@@ -27,9 +27,6 @@ def swish(x):
return x * P.Sigmoid()(x) return x * P.Sigmoid()(x)
ACT2FN = {"gelu": nn.GELU(), "relu": P.ReLU(), "swish": swish}
class Attention(nn.Cell): class Attention(nn.Cell):
"""Attention""" """Attention"""
def __init__(self, config): def __init__(self, config):
...@@ -87,7 +84,7 @@ class Mlp(nn.Cell): ...@@ -87,7 +84,7 @@ class Mlp(nn.Cell):
weight_init='XavierUniform', bias_init='Normal') weight_init='XavierUniform', bias_init='Normal')
self.fc2 = nn.Dense(config.transformer_mlp_dim, config.hidden_size, self.fc2 = nn.Dense(config.transformer_mlp_dim, config.hidden_size,
weight_init='XavierUniform', bias_init='Normal') weight_init='XavierUniform', bias_init='Normal')
self.act_fn = ACT2FN["gelu"] self.act_fn = nn.GELU()
self.dropout = nn.Dropout(config.transformer_dropout_rate) self.dropout = nn.Dropout(config.transformer_dropout_rate)
def construct(self, x): def construct(self, x):
......
...@@ -50,7 +50,9 @@ then ...@@ -50,7 +50,9 @@ then
fi fi
mkdir ./train mkdir ./train
cp ../*.py ./train cp ../*.py ./train
cp ../*.yaml ./train
cp -r ../src ./train cp -r ../src ./train
cp -r ../model_utils ./train
cd ./train || exit cd ./train || exit
echo "start training for device $DEVICE_ID" echo "start training for device $DEVICE_ID"
env > env.log env > env.log
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment