diff --git a/official/cv/fastscnn/train.py b/official/cv/fastscnn/train.py index e2132e9e815c09369f97a00051e92ecc7637378c..8ba9af1a40af158f24d71e27bd63331bf17b97a9 100644 --- a/official/cv/fastscnn/train.py +++ b/official/cv/fastscnn/train.py @@ -106,12 +106,11 @@ def train(): if args.is_distributed: assert args.device_target == "Ascend" context.set_context(device_id=device_id) - init("hccl") + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=args.group_size, parallel_mode=ParallelMode.DATA_PARALLEL) + init() args.rank = get_rank() args.group_size = get_group_size() - device_num = args.group_size - context.reset_auto_parallel_context() - context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL) else: if args.device_target in ["Ascend", "GPU"]: context.set_context(device_id=device_id) @@ -182,34 +181,33 @@ def train(): time_cb = TimeMonitor(data_size=args.steps_per_epoch) loss_cb = LossMonitor() callbacks = [time_cb, loss_cb] + + if args.rank_save_ckpt_flag: + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch*args.save_every, + keep_checkpoint_max=args.ckpt_save_max) + save_ckpt_path = os.path.join(save_dir, 'ckpt_' + str(args.rank) + '/') + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=save_ckpt_path, + prefix='rank_'+str(args.rank)) + callbacks.append(ckpt_cb) + + if args.eval_while_train and args.rank == 0: + + val_dataset, _ = create_CitySegmentation(args, data_path=args.dataset, \ + split='val', mode='val', transform=input_transform, \ + base_size=args.base_size, crop_size=args.crop_size, \ + batch_size=1, shuffle=False) + loss_f = TempLoss() + network_eval = Model(f_model, loss_fn=loss_f, metrics={"SegmentationMetric": SegmentationMetric(19)}) + + eval_cb = EvalCallBack(network_eval, val_dataset, interval=args.eval_steps, + eval_start_epoch=args.eval_start_epoch, save_best_ckpt=True, + ckpt_directory=save_dir, besk_ckpt_name="best_map.ckpt", + metrics_name=("pixAcc", "mIou")) + callbacks.append(eval_cb) else: callbacks = None - if args.rank_save_ckpt_flag: - ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch*args.save_every, - keep_checkpoint_max=args.ckpt_save_max) - save_ckpt_path = os.path.join(save_dir, 'ckpt_' + str(args.rank) + '/') - ckpt_cb = ModelCheckpoint(config=ckpt_config, - directory=save_ckpt_path, - prefix='rank_'+str(args.rank)) - callbacks.append(ckpt_cb) - - if args.eval_while_train and args.rank == 0: - - val_dataset, _ = create_CitySegmentation(args, data_path=args.dataset, \ - split='val', mode='val', transform=input_transform, \ - base_size=args.base_size, crop_size=args.crop_size, \ - batch_size=1, device_num=1, \ - rank=args.rank, shuffle=False) - loss_f = TempLoss() - network_eval = Model(f_model, loss_fn=loss_f, metrics={"SegmentationMetric": SegmentationMetric(19)}) - - eval_cb = EvalCallBack(network_eval, val_dataset, interval=args.eval_steps, - eval_start_epoch=args.eval_start_epoch, save_best_ckpt=True, - ckpt_directory=save_dir, besk_ckpt_name="best_map.ckpt", - metrics_name=("pixAcc", "mIou")) - callbacks.append(eval_cb) - model.train(args.epochs, train_dataset, callbacks=callbacks, dataset_sink_mode=True) args.logger.info("training finished....")