diff --git a/official/nlp/pangu_alpha/src/utils.py b/official/nlp/pangu_alpha/src/utils.py
index a07ed7d2f87707ed00583a265e06d6681dd849c0..57caa2b2b5c4c818fbf40d14eb8db7d32669ffd9 100644
--- a/official/nlp/pangu_alpha/src/utils.py
+++ b/official/nlp/pangu_alpha/src/utils.py
@@ -394,6 +394,9 @@ def add_training_params(opt):
     opt.add_argument("--data_column_name",
                      type=str, default="input_ids",
                      help="Column name of datasets")
+    opt.add_argument("--micro_batch_interleaved",
+                     type=int, default=2,
+                     help="Parallel split num of batch size. default 2")
 
 
 
diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index 847ea9b657a501f0bfc5a340be51773eac49d248..db51564d4cfc247bcd18f2457c0778d99b00f77f 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -32,7 +32,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 import mindspore.common.dtype as mstype
 from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._cost_model_context import _set_multi_subgraphs
-from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
+from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved
 from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net
@@ -133,13 +133,14 @@ def run_train(args_opt):
     model_parallel_num = args_opt.op_level_model_parallel_num
     data_parallel_num = int(device_num / model_parallel_num)
     batch_size = args_opt.per_batch_size * data_parallel_num
+    micro_batch_interleaved = args_opt.micro_batch_interleaved
     parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num,
                                                   pipeline_stage=args_opt.stage_num,
                                                   micro_batch_num=args_opt.micro_size,
                                                   optimizer_shard=bool(args_opt.optimizer_shard),
                                                   vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True,
                                                   gradient_aggregation_group=args_opt.gradient_aggregation_group)
-    config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
+    config = PanguAlphaConfig(batch_size=batch_size // micro_batch_interleaved, num_heads=args_opt.num_heads,
                               hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
                               vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
                               ffn_hidden_size=args_opt.embedding_size * 4, eod_token=bool(args_opt.eod_reset),
@@ -153,7 +154,8 @@ def run_train(args_opt):
     # Define network
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
-    pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
+    pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+                                                      micro_batch_interleaved)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
     print("=====args_opt is: ", args_opt, flush=True)
     # Warm-up and cosine decay learning rate
@@ -172,7 +174,7 @@ def run_train(args_opt):
     loss_scale_value = math.pow(2, 32)
     epoch_num = args_opt.epoch_size
     # Dataset loading mindrecord files
-    ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0,
+    ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0,
                         eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
                         device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num)
     actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size)
@@ -182,7 +184,7 @@ def run_train(args_opt):
         pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
         config=config)
     if args_opt.train_and_eval_mode:
-        ds_eval = create_dataset(config.batch_size, data_path=eval_cache_url,
+        ds_eval = create_dataset(config.batch_size * micro_batch_interleaved, data_path=eval_cache_url,
                                  data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
                                  eod_id=args_opt.eod_id, device_num=device_num, rank=rank,
                                  column_name=args_opt.data_column_name, epoch=epoch_num,
@@ -353,6 +355,22 @@ def restore_exception_checkpoint(args_param, sink_size, dataset, model, network,
         return True
 
 
+def set_pipeline_parallel_context(args_opt):
+    r"""Set pipeline parallel context."""
+    D.init()
+    device_num = D.get_group_size()
+    rank_id = D.get_rank()
+    print("rank_id is {}, device_num is {}".format(rank_id, device_num))
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(
+        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
+        full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
+        device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
+        pipeline_stages=args_opt.stage_num)
+    set_algo_parameters(elementwise_op_strategy_follow=True)
+    _set_multi_subgraphs()
+    return rank_id, device_num
+
 def run_train_pipeline(args_opt):
     r"""The main training process in pipeline."""
     # Set hccl connect time
@@ -360,22 +378,10 @@ def run_train_pipeline(args_opt):
 
     context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(variable_memory_max_size="30GB")
+    rank_id = int(os.getenv("RANK_ID"))
+    device_num = 1
     if args_opt.distribute == "true":
-        D.init()
-        device_num = D.get_group_size()
-        rank_id = D.get_rank()
-        print("rank_id is {}, device_num is {}".format(rank_id, device_num))
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(
-            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
-            full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
-            device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
-            pipeline_stages=args_opt.stage_num)
-        set_algo_parameters(elementwise_op_strategy_follow=True)
-        _set_multi_subgraphs()
-    else:
-        rank_id = int(os.getenv("RANK_ID"))
-        device_num = 1
+        rank_id, device_num = set_pipeline_parallel_context(args_opt)
     # copy data from the cloud to the /cache/Data
     cache_url = '/cache/Data/'
     eval_cache_url = '/cache/EvalData/'
@@ -392,7 +398,7 @@ def run_train_pipeline(args_opt):
         raise ValueError("The dp must large than 1 when applying optimizer shard.")
     per_batch_size = args_opt.per_batch_size
     batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
-
+    micro_batch_interleaved = args_opt.micro_batch_interleaved
     parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
                                                   model_parallel=model_parallel_num,
                                                   pipeline_stage=args_opt.stage_num,
@@ -400,7 +406,7 @@ def run_train_pipeline(args_opt):
                                                   optimizer_shard=bool(args_opt.optimizer_shard),
                                                   vocab_emb_dp=bool(args_opt.word_emb_dp),
                                                   recompute=True)
-    config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num,
+    config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num // micro_batch_interleaved,
                               num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,
                               seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size,
                               num_layers=args_opt.num_layers, ffn_hidden_size=args_opt.embedding_size * 4,
@@ -411,7 +417,8 @@ def run_train_pipeline(args_opt):
     print("===config is: ", config, flush=True)
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
-    pangu_alpha_with_loss_net = PipelineCell(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+    pangu_alpha_with_loss_net = PipelineCell(MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+                                                                   micro_batch_interleaved),
                                              config.parallel_config.micro_batch_num)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
     print("=====args_opt is: ", args_opt, flush=True)
@@ -427,7 +434,8 @@ def run_train_pipeline(args_opt):
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
 
-    ds = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=cache_url,
+    ds = create_dataset(config.batch_size * parallel_config.micro_batch_num * micro_batch_interleaved,
+                        data_path=cache_url,
                         device_num=stage_device_num,
                         rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
                         full_batch=context.get_auto_parallel_context("full_batch"),
@@ -443,7 +451,8 @@ def run_train_pipeline(args_opt):
     pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
         pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
     if args_opt.train_and_eval_mode:
-        ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=eval_cache_url,
+        ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num * micro_batch_interleaved,
+                                 data_path=eval_cache_url,
                                  device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True,
                                  data_start_index=0, full_batch=bool(args_opt.full_batch),
                                  column_name=args_opt.data_column_name,