diff --git a/official/cv/centerface/src/centerface.py b/official/cv/centerface/src/centerface.py
index b240b4abd9954cd56207b96e1c01ebcfa8c3376a..f51aaa2eb64edb67aa721fc0330963c8f8146bd7 100644
--- a/official/cv/centerface/src/centerface.py
+++ b/official/cv/centerface/src/centerface.py
@@ -21,16 +21,8 @@ from src.losses import FocalLoss, SmoothL1LossNew, SmoothL1LossNewCMask
 
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore.common.tensor import Tensor
-from mindspore import context
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
-from mindspore.communication.management import get_group_size
 from mindspore.ops import operations as P
-from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.common import dtype as mstype
-from mindspore.ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual
-from mindspore.context import ParallelMode
 
 _grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()
@@ -192,9 +184,6 @@ class CenterFaceLoss(nn.Cell):
         loss = self.hm_weight * hm_loss + self.wh_weight * wh_loss + \
                self.off_weight * off_loss + self.lm_weight * lm_loss
 
-        # depend is needed when wight_mask and reg_mask is not been used
-        F.depend(loss, F.sqrt(F.cast(wight_mask, mstype.float32)))
-        F.depend(loss, F.sqrt(F.cast(reg_mask, mstype.float32)))
         # add print when you want to see loss detail and do debug
         return loss
 
@@ -217,68 +206,25 @@ class CenterFaceWithLossCell(nn.Cell):
                          hps_mask, landmarks)
         return loss
 
-class TrainingWrapper(nn.Cell):
+class TrainingWrapper(nn.TrainOneStepWithLossScaleCell):
     """
     Training wrapper
     """
     def __init__(self, network, optimizer, sens=1.0):
-        super(TrainingWrapper, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad() #False
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        scaling_sens = sens
+        if isinstance(scaling_sens, (int, float)):
+            scaling_sens = ms.Tensor(scaling_sens, ms.float32)
+        super(TrainingWrapper, self).__init__(network, optimizer, scaling_sens)
         self.sens = sens
-        self.reducer_flag = False
-        self.grad_reducer = None
-
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            if auto_parallel_context().get_device_num_is_set():
-                degree = context.get_auto_parallel_context("device_num")
-            else:
-                degree = get_group_size()
-            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
-
-        self.hyper_map = C.HyperMap()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = NPUAllocFloatStatus()
-            self.get_status = NPUGetFloatStatus()
-            self.clear_status = NPUClearFloatStatus()
-        self.reduce_sum = ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = LessEqual()
-        self.allreduce = P.AllReduce()
-        self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
-
     # x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks
     def construct(self, x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks):
         """
         Construct method.
         """
         weights = self.weights
+        scaling_sens = self.scale_sense
         loss = self.network(x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks)
-
-        init = False
-
-        if not self.gpu_target:
-            # init overflow buffer
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            # clear overflow buffer
-            clear_status = self.clear_status(init)
-            loss = F.depend(loss, clear_status)
-
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         #sens = sens_input #P.Fill()(P.DType()(loss), P.Shape()(loss), sens_input) # user can contral loss scale by add a sens_input
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks,
@@ -286,30 +232,11 @@ class TrainingWrapper(nn.Cell):
         #grads = self.hyper_map(F.partial(_grad_scale, sens), grads) # if add this, the loss_scale optimizer is needed to set to 1
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-
-        if not self.gpu_target:
-            # get the overflow buffer
-            init = F.depend(init, grads)
-
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            # sum overflow buffer elements, 0:not overflow , >0:overflow
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            # convert flag_sum to scalar
-            flag_sum = self.reshape(flag_sum, (()))
-
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         ret = (loss, cond, sens)
-        self.optimizer(grads)
+        if not overflow:
+            self.optimizer(grads)
         return ret
 
 
diff --git a/official/nlp/bert/src/bert_for_finetune.py b/official/nlp/bert/src/bert_for_finetune.py
index 2c1734c66b1752c197070f6b65901bd030a20b5c..12d95755de916c565df383034ecab3e6fe10b4e3 100644
--- a/official/nlp/bert/src/bert_for_finetune.py
+++ b/official/nlp/bert/src/bert_for_finetune.py
@@ -24,10 +24,6 @@ from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
-from mindspore.context import ParallelMode
-from mindspore.communication.management import get_group_size
-from mindspore import context
 from .bert_for_pre_training import clip_grad
 from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel
 from .utils import CrossEntropyCalculation
@@ -52,7 +48,7 @@ def _tensor_grad_overflow(grad):
     return grad_overflow(grad)
 
 
-class BertFinetuneCell(nn.Cell):
+class BertFinetuneCell(nn.TrainOneStepWithLossScaleCell):
     """
     Especially defined for finetuning where only four inputs tensor are needed.
 
@@ -68,44 +64,8 @@ class BertFinetuneCell(nn.Cell):
     """
 
     def __init__(self, network, optimizer, scale_update_cell=None):
-
-        super(BertFinetuneCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(BertFinetuneCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.gpu_target = False
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -116,21 +76,16 @@ class BertFinetuneCell(nn.Cell):
         """Bert Finetune"""
 
         weights = self.weights
-        init = False
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
                             label_ids)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
 
-        if not self.gpu_target:
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -141,70 +96,21 @@ class BertFinetuneCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
         return (loss, cond)
 
 
-class BertSquadCell(nn.Cell):
+class BertSquadCell(nn.TrainOneStepWithLossScaleCell):
     """
     specifically defined for finetuning where only four inputs tensor are needed.
     """
 
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertSquadCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(BertSquadCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.gpu_target = False
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -217,7 +123,6 @@ class BertSquadCell(nn.Cell):
                   sens=None):
         """BertSquad"""
         weights = self.weights
-        init = False
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
@@ -226,14 +131,10 @@ class BertSquadCell(nn.Cell):
                             unique_id,
                             is_impossible)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
-        if not self.gpu_target:
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -247,23 +148,8 @@ class BertSquadCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
         return (loss, cond)
diff --git a/official/nlp/dgu/src/bert_for_finetune.py b/official/nlp/dgu/src/bert_for_finetune.py
index 265a6bb758476f755659806014be7fc95a9ad47f..92402be8ce024561dcab8ecddb5af4026c8d3795 100644
--- a/official/nlp/dgu/src/bert_for_finetune.py
+++ b/official/nlp/dgu/src/bert_for_finetune.py
@@ -24,10 +24,6 @@ from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
-from mindspore.context import ParallelMode
-from mindspore.communication.management import get_group_size
-from mindspore import context
 from .bert_for_pre_training import clip_grad
 from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel
 from .utils import CrossEntropyCalculation
@@ -47,7 +43,7 @@ grad_overflow = P.FloatStatus()
 def _tensor_grad_overflow(grad):
     return grad_overflow(grad)
 
-class BertFinetuneCell(nn.Cell):
+class BertFinetuneCell(nn.TrainOneStepWithLossScaleCell):
     """
     Especially defined for finetuning where only four inputs tensor are needed.
 
@@ -62,44 +58,8 @@ class BertFinetuneCell(nn.Cell):
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-
-        super(BertFinetuneCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(BertFinetuneCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.gpu_target = False
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -110,21 +70,16 @@ class BertFinetuneCell(nn.Cell):
         """Bert Finetune"""
 
         weights = self.weights
-        init = False
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
                             label_ids)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
 
-        if not self.gpu_target:
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -135,61 +90,19 @@ class BertFinetuneCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
         return (loss, cond)
 
-class BertSquadCell(nn.Cell):
+class BertSquadCell(nn.TrainOneStepWithLossScaleCell):
     """
     specifically defined for finetuning where only four inputs tensor are needed.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertSquadCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(BertSquadCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -202,7 +115,6 @@ class BertSquadCell(nn.Cell):
                   sens=None):
         """BertSquad"""
         weights = self.weights
-        init = self.alloc_status()
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
@@ -211,12 +123,10 @@ class BertSquadCell(nn.Cell):
                             unique_id,
                             is_impossible)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
-        init = F.depend(init, loss)
-        clear_status = self.clear_status(init)
-        scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -230,18 +140,8 @@ class BertSquadCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
-        if self.is_distributed:
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
         return (loss, cond)
diff --git a/official/nlp/emotect/src/ernie_for_finetune.py b/official/nlp/emotect/src/ernie_for_finetune.py
index 93b6010517ff384c99302485b31eefeb1ca63c51..3d59cb10a83f14284d543c0247d12d2de18e11a3 100644
--- a/official/nlp/emotect/src/ernie_for_finetune.py
+++ b/official/nlp/emotect/src/ernie_for_finetune.py
@@ -21,13 +21,7 @@ import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
-from mindspore.context import ParallelMode
-from mindspore.communication.management import get_group_size
-from mindspore import context
 from .finetune_eval_model import ErnieCLSModel
 from .utils import CrossEntropyCalculation
 
@@ -70,7 +64,7 @@ grad_overflow = P.FloatStatus()
 def _tensor_grad_overflow(grad):
     return grad_overflow(grad)
 
-class ErnieFinetuneCell(nn.Cell):
+class ErnieFinetuneCell(nn.TrainOneStepWithLossScaleCell):
     """
     Especially defined for finetuning where only four inputs tensor are needed.
     Append an optimizer to the training network after that the construct
@@ -82,44 +76,8 @@ class ErnieFinetuneCell(nn.Cell):
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-
-        super(ErnieFinetuneCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(ErnieFinetuneCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.gpu_target = False
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -130,21 +88,16 @@ class ErnieFinetuneCell(nn.Cell):
         """Ernie Finetune"""
 
         weights = self.weights
-        init = False
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
                             label_ids)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
 
-        if not self.gpu_target:
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -155,23 +108,8 @@ class ErnieFinetuneCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
         return (loss, cond)
diff --git a/official/nlp/ernie/src/ernie_for_finetune.py b/official/nlp/ernie/src/ernie_for_finetune.py
index d10a81a2e9e5372336dffeca9e03dd5bcaa9ecf8..5ea20fc61a2c26dab5a4ccee8764946816e0cfc9 100644
--- a/official/nlp/ernie/src/ernie_for_finetune.py
+++ b/official/nlp/ernie/src/ernie_for_finetune.py
@@ -24,10 +24,6 @@ from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
-from mindspore.context import ParallelMode
-from mindspore.communication.management import get_group_size
-from mindspore import context
 from src.finetune_eval_model import ErnieCLSModel, ErnieMRCModel, ErnieNERModel
 from src.utils import CrossEntropyCalculation
 
@@ -70,7 +66,7 @@ grad_overflow = P.FloatStatus()
 def _tensor_grad_overflow(grad):
     return grad_overflow(grad)
 
-class ErnieFinetuneCell(nn.Cell):
+class ErnieFinetuneCell(nn.TrainOneStepWithLossScaleCell):
     """
     Especially defined for finetuning where only four inputs tensor are needed.
     Append an optimizer to the training network after that the construct
@@ -82,44 +78,8 @@ class ErnieFinetuneCell(nn.Cell):
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-
-        super(ErnieFinetuneCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(ErnieFinetuneCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.gpu_target = False
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -130,21 +90,16 @@ class ErnieFinetuneCell(nn.Cell):
         """Ernie Finetune"""
 
         weights = self.weights
-        init = False
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
                             label_ids)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
 
-        if not self.gpu_target:
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -155,29 +110,11 @@ class ErnieFinetuneCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            flag_sum = self.reshape(flag_sum, (()))
-        if self.is_distributed:
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
-        if overflow:
-            succ = False
-        else:
-            succ = self.optimizer(grads)
-        ret = (loss, cond)
-        return F.depend(ret, succ)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
+        if not overflow:
+            self.optimizer(grads)
+        return (loss, cond)
 
 class ErnieCLS(nn.Cell):
     """
@@ -212,40 +149,13 @@ class ErnieNER(nn.Cell):
         loss = self.loss(logits, label_ids, self.num_labels)
         return loss
 
-class ErnieMRCCell(nn.Cell):
+class ErnieMRCCell(nn.TrainOneStepWithLossScaleCell):
     """
     specifically defined for finetuning where only four inputs tensor are needed.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(ErnieMRCCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(ErnieMRCCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -257,7 +167,6 @@ class ErnieMRCCell(nn.Cell):
                   sens=None):
         """Ernie MRC"""
         weights = self.weights
-        init = self.alloc_status()
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
@@ -265,12 +174,10 @@ class ErnieMRCCell(nn.Cell):
                             end_position,
                             unique_id)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
-        init = F.depend(init, loss)
-        clear_status = self.clear_status(init)
-        scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -283,24 +190,11 @@ class ErnieMRCCell(nn.Cell):
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
             grads = self.grad_reducer(grads)
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
-        if self.is_distributed:
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
-        if overflow:
-            succ = False
-        else:
-            succ = self.optimizer(grads)
-        ret = (loss, cond)
-        return F.depend(ret, succ)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
+        if not overflow:
+            self.optimizer(grads)
+        return (loss, cond)
 
 class ErnieMRC(nn.Cell):
     '''
diff --git a/official/nlp/mass/src/transformer/transformer_for_train.py b/official/nlp/mass/src/transformer/transformer_for_train.py
index 2164e17c1dc2af3b1dd64b3c6e41c76a3076729b..21a130240498cc1951ca7ebe255d18cf56694e38 100644
--- a/official/nlp/mass/src/transformer/transformer_for_train.py
+++ b/official/nlp/mass/src/transformer/transformer_for_train.py
@@ -14,15 +14,11 @@
 # ============================================================================
 """Transformer for training."""
 from mindspore import nn
-import mindspore.context as context
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
-from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
-from mindspore.context import ParallelMode
 
 from .transformer import Transformer
 from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients
@@ -215,7 +211,7 @@ grad_overflow = P.FloatStatus()
 def _tensor_grad_overflow(grad):
     return grad_overflow(grad)
 
-class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
+class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of Transformer network training.
 
@@ -232,50 +228,9 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
     """
 
     def __init__(self, network, optimizer, scale_update_cell=None):
-
-        super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.network.add_flags(defer_inline=True)
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.all_reduce = P.AllReduce()
-
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode not in ParallelMode.MODE_LIST:
-            raise ValueError("Parallel mode does not support: ", self.parallel_mode)
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = None
-        if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            degree = context.get_auto_parallel_context("device_num")
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(TransformerTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
         self.clip_gradients = ClipGradients()
         self.cast = P.Cast()
-        if context.get_context("device_target") == "GPU":
-            self.gpu_target = True
-            self.float_status = P.FloatStatus()
-            self.addn = P.AddN()
-            self.reshape = P.Reshape()
-        else:
-            self.gpu_target = False
-            self.alloc_status = P.NPUAllocFloatStatus()
-            self.get_status = P.NPUGetFloatStatus()
-            self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   source_eos_ids,
@@ -316,19 +271,11 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
                             label_weights)
 
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
 
-        init = False
-        if not self.gpu_target:
-            # init overflow buffer
-            init = self.alloc_status()
-            init = F.depend(init, loss)
-            # clear overflow buffer
-            clear_status = self.clear_status(init)
-            scaling_sens = F.depend(scaling_sens, clear_status)
-
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(source_ids,
                                                  source_mask,
                                                  target_ids,
@@ -346,28 +293,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(grads)
 
         # get the overflow buffer
-        if not self.gpu_target:
-            init = F.depend(init, grads)
-            get_status = self.get_status(init)
-            init = F.depend(init, get_status)
-            # sum overflow buffer elements, 0:not overflow , >0:overflow
-            flag_sum = self.reduce_sum(init, (0,))
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
-            flag_sum = self.addn(flag_sum)
-            # convert flag_sum to scalar
-            flag_sum = self.reshape(flag_sum, (()))
-
-        if self.is_distributed:
-            # Sum overflow flag over devices.
-            flag_reduce = self.all_reduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
 
diff --git a/official/nlp/tinybert/src/tinybert_for_gd_td.py b/official/nlp/tinybert/src/tinybert_for_gd_td.py
index c2e8f9f91a32aaddf22cd424b193c50b55fefe3d..b1a70e92b2df3071e40829023d2483b141db73d5 100644
--- a/official/nlp/tinybert/src/tinybert_for_gd_td.py
+++ b/official/nlp/tinybert/src/tinybert_for_gd_td.py
@@ -21,9 +21,7 @@ from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.common.parameter import Parameter
 from mindspore.communication.management import get_group_size
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
@@ -200,7 +198,7 @@ class BertNetworkWithLoss_gd(nn.Cell):
             total_loss += rep_loss
         return self.cast(total_loss, mstype.float32)
 
-class BertTrainWithLossScaleCell(nn.Cell):
+class BertTrainWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
     """
     Encapsulation class of bert network training.
 
@@ -213,36 +211,8 @@ class BertTrainWithLossScaleCell(nn.Cell):
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = F.identity
-        self.degree = 1
-        if self.reducer_flag:
-            self.degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(BertTrainWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
 
     def construct(self,
                   input_ids,
@@ -255,14 +225,11 @@ class BertTrainWithLossScaleCell(nn.Cell):
                             input_mask,
                             token_type_id)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
         # alloc status and clear should be right before gradoperation
-        init = self.alloc_status()
-        init = F.depend(init, loss)
-        clear_status = self.clear_status(init)
-        scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -272,19 +239,8 @@ class BertTrainWithLossScaleCell(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
         return (loss, cond, scaling_sens)
@@ -468,42 +424,13 @@ class BertNetworkWithLoss_td(nn.Cell):
             total_loss += cls_loss
         return self.cast(total_loss, mstype.float32)
 
-class BertEvaluationWithLossScaleCell(nn.Cell):
+class BertEvaluationWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
     """
     Especially defined for finetuning where only four inputs tensor are needed.
     """
     def __init__(self, network, optimizer, scale_update_cell=None):
-        super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True,
-                                    sens_param=True)
-        self.reducer_flag = False
-        self.allreduce = P.AllReduce()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.grad_reducer = F.identity
-        self.degree = 1
-        if self.reducer_flag:
-            self.degree = get_group_size()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        super(BertEvaluationWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
         self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_status = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
-
     def construct(self,
                   input_ids,
                   input_mask,
@@ -517,14 +444,11 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
                             token_type_id,
                             label_ids)
         if sens is None:
-            scaling_sens = self.loss_scale
+            scaling_sens = self.scale_sense
         else:
             scaling_sens = sens
         # alloc status and clear should be right before gradoperation
-        init = self.alloc_status()
-        init = F.depend(init, loss)
-        clear_status = self.clear_status(init)
-        scaling_sens = F.depend(scaling_sens, clear_status)
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -535,19 +459,8 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
-        if self.is_distributed:
-            # sum overflow flag over devices
-            flag_reduce = self.allreduce(flag_sum)
-            cond = self.less_equal(self.base, flag_reduce)
-        else:
-            cond = self.less_equal(self.base, flag_sum)
-        overflow = cond
-        if sens is None:
-            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
         if not overflow:
             self.optimizer(grads)
         return (loss, cond, scaling_sens)