diff --git a/official/nlp/pangu_alpha/src/adam.py b/official/nlp/pangu_alpha/src/adam.py index c24c9ba06d1b9c19b05afffc60d34bc9d39dadfe..49a08135d1a96fe48d8a393e6ec96580a5adc310 100644 --- a/official/nlp/pangu_alpha/src/adam.py +++ b/official/nlp/pangu_alpha/src/adam.py @@ -42,10 +42,10 @@ def _update_run_kernel(opt, clip_value, beta1, beta2, eps, lr, weight_decay, if optim_filter: if decay_flags: next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, - _cpu_div(P.Cast()(gradient, mstype.float16), clip_value)) + P.Cast()(gradient, mstype.float16), clip_value) else: next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0, - _cpu_div(P.Cast()(gradient, mstype.float16), clip_value)) + P.Cast()(gradient, mstype.float16), clip_value) return F.depend(success, next_param) return success