Unverified commit 322ef912, authored by i-robot, committed by Gitee

!2888 optimize bert for pynative

Merge pull request !2888 from chujinjin/optimize_bert
parents 7face724 d5e20ed3
@@ -22,6 +22,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
+from mindspore.common.api import ms_function
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
@@ -277,10 +278,16 @@ class BertTrainOneStepCell(nn.TrainOneStepCell):
         self.cast = P.Cast()
         self.hyper_map = C.HyperMap()
         self.enable_clip_grad = enable_clip_grad
+        self.enable_tuple_broaden = True
 
     def set_sens(self, value):
         self.sens = value
 
+    @ms_function
+    def clip_grads(self, grads):
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        return grads
+
     def construct(self,
                   input_ids,
                   input_mask,
@@ -309,7 +316,7 @@ class BertTrainOneStepCell(nn.TrainOneStepCell):
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
         if self.enable_clip_grad:
-            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+            grads = self.clip_grads(grads)
         grads = self.grad_reducer(grads)
         self.optimizer(grads)
         return loss
@@ -358,6 +365,12 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
+        self.enable_tuple_broaden = True
+
+    @ms_function
+    def clip_grads(self, grads):
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        return grads
 
     def construct(self,
                   input_ids,
@@ -395,7 +408,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
         grads = self.grad_reducer(grads)
         degree_sens = self.cast(scaling_sens * self.degree, mstype.float32)
         grads = self.hyper_map(F.partial(grad_scale, degree_sens), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        grads = self.clip_grads(grads)
         cond = self.get_overflow_status(status, grads)
         overflow = cond
@@ -431,6 +444,12 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell)
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
+        self.enable_tuple_broaden = True
+
+    @ms_function
+    def clip_grads(self, grads):
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        return grads
 
     def construct(self,
                   input_ids,
@@ -468,7 +487,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        grads = self.clip_grads(grads)
         cond = self.get_overflow_status(status, grads)
         overflow = cond
         if self.loss_scaling_manager is not None:
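
Note on the change (a reading of the diff, not part of the commit message): in PyNative mode the per-parameter gradient clipping loop dispatches many small ops eagerly on every step. Factoring that loop into a clip_grads method decorated with @ms_function lets MindSpore compile the hot path into a single graph, and setting enable_tuple_broaden = True on the cell broadens tuple arguments so the compiled graph is reused across steps instead of being recompiled for each new gradients tuple. Below is a minimal, self-contained sketch of the same pattern; the _clip func graph, ClipCell, and the [-1, 1] bounds are illustrative stand-ins, not the BERT script's clip_grad or GRADIENT_CLIP_* definitions.

    # Sketch of the ms_function + enable_tuple_broaden pattern, assuming
    # MindSpore 1.x in PyNative mode. Names here are illustrative only.
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore.ops import composite as C
    from mindspore.ops import functional as F
    from mindspore.common.api import ms_function
    from mindspore.common.tensor import Tensor
    from mindspore.common import dtype as mstype

    context.set_context(mode=context.PYNATIVE_MODE)

    # Illustrative per-gradient clip; hyper_map maps it over the grads tuple.
    _clip = C.MultitypeFuncGraph("_clip")

    @_clip.register("Tensor", "Tensor", "Tensor")
    def _clip_one(low, high, grad):
        return C.clip_by_value(grad, low, high)

    class ClipCell(nn.Cell):
        def __init__(self):
            super(ClipCell, self).__init__()
            self.hyper_map = C.HyperMap()
            # Broaden tuple inputs so the graph compiled for clip_grads is
            # reused across steps instead of recompiled per gradients tuple.
            self.enable_tuple_broaden = True
            self.low = Tensor(-1.0, mstype.float32)
            self.high = Tensor(1.0, mstype.float32)

        @ms_function  # compiled once to a graph, even under PyNative mode
        def clip_grads(self, grads):
            return self.hyper_map(F.partial(_clip, self.low, self.high), grads)

    grads = (Tensor([2.0, -3.0], mstype.float32), Tensor([0.5], mstype.float32))
    print(ClipCell().clip_grads(grads))  # -> ([1.0, -1.0], [0.5])

The same refactoring is applied identically to all three training cells above, which is why each hunk adds the same clip_grads method and replaces the inline hyper_map call with self.clip_grads(grads).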