diff --git a/official/nlp/pangu_alpha/scripts/run_distribute_train_moe_host_device.sh b/official/nlp/pangu_alpha/scripts/run_distribute_train_moe_host_device.sh
index ee09f62d299fcee07b924b1d4f397705cddc2015..c50344d6263013dedb3a0504b260d722360a74bb 100644
--- a/official/nlp/pangu_alpha/scripts/run_distribute_train_moe_host_device.sh
+++ b/official/nlp/pangu_alpha/scripts/run_distribute_train_moe_host_device.sh
@@ -40,6 +40,8 @@ LOCAL_DEVICE_NUM=${10}
 EXPERT_NUM=${11}
 ENABLE_ALLTOALL=${12}
 EXPERT_PARALLEL_NUM=${13}
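+# Give devices more time to wait for each other in HCCL collectives; with MoE and host offload, graph compile time can differ greatly across devices.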
+export HCCL_EXEC_TIMEOUT=1500
 
 for((i=0;i<${LOCAL_DEVICE_NUM};i++));
 do
diff --git a/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py b/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
index b774ec5b35d7628f0deebc3b44b36595587833b6..b16ea67e8ad687d4e8bf38cd6c6c867a216fd277 100644
--- a/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
+++ b/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
@@ -121,8 +121,6 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
         else:
             self.clip = ClipByGlobalNorm(self.weights, config)
         self.cast = P.Cast()
-        self.sync_all_reduce = P.AllReduce()
-        self.sync_tensor = Tensor(0.0, dtype=mstype.float32)
 
     def construct(self, input_ids, input_position, attention_mask, layer_past=None, sens=None):
         """Defines the computation performed."""
@@ -158,12 +156,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
         # if not, update weights
         if not overflow:
             if self.enable_offload:
-                res = self.optimizer(grads, clip_value)
-                # For moe and enable_offload, compile time of difference devices have great gap and cause
-                # notify wait, a sync allreduce at the end of last graph is need
-                sync_tensor = F.depend(self.sync_tensor, res)
-                sync_flag = self.sync_all_reduce(sync_tensor)
-                loss = F.depend(loss, sync_flag)
+                self.optimizer(grads, clip_value)
             else:
                 self.optimizer(grads)
         return loss, cond, scaling_sens