Skip to content
Snippets Groups Projects
Commit 89feb45e authored by anzhengqi
Browse files

modify networks

parent 263022a5
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ if config.device_target == "Ascend":
def run_export():
"""run export."""
if config.network_dataset == 'se-resnet50_imagenet2012':
from src.resnet import resnet50 as resnet
from src.resnet import se_resnet50 as resnet
elif config.network_dataset == 'se-resnet101_imagenet2012':
from src.resnet import resnet101 as resnet
else:
......
......@@ -31,7 +31,6 @@ args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if __name__ == '__main__':
context.set_context(device_id="Ascend")
net = U2NET()
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
......
......@@ -27,9 +27,6 @@ def swish(x):
return x * P.Sigmoid()(x)
ACT2FN = {"gelu": nn.GELU(), "relu": P.ReLU(), "swish": swish}
class Attention(nn.Cell):
"""Attention"""
def __init__(self, config):
......@@ -87,7 +84,7 @@ class Mlp(nn.Cell):
weight_init='XavierUniform', bias_init='Normal')
self.fc2 = nn.Dense(config.transformer_mlp_dim, config.hidden_size,
weight_init='XavierUniform', bias_init='Normal')
self.act_fn = ACT2FN["gelu"]
self.act_fn = nn.GELU()
self.dropout = nn.Dropout(config.transformer_dropout_rate)
def construct(self, x):
......
......@@ -50,7 +50,9 @@ then
fi
mkdir ./train
cp ../*.py ./train
cp ../*.yaml ./train
cp -r ../src ./train
cp -r ../model_utils ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
......@@ -66,4 +68,4 @@ python train.py \
--per_batch_size=32 \
--weight_decay=0.016 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
\ No newline at end of file
cd ..
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment