Skip to content
Snippets Groups Projects
Unverified Commit 911d60fc authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2815 fix PDarts train script error

Merge pull request !2815 from xubangduo/fixscripts
parents c4fd226a d9f63971
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,7 @@ def run_train():
args_opt = parser.parse_args()
device_id = int(os.getenv("DEVICE_ID"))
device_id = int(os.getenv("DEVICE_ID", '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
if args_opt.distribute == "true":
D.init()
......
......@@ -83,7 +83,7 @@ class Val_Callback(Callback):
"""
def __init__(self, model, train_dataset, val_dataset, checkpoint_path, prefix,
network, img_size, device_id=0, is_eval_train_dataset='False'):
network, img_size, rank_id=0, is_eval_train_dataset='False'):
super(Val_Callback, self).__init__()
self.model = model
self.train_dataset = train_dataset
......@@ -93,7 +93,7 @@ class Val_Callback(Callback):
self.prefix = prefix
self.network = network
self.img_size = img_size
self.device_id = device_id
self.rank_id = rank_id
self.is_eval_train_dataset = is_eval_train_dataset
def epoch_end(self, run_context):
......@@ -118,7 +118,7 @@ class Val_Callback(Callback):
self.max_val_acc = val_acc
cb_params = run_context.original_args()
epoch = cb_params.cur_epoch_num
model_info = self.prefix + '_id' + str(self.device_id) + \
model_info = self.prefix + '_id' + str(self.rank_id) + \
'_epoch' + str(epoch) + '_valacc' + str(val_acc)
if self.checkpoint_path.startswith('s3://') or self.checkpoint_path.startswith('obs://'):
save_path = '/cache/save_model/'
......
......@@ -22,14 +22,14 @@ from mindspore.dataset.vision.utils import Inter
def create_cifar10_dataset(data_dir, training=True, repeat_num=1, num_parallel_workers=5,
resize_height=32, resize_width=32, batch_size=512,
num_samples=None, shuffle=None, cutout_length=0, device_id=0, device_num=1):
num_samples=None, shuffle=None, cutout_length=0, rank_id=0, rank_size=1):
"""Data operations."""
ds.config.set_seed(1)
ds.config.set_num_parallel_workers(num_parallel_workers)
if training:
data_set = ds.Cifar10Dataset(data_dir, num_samples=num_samples,
shuffle=shuffle, num_shards=device_num, shard_id=device_id)
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id)
else:
data_set = ds.Cifar10Dataset(data_dir, num_samples=num_samples,
shuffle=shuffle, num_shards=1, shard_id=0)
......
......@@ -99,14 +99,15 @@ def cosine_lr(base_lr, decay_steps, total_steps):
def main():
device_num = int(os.getenv('RANK_SIZE', '1'))
device_id = get_rank()
context.set_context(mode=context.GRAPH_MODE,
device_target=args.device_target)
context.set_context(enable_graph_kernel=True)
if device_num > 1:
rank_size = int(os.getenv('RANK_SIZE', '1'))
rank_id = 0
if rank_size > 1:
init()
context.set_auto_parallel_context(device_num=device_num,
rank_id = get_rank()
context.set_auto_parallel_context(device_num=rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
......@@ -142,10 +143,10 @@ def main():
train_path = os.path.join(args.data_url, 'train')
train_dataset = create_cifar10_dataset(
train_path, True, batch_size=args.batch_size, shuffle=True, cutout_length=args.cutout_length,
device_id=device_id, device_num=device_num)
rank_id=rank_id, rank_size=rank_size)
val_path = os.path.join(args.data_url, 'val')
val_dataset = create_cifar10_dataset(
val_path, False, batch_size=128, shuffle=False, device_id=device_id, device_num=device_num)
val_path, False, batch_size=128, shuffle=False, rank_id=rank_id, rank_size=rank_size)
# learning rate setting
step_size = train_dataset.get_dataset_size()
......@@ -185,7 +186,7 @@ def main():
time_cb = TimeMonitor()
val_callback = Val_Callback(model, train_dataset, val_dataset, args.train_url,
prefix='PDarts', network=network, img_size=32,
device_id=device_id, is_eval_train_dataset=True)
rank_id=rank_id, is_eval_train_dataset=True)
callbacks = [loss_cb, time_cb, val_callback, set_attr_cb]
model.train(args.epochs, train_dataset,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment