Unverified commit fe2bb1ac, authored by i-robot, committed by Gitee

!3515 support running on modelarts & fix fasttext performance

Merge pull request !3515 from JichenZhao/master
parents 2f5db948 502655c1
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -35,14 +35,14 @@ def load_dataset(dataset_path,
shuffle=shuffle,
num_shards=rank_size,
shard_id=rank_id,
-                             num_parallel_workers=4)
+                             num_parallel_workers=32)
ori_dataset_size = data_set.get_dataset_size()
print(f"Dataset size: {ori_dataset_size}")
repeat_count = epoch_count
data_set = data_set.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
-    data_set = data_set.batch(batch_size, drop_remainder=False)
+    data_set = data_set.batch(batch_size, drop_remainder=False, num_parallel_workers=32)
data_set = data_set.repeat(repeat_count)
return data_set
......
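For context, here is a minimal sketch of what the updated FastText loader looks like after this change, assuming it is built on `mindspore.dataset.MindDataset` (the column names come from the diff; the surrounding signature is illustrative):

```python
import mindspore.dataset as ds


def load_dataset(dataset_path, batch_size, epoch_count=1,
                 rank_size=1, rank_id=0, shuffle=True):
    """Load the FastText mindrecord dataset with more parallel workers."""
    data_set = ds.MindDataset(dataset_path,
                              columns_list=['src_tokens', 'src_tokens_length', 'label_idx'],
                              shuffle=shuffle,
                              num_shards=rank_size,
                              shard_id=rank_id,
                              num_parallel_workers=32)  # raised from 4 for throughput
    print(f"Dataset size: {data_set.get_dataset_size()}")
    data_set = data_set.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
                               output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
    # batching is also parallelized now; drop_remainder=False keeps the final partial batch
    data_set = data_set.batch(batch_size, drop_remainder=False, num_parallel_workers=32)
    return data_set.repeat(epoch_count)
```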
......@@ -128,6 +128,35 @@ After installing MindSpore via the official website, you can start training and
<https://gitee.com/mindspore/models/tree/master/utils/hccl_tools>.
If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); you can then start training and evaluation as follows:
```python
# run distributed training on modelarts example
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the yaml file.
#          Set any other parameters you need in the yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add any other parameters on the website UI interface.
# (2) Set the code directory to "/path/resnet" on the website UI interface.
# (3) Set the startup file to "train.py" on the website UI interface.
# (4) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (5) Create your job.
# run evaluation on modelarts example
# (1) Copy or upload your trained model to an S3 bucket.
# (2) Perform a or b.
#       a. Set "enable_modelarts=True" in the yaml file.
#          Set "checkpoint_file_path='/cache/checkpoint_path/model.ckpt'" in the yaml file.
#          Set "checkpoint_url=/The path of checkpoint in S3/" in the yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "ckpt_file='checkpoint_file_name'" on the website UI interface.
#          Add "checkpoint_url=/The path of checkpoint in S3/" on the website UI interface.
# (3) Set the code directory to "/path/resnet" on the website UI interface.
# (4) Set the startup file to "eval.py" on the website UI interface.
# (5) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (6) Create your job.
```
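Parameters added on the web UI reach the startup script as `key=value` overrides of the yaml defaults. Below is a hedged sketch of that merge, assuming PyYAML; the helper is illustrative and not the repo's actual `src/model_utils/config.py`:

```python
import argparse

import yaml  # assumption: PyYAML is available


def parse_config(yaml_path, overrides):
    """Merge yaml defaults with 'key=value' overrides from the ModelArts UI."""
    with open(yaml_path, 'r') as f:
        cfg = yaml.safe_load(f)
    for item in overrides:                        # e.g. ['enable_modelarts=True']
        key, _, value = item.partition('=')
        if key not in cfg:
            continue
        default = cfg[key]
        if isinstance(default, bool):             # bool('False') is True, so parse by hand
            cfg[key] = value.lower() in ('true', '1', 'yes')
        elif default is not None:
            cfg[key] = type(default)(value)
        else:
            cfg[key] = value
    return argparse.Namespace(**cfg)
```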
# [Script Description](#contents)
## [Script and Sample Code](#contents)
......
......@@ -3,9 +3,9 @@ enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/data/pafnucy/tests/data/dataset"
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
load_path: "/cache/data"
mindrecord_path: ""
device_target: Ascend
enable_profiling: False
......
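The `/cache` values above reflect how ModelArts works: inputs are copied from OBS into the job's local `/cache` storage before the script runs. A sketch of that pre-processing step, assuming the `moxing` library available in ModelArts runtime images (the actual wrapper in `src/model_utils/moxing_adapter.py` may differ in detail):

```python
import moxing as mox  # available inside ModelArts runtime images


def sync_obs_to_cache(config):
    """Copy dataset and optional checkpoint from OBS into local /cache paths."""
    if not config.enable_modelarts:
        return
    # the "Dataset path" set on the UI arrives as data_url
    mox.file.copy_parallel(config.data_url, config.data_path)            # -> /cache/data
    if config.checkpoint_url:
        mox.file.copy_parallel(config.checkpoint_url, config.load_path)  # -> /cache/data
```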
......@@ -38,6 +38,8 @@ def get_data_size(csv_path):
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
+    config.logger = get_logger('./', config.device_id)
+    config.logger.save_args(config)
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=config.device_id)
network = SBNetWork(in_chanel=[19, 64, 128],
out_chanle=config.conv_channels,
......@@ -45,7 +47,13 @@ def run_eval():
lmbda=config.lmbda,
isize=config.isize, keep_prob=1.0)
network.set_train(False)
-    param_dict = load_checkpoint(config.ckpt_file)
+    if config.enable_modelarts:
+        config.mindrecord_path = config.data_path
+        config.ckpt_path = config.load_path
+        print("checkpoint path:", config.ckpt_path, flush=True)
+        param_dict = load_checkpoint(os.path.join(config.ckpt_path, config.ckpt_file))
+    else:
+        param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(network, param_dict)
val_size = get_data_size(config.mindrecord_path)
val_path = './val/validation_dataset.mindrecord'
......@@ -66,6 +74,4 @@ def run_eval():
if __name__ == '__main__':
-    config.logger = get_logger('./', config.device_id)
-    config.logger.save_args(config)
run_eval()
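In isolation, the checkpoint-resolution pattern this eval change introduces looks as follows; a minimal sketch using the config fields from the diff, with everything else illustrative:

```python
import os

from mindspore.train.serialization import load_checkpoint, load_param_into_net


def restore_weights(network, config):
    """Load eval weights from /cache on ModelArts, or from ckpt_file locally."""
    if config.enable_modelarts:
        # moxing has already copied the checkpoint from OBS into load_path
        ckpt = os.path.join(config.load_path, config.ckpt_file)
    else:
        ckpt = config.ckpt_file
    param_dict = load_checkpoint(ckpt)
    load_param_into_net(network, param_dict)
```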
......@@ -26,6 +26,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.net import SBNetWork
from src.logger import get_logger
from src.model_utils.config import config
+from src.model_utils.device_adapter import get_device_id
from src.model_utils.moxing_adapter import moxing_wrapper
from src.data import Featurizer, make_grid
......@@ -100,16 +101,26 @@ def load_evaldata(configs, data_path):
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
-    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=5)
+    config.logger = get_logger('./', config.device_id)
+    config.logger.save_args(config)
+    if config.enable_modelarts:
+        path = config.data_path
+        config.ckpt_path = config.load_path
+        param_dict = load_checkpoint(os.path.join(config.ckpt_path, config.ckpt_file))
+        hdf_file_path = os.path.join(path, config.hdf_file)
+    else:
+        path = config.predict_input
+        param_dict = load_checkpoint(config.ckpt_file)
+        hdf_file_path = path
+    device_id = get_device_id()
+    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
network = SBNetWork(in_chanel=[19, 64, 128],
out_chanle=config.conv_channels,
dense_size=config.dense_sizes,
lmbda=config.lmbda,
isize=config.isize, keep_prob=1.0, is_training=False)
network.set_train(False)
-    hdf_file_path = input_file(config.hdf_file)
data_loader, names = load_evaldata(config, hdf_file_path)
-    param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(network, param_dict)
prediction = []
for data in data_loader.create_dict_iterator(output_numpy=True):
......@@ -129,9 +140,4 @@ def run_eval():
if __name__ == '__main__':
-    config.logger = get_logger('./', config.device_id)
-    config.logger.save_args(config)
-    path = config.predict_input
-    paths = input_file(path)
-    config.hdf_file = paths
run_eval()
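Replacing the hardcoded `device_id=5` with `get_device_id()` lets the same script run on whatever device the scheduler assigns. The adapter in `src/model_utils/device_adapter.py` typically reads the `DEVICE_ID` environment variable; a sketch under that assumption:

```python
import os

from mindspore import context


def get_device_id():
    """Assumed mirror of device_adapter.get_device_id: DEVICE_ID env var, default 0."""
    return int(os.getenv('DEVICE_ID', '0'))


context.set_context(mode=context.GRAPH_MODE,
                    device_target='Ascend',
                    device_id=get_device_id())
```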
......@@ -22,14 +22,14 @@ from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.common import set_seed
from mindspore.nn import TrainOneStepCell
+from mindspore.common import dtype as mstype
from mindspore.communication.management import init
from mindspore.context import ParallelMode
from mindspore.train.serialization import save_checkpoint
from src.logger import get_logger
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
-from src.model_utils.device_adapter import get_device_id, get_device_num
+from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
from src.net import SBNetWork
from src.dataloader import minddataset_loader, minddataset_loader_val
......@@ -68,13 +68,22 @@ def run_train():
dense_size=config.dense_sizes,
lmbda=config.lmbda,
isize=config.isize, keep_prob=config.keep_prob)
-    lr = Tensor(float(config.lr))
+    lr = Tensor(float(config.lr), mstype.float32)
+    if config.enable_modelarts:
+        config.mindrecord_path = config.data_path
+        config.ckpt_path = os.path.join(config.output_path, 'ckpt_' + str(get_rank_id()))
+        if not os.path.exists(config.ckpt_path):
+            os.mkdir(config.ckpt_path)
+            config.logger.info("mkdir %s", config.ckpt_path)
+    else:
+        config.ckpt_path = './ckpt/'
optimizer = nn.Adam(params=network.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay)
train_wrapper = TrainOneStepCell(network, optimizer=optimizer)
train_size, val_size = get_data_size(config.mindrecord_path)
rot_path = './train_rotation/train_rotation_dataset.mindrecord'
nrot_path = './no_rotation/train_norotation_dataset.mindrecord'
val_path = './val/validation_dataset.mindrecord'
+    config.logger.info("mindrecord files: %s", os.listdir(config.mindrecord_path))
rotation_data, _ = minddataset_loader(configs=config,
mindfile=os.path.join(config.mindrecord_path, rot_path),
no_batch_size=train_size)
......@@ -130,7 +139,7 @@ def run_train():
if final_mse_v <= stand_mse_v:
stand_mse_v = final_mse_v
config.logger.info("Saving checkpoint file")
-            save_checkpoint(train_wrapper, f'./ckpt/pafnucy_{final_mse_v}_{epoch}.ckpt')
+            save_checkpoint(train_wrapper, f'{config.ckpt_path}/pafnucy_{final_mse_v}_{epoch}.ckpt')
config.logger.info("Finish Training.....")
......
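With this change, each rank writes its checkpoints to its own `ckpt_<rank>` directory under `output_path` on ModelArts, instead of a shared `./ckpt/`. A minimal sketch of that logic (`get_rank_id` reading `RANK_ID` is an assumption mirroring the usual device adapter):

```python
import os


def get_rank_id():
    """Assumed mirror of device_adapter.get_rank_id: RANK_ID env var, default 0."""
    return int(os.getenv('RANK_ID', '0'))


def resolve_ckpt_path(config):
    """Per-rank checkpoint directory on ModelArts, plain ./ckpt/ otherwise."""
    if config.enable_modelarts:
        ckpt_path = os.path.join(config.output_path, 'ckpt_' + str(get_rank_id()))
        os.makedirs(ckpt_path, exist_ok=True)  # safer than a bare mkdir() across ranks
    else:
        ckpt_path = './ckpt/'
    return ckpt_path
```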