diff --git a/research/cv/ICNet/src/models/icnet.py b/research/cv/ICNet/src/models/icnet.py
index 333a54998acd0fe93eda76e3c566a682a3902c8a..9973816cac0784cf1d0712fa1805105d9f0c6839 100644
--- a/research/cv/ICNet/src/models/icnet.py
+++ b/research/cv/ICNet/src/models/icnet.py
@@ -16,15 +16,11 @@
 import mindspore as ms
 import mindspore.nn as nn
 import mindspore.ops as ops
-from mindspore import context
 from src.loss import ICNetLoss
 from src.models.resnet50_v1 import get_resnet50v1b
 
 __all__ = ['ICNet']
 
-context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
-
-
 class ICNet(nn.Cell):
     """Image Cascade Network"""
 
diff --git a/research/cv/ICNet/src/models/icnet_dc.py b/research/cv/ICNet/src/models/icnet_dc.py
index 6476850fbce668371b9f88d0de8261d7deb119d5..0ecd90ec82f7b7fbd92191a742648fb368a22ade 100644
--- a/research/cv/ICNet/src/models/icnet_dc.py
+++ b/research/cv/ICNet/src/models/icnet_dc.py
@@ -16,15 +16,11 @@
 import mindspore as ms
 import mindspore.nn as nn
 import mindspore.ops as ops
-from mindspore import context
 from src.loss import ICNetLoss
 from src.models.resnet50_v1 import get_resnet50v1b
 
 __all__ = ['ICNetdc']
 
-context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
-
-
 class ICNetdc(nn.Cell):
     """Image Cascade Network"""
 
diff --git a/research/cv/ICNet/src/models/resnet50_v1.py b/research/cv/ICNet/src/models/resnet50_v1.py
index 19d248be3ab3674f548ca1a703b6b16db960ad9d..b2a934fe75475f1b77cd1ad771777360b01b9153 100644
--- a/research/cv/ICNet/src/models/resnet50_v1.py
+++ b/research/cv/ICNet/src/models/resnet50_v1.py
@@ -279,7 +279,7 @@ def get_resnet50v1b(class_num=1001, ckpt_root='', pretrained=True):
                         strides=[1, 2, 2, 2],
                         num_classes=class_num)
 
-    if pretrained:
+    if pretrained and model.training:
         pretrained_ckpt = ckpt_root
         param_dict = load_checkpoint(pretrained_ckpt)
         load_param_into_net(model, param_dict)