Unverified Commit 46f077dc authored by i-robot, committed by Gitee

!3092 add sync allreduce for notify wait

Merge pull request !3092 from bichaoyang/master
parents 3d1fc590 1ce29c1f
@@ -121,6 +121,8 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
         else:
             self.clip = ClipByGlobalNorm(self.weights, config)
         self.cast = P.Cast()
+        self.sync_all_reduce = P.AllReduce()
+        self.sync_tensor = Tensor(0.0, dtype=mstype.float32)

     def construct(self, input_ids, input_position, attention_mask, layer_past=None, sens=None):
         """Defines the computation performed."""
@@ -156,7 +158,12 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
         # if not, update weights
         if not overflow:
             if self.enable_offload:
-                self.optimizer(grads, clip_value)
+                res = self.optimizer(grads, clip_value)
+                # For moe and enable_offload, the compile time of different devices can differ greatly
+                # and cause a notify wait; a sync AllReduce at the end of the last graph is needed.
+                sync_tensor = F.depend(self.sync_tensor, res)
+                sync_flag = self.sync_all_reduce(sync_tensor)
+                loss = F.depend(loss, sync_flag)
             else:
                 self.optimizer(grads)
         return loss, cond, scaling_sens
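
For context on the pattern above: the AllReduce on a dummy scalar acts as a cross-device barrier. F.depend orders the dummy tensor after the optimizer update, the collective blocks until every device reaches it, and a second F.depend ties the returned loss to the barrier so graph optimization cannot prune it away. A minimal self-contained sketch of the same chain (not part of the patch; SyncBarrierCell and its argument names are hypothetical, and it assumes a distributed context initialized via mindspore.communication.init):

    from mindspore import nn, ops, Tensor
    from mindspore.ops import functional as F
    import mindspore.common.dtype as mstype

    class SyncBarrierCell(nn.Cell):
        """Hypothetical cell illustrating the depend -> AllReduce -> depend chain."""
        def __init__(self):
            super().__init__()
            # AllReduce requires mindspore.communication.init() to have been called.
            self.sync_all_reduce = ops.AllReduce()
            self.sync_tensor = Tensor(0.0, dtype=mstype.float32)

        def construct(self, loss, update_result):
            # Order the dummy scalar after the optimizer update.
            sync_tensor = F.depend(self.sync_tensor, update_result)
            # Every device must reach this collective before any can proceed.
            sync_flag = self.sync_all_reduce(sync_tensor)
            # Tie the returned loss to the barrier so it is not pruned from the graph.
            return F.depend(loss, sync_flag)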