diff --git a/research/cv/ICNet/src/models/icnet.py b/research/cv/ICNet/src/models/icnet.py index 333a54998acd0fe93eda76e3c566a682a3902c8a..9973816cac0784cf1d0712fa1805105d9f0c6839 100644 --- a/research/cv/ICNet/src/models/icnet.py +++ b/research/cv/ICNet/src/models/icnet.py @@ -16,15 +16,11 @@ import mindspore as ms import mindspore.nn as nn import mindspore.ops as ops -from mindspore import context from src.loss import ICNetLoss from src.models.resnet50_v1 import get_resnet50v1b __all__ = ['ICNet'] -context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') - - class ICNet(nn.Cell): """Image Cascade Network""" diff --git a/research/cv/ICNet/src/models/icnet_dc.py b/research/cv/ICNet/src/models/icnet_dc.py index 6476850fbce668371b9f88d0de8261d7deb119d5..0ecd90ec82f7b7fbd92191a742648fb368a22ade 100644 --- a/research/cv/ICNet/src/models/icnet_dc.py +++ b/research/cv/ICNet/src/models/icnet_dc.py @@ -16,15 +16,11 @@ import mindspore as ms import mindspore.nn as nn import mindspore.ops as ops -from mindspore import context from src.loss import ICNetLoss from src.models.resnet50_v1 import get_resnet50v1b __all__ = ['ICNetdc'] -context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') - - class ICNetdc(nn.Cell): """Image Cascade Network""" diff --git a/research/cv/ICNet/src/models/resnet50_v1.py b/research/cv/ICNet/src/models/resnet50_v1.py index 19d248be3ab3674f548ca1a703b6b16db960ad9d..b2a934fe75475f1b77cd1ad771777360b01b9153 100644 --- a/research/cv/ICNet/src/models/resnet50_v1.py +++ b/research/cv/ICNet/src/models/resnet50_v1.py @@ -279,7 +279,7 @@ def get_resnet50v1b(class_num=1001, ckpt_root='', pretrained=True): strides=[1, 2, 2, 2], num_classes=class_num) - if pretrained: + if pretrained and model.training: pretrained_ckpt = ckpt_root param_dict = load_checkpoint(pretrained_ckpt) load_param_into_net(model, param_dict)