diff --git a/official/nlp/pangu_alpha/src/utils.py b/official/nlp/pangu_alpha/src/utils.py
index 57caa2b2b5c4c818fbf40d14eb8db7d32669ffd9..a07ed7d2f87707ed00583a265e06d6681dd849c0 100644
--- a/official/nlp/pangu_alpha/src/utils.py
+++ b/official/nlp/pangu_alpha/src/utils.py
@@ -394,9 +394,6 @@ def add_training_params(opt):
     opt.add_argument("--data_column_name",
                      type=str, default="input_ids",
                      help="Column name of datasets")
-    opt.add_argument("--micro_batch_interleaved",
-                     type=int, default=2,
-                     help="Parallel split num of batch size. default 2")
 
 
 
diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index 7ada03c2db5f0674090d52a5cfe6b0c5afdd3d75..a70ad082a1bedc673a61273ccd67ce5ebb85ea8d 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -30,7 +30,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 import mindspore.common.dtype as mstype
 from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._cost_model_context import _set_multi_subgraphs
-from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved
+from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
 from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net
@@ -103,6 +103,7 @@ def set_parallel_context(args_opt):
         enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt')
     set_algo_parameters(elementwise_op_strategy_follow=True)
     _set_multi_subgraphs()
+    return rank, device_num
 
 
 def run_train(args_opt):
@@ -115,7 +116,7 @@ def run_train(args_opt):
     rank = 0
     device_num = 1
     if args_opt.distribute == "true":
-        set_parallel_context(args_opt)
+        rank, device_num = set_parallel_context(args_opt)
     context.set_context(save_graphs=False, save_graphs_path="./graphs_of_device_id_" + str(rank))
     # copy data from the cloud to the /cache/Data
     cache_url = '/cache/Data/'
@@ -130,14 +131,13 @@ def run_train(args_opt):
     model_parallel_num = args_opt.op_level_model_parallel_num
     data_parallel_num = int(device_num / model_parallel_num)
     batch_size = args_opt.per_batch_size * data_parallel_num
-    micro_batch_interleaved = args_opt.micro_batch_interleaved
     parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num,
                                                   pipeline_stage=args_opt.stage_num,
                                                   micro_batch_num=args_opt.micro_size,
                                                   optimizer_shard=bool(args_opt.optimizer_shard),
                                                   vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True,
                                                   gradient_aggregation_group=args_opt.gradient_aggregation_group)
-    config = PanguAlphaConfig(batch_size=batch_size // micro_batch_interleaved, num_heads=args_opt.num_heads,
+    config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
                               hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
                               vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
                               ffn_hidden_size=args_opt.embedding_size * 4, eod_token=bool(args_opt.eod_reset),
@@ -151,8 +151,7 @@ def run_train(args_opt):
     # Define network
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
-    pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
-                                                      micro_batch_interleaved)
+    pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
     print("=====args_opt is: ", args_opt, flush=True)
     # Warm-up and cosine decay learning rate
@@ -171,7 +170,7 @@ def run_train(args_opt):
     loss_scale_value = math.pow(2, 32)
     epoch_num = args_opt.epoch_size
     # Dataset loading mindrecord files
-    ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0,
+    ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0,
                         eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
                         device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num)
     actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size)
@@ -181,7 +180,7 @@ def run_train(args_opt):
         pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
         config=config)
     if args_opt.train_and_eval_mode:
-        ds_eval = create_dataset(config.batch_size * micro_batch_interleaved, data_path=eval_cache_url,
+        ds_eval = create_dataset(config.batch_size, data_path=eval_cache_url,
                                  data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
                                  eod_id=args_opt.eod_id, device_num=device_num, rank=rank,
                                  column_name=args_opt.data_column_name, epoch=epoch_num,