Skip to content
Snippets Groups Projects
Unverified Commit a928bb4b authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2230 Mobilenet_v1 usability rectification

Merge pull request !2230 from 张毅辉/mobilenet_v1
parents 3c05fc54 d1a4fbdd
No related branches found
No related tags found
No related merge requests found
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,20 +14,15 @@
# ============================================================================
"""eval mobilenet_v1."""
import os
import time
from mindspore import context
from mindspore.common import set_seed
import mindspore as ms
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth
from src.mobilenet_v1 import mobilenet_v1 as mobilenet
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_process
set_seed(1)
ms.set_seed(1)
if config.dataset == 'cifar10':
from src.dataset import create_dataset1 as create_dataset
......@@ -35,63 +30,6 @@ else:
from src.dataset import create_dataset2 as create_dataset
def modelarts_process():
""" modelarts process """
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
int(int(time.time() - s_time) % 60)))
print("Extract Done")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
print("#" * 200, os.listdir(save_dir_1))
print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path)
@moxing_wrapper(pre_process=modelarts_process)
def eval_mobilenetv1():
""" eval_mobilenetv1 """
......@@ -101,10 +39,10 @@ def eval_mobilenetv1():
target = config.device_target
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
ms.set_context(device_id=device_id)
# create dataset
dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, batch_size=config.batch_size,
......@@ -115,8 +53,8 @@ def eval_mobilenetv1():
net = mobilenet(class_num=config.class_num)
# load checkpoint
param_dict = load_checkpoint(config.checkpoint_path)
load_param_into_net(net, param_dict)
param_dict = ms.load_checkpoint(config.checkpoint_path)
ms.load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
......@@ -129,11 +67,12 @@ def eval_mobilenetv1():
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
model = ms.Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
# eval model
res = model.eval(dataset)
print("result:", res, "ckpt=", config.checkpoint_path)
if __name__ == '__main__':
eval_mobilenetv1()
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,32 +15,30 @@
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint
import mindspore as ms
from src.mobilenet_v1 import mobilenet_v1 as mobilenet
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_export_preprocess
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
ms.set_context(mode=ms.GRAPH_MODE, device_target=config.device_target)
def modelarts_process():
pass
@moxing_wrapper(pre_process=modelarts_process)
@moxing_wrapper(pre_process=modelarts_export_preprocess)
def export_mobilenetv1():
""" export_mobilenetv1 """
target = config.device_target
if target != "GPU":
context.set_context(device_id=get_device_id())
ms.set_context(device_id=get_device_id())
network = mobilenet(class_num=config.class_num)
load_checkpoint(config.ckpt_file, net=network)
ms.load_checkpoint(config.ckpt_file, net=network)
network.set_train(False)
input_data = Tensor(np.zeros([config.batch_size, 3, config.height, config.width]).astype(np.float32))
export(network, input_data, file_name=config.file_name, file_format=config.file_format)
input_data = ms.numpy.zeros([config.batch_size, 3, config.height, config.width]).astype(np.float32)
ms.export(network, input_data, file_name=config.file_name, file_format=config.file_format)
if __name__ == '__main__':
export_mobilenetv1()
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,26 +13,23 @@
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import LossBase
from mindspore.ops import functional as F
from mindspore.ops import operations as P
import mindspore.ops as ops
class CrossEntropySmooth(LossBase):
class CrossEntropySmooth(nn.LossBase):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.onehot = ops.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,22 +17,19 @@ create train or eval dataset.
"""
import os
from multiprocessing import cpu_count
import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import get_rank, get_group_size
import mindspore.communication as comm
THREAD_NUM = 12 if cpu_count() >= 12 else 8
def create_dataset1(dataset_path, do_train, device_num=1, repeat_num=1, batch_size=32, target="Ascend"):
def create_dataset1(dataset_path, do_train, device_num=1, batch_size=32, target="Ascend"):
"""
create a train or evaluate cifar10 dataset for mobilenet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
......@@ -50,38 +47,35 @@ def create_dataset1(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
trans = []
if do_train:
trans += [
C.RandomCrop((32, 32), (4, 4, 4, 4)),
C.RandomHorizontalFlip(prob=0.5)
ds.vision.c_transforms.RandomCrop((32, 32), (4, 4, 4, 4)),
ds.vision.c_transforms.RandomHorizontalFlip(prob=0.5)
]
trans += [
C.Resize((224, 224)),
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.HWC2CHW()
ds.vision.c_transforms.Resize((224, 224)),
ds.vision.c_transforms.Rescale(1.0 / 255.0, 0.0),
ds.vision.c_transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
ds.vision.c_transforms.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_size=32, target="Ascend"):
def create_dataset2(dataset_path, do_train, device_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval imagenet2012 dataset for mobilenet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
......@@ -103,21 +97,21 @@ def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
ds.vision.c_transforms.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
ds.vision.c_transforms.RandomHorizontalFlip(prob=0.5),
ds.vision.c_transforms.Normalize(mean=mean, std=std),
ds.vision.c_transforms.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
ds.vision.c_transforms.Decode(),
ds.vision.c_transforms.Resize(256),
ds.vision.c_transforms.CenterCrop(image_size),
ds.vision.c_transforms.Normalize(mean=mean, std=std),
ds.vision.c_transforms.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=THREAD_NUM)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=THREAD_NUM)
......@@ -125,9 +119,6 @@ def create_dataset2(dataset_path, do_train, device_num=1, repeat_num=1, batch_si
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
......@@ -138,8 +129,8 @@ def _get_rank_info():
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
rank_size = comm.get_group_size()
rank_id = comm.get_rank()
else:
rank_size = 1
rank_id = 0
......
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,8 +13,9 @@
# limitations under the License.
# ============================================================================
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
output = []
......@@ -83,7 +84,7 @@ class MobileNetV1(nn.Cell):
features = features + (output,)
return features
output = self.network(x)
output = P.ReduceMean()(output, (2, 3))
output = ops.ReduceMean()(output, (2, 3))
output = self.fc(output)
return output
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,8 +17,10 @@
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
import zipfile
import time
import mindspore as ms
from .config import config
_global_sync_count = 0
......@@ -43,13 +45,13 @@ def get_job_id():
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
......@@ -93,7 +95,7 @@ def moxing_wrapper(pre_process=None, post_process=None):
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
ms.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
......@@ -103,7 +105,7 @@ def moxing_wrapper(pre_process=None, post_process=None):
pre_process()
if config.enable_profiling:
profiler = Profiler()
profiler = ms.profiler.Profiler()
run_func(*args, **kwargs)
......@@ -120,3 +122,64 @@ def moxing_wrapper(pre_process=None, post_process=None):
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper
def modelarts_process():
""" modelarts process """
def unzip(zip_file, save_dir):
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
print("#" * 200, os.listdir(save_dir_1))
print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path)
def modelarts_export_preprocess():
""" modelarts export process """
config.file_name = os.path.join(config.output_path, config.file_name)
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,29 +14,22 @@
# ============================================================================
"""train mobilenet_v1."""
import os
import time
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed
import mindspore as ms
import mindspore.nn as nn
import mindspore.communication as comm
import mindspore.common.initializer as weight_init
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.mobilenet_v1 import mobilenet_v1 as mobilenet
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
from src.model_utils.moxing_adapter import moxing_wrapper, modelarts_process
from src.model_utils.device_adapter import get_device_num
set_seed(1)
ms.set_seed(1)
if config.dataset == 'cifar10':
from src.dataset import create_dataset1 as create_dataset
......@@ -44,66 +37,11 @@ else:
from src.dataset import create_dataset2 as create_dataset
def modelarts_pre_process():
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
int(int(time.time() - s_time) % 60)))
print("Extract Done")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
print("#" * 200, os.listdir(save_dir_1))
print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
config.save_checkpoint_path = config.output_path
def init_weigth(net):
# init weight
if config.pre_trained:
param_dict = load_checkpoint(config.pre_trained)
load_param_into_net(net, param_dict)
param_dict = ms.load_checkpoint(config.pre_trained)
ms.load_param_into_net(net, param_dict)
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
......@@ -115,7 +53,8 @@ def init_weigth(net):
cell.weight.shape,
cell.weight.dtype))
@moxing_wrapper(pre_process=modelarts_pre_process)
@moxing_wrapper(pre_process=modelarts_process)
def train_mobilenetv1():
""" train_mobilenetv1 """
if config.dataset == 'imagenet2012':
......@@ -124,31 +63,31 @@ def train_mobilenetv1():
ckpt_save_dir = config.save_checkpoint_path
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
ms.set_context(mode=ms.GRAPH_MODE, device_target=target, save_graphs=False)
# Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE
context.set_context(mempool_block_size="31GB")
ms.set_context(mempool_block_size="31GB")
if config.parameter_server:
context.set_ps_context(enable_ps=True)
ms.set_ps_context(enable_ps=True)
device_id = int(os.getenv('DEVICE_ID', '0'))
if config.run_distribute:
if target == "Ascend":
context.set_context(device_id=device_id)
context.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
context.set_auto_parallel_context(all_reduce_fusion_config=[75])
ms.set_context(device_id=device_id)
ms.set_auto_parallel_context(device_num=get_device_num(), parallel_mode=ms.ParallelMode.DATA_PARALLEL,
gradients_mean=True)
comm.init()
ms.set_auto_parallel_context(all_reduce_fusion_config=[75])
# GPU target
else:
init()
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
comm.init()
ms.set_auto_parallel_context(device_num=ms.get_group_size(), parallel_mode=ms.ParallelMode.DATA_PARALLEL,
gradients_mean=True)
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(comm.get_rank()) + "/"
# create dataset
dataset = create_dataset(dataset_path=config.dataset_path, do_train=True, device_num=config.device_num,
repeat_num=1, batch_size=config.batch_size, target=target)
batch_size=config.batch_size, target=target)
step_size = dataset.get_dataset_size()
# define net
......@@ -162,7 +101,7 @@ def train_mobilenetv1():
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
lr = Tensor(lr)
lr = ms.Tensor(lr)
# define opt
decayed_params = []
......@@ -177,10 +116,10 @@ def train_mobilenetv1():
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
opt = nn.Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
else:
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay)
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay)
# define loss, model
if config.dataset == "imagenet2012":
if not config.use_label_smooth:
......@@ -188,13 +127,13 @@ def train_mobilenetv1():
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = ms.FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
if target == "Ascend":
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
model = ms.Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
else:
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
model = ms.Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
......@@ -210,5 +149,6 @@ def train_mobilenetv1():
model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb,
sink_size=dataset.get_dataset_size(), dataset_sink_mode=(not config.parameter_server))
if __name__ == '__main__':
train_mobilenetv1()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment