diff --git a/official/cv/fastscnn/train.py b/official/cv/fastscnn/train.py
index e2132e9e815c09369f97a00051e92ecc7637378c..8ba9af1a40af158f24d71e27bd63331bf17b97a9 100644
--- a/official/cv/fastscnn/train.py
+++ b/official/cv/fastscnn/train.py
@@ -106,12 +106,11 @@ def train():
     if args.is_distributed:
         assert args.device_target == "Ascend"
         context.set_context(device_id=device_id)
-        init("hccl")
+        context.reset_auto_parallel_context()
+        context.set_auto_parallel_context(device_num=args.group_size, parallel_mode=ParallelMode.DATA_PARALLEL)
+        init()
         args.rank = get_rank()
         args.group_size = get_group_size()
-        device_num = args.group_size
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
     else:
         if args.device_target in ["Ascend", "GPU"]:
             context.set_context(device_id=device_id)
@@ -182,34 +181,33 @@ def train():
         time_cb = TimeMonitor(data_size=args.steps_per_epoch)
         loss_cb = LossMonitor()
         callbacks = [time_cb, loss_cb]
+
+        if args.rank_save_ckpt_flag:
+            ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch*args.save_every,
+                                           keep_checkpoint_max=args.ckpt_save_max)
+            save_ckpt_path = os.path.join(save_dir, 'ckpt_' + str(args.rank) + '/')
+            ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                                      directory=save_ckpt_path,
+                                      prefix='rank_'+str(args.rank))
+            callbacks.append(ckpt_cb)
+
+        if args.eval_while_train and args.rank == 0:
+
+            val_dataset, _ = create_CitySegmentation(args, data_path=args.dataset, \
+                                           split='val', mode='val', transform=input_transform, \
+                                           base_size=args.base_size, crop_size=args.crop_size, \
+                                           batch_size=1, shuffle=False)
+            loss_f = TempLoss()
+            network_eval = Model(f_model, loss_fn=loss_f, metrics={"SegmentationMetric": SegmentationMetric(19)})
+
+            eval_cb = EvalCallBack(network_eval, val_dataset, interval=args.eval_steps,
+                                   eval_start_epoch=args.eval_start_epoch, save_best_ckpt=True,
+                                   ckpt_directory=save_dir, besk_ckpt_name="best_map.ckpt",
+                                   metrics_name=("pixAcc", "mIou"))
+            callbacks.append(eval_cb)
     else:
         callbacks = None
 
-    if args.rank_save_ckpt_flag:
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch*args.save_every,
-                                       keep_checkpoint_max=args.ckpt_save_max)
-        save_ckpt_path = os.path.join(save_dir, 'ckpt_' + str(args.rank) + '/')
-        ckpt_cb = ModelCheckpoint(config=ckpt_config,
-                                  directory=save_ckpt_path,
-                                  prefix='rank_'+str(args.rank))
-        callbacks.append(ckpt_cb)
-
-    if args.eval_while_train and args.rank == 0:
-
-        val_dataset, _ = create_CitySegmentation(args, data_path=args.dataset, \
-                                       split='val', mode='val', transform=input_transform, \
-                                       base_size=args.base_size, crop_size=args.crop_size, \
-                                       batch_size=1, device_num=1, \
-                                       rank=args.rank, shuffle=False)
-        loss_f = TempLoss()
-        network_eval = Model(f_model, loss_fn=loss_f, metrics={"SegmentationMetric": SegmentationMetric(19)})
-
-        eval_cb = EvalCallBack(network_eval, val_dataset, interval=args.eval_steps,
-                               eval_start_epoch=args.eval_start_epoch, save_best_ckpt=True,
-                               ckpt_directory=save_dir, besk_ckpt_name="best_map.ckpt",
-                               metrics_name=("pixAcc", "mIou"))
-        callbacks.append(eval_cb)
-
     model.train(args.epochs, train_dataset, callbacks=callbacks, dataset_sink_mode=True)
 
     args.logger.info("training finished....")