diff --git a/official/cv/mobilenetv1/eval.py b/official/cv/mobilenetv1/eval.py
index c6a549bed5723c6db1991fef5e0359eece7fb748..4d7c90a976691ec2055a45323ef3cd033e44b5bb 100644
--- a/official/cv/mobilenetv1/eval.py
+++ b/official/cv/mobilenetv1/eval.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,20 +14,15 @@
 # ============================================================================
 """eval mobilenet_v1."""
 import os
-import time
-from mindspore import context
-from mindspore.common import set_seed
+import mindspore as ms
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.train.model import Model
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
 from src.model_utils.config import config
-from src.model_utils.moxing_adapter import moxing_wrapper
-from src.model_utils.device_adapter import get_device_id, get_device_num
+from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_process
 
 
-set_seed(1)
+ms.set_seed(1)
 
 if config.dataset == 'cifar10':
     from src.dataset import create_dataset1 as create_dataset
@@ -35,63 +30,6 @@ else:
     from src.dataset import create_dataset2 as create_dataset
 
 
-def modelarts_process():
-    """ modelarts process """
-    def unzip(zip_file, save_dir):
-        import zipfile
-        s_time = time.time()
-        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
-            zip_isexist = zipfile.is_zipfile(zip_file)
-            if zip_isexist:
-                fz = zipfile.ZipFile(zip_file, 'r')
-                data_num = len(fz.namelist())
-                print("Extract Start...")
-                print("unzip file num: {}".format(data_num))
-                data_print = int(data_num / 100) if data_num > 100 else 1
-                i = 0
-                for file in fz.namelist():
-                    if i % data_print == 0:
-                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
-                    i += 1
-                    fz.extract(file, save_dir)
-                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
-                    int(int(time.time() - s_time) % 60)))
-                print("Extract Done")
-            else:
-                print("This is not zip.")
-        else:
-            print("Zip has been extracted.")
-
-    if config.need_modelarts_dataset_unzip:
-        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
-        save_dir_1 = os.path.join(config.data_path)
-
-        sync_lock = "/tmp/unzip_sync.lock"
-
-        # Each server contains 8 devices as most
-        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
-            print("Zip file path: ", zip_file_1)
-            print("Unzip file save dir: ", save_dir_1)
-            unzip(zip_file_1, save_dir_1)
-            print("===Finish extract data synchronization===")
-            try:
-                os.mknod(sync_lock)
-            except IOError:
-                pass
-
-        while True:
-            if os.path.exists(sync_lock):
-                break
-            time.sleep(1)
-
-        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
-        print("#" * 200, os.listdir(save_dir_1))
-        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
-
-        config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
-    config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path)
-
-
 @moxing_wrapper(pre_process=modelarts_process)
 def eval_mobilenetv1():
     """ eval_mobilenetv1 """
@@ -101,10 +39,10 @@ def eval_mobilenetv1():
     target = config.device_target
 
     # init context
-    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
     if target == "Ascend":
         device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(device_id=device_id)
+        ms.set_context(device_id=device_id)
 
     # create dataset
     dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, batch_size=config.batch_size,
@@ -115,8 +53,8 @@ def eval_mobilenetv1():
     net = mobilenet(class_num=config.class_num)
 
     # load checkpoint
-    param_dict = load_checkpoint(config.checkpoint_path)
-    load_param_into_net(net, param_dict)
+    param_dict = ms.load_checkpoint(config.checkpoint_path)
+    ms.load_param_into_net(net, param_dict)
     net.set_train(False)
 
     # define loss, model
@@ -129,11 +67,12 @@ def eval_mobilenetv1():
         loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
 
     # define model
-    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+    model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
 
     # eval model
     res = model.eval(dataset)
     print("result:", res, "ckpt=", config.checkpoint_path)
 
+
 if __name__ == '__main__':
     eval_mobilenetv1()
diff --git a/official/cv/mobilenetv1/export.py b/official/cv/mobilenetv1/export.py
index 399c1a3373bbeaf5591ebd9e3a3fd591df63b778..679c4f5d1c160b43d4706be3f72dd049bdb5d862 100644
--- a/official/cv/mobilenetv1/export.py
+++ b/official/cv/mobilenetv1/export.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,32 +15,30 @@
 
 import numpy as np
 
-from mindspore import context, Tensor
-from mindspore.train.serialization import export, load_checkpoint
+import mindspore as ms
 
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
 from src.model_utils.config import config
 from src.model_utils.device_adapter import get_device_id
-from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_export_preprocess
 
 
-context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+ms.set_context(mode=ms.GRAPH_MODE, device_target=config.device_target)
 
-def modelarts_process():
-    pass
 
-@moxing_wrapper(pre_process=modelarts_process)
+@moxing_wrapper(pre_process=modelarts_export_preprocess)
 def export_mobilenetv1():
     """ export_mobilenetv1 """
     target = config.device_target
     if target != "GPU":
-        context.set_context(device_id=get_device_id())
+        ms.set_context(device_id=get_device_id())
 
     network = mobilenet(class_num=config.class_num)
-    load_checkpoint(config.ckpt_file, net=network)
+    ms.load_checkpoint(config.ckpt_file, net=network)
     network.set_train(False)
-    input_data = Tensor(np.zeros([config.batch_size, 3, config.height, config.width]).astype(np.float32))
-    export(network, input_data, file_name=config.file_name, file_format=config.file_format)
+    input_data = ms.numpy.zeros([config.batch_size, 3, config.height, config.width]).astype(np.float32)
+    ms.export(network, input_data, file_name=config.file_name, file_format=config.file_format)
+
 
 if __name__ == '__main__':
     export_mobilenetv1()
diff --git a/official/cv/mobilenetv1/src/CrossEntropySmooth.py b/official/cv/mobilenetv1/src/CrossEntropySmooth.py
index a15fd2fbf74350bf579298e84d8c722ac3c46c53..f8283eb7609a5a1cfd8f39a792d24063dd744ad1 100644
--- a/official/cv/mobilenetv1/src/CrossEntropySmooth.py
+++ b/official/cv/mobilenetv1/src/CrossEntropySmooth.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,26 +13,23 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network"""
+import mindspore as ms
 import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import LossBase
-from mindspore.ops import functional as F
-from mindspore.ops import operations as P
+import mindspore.ops as ops
 
 
-class CrossEntropySmooth(LossBase):
+class CrossEntropySmooth(nn.LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()
-        self.onehot = P.OneHot()
+        self.onehot = ops.OneHot()
         self.sparse = sparse
-        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
-        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
+        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
+        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
         self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
 
     def construct(self, logit, label):
         if self.sparse:
-            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
+            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
         loss = self.ce(logit, label)
         return loss
diff --git a/official/cv/mobilenetv1/src/dataset.py b/official/cv/mobilenetv1/src/dataset.py
index 9f5b9328198eb3bdae3f3551806fb3fa8f6cf7ee..3d7c671508462dad5fed4f5541ae0076e123ccc1 100644
--- a/official/cv/mobilenetv1/src/dataset.py
+++ b/official/cv/mobilenetv1/src/dataset.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,22 +17,19 @@ create train or eval dataset.
 """
 import os
 from multiprocessing import cpu_count
-import mindspore.common.dtype as mstype
+import mindspore as ms
 import mindspore.dataset as ds
-import mindspore.dataset.vision.c_transforms as C
-import mindspore.dataset.transforms.c_transforms as C2
-from mindspore.communication.management import get_rank, get_group_size
+import mindspore.communication as comm
 
 THREAD_NUM = 12 if cpu_count() >= 12 else 8
 
 
-def create_dataset1(dataset_path, do_train, device_num=1, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset1(dataset_path, do_train, device_num=1, batch_size=32, target="Ascend"):
     """
     create a train or evaluate cifar10 dataset for mobilenet
     Args:
         dataset_path(string): the path of dataset.
         do_train(bool): whether dataset is used for train or eval.
-        repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
 
@@ -50,38 +47,35 @@ def create_dataset1(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
     trans = []
     if do_train:
         trans += [
-            C.RandomCrop((32, 32), (4, 4, 4, 4)),
-            C.RandomHorizontalFlip(prob=0.5)
+            ds.vision.c_transforms.RandomCrop((32, 32), (4, 4, 4, 4)),
+            ds.vision.c_transforms.RandomHorizontalFlip(prob=0.5)
         ]
 
     trans += [
-        C.Resize((224, 224)),
-        C.Rescale(1.0 / 255.0, 0.0),
-        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
-        C.HWC2CHW()
+        ds.vision.c_transforms.Resize((224, 224)),
+        ds.vision.c_transforms.Rescale(1.0 / 255.0, 0.0),
+        ds.vision.c_transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+        ds.vision.c_transforms.HWC2CHW()
     ]
 
-    type_cast_op = C2.TypeCast(mstype.int32)
+    type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32)
 
     data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM)
     data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM)
 
     # apply batch operations
     data_set = data_set.batch(batch_size, drop_remainder=True)
-    # apply dataset repeat operation
-    data_set = data_set.repeat(repeat_num)
 
     return data_set
 
 
-def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset2(dataset_path, do_train, device_num=1, batch_size=32, target="Ascend"):
     """
     create a train or eval imagenet2012 dataset for mobilenet
 
     Args:
         dataset_path(string): the path of dataset.
         do_train(bool): whether dataset is used for train or eval.
-        repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
 
@@ -103,21 +97,21 @@ def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
     # define map operations
     if do_train:
         trans = [
-            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
-            C.RandomHorizontalFlip(prob=0.5),
-            C.Normalize(mean=mean, std=std),
-            C.HWC2CHW()
+            ds.vision.c_transforms.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            ds.vision.c_transforms.RandomHorizontalFlip(prob=0.5),
+            ds.vision.c_transforms.Normalize(mean=mean, std=std),
+            ds.vision.c_transforms.HWC2CHW()
         ]
     else:
         trans = [
-            C.Decode(),
-            C.Resize(256),
-            C.CenterCrop(image_size),
-            C.Normalize(mean=mean, std=std),
-            C.HWC2CHW()
+            ds.vision.c_transforms.Decode(),
+            ds.vision.c_transforms.Resize(256),
+            ds.vision.c_transforms.CenterCrop(image_size),
+            ds.vision.c_transforms.Normalize(mean=mean, std=std),
+            ds.vision.c_transforms.HWC2CHW()
         ]
 
-    type_cast_op = C2.TypeCast(mstype.int32)
+    type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32)
 
     data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM)
     data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM)
@@ -125,9 +119,6 @@ def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
     # apply batch operations
     data_set = data_set.batch(batch_size, drop_remainder=True)
 
-    # apply dataset repeat operation
-    data_set = data_set.repeat(repeat_num)
-
     return data_set
 
 
@@ -138,8 +129,8 @@ def _get_rank_info():
     rank_size = int(os.environ.get("RANK_SIZE", 1))
 
     if rank_size > 1:
-        rank_size = get_group_size()
-        rank_id = get_rank()
+        rank_size = comm.get_group_size()
+        rank_id = comm.get_rank()
     else:
         rank_size = 1
         rank_id = 0
diff --git a/official/cv/mobilenetv1/src/mobilenet_v1.py b/official/cv/mobilenetv1/src/mobilenet_v1.py
index 98b9dae418ec327ce4a4aeb4195b193eb09df6d0..29bfbabbae79e6ed885d2727c79c6ffb7db163d3 100644
--- a/official/cv/mobilenetv1/src/mobilenet_v1.py
+++ b/official/cv/mobilenetv1/src/mobilenet_v1.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +13,9 @@
 # limitations under the License.
 # ============================================================================
 
+import mindspore.ops as ops
 import mindspore.nn as nn
-from mindspore.ops import operations as P
+
 
 def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
     output = []
@@ -83,7 +84,7 @@ class MobileNetV1(nn.Cell):
                 features = features + (output,)
             return features
         output = self.network(x)
-        output = P.ReduceMean()(output, (2, 3))
+        output = ops.ReduceMean()(output, (2, 3))
         output = self.fc(output)
         return output
 
diff --git a/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py b/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py
index 830d19a6fc99de8d602703971d5ac5b24e060d11..31233d6fce26aafb3b52fa910e03f63a6fd31db4 100644
--- a/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py
+++ b/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,10 @@
 
 import os
 import functools
-from mindspore import context
-from mindspore.profiler import Profiler
+import zipfile
+import time
+
+import mindspore as ms
 from .config import config
 
 _global_sync_count = 0
@@ -43,13 +45,13 @@ def get_job_id():
     job_id = job_id if job_id != "" else "default"
     return job_id
 
+
 def sync_data(from_path, to_path):
     """
     Download data from remote obs to local directory if the first url is remote url and the second one is local path
     Upload data from local directory to remote obs in contrast.
     """
     import moxing as mox
-    import time
     global _global_sync_count
     sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
     _global_sync_count += 1
@@ -93,7 +95,7 @@ def moxing_wrapper(pre_process=None, post_process=None):
                     sync_data(config.train_url, config.output_path)
                     print("Workspace downloaded: ", os.listdir(config.output_path))
 
-                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                ms.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                 config.device_num = get_device_num()
                 config.device_id = get_device_id()
                 if not os.path.exists(config.output_path):
@@ -103,7 +105,7 @@ def moxing_wrapper(pre_process=None, post_process=None):
                     pre_process()
 
             if config.enable_profiling:
-                profiler = Profiler()
+                profiler = ms.profiler.Profiler()
 
             run_func(*args, **kwargs)
 
@@ -120,3 +122,64 @@ def moxing_wrapper(pre_process=None, post_process=None):
                     sync_data(config.output_path, config.train_url)
         return wrapped_func
     return wrapper
+
+
+def modelarts_process():
+    """ modelarts process """
+    def unzip(zip_file, save_dir):
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                data_print = int(data_num / 100) if data_num > 100 else 1
+                i = 0
+                for file in fz.namelist():
+                    if i % data_print == 0:
+                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
+                                                     int(int(time.time() - s_time) % 60)))
+                print("Extract Done")
+            else:
+                print("This is not zip.")
+        else:
+            print("Zip has been extracted.")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+
+        # Each server contains 8 devices as most
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+        print("#" * 200, os.listdir(save_dir_1))
+        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
+
+        config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
+    config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path)
+
+
+def modelarts_export_preprocess():
+    """ modelarts export process """
+    config.file_name = os.path.join(config.output_path, config.file_name)
diff --git a/official/cv/mobilenetv1/train.py b/official/cv/mobilenetv1/train.py
index 1324a7c07316f9271486ebafbb6863a980370527..8c13edb497871dd655651449c83a24d5375141bb 100644
--- a/official/cv/mobilenetv1/train.py
+++ b/official/cv/mobilenetv1/train.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,29 +14,22 @@
 # ============================================================================
 """train mobilenet_v1."""
 import os
-import time
-from mindspore import context
-from mindspore import Tensor
-from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.model import Model
-from mindspore.context import ParallelMode
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init, get_rank, get_group_size
-from mindspore.common import set_seed
+
+import mindspore as ms
 import mindspore.nn as nn
+import mindspore.communication as comm
 import mindspore.common.initializer as weight_init
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+
 from src.lr_generator import get_lr
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
 from src.model_utils.config import config
-from src.model_utils.moxing_adapter import moxing_wrapper
-from src.model_utils.device_adapter import get_device_id, get_device_num
+from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_process
+from src.model_utils.device_adapter import get_device_num
 
 
-set_seed(1)
+ms.set_seed(1)
 
 if config.dataset == 'cifar10':
     from src.dataset import create_dataset1 as create_dataset
@@ -44,66 +37,11 @@ else:
     from src.dataset import create_dataset2 as create_dataset
 
 
-def modelarts_pre_process():
-    def unzip(zip_file, save_dir):
-        import zipfile
-        s_time = time.time()
-        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
-            zip_isexist = zipfile.is_zipfile(zip_file)
-            if zip_isexist:
-                fz = zipfile.ZipFile(zip_file, 'r')
-                data_num = len(fz.namelist())
-                print("Extract Start...")
-                print("unzip file num: {}".format(data_num))
-                data_print = int(data_num / 100) if data_num > 100 else 1
-                i = 0
-                for file in fz.namelist():
-                    if i % data_print == 0:
-                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
-                    i += 1
-                    fz.extract(file, save_dir)
-                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
-                    int(int(time.time() - s_time) % 60)))
-                print("Extract Done")
-            else:
-                print("This is not zip.")
-        else:
-            print("Zip has been extracted.")
-
-    if config.need_modelarts_dataset_unzip:
-        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
-        save_dir_1 = os.path.join(config.data_path)
-
-        sync_lock = "/tmp/unzip_sync.lock"
-
-        # Each server contains 8 devices as most
-        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
-            print("Zip file path: ", zip_file_1)
-            print("Unzip file save dir: ", save_dir_1)
-            unzip(zip_file_1, save_dir_1)
-            print("===Finish extract data synchronization===")
-            try:
-                os.mknod(sync_lock)
-            except IOError:
-                pass
-
-        while True:
-            if os.path.exists(sync_lock):
-                break
-            time.sleep(1)
-
-        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
-        print("#" * 200, os.listdir(save_dir_1))
-        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
-
-        config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
-    config.save_checkpoint_path = config.output_path
-
 def init_weigth(net):
     # init weight
     if config.pre_trained:
-        param_dict = load_checkpoint(config.pre_trained)
-        load_param_into_net(net, param_dict)
+        param_dict = ms.load_checkpoint(config.pre_trained)
+        ms.load_param_into_net(net, param_dict)
     else:
         for _, cell in net.cells_and_names():
             if isinstance(cell, nn.Conv2d):
@@ -115,7 +53,8 @@ def init_weigth(net):
                                                              cell.weight.shape,
                                                              cell.weight.dtype))
 
-@moxing_wrapper(pre_process=modelarts_pre_process)
+
+@moxing_wrapper(pre_process=modelarts_process)
 def train_mobilenetv1():
     """ train_mobilenetv1 """
     if config.dataset == 'imagenet2012':
@@ -124,31 +63,31 @@ def train_mobilenetv1():
     ckpt_save_dir = config.save_checkpoint_path
 
     # init context
-    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
 
     # Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE
-    context.set_context(mempool_block_size="31GB")
+    ms.set_context(mempool_block_size="31GB")
 
     if config.parameter_server:
-        context.set_ps_context(enable_ps=True)
+        ms.set_ps_context(enable_ps=True)
     device_id = int(os.getenv('DEVICE_ID', '0'))
     if config.run_distribute:
         if target == "Ascend":
-            context.set_context(device_id=device_id)
-            context.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              gradients_mean=True)
-            init()
-            context.set_auto_parallel_context(all_reduce_fusion_config=[75])
+            ms.set_context(device_id=device_id)
+            ms.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+                                         gradients_mean=True)
+            comm.init()
+            ms.set_auto_parallel_context(all_reduce_fusion_config=[75])
         # GPU target
         else:
-            init()
-            context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              gradients_mean=True)
-        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+            comm.init()
+            ms.set_auto_parallel_context(device_num=ms.get_group_size(), parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+                                         gradients_mean=True)
+        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(comm.get_rank()) + "/"
 
     # create dataset
     dataset = create_dataset(dataset_path=config.dataset_path, do_train=True, device_num=config.device_num,
-                             repeat_num=1, batch_size=config.batch_size, target=target)
+                             batch_size=config.batch_size, target=target)
     step_size = dataset.get_dataset_size()
 
     # define net
@@ -162,7 +101,7 @@ def train_mobilenetv1():
     lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
                 warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
                 lr_decay_mode=config.lr_decay_mode)
-    lr = Tensor(lr)
+    lr = ms.Tensor(lr)
 
     # define opt
     decayed_params = []
@@ -177,10 +116,10 @@ def train_mobilenetv1():
         group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                         {'params': no_decayed_params},
                         {'order_params': net.trainable_params()}]
-        opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
+        opt = nn.Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
     else:
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
-                       config.weight_decay)
+        opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
+                          config.weight_decay)
     # define loss, model
     if config.dataset == "imagenet2012":
         if not config.use_label_smooth:
@@ -188,13 +127,13 @@ def train_mobilenetv1():
         loss = CrossEntropySmooth(sparse=True, reduction="mean",
                                   smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
     else:
-        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    loss_scale = ms.FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
     if target == "Ascend":
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=False)
+        model = ms.Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
+                         amp_level="O2", keep_batchnorm_fp32=False)
     else:
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
+        model = ms.Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
 
     # define callbacks
     time_cb = TimeMonitor(data_size=step_size)
@@ -210,5 +149,6 @@ def train_mobilenetv1():
     model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb,
                 sink_size=dataset.get_dataset_size(), dataset_sink_mode=(not config.parameter_server))
 
+
 if __name__ == '__main__':
     train_mobilenetv1()