diff --git a/official/cv/resnet/infer.py b/official/cv/resnet/infer.py index ac81d18d98da4d1d0815ca6fc73ebd4d6ab6424f..6095ebed26343d29951ecb35a72dbc51a40016dd 100644 --- a/official/cv/resnet/infer.py +++ b/official/cv/resnet/infer.py @@ -55,7 +55,7 @@ def infer_net(): target = config.device_target # init context - ms.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False) if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) ms.set_context(device_id=device_id) diff --git a/official/cv/resnet/src/resnet_gpu_benchmark.py b/official/cv/resnet/src/resnet_gpu_benchmark.py index 9eeeb9a74b72e694683f04a2de0d373bb90b5390..67ec6ffa67598dee896fbf1d59a204d97ab85f6f 100644 --- a/official/cv/resnet/src/resnet_gpu_benchmark.py +++ b/official/cv/resnet/src/resnet_gpu_benchmark.py @@ -15,6 +15,7 @@ """ResNet.""" import numpy as np from scipy.stats import truncnorm +import mindspore as ms import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.common.tensor import Tensor diff --git a/official/cv/resnet/train.py b/official/cv/resnet/train.py index 71e1184f224b2b97e4e8cf2eb45f01fa4b89d0c8..3f2790077d3dfcf6d5f9eec0b6c5afcd608a37f1 100644 --- a/official/cv/resnet/train.py +++ b/official/cv/resnet/train.py @@ -144,7 +144,7 @@ def set_parameter(): ms.set_context(mode=ms.PYNATIVE_MODE, device_target=target, save_graphs=False) if config.parameter_server: - context.set_ps_context(enable_ps=True) + ms.set_ps_context(enable_ps=True) if config.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID'))