diff --git a/official/nlp/pangu_alpha/src/adam.py b/official/nlp/pangu_alpha/src/adam.py
index c24c9ba06d1b9c19b05afffc60d34bc9d39dadfe..49a08135d1a96fe48d8a393e6ec96580a5adc310 100644
--- a/official/nlp/pangu_alpha/src/adam.py
+++ b/official/nlp/pangu_alpha/src/adam.py
@@ -42,10 +42,10 @@ def _update_run_kernel(opt, clip_value, beta1, beta2, eps, lr, weight_decay,
     if optim_filter:
         if decay_flags:
             next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay,
-                             _cpu_div(P.Cast()(gradient, mstype.float16), clip_value))
+                             P.Cast()(gradient, mstype.float16), clip_value)
         else:
             next_param = opt(param, m, v, lr, beta1, beta2, eps, 0.0,
-                             _cpu_div(P.Cast()(gradient, mstype.float16), clip_value))
+                             P.Cast()(gradient, mstype.float16), clip_value)
         return F.depend(success, next_param)
     return success