Commit c2a58d91 authored by zhaoting

move set_context to run.py

parent 9f0a3234
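
For context, the following is a minimal, self-contained sketch of the initialization pattern this commit inlines into each run script (eval.py, export.py, train.py), replacing the shared context_device_init helper removed from src/utils.py. The SimpleNamespace config here is a hypothetical stand-in for the project's real config object; the MindSpore calls mirror the added lines in the hunks below.

from types import SimpleNamespace

import mindspore as ms
import mindspore.communication as comm

# Stand-in for src.model_utils.config.config; real runs load this from YAML/CLI.
config = SimpleNamespace(platform="GPU", run_distribute=False,
                         device_id=0, rank_id=0, rank_size=1)

if config.platform == "CPU":
    # CPU never runs distributed, mirroring the new train.py guard.
    config.run_distribute = False
# Graph mode on the chosen backend, as set directly in the new run scripts.
ms.set_context(mode=ms.GRAPH_MODE, device_target=config.platform, save_graphs=False)
if config.run_distribute:
    # Distributed setup now lives in the run script instead of a utils helper.
    comm.init()
    config.rank_id = comm.get_rank()
    config.rank_size = comm.get_group_size()
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL,
                                 gradients_mean=True)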
@@ -20,7 +20,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
| Audio | Speech Synthesis | [lpcnet](https://gitee.com/mindspore/models/tree/master/official/audio/lpcnet) |✅| ✅ | |
| Audio | Speech Synthesis | [melgan](https://gitee.com/mindspore/models/tree/master/official/audio/melgan) |✅| ✅ | |
| Audio | Speech Synthesis | [tacotron2](https://gitee.com/mindspore/models/tree/master/research/audio/tacotron2) |✅| | |
-| Graph Neural Network | Recommender System | [bgcf](https://gitee.com/mindspore/models/tree/master/official/gnn/bgcf) |✅| ✅ | |
+| Graph Neural Network | Text Classification | [bgcf](https://gitee.com/mindspore/models/tree/master/official/gnn/bgcf) |✅| ✅ | |
| Graph Neural Network | Text Classification | [gat](https://gitee.com/mindspore/models/tree/master/official/gnn/gat) |✅| ✅ | |
| Graph Neural Network | Text Classification | [gcn](https://gitee.com/mindspore/models/tree/master/official/gnn/gcn) |✅| ✅ | |
| Recommendation | Recommender System | [naml](https://gitee.com/mindspore/models/tree/master/official/recommend/naml) |✅| ✅ | |
@@ -20,7 +20,7 @@
| 语音 | 语音合成 | [lpcnet](https://gitee.com/mindspore/models/tree/master/official/audio/lpcnet) |✅| ✅ | |
| 语音 | 语音合成 | [melgan](https://gitee.com/mindspore/models/tree/master/official/audio/melgan) |✅| ✅ | |
| 语音 | 语音合成 | [tacotron2](https://gitee.com/mindspore/models/tree/master/research/audio/tacotron2) |✅| | |
-| 推荐 | 推荐系统 | [bgcf](https://gitee.com/mindspore/models/tree/master/official/gnn/bgcf) |✅| ✅ | |
+| 图神经网络 | 文本分类 | [bgcf](https://gitee.com/mindspore/models/tree/master/official/gnn/bgcf) |✅| ✅ | |
| 图神经网络 | 文本分类 | [gat](https://gitee.com/mindspore/models/tree/master/official/gnn/gat) |✅| ✅ | |
| 图神经网络 | 文本分类 | [gcn](https://gitee.com/mindspore/models/tree/master/official/gnn/gcn) |✅| ✅ | |
| 推荐 | 推荐系统 | [naml](https://gitee.com/mindspore/models/tree/master/official/recommend/naml) |✅| ✅ | |
@@ -20,7 +20,7 @@ import mindspore as ms
import mindspore.nn as nn
from src.dataset import create_dataset
from src.models import define_net, load_ckpt
-from src.utils import context_device_init
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_process
from src.model_utils.device_adapter import get_device_id
@@ -30,11 +29,11 @@ config.is_training = config.is_training_eval
@moxing_wrapper(pre_process=modelarts_process)
def eval_mobilenetv2():
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=config.platform, save_graphs=False)
config.dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
print('\nconfig: \n', config)
if not config.device_id:
config.device_id = get_device_id()
-    context_device_init(config)
_, _, net = define_net(config, config.is_training)
load_ckpt(net, config.pretrain_ckpt)
@@ -18,9 +18,7 @@ mobilenetv2 export file.
import numpy as np
import mindspore as ms
from src.models import define_net, load_ckpt
-from src.utils import context_device_init
from src.model_utils.config import config
-from src.model_utils.device_adapter import get_device_id
from src.model_utils.moxing_adapter import moxing_wrapper
config.batch_size = config.batch_size_export
@@ -31,9 +29,7 @@ config.is_training = config.is_training_export
def export_mobilenetv2():
""" export_mobilenetv2 """
print('\nconfig: \n', config)
-    if not config.device_id:
-        config.device_id = get_device_id()
-    context_device_init(config)
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=config.platform, save_graphs=False)
_, _, net = define_net(config, config.is_training)
load_ckpt(net, config.ckpt_file)
@@ -68,7 +68,7 @@ run_ascend_or_gpu()
fi
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-    if [ $1 = "Ascend" ]; then
+    if [ $1 = "Ascend" ] && [ $3 -eq 1 ]; then
export DEVICE_ID=${CANDIDATE_DEVICE[0]}
export RANK_ID=0
elif [ $1 = "GPU" ]; then
@@ -13,34 +13,9 @@
# limitations under the License.
# ============================================================================
import mindspore as ms
-import mindspore.communication as comm
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from src.models import Monitor
-def context_device_init(config):
-    if config.platform == "GPU" and config.run_distribute:
-        config.device_id = 0
-        config.rank_id = 0
-        config.rank_size = 1
-    if config.platform == "CPU":
-        ms.set_context(mode=ms.GRAPH_MODE, device_target=config.platform, save_graphs=False)
-    elif config.platform in ["Ascend", "GPU"]:
-        ms.set_context(mode=ms.GRAPH_MODE, device_target=config.platform, device_id=config.device_id,
-                       save_graphs=False)
-        if config.run_distribute:
-            comm.init()
-            config.rank_id = comm.get_rank()
-            config.rank_size = comm.get_group_size()
-            ms.set_auto_parallel_context(device_num=config.rank_size,
-                                         parallel_mode=ms.ParallelMode.DATA_PARALLEL,
-                                         gradients_mean=True)
-    else:
-        raise ValueError("Only support CPU, GPU and Ascend.")
def config_ckpoint(config, lr, step_size, model=None, eval_dataset=None):
cb = [Monitor(lr_init=lr.asnumpy(), model=model, eval_dataset=eval_dataset)]
if config.save_checkpoint and config.rank_id == 0:
@@ -20,11 +20,12 @@ import random
import numpy as np
import mindspore as ms
+import mindspore.communication as comm
import mindspore.nn as nn
from src.dataset import create_dataset, extract_features
from src.lr_generator import get_lr
-from src.utils import context_device_init, config_ckpoint
+from src.utils import config_ckpoint
from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt, build_params_groups
from src.metric import DistAccuracy, ClassifyCorrectCell
from src.model_utils.config import config
@@ -38,13 +39,21 @@ ms.set_seed(1)
@moxing_wrapper(pre_process=modelarts_process)
def train_mobilenetv2():
""" train_mobilenetv2 """
+    if config.platform == "CPU":
+        config.run_distribute = False
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=config.platform, save_graphs=False)
+    if config.run_distribute:
+        comm.init()
+        config.rank_id = comm.get_rank()
+        config.rank_size = comm.get_group_size()
+        ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+                                     gradients_mean=True)
config.train_dataset_path = os.path.join(config.dataset_path, 'train')
config.eval_dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
if not config.device_id:
config.device_id = get_device_id()
start = time.time()
-    # set context and device init
-    context_device_init(config)
print('\nconfig: {} \n'.format(config))
# define network
backbone_net, head_net, net = define_net(config, config.is_training)
@@ -16,7 +16,7 @@
"""Cycle GAN test."""
import os
-from mindspore import Tensor
+import mindspore as ms
from src.models.cycle_gan import get_generator
from src.utils.args import get_args
from src.dataset.cyclegan_dataset import create_dataset
@@ -27,6 +27,12 @@ from src.utils.tools import save_image, load_ckpt
def predict():
"""Predict function."""
args = get_args("predict")
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=args.platform,
+                   save_graphs=args.save_graphs, device_id=args.device_id)
+    args.rank = 0
+    args.device_num = 1
+    if args.platform == "GPU":
+        ms.set_context(enable_graph_kernel=True)
G_A = get_generator(args)
G_B = get_generator(args)
G_A.set_train(True)
@@ -44,7 +50,7 @@ def predict():
reporter = Reporter(args)
reporter.start_predict("A to B")
for data in ds.create_dict_iterator(output_numpy=True):
-        img_A = Tensor(data["image"])
+        img_A = ms.Tensor(data["image"])
path_A = data["image_name"][0]
path_B = path_A[0:-4] + "_fake_B.jpg"
fake_B = G_A(img_A)
@@ -57,7 +63,7 @@ def predict():
reporter.dataset_size = args.dataset_size
reporter.start_predict("B to A")
for data in ds.create_dict_iterator(output_numpy=True):
-        img_B = Tensor(data["image"])
+        img_B = ms.Tensor(data["image"])
path_B = data["image_name"][0]
path_A = path_B[0:-4] + "_fake_A.jpg"
fake_A = G_B(img_B)
@@ -16,16 +16,15 @@
"""export file."""
import numpy as np
-from mindspore import context, Tensor
-from mindspore.train.serialization import export
+import mindspore as ms
from src.models.cycle_gan import get_generator
from src.utils.args import get_args
from src.utils.tools import load_ckpt, enable_batch_statistics
-args = get_args("export")
-context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
if __name__ == '__main__':
+    args = get_args("export")
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=args.platform)
G_A = get_generator(args)
G_B = get_generator(args)
# Use BatchNorm2d with batchsize=1, affine=False, use_batch_statistics=True instead of InstanceNorm2d
@@ -35,8 +34,8 @@ if __name__ == '__main__':
load_ckpt(args, G_A, G_B)
input_shp = [args.export_batch_size, 3, args.image_size, args.image_size]
-    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
+    input_array = ms.Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
G_A_file = f"{args.export_file_name}_AtoB"
-    export(G_A, input_array, file_name=G_A_file, file_format=args.export_file_format)
+    ms.export(G_A, input_array, file_name=G_A_file, file_format=args.export_file_format)
G_B_file = f"{args.export_file_name}_BtoA"
-    export(G_B, input_array, file_name=G_B_file, file_format=args.export_file_format)
+    ms.export(G_B, input_array, file_name=G_B_file, file_format=args.export_file_format)
@@ -17,10 +17,6 @@
import argparse
import ast
-from mindspore import context
-from mindspore.context import ParallelMode
-from mindspore.communication.management import init, get_rank, get_group_size
parser = argparse.ArgumentParser(description='Cycle GAN')
# basic parameters
@@ -118,24 +114,6 @@ args = parser.parse_args()
def get_args(phase):
"""Define the common options that are used in both training and test."""
-    if args.device_num > 1:
-        context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
-                                          device_num=args.device_num)
-        init()
-        args.rank = get_rank()
-        args.group_size = get_group_size()
-    else:
-        context.set_context(mode=context.GRAPH_MODE, device_target=args.platform,
-                            save_graphs=args.save_graphs, device_id=args.device_id)
-        args.rank = 0
-        args.device_num = 1
-    if args.platform == "GPU":
-        context.set_context(enable_graph_kernel=True)
if args.platform == "Ascend":
args.pad_mode = "CONSTANT"
@@ -22,6 +22,7 @@ Example:
import mindspore as ms
import mindspore.nn as nn
+from mindspore.communication.management import init, get_rank, get_group_size
from src.utils.args import get_args
from src.utils.reporter import Reporter
from src.utils.tools import get_lr, ImagePool, load_ckpt
@@ -34,6 +35,21 @@ ms.set_seed(1)
def train():
"""Train function."""
args = get_args("train")
+    if args.device_num > 1:
+        ms.set_context(mode=ms.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs)
+        init()
+        ms.reset_auto_parallel_context()
+        ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+        args.rank = get_rank()
+        args.group_size = get_group_size()
+    else:
+        ms.set_context(mode=ms.GRAPH_MODE, device_target=args.platform,
+                       save_graphs=args.save_graphs, device_id=args.device_id)
+        args.rank = 0
+        args.device_num = 1
+    if args.platform == "GPU":
+        ms.set_context(enable_graph_kernel=True)
if args.need_profiler:
from mindspore.profiler.profiling import Profiler
profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)