diff --git a/official/cv/resnet/infer.py b/official/cv/resnet/infer.py
index ac81d18d98da4d1d0815ca6fc73ebd4d6ab6424f..6095ebed26343d29951ecb35a72dbc51a40016dd 100644
--- a/official/cv/resnet/infer.py
+++ b/official/cv/resnet/infer.py
@@ -55,7 +55,7 @@ def infer_net():
     target = config.device_target
 
     # init context
-    ms.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
     if target == "Ascend":
         device_id = int(os.getenv('DEVICE_ID'))
         ms.set_context(device_id=device_id)
diff --git a/official/cv/resnet/src/resnet_gpu_benchmark.py b/official/cv/resnet/src/resnet_gpu_benchmark.py
index 9eeeb9a74b72e694683f04a2de0d373bb90b5390..67ec6ffa67598dee896fbf1d59a204d97ab85f6f 100644
--- a/official/cv/resnet/src/resnet_gpu_benchmark.py
+++ b/official/cv/resnet/src/resnet_gpu_benchmark.py
@@ -15,6 +15,7 @@
 """ResNet."""
 import numpy as np
 from scipy.stats import truncnorm
+import mindspore as ms
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
diff --git a/official/cv/resnet/train.py b/official/cv/resnet/train.py
index 71e1184f224b2b97e4e8cf2eb45f01fa4b89d0c8..3f2790077d3dfcf6d5f9eec0b6c5afcd608a37f1 100644
--- a/official/cv/resnet/train.py
+++ b/official/cv/resnet/train.py
@@ -144,7 +144,7 @@ def set_parameter():
         ms.set_context(mode=ms.PYNATIVE_MODE, device_target=target, save_graphs=False)
 
     if config.parameter_server:
-        context.set_ps_context(enable_ps=True)
+        ms.set_ps_context(enable_ps=True)
     if config.run_distribute:
         if target == "Ascend":
             device_id = int(os.getenv('DEVICE_ID'))