From d1a4fbdda2634edb06417b947ce5f0b59b6bbaf9 Mon Sep 17 00:00:00 2001
From: zyhStack <zhangyihui7@huawei.com>
Date: Thu, 10 Mar 2022 11:21:18 +0800
Subject: [PATCH] mobilenet_v1 usability rectification

---
 official/cv/mobilenetv1/eval.py               |  81 ++---------
 official/cv/mobilenetv1/export.py             |  22 ++-
 .../cv/mobilenetv1/src/CrossEntropySmooth.py  |  19 ++-
 official/cv/mobilenetv1/src/dataset.py        |  57 ++++----
 official/cv/mobilenetv1/src/mobilenet_v1.py   |   7 +-
 .../src/model_utils/moxing_adapter.py         |  75 +++++++++-
 official/cv/mobilenetv1/train.py              | 132 +++++-------------
 7 files changed, 161 insertions(+), 232 deletions(-)

diff --git a/official/cv/mobilenetv1/eval.py b/official/cv/mobilenetv1/eval.py
index c6a549bed..4d7c90a97 100644
--- a/official/cv/mobilenetv1/eval.py
+++ b/official/cv/mobilenetv1/eval.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,20 +14,15 @@
 # ============================================================================
 """eval mobilenet_v1."""
 import os
-import time
-from mindspore import context
-from mindspore.common import set_seed
+import mindspore as ms
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.train.model import Model
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
 from src.model_utils.config import config
-from src.model_utils.moxing_adapter import moxing_wrapper
-from src.model_utils.device_adapter import get_device_id, get_device_num
+from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_process
 
 
-set_seed(1)
+ms.set_seed(1)
 
 if config.dataset == 'cifar10':
     from src.dataset import create_dataset1 as create_dataset
@@ -35,63 +30,6 @@ else:
     from src.dataset import create_dataset2 as create_dataset
 
 
-def modelarts_process():
-    """ modelarts process """
-    def unzip(zip_file, save_dir):
-        import zipfile
-        s_time = time.time()
-        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
-            zip_isexist = zipfile.is_zipfile(zip_file)
-            if zip_isexist:
-                fz = zipfile.ZipFile(zip_file, 'r')
-                data_num = len(fz.namelist())
-                print("Extract Start...")
-                print("unzip file num: {}".format(data_num))
-                data_print = int(data_num / 100) if data_num > 100 else 1
-                i = 0
-                for file in fz.namelist():
-                    if i % data_print == 0:
-                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
-                    i += 1
-                    fz.extract(file, save_dir)
-                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
-                    int(int(time.time() - s_time) % 60)))
-                print("Extract Done")
-            else:
-                print("This is not zip.")
-        else:
-            print("Zip has been extracted.")
-
-    if config.need_modelarts_dataset_unzip:
-        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
-        save_dir_1 = os.path.join(config.data_path)
-
-        sync_lock = "/tmp/unzip_sync.lock"
-
-        # Each server contains 8 devices as most
-        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
-            print("Zip file path: ", zip_file_1)
-            print("Unzip file save dir: ", save_dir_1)
-            unzip(zip_file_1, save_dir_1)
-            print("===Finish extract data synchronization===")
-            try:
-                os.mknod(sync_lock)
-            except IOError:
-                pass
-
-        while True:
-            if os.path.exists(sync_lock):
-                break
-            time.sleep(1)
-
-        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
-        print("#" * 200, os.listdir(save_dir_1))
-        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
-
-        config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
-    config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path)
-
-
 @moxing_wrapper(pre_process=modelarts_process)
 def eval_mobilenetv1():
     """ eval_mobilenetv1 """
@@ -101,10 +39,10 @@ def eval_mobilenetv1():
     target = config.device_target
 
     # init context
-    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
     if target == "Ascend":
         device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(device_id=device_id)
+        ms.set_context(device_id=device_id)
 
     # create dataset
     dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, batch_size=config.batch_size,
@@ -115,8 +53,8 @@ def eval_mobilenetv1():
     net = mobilenet(class_num=config.class_num)
 
     # load checkpoint
-    param_dict = load_checkpoint(config.checkpoint_path)
-    load_param_into_net(net, param_dict)
+    param_dict = ms.load_checkpoint(config.checkpoint_path)
+    ms.load_param_into_net(net, param_dict)
     net.set_train(False)
 
     # define loss, model
@@ -129,11 +67,12 @@ def eval_mobilenetv1():
         loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
 
     # define model
-    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+    model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
 
     # eval model
     res = model.eval(dataset)
     print("result:", res, "ckpt=", config.checkpoint_path)
 
+
 if __name__ == '__main__':
     eval_mobilenetv1()
diff --git a/official/cv/mobilenetv1/export.py b/official/cv/mobilenetv1/export.py
index 399c1a337..679c4f5d1 100644
--- a/official/cv/mobilenetv1/export.py
+++ b/official/cv/mobilenetv1/export.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,32 +15,30 @@
 
 import numpy as np
 
-from mindspore import context, Tensor
-from mindspore.train.serialization import export, load_checkpoint
+import mindspore as ms
 
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
 from src.model_utils.config import config
 from src.model_utils.device_adapter import get_device_id
-from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_export_preprocess
 
 
-context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+ms.set_context(mode=ms.GRAPH_MODE, device_target=config.device_target)
 
-def modelarts_process():
-    pass
 
-@moxing_wrapper(pre_process=modelarts_process)
+@moxing_wrapper(pre_process=modelarts_export_preprocess)
 def export_mobilenetv1():
     """ export_mobilenetv1 """
     target = config.device_target
     if target != "GPU":
-        context.set_context(device_id=get_device_id())
+        ms.set_context(device_id=get_device_id())
 
     network = mobilenet(class_num=config.class_num)
-    load_checkpoint(config.ckpt_file, net=network)
+    ms.load_checkpoint(config.ckpt_file, net=network)
     network.set_train(False)
-    input_data = Tensor(np.zeros([config.batch_size, 3, config.height, config.width]).astype(np.float32))
-    export(network, input_data, file_name=config.file_name, file_format=config.file_format)
+    input_data = ms.numpy.zeros([config.batch_size, 3, config.height, config.width]).astype(np.float32)
+    ms.export(network, input_data, file_name=config.file_name, file_format=config.file_format)
+
 
 if __name__ == '__main__':
     export_mobilenetv1()
diff --git a/official/cv/mobilenetv1/src/CrossEntropySmooth.py b/official/cv/mobilenetv1/src/CrossEntropySmooth.py
index a15fd2fbf..f8283eb76 100644
--- a/official/cv/mobilenetv1/src/CrossEntropySmooth.py
+++ b/official/cv/mobilenetv1/src/CrossEntropySmooth.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,26 +13,23 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network"""
+import mindspore as ms
 import mindspore.nn as nn
-from mindspore import Tensor
-from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import LossBase
-from mindspore.ops import functional as F
-from mindspore.ops import operations as P
+import mindspore.ops as ops
 
 
-class CrossEntropySmooth(LossBase):
+class CrossEntropySmooth(nn.LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()
-        self.onehot = P.OneHot()
+        self.onehot = ops.OneHot()
         self.sparse = sparse
-        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
-        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
+        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
+        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
         self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
 
     def construct(self, logit, label):
         if self.sparse:
-            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
+            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
         loss = self.ce(logit, label)
         return loss
diff --git a/official/cv/mobilenetv1/src/dataset.py b/official/cv/mobilenetv1/src/dataset.py
index 9f5b93281..3d7c67150 100644
--- a/official/cv/mobilenetv1/src/dataset.py
+++ b/official/cv/mobilenetv1/src/dataset.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,22 +17,19 @@ create train or eval dataset.
 """
 import os
 from multiprocessing import cpu_count
-import mindspore.common.dtype as mstype
+import mindspore as ms
 import mindspore.dataset as ds
-import mindspore.dataset.vision.c_transforms as C
-import mindspore.dataset.transforms.c_transforms as C2
-from mindspore.communication.management import get_rank, get_group_size
+import mindspore.communication as comm
 
 THREAD_NUM = 12 if cpu_count() >= 12 else 8
 
 
-def create_dataset1(dataset_path, do_train, device_num=1, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset1(dataset_path, do_train, device_num=1, batch_size=32, target="Ascend"):
     """
     create a train or evaluate cifar10 dataset for mobilenet
     Args:
         dataset_path(string): the path of dataset.
         do_train(bool): whether dataset is used for train or eval.
-        repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
 
@@ -50,38 +47,35 @@ def create_dataset1(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
     trans = []
     if do_train:
         trans += [
-            C.RandomCrop((32, 32), (4, 4, 4, 4)),
-            C.RandomHorizontalFlip(prob=0.5)
+            ds.vision.c_transforms.RandomCrop((32, 32), (4, 4, 4, 4)),
+            ds.vision.c_transforms.RandomHorizontalFlip(prob=0.5)
         ]
 
     trans += [
-        C.Resize((224, 224)),
-        C.Rescale(1.0 / 255.0, 0.0),
-        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
-        C.HWC2CHW()
+        ds.vision.c_transforms.Resize((224, 224)),
+        ds.vision.c_transforms.Rescale(1.0 / 255.0, 0.0),
+        ds.vision.c_transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+        ds.vision.c_transforms.HWC2CHW()
     ]
 
-    type_cast_op = C2.TypeCast(mstype.int32)
+    type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32)
 
     data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM)
     data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM)
 
     # apply batch operations
     data_set = data_set.batch(batch_size, drop_remainder=True)
-    # apply dataset repeat operation
-    data_set = data_set.repeat(repeat_num)
 
     return data_set
 
 
-def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_size=32, target="Ascend"):
+def create_dataset2(dataset_path, do_train, device_num=1, batch_size=32, target="Ascend"):
     """
     create a train or eval imagenet2012 dataset for mobilenet
 
     Args:
         dataset_path(string): the path of dataset.
         do_train(bool): whether dataset is used for train or eval.
-        repeat_num(int): the repeat times of dataset. Default: 1
         batch_size(int): the batch size of dataset. Default: 32
         target(str): the device target. Default: Ascend
 
@@ -103,21 +97,21 @@ def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
     # define map operations
     if do_train:
         trans = [
-            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
-            C.RandomHorizontalFlip(prob=0.5),
-            C.Normalize(mean=mean, std=std),
-            C.HWC2CHW()
+            ds.vision.c_transforms.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            ds.vision.c_transforms.RandomHorizontalFlip(prob=0.5),
+            ds.vision.c_transforms.Normalize(mean=mean, std=std),
+            ds.vision.c_transforms.HWC2CHW()
         ]
     else:
         trans = [
-            C.Decode(),
-            C.Resize(256),
-            C.CenterCrop(image_size),
-            C.Normalize(mean=mean, std=std),
-            C.HWC2CHW()
+            ds.vision.c_transforms.Decode(),
+            ds.vision.c_transforms.Resize(256),
+            ds.vision.c_transforms.CenterCrop(image_size),
+            ds.vision.c_transforms.Normalize(mean=mean, std=std),
+            ds.vision.c_transforms.HWC2CHW()
         ]
 
-    type_cast_op = C2.TypeCast(mstype.int32)
+    type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32)
 
     data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM)
     data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM)
@@ -125,9 +119,6 @@ def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
     # apply batch operations
     data_set = data_set.batch(batch_size, drop_remainder=True)
 
-    # apply dataset repeat operation
-    data_set = data_set.repeat(repeat_num)
-
     return data_set
 
 
@@ -138,8 +129,8 @@ def _get_rank_info():
     rank_size = int(os.environ.get("RANK_SIZE", 1))
 
     if rank_size > 1:
-        rank_size = get_group_size()
-        rank_id = get_rank()
+        rank_size = comm.get_group_size()
+        rank_id = comm.get_rank()
     else:
         rank_size = 1
         rank_id = 0
diff --git a/official/cv/mobilenetv1/src/mobilenet_v1.py b/official/cv/mobilenetv1/src/mobilenet_v1.py
index 98b9dae41..29bfbabba 100644
--- a/official/cv/mobilenetv1/src/mobilenet_v1.py
+++ b/official/cv/mobilenetv1/src/mobilenet_v1.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +13,9 @@
 # limitations under the License.
 # ============================================================================
 
+import mindspore.ops as ops
 import mindspore.nn as nn
-from mindspore.ops import operations as P
+
 
 def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
     output = []
@@ -83,7 +84,7 @@ class MobileNetV1(nn.Cell):
                 features = features + (output,)
             return features
         output = self.network(x)
-        output = P.ReduceMean()(output, (2, 3))
+        output = ops.ReduceMean()(output, (2, 3))
         output = self.fc(output)
         return output
 
diff --git a/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py b/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py
index 830d19a6f..31233d6fc 100644
--- a/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py
+++ b/official/cv/mobilenetv1/src/model_utils/moxing_adapter.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,10 @@
 
 import os
 import functools
-from mindspore import context
-from mindspore.profiler import Profiler
+import zipfile
+import time
+
+import mindspore as ms
 from .config import config
 
 _global_sync_count = 0
@@ -43,13 +45,13 @@ def get_job_id():
     job_id = job_id if job_id != "" else "default"
     return job_id
 
+
 def sync_data(from_path, to_path):
     """
     Download data from remote obs to local directory if the first url is remote url and the second one is local path
     Upload data from local directory to remote obs in contrast.
     """
     import moxing as mox
-    import time
     global _global_sync_count
     sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
     _global_sync_count += 1
@@ -93,7 +95,7 @@ def moxing_wrapper(pre_process=None, post_process=None):
                     sync_data(config.train_url, config.output_path)
                     print("Workspace downloaded: ", os.listdir(config.output_path))
 
-                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                ms.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                 config.device_num = get_device_num()
                 config.device_id = get_device_id()
                 if not os.path.exists(config.output_path):
@@ -103,7 +105,7 @@ def moxing_wrapper(pre_process=None, post_process=None):
                     pre_process()
 
             if config.enable_profiling:
-                profiler = Profiler()
+                profiler = ms.profiler.Profiler()
 
             run_func(*args, **kwargs)
 
@@ -120,3 +122,64 @@ def moxing_wrapper(pre_process=None, post_process=None):
                     sync_data(config.output_path, config.train_url)
         return wrapped_func
     return wrapper
+
+
+def modelarts_process():
+    """Unzip the ModelArts dataset (one device per server) and point config paths at the job workspace."""
+    def unzip(zip_file, save_dir):
+        """Extract zip_file into save_dir with progress output; no-op if already extracted or not a zip."""
+        s_time = time.time()
+        if os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
+            print("Zip has been extracted.")
+            return
+        if not zipfile.is_zipfile(zip_file):
+            print("This is not zip.")
+            return
+        with zipfile.ZipFile(zip_file, 'r') as fz:
+            names = fz.namelist()
+            data_num = len(names)
+            print("Extract Start...")
+            print("unzip file num: {}".format(data_num))
+            data_print = int(data_num / 100) if data_num > 100 else 1
+            for i, file in enumerate(names):
+                if i % data_print == 0:
+                    print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
+                fz.extract(file, save_dir)
+        print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
+                                             int(int(time.time() - s_time) % 60)))
+        print("Extract Done")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+
+        # Each server contains 8 devices at most
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        # Every device blocks here until the extracting device has created the lock file.
+        while not os.path.exists(sync_lock):
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+        print("#" * 200, os.listdir(save_dir_1))
+        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
+
+        config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
+    # train.py writes checkpoints to save_checkpoint_path; keep it inside output_path so they are uploaded.
+    config.save_checkpoint_path = config.output_path
+    config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path)
+
+
+def modelarts_export_preprocess():
+    """ modelarts export process """
+    config.file_name = os.path.join(config.output_path, config.file_name)
diff --git a/official/cv/mobilenetv1/train.py b/official/cv/mobilenetv1/train.py
index 1324a7c07..8c13edb49 100644
--- a/official/cv/mobilenetv1/train.py
+++ b/official/cv/mobilenetv1/train.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,29 +14,22 @@
 # ============================================================================
 """train mobilenet_v1."""
 import os
-import time
-from mindspore import context
-from mindspore import Tensor
-from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.model import Model
-from mindspore.context import ParallelMode
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init, get_rank, get_group_size
-from mindspore.common import set_seed
+
+import mindspore as ms
 import mindspore.nn as nn
+import mindspore.communication as comm
 import mindspore.common.initializer as weight_init
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+
 from src.lr_generator import get_lr
 from src.CrossEntropySmooth import CrossEntropySmooth
 from src.mobilenet_v1 import mobilenet_v1 as mobilenet
 from src.model_utils.config import config
-from src.model_utils.moxing_adapter import moxing_wrapper
-from src.model_utils.device_adapter import get_device_id, get_device_num
+from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_process
+from src.model_utils.device_adapter import get_device_num
 
 
-set_seed(1)
+ms.set_seed(1)
 
 if config.dataset == 'cifar10':
     from src.dataset import create_dataset1 as create_dataset
@@ -44,66 +37,11 @@ else:
     from src.dataset import create_dataset2 as create_dataset
 
 
-def modelarts_pre_process():
-    def unzip(zip_file, save_dir):
-        import zipfile
-        s_time = time.time()
-        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
-            zip_isexist = zipfile.is_zipfile(zip_file)
-            if zip_isexist:
-                fz = zipfile.ZipFile(zip_file, 'r')
-                data_num = len(fz.namelist())
-                print("Extract Start...")
-                print("unzip file num: {}".format(data_num))
-                data_print = int(data_num / 100) if data_num > 100 else 1
-                i = 0
-                for file in fz.namelist():
-                    if i % data_print == 0:
-                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
-                    i += 1
-                    fz.extract(file, save_dir)
-                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
-                    int(int(time.time() - s_time) % 60)))
-                print("Extract Done")
-            else:
-                print("This is not zip.")
-        else:
-            print("Zip has been extracted.")
-
-    if config.need_modelarts_dataset_unzip:
-        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
-        save_dir_1 = os.path.join(config.data_path)
-
-        sync_lock = "/tmp/unzip_sync.lock"
-
-        # Each server contains 8 devices as most
-        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
-            print("Zip file path: ", zip_file_1)
-            print("Unzip file save dir: ", save_dir_1)
-            unzip(zip_file_1, save_dir_1)
-            print("===Finish extract data synchronization===")
-            try:
-                os.mknod(sync_lock)
-            except IOError:
-                pass
-
-        while True:
-            if os.path.exists(sync_lock):
-                break
-            time.sleep(1)
-
-        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
-        print("#" * 200, os.listdir(save_dir_1))
-        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
-
-        config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
-    config.save_checkpoint_path = config.output_path
-
 def init_weigth(net):
     # init weight
     if config.pre_trained:
-        param_dict = load_checkpoint(config.pre_trained)
-        load_param_into_net(net, param_dict)
+        param_dict = ms.load_checkpoint(config.pre_trained)
+        ms.load_param_into_net(net, param_dict)
     else:
         for _, cell in net.cells_and_names():
             if isinstance(cell, nn.Conv2d):
@@ -115,7 +53,8 @@ def init_weigth(net):
                                                              cell.weight.shape,
                                                              cell.weight.dtype))
 
-@moxing_wrapper(pre_process=modelarts_pre_process)
+
+@moxing_wrapper(pre_process=modelarts_process)
 def train_mobilenetv1():
     """ train_mobilenetv1 """
     if config.dataset == 'imagenet2012':
@@ -124,31 +63,31 @@ def train_mobilenetv1():
     ckpt_save_dir = config.save_checkpoint_path
 
     # init context
-    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
 
     # Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE
-    context.set_context(mempool_block_size="31GB")
+    ms.set_context(mempool_block_size="31GB")
 
     if config.parameter_server:
-        context.set_ps_context(enable_ps=True)
+        ms.set_ps_context(enable_ps=True)
     device_id = int(os.getenv('DEVICE_ID', '0'))
     if config.run_distribute:
         if target == "Ascend":
-            context.set_context(device_id=device_id)
-            context.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              gradients_mean=True)
-            init()
-            context.set_auto_parallel_context(all_reduce_fusion_config=[75])
+            ms.set_context(device_id=device_id)
+            ms.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+                                         gradients_mean=True)
+            comm.init()
+            ms.set_auto_parallel_context(all_reduce_fusion_config=[75])
         # GPU target
         else:
-            init()
-            context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              gradients_mean=True)
-        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+            comm.init()
+            ms.set_auto_parallel_context(device_num=comm.get_group_size(),
+                                         parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+        ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(comm.get_rank()) + "/"
 
     # create dataset
     dataset = create_dataset(dataset_path=config.dataset_path, do_train=True, device_num=config.device_num,
-                             repeat_num=1, batch_size=config.batch_size, target=target)
+                             batch_size=config.batch_size, target=target)
     step_size = dataset.get_dataset_size()
 
     # define net
@@ -162,7 +101,7 @@ def train_mobilenetv1():
     lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
                 warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
                 lr_decay_mode=config.lr_decay_mode)
-    lr = Tensor(lr)
+    lr = ms.Tensor(lr)
 
     # define opt
     decayed_params = []
@@ -177,10 +116,10 @@ def train_mobilenetv1():
         group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                         {'params': no_decayed_params},
                         {'order_params': net.trainable_params()}]
-        opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
+        opt = nn.Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
     else:
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
-                       config.weight_decay)
+        opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
+                          config.weight_decay)
     # define loss, model
     if config.dataset == "imagenet2012":
         if not config.use_label_smooth:
@@ -188,13 +127,13 @@ def train_mobilenetv1():
         loss = CrossEntropySmooth(sparse=True, reduction="mean",
                                   smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
     else:
-        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    loss_scale = ms.FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
     if target == "Ascend":
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=False)
+        model = ms.Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
+                         amp_level="O2", keep_batchnorm_fp32=False)
     else:
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
+        model = ms.Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
 
     # define callbacks
     time_cb = TimeMonitor(data_size=step_size)
@@ -210,5 +149,6 @@ def train_mobilenetv1():
     model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb,
                 sink_size=dataset.get_dataset_size(), dataset_sink_mode=(not config.parameter_server))
 
+
 if __name__ == '__main__':
     train_mobilenetv1()
-- 
GitLab