diff --git a/official/cv/pvnet/src/dataset.py b/official/cv/pvnet/src/dataset.py
index bea092778ef18bf209d779df91da353adfac7f69..31b7bcb783969673b646396650f36d6caf30544e 100644
--- a/official/cv/pvnet/src/dataset.py
+++ b/official/cv/pvnet/src/dataset.py
@@ -241,10 +241,10 @@ def create_dataset(cls_list, batch_size=16, workers=16, devices=1, rank=0, multi
         CV.RandomColorAdjust(
             cfg.brightness, cfg.contrast,
             cfg.saturation, cfg.hue),
-        C.ToTensor(),  # 0~255 HWC to 0~1 CHW
+        CV.ToTensor(),  # 0~255 HWC to 0~1 CHW
         C.TypeCast(mstype.float32),
         # Computed from random subset of ImageNet training images
-        C.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), is_hwc=False),
+        CV.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), is_hwc=False),
     ])
 
     mask_transforms = [
diff --git a/official/cv/pvnet/train.py b/official/cv/pvnet/train.py
index 5bc18d1b9bdd47508bfad74c66acf02089a0dc98..65d0d7c7e088431677aee30a7d7b762eb2b5b0b2 100644
--- a/official/cv/pvnet/train.py
+++ b/official/cv/pvnet/train.py
@@ -49,7 +49,7 @@ class Train:
             rank=cfg.rank
         )
         self.current_dir = os.path.dirname(os.path.abspath(__file__))
-        if cfg.pretrained_path is None:
+        if str(cfg.pretrained_path).lower() == 'none':
             self.pretrained_path = None
         else:
             self.pretrained_path = os.path.join(self.current_dir, cfg.pretrained_path)
diff --git a/official/nlp/transformer/src/transformer_for_train.py b/official/nlp/transformer/src/transformer_for_train.py
index 42c4f87489322f549a46bf76b58aca5625237304..cbfccc881cb9907811f314529e5f98fc47174005 100644
--- a/official/nlp/transformer/src/transformer_for_train.py
+++ b/official/nlp/transformer/src/transformer_for_train.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """Transformer for training."""
 import numpy as np
+from mindspore import ms_function
 
 from mindspore.common.initializer import initializer
 import mindspore as ms
@@ -147,10 +148,16 @@ class TransformerTrainOneStepCell(nn.TrainOneStepCell):
 
         self.cast = ops.Cast()
         self.hyper_map = ops.HyperMap()
+        self.enable_tuple_broaden = True
 
     def set_sens(self, value):
         self.sens = value
 
+    @ms_function
+    def clip_grads(self, grads):
+        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        return grads
+
     def construct(self,
                   source_eos_ids,
                   source_eos_mask,
@@ -181,7 +188,7 @@ class TransformerTrainOneStepCell(nn.TrainOneStepCell):
                                                  label_weights,
                                                  self.cast(ops.tuple_to_array((self.sens,)),
                                                            ms.float32))
-        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        grads = self.clip_grads(grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         self.optimizer(grads)
@@ -227,6 +234,17 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=ms.float32))
+        self.enable_tuple_broaden = True
+
+    @ms_function
+    def clip_grads(self, grads):
+        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        return grads
+
+    @ms_function
+    def clip_scale_grads(self, scale, grads):
+        grads = self.hyper_map(ops.partial(grad_scale, scale * self.degree), grads)
+        return grads
 
     def construct(self,
                   source_eos_ids,
@@ -267,8 +285,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell)
 
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
-        grads = self.hyper_map(ops.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        grads = self.clip_scale_grads(scaling_sens, grads)
+        grads = self.clip_grads(grads)
 
         cond = self.get_overflow_status(status, grads)
         overflow = cond
@@ -377,6 +395,26 @@ class TransformerTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
         self.loss_scaling_manager = scale_update_cell
         if scale_update_cell:
             self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=ms.float32))
+        self.enable_tuple_broaden = True
+
+    @ms_function
+    def clip_grads(self, grads):
+        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        return grads
+
+    @ms_function
+    def clip_scale_grads(self, scale, grads):
+        grads = self.hyper_map(ops.partial(grad_scale, scale * self.degree), grads)
+        return grads
+
+    @ms_function
+    def clip_accumulate_hyper_map(self, grads):
+        return self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
+
+    @ms_function
+    def clip_reset_hyper_map(self):
+        return self.hyper_map(reset_accu_grads, self.accu_grads)
+
 
     def construct(self,
                   source_eos_ids,
@@ -426,7 +464,7 @@ class TransformerTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
                                                  self.cast(scaling_sens,
                                                            ms.float32))
 
-        accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
+        accu_succ = self.clip_accumulate_hyper_map(grads)
         mean_loss = ops.depend(mean_loss, accu_succ)
 
         init = ops.depend(init, mean_loss)
@@ -442,15 +480,15 @@ class TransformerTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
             # apply grad reducer on grads
             grads = self.grad_reducer(self.accu_grads)
             scaling = scaling_sens * self.degree * self.accumulation_steps
-            grads = self.hyper_map(ops.partial(grad_scale, scaling), grads)
+            grads = self.clip_scale_grads(scaling, grads)
             if self.enable_global_norm:
                 grads = ops.clip_by_global_norm(grads, 1.0, None)
             else:
-                grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+                grads = self.clip_grads(grads)
             accu_overflow = ops.depend(accu_overflow, grads)
             accu_overflow = self.overflow_reducer(accu_overflow)
             overflow = self.less_equal(self.base, accu_overflow)
-            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
+            accu_succ = self.clip_reset_hyper_map()
             overflow = ops.depend(overflow, accu_succ)
             overflow = self.reshape(overflow, (()))
             if sens is None: