Commit 2279201d authored by bichaoyang

fix issue of offload

parent 84501432
@@ -40,6 +40,7 @@ LOCAL_DEVICE_NUM=${10}
 EXPERT_NUM=${11}
 ENABLE_ALLTOALL=${12}
 EXPERT_PARALLEL_NUM=${13}
+export HCCL_EXEC_TIMEOUT=1500
 for((i=0;i<${LOCAL_DEVICE_NUM};i++));
 do
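For reference, HCCL_EXEC_TIMEOUT (in seconds) bounds how long HCCL waits on peer devices during collective execution (notify wait); exporting it before the device loop lets every spawned rank inherit the value. The sketch below mirrors that launch pattern in Python; the entry script name, device count, and per-rank environment variables are assumptions for illustration, not taken from this repository.

import os
import subprocess

# Set the HCCL execution timeout once so all child processes inherit it.
os.environ["HCCL_EXEC_TIMEOUT"] = "1500"

LOCAL_DEVICE_NUM = 8  # assumed local device count
procs = []
for i in range(LOCAL_DEVICE_NUM):
    # One training process per device; DEVICE_ID/RANK_ID are illustrative.
    env = dict(os.environ, DEVICE_ID=str(i), RANK_ID=str(i))
    procs.append(subprocess.Popen(["python", "train.py"], env=env))  # "train.py" is assumed
for p in procs:
    p.wait()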
@@ -121,8 +121,6 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
         else:
             self.clip = ClipByGlobalNorm(self.weights, config)
         self.cast = P.Cast()
-        self.sync_all_reduce = P.AllReduce()
-        self.sync_tensor = Tensor(0.0, dtype=mstype.float32)
     def construct(self, input_ids, input_position, attention_mask, layer_past=None, sens=None):
         """Defines the computation performed."""
@@ -158,12 +156,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
         # if not, update weights
         if not overflow:
             if self.enable_offload:
-                res = self.optimizer(grads, clip_value)
-                # For moe and enable_offload, the compile time of different devices has a large gap
-                # and causes notify wait, so a sync allreduce at the end of the last graph is needed
-                sync_tensor = F.depend(self.sync_tensor, res)
-                sync_flag = self.sync_all_reduce(sync_tensor)
-                loss = F.depend(loss, sync_flag)
+                self.optimizer(grads, clip_value)
             else:
                 self.optimizer(grads)
         return loss, cond, scaling_sens
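The six removed lines implemented a manual cross-device barrier: an AllReduce on a dummy scalar, chained with F.depend, forced every device to reach the same point in the graph before the step finished, masking the compile-time gap between devices. Below is a minimal, self-contained sketch of that idiom (a hedged reconstruction, not the repository's code), assuming MindSpore's operations/functional APIs as imported in this file; the cell name is made up for illustration. The commit drops this barrier in favor of the HCCL_EXEC_TIMEOUT setting in the launch script above.

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F

class SyncBarrierCell(nn.Cell):  # hypothetical name, for illustration only
    """Dummy-AllReduce barrier: every device must reach this point before the step ends."""
    def __init__(self):
        super().__init__()
        self.sync_all_reduce = P.AllReduce()
        self.sync_tensor = Tensor(0.0, dtype=mstype.float32)

    def construct(self, loss, update_result):
        # Order the AllReduce after the optimizer update, then tie the returned
        # loss to the AllReduce so graph pruning cannot drop the barrier.
        sync_tensor = F.depend(self.sync_tensor, update_result)
        sync_flag = self.sync_all_reduce(sync_tensor)
        return F.depend(loss, sync_flag)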