diff --git a/official/cv/retinaface_resnet50/eval.py b/official/cv/retinaface_resnet50/eval.py
index f9c487d1a5620707a37df33144fbebaf61f30415..6e79c5f29e3e6ce5c79e9996f35ce97a97a28905 100644
--- a/official/cv/retinaface_resnet50/eval.py
+++ b/official/cv/retinaface_resnet50/eval.py
@@ -20,9 +20,9 @@ import datetime
 import numpy as np
 import cv2
 
-from mindspore import Tensor, context
+import mindspore as ms
+from mindspore import Tensor
 from mindspore.common import set_seed
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
 from src.config import cfg_res50
 from src.network import RetinaFace, resnet50
@@ -296,7 +296,7 @@ class DetectionEngine:
 
 
 def val():
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
+    ms.set_context(mode=ms.GRAPH_MODE, device_target='GPU', save_graphs=False)
 
     cfg = cfg_res50
 
@@ -307,10 +307,10 @@ def val():
 
     # load checkpoint
     assert cfg['val_model'] is not None, 'val_model is None.'
-    param_dict = load_checkpoint(cfg['val_model'])
+    param_dict = ms.load_checkpoint(cfg['val_model'])
     print('Load trained model done. {}'.format(cfg['val_model']))
     network.init_parameters_data()
-    load_param_into_net(network, param_dict)
+    ms.load_param_into_net(network, param_dict)
 
     # testing dataset
     testset_folder = cfg['val_dataset_folder']
diff --git a/official/cv/retinaface_resnet50/src/loss.py b/official/cv/retinaface_resnet50/src/loss.py
index 6865a8a3186dfec8d2f5165ec34aea7961202504..01fc8a34c60bc1a66fa3b63793c9039d365f5095 100644
--- a/official/cv/retinaface_resnet50/src/loss.py
+++ b/official/cv/retinaface_resnet50/src/loss.py
@@ -14,27 +14,25 @@
 # ============================================================================
 """Loss."""
 import numpy as np
-import mindspore.common.dtype as mstype
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore.ops import operations as P
-from mindspore.ops import functional as F
+import mindspore.ops as ops
 from mindspore import Tensor
 
 
 class SoftmaxCrossEntropyWithLogits(nn.Cell):
     def __init__(self):
         super(SoftmaxCrossEntropyWithLogits, self).__init__()
-        self.log_softmax = P.LogSoftmax()
-        self.neg = P.Neg()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
-        self.reduce_sum = P.ReduceSum()
+        self.log_softmax = ops.LogSoftmax()
+        self.neg = ops.Neg()
+        self.one_hot = ops.OneHot()
+        self.on_value = Tensor(1.0, ms.float32)
+        self.off_value = Tensor(0.0, ms.float32)
+        self.reduce_sum = ops.ReduceSum()
 
     def construct(self, logits, labels):
         prob = self.log_softmax(logits)
-        labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value)
+        labels = self.one_hot(labels, ops.shape(logits)[-1], self.on_value, self.off_value)
 
         return self.neg(self.reduce_sum(prob * labels, 1))
 
@@ -45,30 +43,30 @@ class MultiBoxLoss(nn.Cell):
         self.num_classes = num_classes
         self.num_boxes = num_boxes
         self.neg_pre_positive = neg_pre_positive
-        self.notequal = P.NotEqual()
-        self.less = P.Less()
-        self.tile = P.Tile()
-        self.reduce_sum = P.ReduceSum()
-        self.reduce_mean = P.ReduceMean()
-        self.expand_dims = P.ExpandDims()
-        self.smooth_l1_loss = P.SmoothL1Loss()
+        self.notequal = ops.NotEqual()
+        self.less = ops.Less()
+        self.tile = ops.Tile()
+        self.reduce_sum = ops.ReduceSum()
+        self.reduce_mean = ops.ReduceMean()
+        self.expand_dims = ops.ExpandDims()
+        self.smooth_l1_loss = ops.SmoothL1Loss()
         self.cross_entropy = SoftmaxCrossEntropyWithLogits()
-        self.maximum = P.Maximum()
-        self.minimum = P.Minimum()
-        self.sort_descend = P.TopK(True)
-        self.sort = P.TopK(True)
-        self.gather = P.GatherNd()
-        self.max = P.ReduceMax()
-        self.log = P.Log()
-        self.exp = P.Exp()
-        self.concat = P.Concat(axis=1)
-        self.reduce_sum2 = P.ReduceSum(keep_dims=True)
+        self.maximum = ops.Maximum()
+        self.minimum = ops.Minimum()
+        self.sort_descend = ops.TopK(True)
+        self.sort = ops.TopK(True)
+        self.gather = ops.GatherNd()
+        self.max = ops.ReduceMax()
+        self.log = ops.Log()
+        self.exp = ops.Exp()
+        self.concat = ops.Concat(axis=1)
+        self.reduce_sum2 = ops.ReduceSum(keep_dims=True)
         self.idx = Tensor(np.reshape(np.arange(batch_size * num_boxes), (-1, 1)), ms.int32)
 
     def construct(self, loc_data, loc_t, conf_data, conf_t, landm_data, landm_t):
 
         # landm loss
-        mask_pos1 = F.cast(self.less(0.0, F.cast(conf_t, mstype.float32)), mstype.float32)
+        mask_pos1 = ops.cast(self.less(0.0, ops.cast(conf_t, ms.float32)), ms.float32)
 
         N1 = self.maximum(self.reduce_sum(mask_pos1), 1)
         mask_pos_idx1 = self.tile(self.expand_dims(mask_pos1, -1), (1, 1, 10))
@@ -76,8 +74,8 @@ class MultiBoxLoss(nn.Cell):
         loss_landm = loss_landm / N1
 
         # Localization Loss
-        mask_pos = F.cast(self.notequal(0, conf_t), mstype.float32)
-        conf_t = F.cast(mask_pos, mstype.int32)
+        mask_pos = ops.cast(self.notequal(0, conf_t), ms.float32)
+        conf_t = ops.cast(mask_pos, ms.int32)
 
         N = self.maximum(self.reduce_sum(mask_pos), 1)
         mask_pos_idx = self.tile(self.expand_dims(mask_pos, -1), (1, 1, 4))
@@ -85,32 +83,32 @@ class MultiBoxLoss(nn.Cell):
         loss_l = loss_l / N
 
         # Conf Loss
-        conf_t_shape = F.shape(conf_t)
-        conf_t = F.reshape(conf_t, (-1,))
-        indices = self.concat((self.idx, F.reshape(conf_t, (-1, 1))))
+        conf_t_shape = ops.shape(conf_t)
+        conf_t = ops.reshape(conf_t, (-1,))
+        indices = self.concat((self.idx, ops.reshape(conf_t, (-1, 1))))
 
-        batch_conf = F.reshape(conf_data, (-1, self.num_classes))
+        batch_conf = ops.reshape(conf_data, (-1, self.num_classes))
         x_max = self.max(batch_conf)
         loss_c = self.log(self.reduce_sum2(self.exp(batch_conf - x_max), 1)) + x_max
-        loss_c = loss_c - F.reshape(self.gather(batch_conf, indices), (-1, 1))
-        loss_c = F.reshape(loss_c, conf_t_shape)
+        loss_c = loss_c - ops.reshape(self.gather(batch_conf, indices), (-1, 1))
+        loss_c = ops.reshape(loss_c, conf_t_shape)
 
         # hard example mining
-        num_matched_boxes = F.reshape(self.reduce_sum(mask_pos, 1), (-1,))
-        neg_masked_cross_entropy = F.cast(loss_c * (1 - mask_pos), mstype.float32)
+        num_matched_boxes = ops.reshape(self.reduce_sum(mask_pos, 1), (-1,))
+        neg_masked_cross_entropy = ops.cast(loss_c * (1 - mask_pos), ms.float32)
 
         _, loss_idx = self.sort_descend(neg_masked_cross_entropy, self.num_boxes)
-        _, relative_position = self.sort(F.cast(loss_idx, mstype.float32), self.num_boxes)
-        relative_position = F.cast(relative_position, mstype.float32)
+        _, relative_position = self.sort(ops.cast(loss_idx, ms.float32), self.num_boxes)
+        relative_position = ops.cast(relative_position, ms.float32)
         relative_position = relative_position[:, ::-1]
-        relative_position = F.cast(relative_position, mstype.int32)
+        relative_position = ops.cast(relative_position, ms.int32)
 
         num_neg_boxes = self.minimum(num_matched_boxes * self.neg_pre_positive, self.num_boxes - 1)
         tile_num_neg_boxes = self.tile(self.expand_dims(num_neg_boxes, -1), (1, self.num_boxes))
-        top_k_neg_mask = F.cast(self.less(relative_position, tile_num_neg_boxes), mstype.float32)
+        top_k_neg_mask = ops.cast(self.less(relative_position, tile_num_neg_boxes), ms.float32)
 
         cross_entropy = self.cross_entropy(batch_conf, conf_t)
-        cross_entropy = F.reshape(cross_entropy, conf_t_shape)
+        cross_entropy = ops.reshape(cross_entropy, conf_t_shape)
 
         loss_c = self.reduce_sum(cross_entropy * self.minimum(mask_pos + top_k_neg_mask, 1))
 
diff --git a/official/cv/retinaface_resnet50/src/network.py b/official/cv/retinaface_resnet50/src/network.py
index 568189dc4d64e056ffc48fc1e76678eeb61ac11a..b417e615d638b9e27abb00cd32063a27109cd17b 100644
--- a/official/cv/retinaface_resnet50/src/network.py
+++ b/official/cv/retinaface_resnet50/src/network.py
@@ -17,11 +17,10 @@ import math
 from functools import reduce
 import numpy as np
 
-import mindspore
+import mindspore as ms
 import mindspore.nn as nn
-from mindspore.ops import operations as P
-from mindspore.ops import composite as C
-from mindspore import context, Tensor
+import mindspore.ops as ops
+from mindspore import Tensor
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.communication.management import get_group_size
 
@@ -94,7 +93,7 @@ class ResidualBlock(nn.Cell):
         if self.down_sample:
             self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
                                                         _bn(out_channel)])
-        self.add = P.Add()
+        self.add = ops.Add()
 
     def construct(self, x):
         identity = x
@@ -133,10 +132,10 @@ class ResNet(nn.Cell):
 
         self.conv1 = _conv7x7(3, 64, stride=2)
         self.bn1 = _bn(64)
-        self.relu = P.ReLU()
+        self.relu = ops.ReLU()
 
 
-        self.pad = P.Pad(((0, 0), (0, 0), (1, 0), (1, 0)))
+        self.pad = ops.Pad(((0, 0), (0, 0), (1, 0), (1, 0)))
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
 
 
@@ -161,7 +160,7 @@ class ResNet(nn.Cell):
                                        out_channel=out_channels[3],
                                        stride=strides[3])
 
-        self.mean = P.ReduceMean(keep_dims=True)
+        self.mean = ops.ReduceMean(keep_dims=True)
         self.flatten = nn.Flatten()
         self.end_point = _fc(out_channels[3], num_classes)
 
@@ -302,7 +301,7 @@ class SSH(nn.Cell):
         self.conv7X7_3 = ConvBN(out_channel // 4, out_channel // 4, kernel_size=3, stride=1, padding=1, groups=1,
                                 norm_layer=norm_layer)
 
-        self.cat = P.Concat(axis=1)
+        self.cat = ops.Concat(axis=1)
         self.relu = nn.ReLU()
 
     def construct(self, x):
@@ -344,11 +343,11 @@ class FPN(nn.Cell):
         output2 = self.output2(input2)
         output3 = self.output3(input3)
 
-        up3 = P.ResizeNearestNeighbor([P.Shape()(output2)[2], P.Shape()(output2)[3]])(output3)
+        up3 = ops.ResizeNearestNeighbor([ops.Shape()(output2)[2], ops.Shape()(output2)[3]])(output3)
         output2 = up3 + output2
         output2 = self.merge2(output2)
 
-        up2 = P.ResizeNearestNeighbor([P.Shape()(output1)[2], P.Shape()(output1)[3]])(output2)
+        up2 = ops.ResizeNearestNeighbor([ops.Shape()(output1)[2], ops.Shape()(output1)[3]])(output2)
         output1 = up2 + output1
         output1 = self.merge1(output1)
 
@@ -364,13 +363,13 @@ class ClassHead(nn.Cell):
         self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0,
                                  has_bias=True, weight_init=kaiming_weight, bias_init=kaiming_bias)
 
-        self.permute = P.Transpose()
-        self.reshape = P.Reshape()
+        self.permute = ops.Transpose()
+        self.reshape = ops.Reshape()
 
     def construct(self, x):
         out = self.conv1x1(x)
         out = self.permute(out, (0, 2, 3, 1))
-        return self.reshape(out, (P.Shape()(out)[0], -1, 2))
+        return self.reshape(out, (ops.Shape()(out)[0], -1, 2))
 
 class BboxHead(nn.Cell):
     def __init__(self, inchannels=512, num_anchors=3):
@@ -381,13 +380,13 @@ class BboxHead(nn.Cell):
         self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0, has_bias=True,
                                  weight_init=kaiming_weight, bias_init=kaiming_bias)
 
-        self.permute = P.Transpose()
-        self.reshape = P.Reshape()
+        self.permute = ops.Transpose()
+        self.reshape = ops.Reshape()
 
     def construct(self, x):
         out = self.conv1x1(x)
         out = self.permute(out, (0, 2, 3, 1))
-        return self.reshape(out, (P.Shape()(out)[0], -1, 4))
+        return self.reshape(out, (ops.Shape()(out)[0], -1, 4))
 
 class LandmarkHead(nn.Cell):
     def __init__(self, inchannels=512, num_anchors=3):
@@ -398,13 +397,13 @@ class LandmarkHead(nn.Cell):
         self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0, has_bias=True,
                                  weight_init=kaiming_weight, bias_init=kaiming_bias)
 
-        self.permute = P.Transpose()
-        self.reshape = P.Reshape()
+        self.permute = ops.Transpose()
+        self.reshape = ops.Reshape()
 
     def construct(self, x):
         out = self.conv1x1(x)
         out = self.permute(out, (0, 2, 3, 1))
-        return self.reshape(out, (P.Shape()(out)[0], -1, 10))
+        return self.reshape(out, (ops.Shape()(out)[0], -1, 10))
 
 class RetinaFace(nn.Cell):
     def __init__(self, phase='train', backbone=None):
@@ -424,7 +423,7 @@ class RetinaFace(nn.Cell):
         self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])
         self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=[256, 256, 256], anchor_num=[2, 2, 2])
 
-        self.cat = P.Concat(axis=1)
+        self.cat = ops.Concat(axis=1)
 
     def _make_class_head(self, fpn_num, inchannels, anchor_num):
         classhead = nn.CellList()
@@ -474,7 +473,7 @@ class RetinaFace(nn.Cell):
         if self.phase == 'train':
             output = (bbox_regressions, classifications, ldm_regressions)
         else:
-            output = (bbox_regressions, P.Softmax(-1)(classifications), ldm_regressions)
+            output = (bbox_regressions, ops.Softmax(-1)(classifications), ldm_regressions)
 
         return output
 
@@ -497,20 +496,20 @@ class TrainingWrapper(nn.Cell):
     def __init__(self, network, optimizer, sens=1.0):
         super(TrainingWrapper, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = mindspore.ParameterTuple(network.trainable_params())
+        self.weights = ms.ParameterTuple(network.trainable_params())
         self.optimizer = optimizer
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
         self.reducer_flag = False
         self.grad_reducer = None
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        class_list = [mindspore.context.ParallelMode.DATA_PARALLEL, mindspore.context.ParallelMode.HYBRID_PARALLEL]
+        self.parallel_mode = ms.get_auto_parallel_context("parallel_mode")
+        class_list = [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]
         if self.parallel_mode in class_list:
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
+            mean = ms.get_auto_parallel_context("gradients_mean")
             if auto_parallel_context().get_device_num_is_set():
-                degree = context.get_auto_parallel_context("device_num")
+                degree = ms.get_auto_parallel_context("device_num")
             else:
                 degree = get_group_size()
             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
@@ -518,7 +517,7 @@ class TrainingWrapper(nn.Cell):
     def construct(self, *args):
         weights = self.weights
         loss = self.network(*args)
-        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
+        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(*args, sens)
         if self.reducer_flag:
             # apply grad reducer on grads
diff --git a/official/cv/retinaface_resnet50/train.py b/official/cv/retinaface_resnet50/train.py
index c3c04e1c4a4cbdce37f436eec8b655ffaa2d4690..e318bd377a14716be3ed2c6c685d8f2b0a37a18e 100644
--- a/official/cv/retinaface_resnet50/train.py
+++ b/official/cv/retinaface_resnet50/train.py
@@ -16,14 +16,12 @@
 from __future__ import print_function
 import math
 import argparse
-import mindspore
+import mindspore as ms
 
-from mindspore import context
 from mindspore.context import ParallelMode
 from mindspore.train import Model
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.communication.management import init, get_rank, get_group_size
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
 from src.config import cfg_res50
 from src.network import RetinaFace, RetinaFaceWithLossCell, TrainingWrapper, resnet50
@@ -33,14 +31,14 @@ from src.lr_schedule import adjust_learning_rate
 
 def train(cfg, args):
 
-    context.set_context(mode=context.GRAPH_MODE, device_target='GPU', save_graphs=False)
-    if context.get_context("device_target") == "GPU":
+    ms.set_context(mode=ms.GRAPH_MODE, device_target='GPU', save_graphs=False)
+    if ms.get_context("device_target") == "GPU":
         # Enable graph kernel
-        context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
+        ms.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion")
     if args.is_distributed:
         init("nccl")
-        context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True)
+        ms.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
+                                     gradients_mean=True)
         cfg['ckpt_path'] = cfg['ckpt_path'] + "ckpt_" + str(get_rank()) + "/"
 
     batch_size = cfg['batch_size']
@@ -67,8 +65,8 @@ def train(cfg, args):
 
     if cfg['pretrain'] and cfg['resume_net'] is None:
         pretrained_res50 = cfg['pretrain_path']
-        param_dict_res50 = load_checkpoint(pretrained_res50)
-        load_param_into_net(backbone, param_dict_res50)
+        param_dict_res50 = ms.load_checkpoint(pretrained_res50)
+        ms.load_param_into_net(backbone, param_dict_res50)
         print('Load resnet50 from [{}] done.'.format(pretrained_res50))
 
     net = RetinaFace(phase='train', backbone=backbone)
@@ -76,8 +74,8 @@ def train(cfg, args):
 
     if cfg['resume_net'] is not None:
         pretrain_model_path = cfg['resume_net']
-        param_dict_retinaface = load_checkpoint(pretrain_model_path)
-        load_param_into_net(net, param_dict_retinaface)
+        param_dict_retinaface = ms.load_checkpoint(pretrain_model_path)
+        ms.load_param_into_net(net, param_dict_retinaface)
         print('Resume Model from [{}] Done.'.format(cfg['resume_net']))
 
     net = RetinaFaceWithLossCell(net, multibox_loss, cfg)
@@ -86,10 +84,10 @@ def train(cfg, args):
                               warmup_epoch=cfg['warmup_epoch'])
 
     if cfg['optim'] == 'momentum':
-        opt = mindspore.nn.Momentum(net.trainable_params(), lr, momentum)
+        opt = ms.nn.Momentum(net.trainable_params(), lr, momentum)
     elif cfg['optim'] == 'sgd':
-        opt = mindspore.nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
-                               weight_decay=weight_decay, loss_scale=1)
+        opt = ms.nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=momentum,
+                        weight_decay=weight_decay, loss_scale=1)
     else:
         raise ValueError('optim is not define.')
 
@@ -115,7 +113,7 @@ if __name__ == '__main__':
     arg, _ = parser.parse_known_args()
 
     config = cfg_res50
-    mindspore.common.seed.set_seed(config['seed'])
+    ms.set_seed(config['seed'])
     print('train config:\n', config)
 
     train(cfg=config, args=arg)