From 58307a2c6c2053c74fe9a46daa8d609a5b23c981 Mon Sep 17 00:00:00 2001 From: wangzeyangyi <tomzwang11@gmail.com> Date: Thu, 3 Mar 2022 11:36:19 +0800 Subject: [PATCH] updated deprecated API --- official/cv/resnet/infer.py | 2 +- official/cv/resnet/src/resnet_gpu_benchmark.py | 1 + official/cv/resnet/train.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/official/cv/resnet/infer.py b/official/cv/resnet/infer.py index ac81d18d9..6095ebed2 100644 --- a/official/cv/resnet/infer.py +++ b/official/cv/resnet/infer.py @@ -55,7 +55,7 @@ def infer_net(): target = config.device_target # init context - ms.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False) if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) ms.set_context(device_id=device_id) diff --git a/official/cv/resnet/src/resnet_gpu_benchmark.py b/official/cv/resnet/src/resnet_gpu_benchmark.py index 9eeeb9a74..67ec6ffa6 100644 --- a/official/cv/resnet/src/resnet_gpu_benchmark.py +++ b/official/cv/resnet/src/resnet_gpu_benchmark.py @@ -15,6 +15,7 @@ """ResNet.""" import numpy as np from scipy.stats import truncnorm +import mindspore as ms import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.common.tensor import Tensor diff --git a/official/cv/resnet/train.py b/official/cv/resnet/train.py index 71e1184f2..3f2790077 100644 --- a/official/cv/resnet/train.py +++ b/official/cv/resnet/train.py @@ -144,7 +144,7 @@ def set_parameter(): ms.set_context(mode=ms.PYNATIVE_MODE, device_target=target, save_graphs=False) if config.parameter_server: - context.set_ps_context(enable_ps=True) + ms.set_ps_context(enable_ps=True) if config.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) -- GitLab