diff --git a/official/cv/centerface/src/centerface.py b/official/cv/centerface/src/centerface.py index b240b4abd9954cd56207b96e1c01ebcfa8c3376a..f51aaa2eb64edb67aa721fc0330963c8f8146bd7 100644 --- a/official/cv/centerface/src/centerface.py +++ b/official/cv/centerface/src/centerface.py @@ -21,16 +21,8 @@ from src.losses import FocalLoss, SmoothL1LossNew, SmoothL1LossNewCMask import mindspore as ms import mindspore.nn as nn -from mindspore.common.tensor import Tensor -from mindspore import context -from mindspore.parallel._auto_parallel_context import auto_parallel_context -from mindspore.communication.management import get_group_size from mindspore.ops import operations as P -from mindspore.ops import functional as F from mindspore.ops import composite as C -from mindspore.common import dtype as mstype -from mindspore.ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual -from mindspore.context import ParallelMode _grad_scale = C.MultitypeFuncGraph("grad_scale") reciprocal = P.Reciprocal() @@ -192,9 +184,6 @@ class CenterFaceLoss(nn.Cell): loss = self.hm_weight * hm_loss + self.wh_weight * wh_loss + \ self.off_weight * off_loss + self.lm_weight * lm_loss - # depend is needed when wight_mask and reg_mask is not been used - F.depend(loss, F.sqrt(F.cast(wight_mask, mstype.float32))) - F.depend(loss, F.sqrt(F.cast(reg_mask, mstype.float32))) # add print when you want to see loss detail and do debug return loss @@ -217,68 +206,25 @@ class CenterFaceWithLossCell(nn.Cell): hps_mask, landmarks) return loss -class TrainingWrapper(nn.Cell): +class TrainingWrapper(nn.TrainOneStepWithLossScaleCell): """ Training wrapper """ def __init__(self, network, optimizer, sens=1.0): - super(TrainingWrapper, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() #False - self.network.add_flags(defer_inline=True) - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, sens_param=True) + scaling_sens = sens + if isinstance(scaling_sens, (int, float)): + scaling_sens = ms.Tensor(scaling_sens, ms.float32) + super(TrainingWrapper, self).__init__(network, optimizer, scaling_sens) self.sens = sens - self.reducer_flag = False - self.grad_reducer = None - - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - if auto_parallel_context().get_device_num_is_set(): - degree = context.get_auto_parallel_context("device_num") - else: - degree = get_group_size() - self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) - - self.hyper_map = C.HyperMap() - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.gpu_target = False - self.alloc_status = NPUAllocFloatStatus() - self.get_status = NPUGetFloatStatus() - self.clear_status = NPUClearFloatStatus() - self.reduce_sum = ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = LessEqual() - self.allreduce = P.AllReduce() - self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE - # x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks def construct(self, x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks): """ Construct method. """ weights = self.weights + scaling_sens = self.scale_sense loss = self.network(x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks) - - init = False - - if not self.gpu_target: - # init overflow buffer - init = self.alloc_status() - init = F.depend(init, loss) - # clear overflow buffer - clear_status = self.clear_status(init) - loss = F.depend(loss, clear_status) - + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) #sens = sens_input #P.Fill()(P.DType()(loss), P.Shape()(loss), sens_input) # user can contral loss scale by add a sens_input sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) grads = self.grad(self.network, weights)(x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks, @@ -286,30 +232,11 @@ class TrainingWrapper(nn.Cell): #grads = self.hyper_map(F.partial(_grad_scale, sens), grads) # if add this, the loss_scale optimizer is needed to set to 1 if self.reducer_flag: grads = self.grad_reducer(grads) - - if not self.gpu_target: - # get the overflow buffer - init = F.depend(init, grads) - - get_status = self.get_status(init) - init = F.depend(init, get_status) - # sum overflow buffer elements, 0:not overflow , >0:overflow - flag_sum = self.reduce_sum(init, (0,)) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - # convert flag_sum to scalar - flag_sum = self.reshape(flag_sum, (())) - - if self.is_distributed: - # sum overflow flag over devices - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) ret = (loss, cond, sens) - self.optimizer(grads) + if not overflow: + self.optimizer(grads) return ret diff --git a/official/nlp/bert/src/bert_for_finetune.py b/official/nlp/bert/src/bert_for_finetune.py index 2c1734c66b1752c197070f6b65901bd030a20b5c..12d95755de916c565df383034ecab3e6fe10b4e3 100644 --- a/official/nlp/bert/src/bert_for_finetune.py +++ b/official/nlp/bert/src/bert_for_finetune.py @@ -24,10 +24,6 @@ from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.context import ParallelMode -from mindspore.communication.management import get_group_size -from mindspore import context from .bert_for_pre_training import clip_grad from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel from .utils import CrossEntropyCalculation @@ -52,7 +48,7 @@ def _tensor_grad_overflow(grad): return grad_overflow(grad) -class BertFinetuneCell(nn.Cell): +class BertFinetuneCell(nn.TrainOneStepWithLossScaleCell): """ Especially defined for finetuning where only four inputs tensor are needed. @@ -68,44 +64,8 @@ class BertFinetuneCell(nn.Cell): """ def __init__(self, network, optimizer, scale_update_cell=None): - - super(BertFinetuneCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(BertFinetuneCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.gpu_target = False - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -116,21 +76,16 @@ class BertFinetuneCell(nn.Cell): """Bert Finetune""" weights = self.weights - init = False loss = self.network(input_ids, input_mask, token_type_id, label_ids) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - if not self.gpu_target: - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -141,70 +96,21 @@ class BertFinetuneCell(nn.Cell): grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - if not self.gpu_target: - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - flag_sum = self.reshape(flag_sum, (())) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) return (loss, cond) -class BertSquadCell(nn.Cell): +class BertSquadCell(nn.TrainOneStepWithLossScaleCell): """ specifically defined for finetuning where only four inputs tensor are needed. """ def __init__(self, network, optimizer, scale_update_cell=None): - super(BertSquadCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(BertSquadCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.gpu_target = False - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -217,7 +123,6 @@ class BertSquadCell(nn.Cell): sens=None): """BertSquad""" weights = self.weights - init = False loss = self.network(input_ids, input_mask, token_type_id, @@ -226,14 +131,10 @@ class BertSquadCell(nn.Cell): unique_id, is_impossible) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - if not self.gpu_target: - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -247,23 +148,8 @@ class BertSquadCell(nn.Cell): grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - if not self.gpu_target: - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - flag_sum = self.reshape(flag_sum, (())) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) return (loss, cond) diff --git a/official/nlp/dgu/src/bert_for_finetune.py b/official/nlp/dgu/src/bert_for_finetune.py index 265a6bb758476f755659806014be7fc95a9ad47f..92402be8ce024561dcab8ecddb5af4026c8d3795 100644 --- a/official/nlp/dgu/src/bert_for_finetune.py +++ b/official/nlp/dgu/src/bert_for_finetune.py @@ -24,10 +24,6 @@ from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.context import ParallelMode -from mindspore.communication.management import get_group_size -from mindspore import context from .bert_for_pre_training import clip_grad from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel from .utils import CrossEntropyCalculation @@ -47,7 +43,7 @@ grad_overflow = P.FloatStatus() def _tensor_grad_overflow(grad): return grad_overflow(grad) -class BertFinetuneCell(nn.Cell): +class BertFinetuneCell(nn.TrainOneStepWithLossScaleCell): """ Especially defined for finetuning where only four inputs tensor are needed. @@ -62,44 +58,8 @@ class BertFinetuneCell(nn.Cell): scale_update_cell (Cell): Cell to do the loss scale. Default: None. """ def __init__(self, network, optimizer, scale_update_cell=None): - - super(BertFinetuneCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(BertFinetuneCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.gpu_target = False - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -110,21 +70,16 @@ class BertFinetuneCell(nn.Cell): """Bert Finetune""" weights = self.weights - init = False loss = self.network(input_ids, input_mask, token_type_id, label_ids) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - if not self.gpu_target: - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -135,61 +90,19 @@ class BertFinetuneCell(nn.Cell): grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - if not self.gpu_target: - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - flag_sum = self.reshape(flag_sum, (())) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) return (loss, cond) -class BertSquadCell(nn.Cell): +class BertSquadCell(nn.TrainOneStepWithLossScaleCell): """ specifically defined for finetuning where only four inputs tensor are needed. """ def __init__(self, network, optimizer, scale_update_cell=None): - super(BertSquadCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(BertSquadCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -202,7 +115,6 @@ class BertSquadCell(nn.Cell): sens=None): """BertSquad""" weights = self.weights - init = self.alloc_status() loss = self.network(input_ids, input_mask, token_type_id, @@ -211,12 +123,10 @@ class BertSquadCell(nn.Cell): unique_id, is_impossible) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -230,18 +140,8 @@ class BertSquadCell(nn.Cell): grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) return (loss, cond) diff --git a/official/nlp/emotect/src/ernie_for_finetune.py b/official/nlp/emotect/src/ernie_for_finetune.py index 93b6010517ff384c99302485b31eefeb1ca63c51..3d59cb10a83f14284d543c0247d12d2de18e11a3 100644 --- a/official/nlp/emotect/src/ernie_for_finetune.py +++ b/official/nlp/emotect/src/ernie_for_finetune.py @@ -21,13 +21,7 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C -from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.context import ParallelMode -from mindspore.communication.management import get_group_size -from mindspore import context from .finetune_eval_model import ErnieCLSModel from .utils import CrossEntropyCalculation @@ -70,7 +64,7 @@ grad_overflow = P.FloatStatus() def _tensor_grad_overflow(grad): return grad_overflow(grad) -class ErnieFinetuneCell(nn.Cell): +class ErnieFinetuneCell(nn.TrainOneStepWithLossScaleCell): """ Especially defined for finetuning where only four inputs tensor are needed. Append an optimizer to the training network after that the construct @@ -82,44 +76,8 @@ class ErnieFinetuneCell(nn.Cell): scale_update_cell (Cell): Cell to do the loss scale. Default: None. """ def __init__(self, network, optimizer, scale_update_cell=None): - - super(ErnieFinetuneCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(ErnieFinetuneCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.gpu_target = False - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -130,21 +88,16 @@ class ErnieFinetuneCell(nn.Cell): """Ernie Finetune""" weights = self.weights - init = False loss = self.network(input_ids, input_mask, token_type_id, label_ids) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - if not self.gpu_target: - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -155,23 +108,8 @@ class ErnieFinetuneCell(nn.Cell): grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - if not self.gpu_target: - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - flag_sum = self.reshape(flag_sum, (())) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) return (loss, cond) diff --git a/official/nlp/ernie/src/ernie_for_finetune.py b/official/nlp/ernie/src/ernie_for_finetune.py index d10a81a2e9e5372336dffeca9e03dd5bcaa9ecf8..5ea20fc61a2c26dab5a4ccee8764946816e0cfc9 100644 --- a/official/nlp/ernie/src/ernie_for_finetune.py +++ b/official/nlp/ernie/src/ernie_for_finetune.py @@ -24,10 +24,6 @@ from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.context import ParallelMode -from mindspore.communication.management import get_group_size -from mindspore import context from src.finetune_eval_model import ErnieCLSModel, ErnieMRCModel, ErnieNERModel from src.utils import CrossEntropyCalculation @@ -70,7 +66,7 @@ grad_overflow = P.FloatStatus() def _tensor_grad_overflow(grad): return grad_overflow(grad) -class ErnieFinetuneCell(nn.Cell): +class ErnieFinetuneCell(nn.TrainOneStepWithLossScaleCell): """ Especially defined for finetuning where only four inputs tensor are needed. Append an optimizer to the training network after that the construct @@ -82,44 +78,8 @@ class ErnieFinetuneCell(nn.Cell): scale_update_cell (Cell): Cell to do the loss scale. Default: None. """ def __init__(self, network, optimizer, scale_update_cell=None): - - super(ErnieFinetuneCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(ErnieFinetuneCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.gpu_target = False - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -130,21 +90,16 @@ class ErnieFinetuneCell(nn.Cell): """Ernie Finetune""" weights = self.weights - init = False loss = self.network(input_ids, input_mask, token_type_id, label_ids) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - if not self.gpu_target: - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -155,29 +110,11 @@ class ErnieFinetuneCell(nn.Cell): grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - if not self.gpu_target: - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - flag_sum = self.reshape(flag_sum, (())) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - ret = (loss, cond) - return F.depend(ret, succ) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) + if not overflow: + self.optimizer(grads) + return (loss, cond) class ErnieCLS(nn.Cell): """ @@ -212,40 +149,13 @@ class ErnieNER(nn.Cell): loss = self.loss(logits, label_ids, self.num_labels) return loss -class ErnieMRCCell(nn.Cell): +class ErnieMRCCell(nn.TrainOneStepWithLossScaleCell): """ specifically defined for finetuning where only four inputs tensor are needed. """ def __init__(self, network, optimizer, scale_update_cell=None): - super(ErnieMRCCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(ErnieMRCCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -257,7 +167,6 @@ class ErnieMRCCell(nn.Cell): sens=None): """Ernie MRC""" weights = self.weights - init = self.alloc_status() loss = self.network(input_ids, input_mask, token_type_id, @@ -265,12 +174,10 @@ class ErnieMRCCell(nn.Cell): end_position, unique_id) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -283,24 +190,11 @@ class ErnieMRCCell(nn.Cell): grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - ret = (loss, cond) - return F.depend(ret, succ) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) + if not overflow: + self.optimizer(grads) + return (loss, cond) class ErnieMRC(nn.Cell): ''' diff --git a/official/nlp/mass/src/transformer/transformer_for_train.py b/official/nlp/mass/src/transformer/transformer_for_train.py index 2164e17c1dc2af3b1dd64b3c6e41c76a3076729b..21a130240498cc1951ca7ebe255d18cf56694e38 100644 --- a/official/nlp/mass/src/transformer/transformer_for_train.py +++ b/official/nlp/mass/src/transformer/transformer_for_train.py @@ -14,15 +14,11 @@ # ============================================================================ """Transformer for training.""" from mindspore import nn -import mindspore.context as context from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.context import ParallelMode from .transformer import Transformer from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients @@ -215,7 +211,7 @@ grad_overflow = P.FloatStatus() def _tensor_grad_overflow(grad): return grad_overflow(grad) -class TransformerTrainOneStepWithLossScaleCell(nn.Cell): +class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): """ Encapsulation class of Transformer network training. @@ -232,50 +228,9 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): """ def __init__(self, network, optimizer, scale_update_cell=None): - - super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.network.add_flags(defer_inline=True) - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.all_reduce = P.AllReduce() - - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode not in ParallelMode.MODE_LIST: - raise ValueError("Parallel mode does not support: ", self.parallel_mode) - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("gradients_mean") - degree = context.get_auto_parallel_context("device_num") - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(TransformerTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) self.clip_gradients = ClipGradients() self.cast = P.Cast() - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.gpu_target = False - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, source_eos_ids, @@ -316,19 +271,11 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): label_weights) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens - init = False - if not self.gpu_target: - # init overflow buffer - init = self.alloc_status() - init = F.depend(init, loss) - # clear overflow buffer - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) - + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(source_ids, source_mask, target_ids, @@ -346,28 +293,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) # get the overflow buffer - if not self.gpu_target: - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - # sum overflow buffer elements, 0:not overflow , >0:overflow - flag_sum = self.reduce_sum(init, (0,)) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - # convert flag_sum to scalar - flag_sum = self.reshape(flag_sum, (())) - - if self.is_distributed: - # Sum overflow flag over devices. - flag_reduce = self.all_reduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) diff --git a/official/nlp/tinybert/src/tinybert_for_gd_td.py b/official/nlp/tinybert/src/tinybert_for_gd_td.py index c2e8f9f91a32aaddf22cd424b193c50b55fefe3d..b1a70e92b2df3071e40829023d2483b141db73d5 100644 --- a/official/nlp/tinybert/src/tinybert_for_gd_td.py +++ b/official/nlp/tinybert/src/tinybert_for_gd_td.py @@ -21,9 +21,7 @@ from mindspore import context from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C -from mindspore.common.tensor import Tensor from mindspore.common import dtype as mstype -from mindspore.common.parameter import Parameter from mindspore.communication.management import get_group_size from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.context import ParallelMode @@ -200,7 +198,7 @@ class BertNetworkWithLoss_gd(nn.Cell): total_loss += rep_loss return self.cast(total_loss, mstype.float32) -class BertTrainWithLossScaleCell(nn.Cell): +class BertTrainWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): """ Encapsulation class of bert network training. @@ -213,36 +211,8 @@ class BertTrainWithLossScaleCell(nn.Cell): scale_update_cell (Cell): Cell to do the loss scale. Default: None. """ def __init__(self, network, optimizer, scale_update_cell=None): - super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = F.identity - self.degree = 1 - if self.reducer_flag: - self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(BertTrainWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) def construct(self, input_ids, @@ -255,14 +225,11 @@ class BertTrainWithLossScaleCell(nn.Cell): input_mask, token_type_id) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens # alloc status and clear should be right before gradoperation - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -272,19 +239,8 @@ class BertTrainWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - if self.is_distributed: - # sum overflow flag over devices - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) return (loss, cond, scaling_sens) @@ -468,42 +424,13 @@ class BertNetworkWithLoss_td(nn.Cell): total_loss += cls_loss return self.cast(total_loss, mstype.float32) -class BertEvaluationWithLossScaleCell(nn.Cell): +class BertEvaluationWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): """ Especially defined for finetuning where only four inputs tensor are needed. """ def __init__(self, network, optimizer, scale_update_cell=None): - super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False) - self.network = network - self.network.set_grad() - self.weights = optimizer.parameters - self.optimizer = optimizer - self.grad = C.GradOperation(get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = F.identity - self.degree = 1 - if self.reducer_flag: - self.degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + super(BertEvaluationWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell) self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_status = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - def construct(self, input_ids, input_mask, @@ -517,14 +444,11 @@ class BertEvaluationWithLossScaleCell(nn.Cell): token_type_id, label_ids) if sens is None: - scaling_sens = self.loss_scale + scaling_sens = self.scale_sense else: scaling_sens = sens # alloc status and clear should be right before gradoperation - init = self.alloc_status() - init = F.depend(init, loss) - clear_status = self.clear_status(init) - scaling_sens = F.depend(scaling_sens, clear_status) + status, scaling_sens = self.start_overflow_check(loss, scaling_sens) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -535,19 +459,8 @@ class BertEvaluationWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - init = F.depend(init, grads) - get_status = self.get_status(init) - init = F.depend(init, get_status) - flag_sum = self.reduce_sum(init, (0,)) - if self.is_distributed: - # sum overflow flag over devices - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) + cond = self.get_overflow_status(status, grads) + overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) return (loss, cond, scaling_sens)