Unverified Commit 272db35c authored by i-robot, committed by Gitee
!2667 Models: Inception-V2 GPU

Merge pull request !2667 from adenisov/models-pr-inception-v2
parents 0df24a66 9312951d
Showing 1419 additions and 0 deletions
# Contents
- [Inception-v2 Description](#inception-v2-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed precision](#mixed-precision)
- [Environmental Requirements](#environmental-requirements)
- [Script Description](#script-description)
    - [Script and sample code](#script-and-sample-code)
    - [Script parameters](#script-parameters)
    - [Training process](#training-process)
    - [Evaluation](#evaluation)
    - [Export process](#export-process)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training performance](#training-performance)
            - [Inception-v2 on ImageNet-1k](#inception-v2-on-imagenet-1k)
        - [Evaluation performance](#evaluation-performance)
            - [Inception-v2 on ImageNet-1k](#inception-v2-on-imagenet-1k-1)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Inception-v2 Description](#contents)
Google's Inception-v2 is the second release in a series of deep convolutional architectures. Inception-v2 mainly adds Batch Normalization to Inception-v1 and restructures the original Inception blocks to reduce the consumption of computing resources. The idea was proposed in the article ['Rethinking the Inception Architecture for Computer Vision'](https://arxiv.org/pdf/1512.00567.pdf), published in 2015.
# [Model Architecture](#contents)
The overall architecture of Inception-v2 is described in the article:
[Paper](https://arxiv.org/pdf/1512.00567.pdf)
# [Dataset](#contents)
Dataset used: [ImageNet2012](http://www.image-net.org/)
- Dataset size: 1,000 classes of 224×224 color images
    - Training set: 1,281,167 images in total
    - Test set: 50,000 images in total
- Data format: JPEG
    - Note: Data is processed in dataset.py.
- Download the dataset; the directory structure should be as follows:

```text
└─dataset
    ├─train   # training dataset
    └─val     # validation dataset
```
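`src/dataset.py` reads these directories with `mindspore.dataset.ImageFolderDataset`, which expects one sub-directory per class inside `train` and `val` (e.g. `dataset/train/n01440764/*.JPEG`). The following minimal sketch (not part of the repository; the helper name and path are placeholders) checks the layout before training:

```python
import os

def check_imagenet_layout(root):
    """Print how many class sub-directories each split contains."""
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        classes = [d for d in os.listdir(split_dir)
                   if os.path.isdir(os.path.join(split_dir, d))]
        print(f"{split}: {len(classes)} class directories")

check_imagenet_layout("/path/to/dataset")  # expect 1000 per split
```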
# [Features](#contents)
## Mixed precision
[Mixed-precision](https://www.mindspore.cn/tutorials/experts/en/master/others/mixed_precision.html) training uses both single-precision and half-precision data to speed up the training of deep neural networks while preserving the accuracy achievable with pure single-precision training. It increases computational throughput and reduces memory usage, which makes it possible to train larger models or use larger batch sizes on the same hardware.
Taking the FP16 operator as an example: if the input data type is FP32, the MindSpore backend automatically reduces the precision to process the data. You can open the INFO log and search for "reduce precision" to view operators whose precision was reduced.
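In this repository, mixed precision is enabled by passing `cfg.amp_level` to `Model` in `train.py`, together with a fixed loss scale. A minimal, self-contained sketch (the stand-in network and hyperparameters are illustrative only):

```python
import mindspore.nn as nn
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model

net = nn.Dense(16, 10)  # stand-in network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=1024)
# amp_level="O3" casts the network to float16; the fixed loss scale guards
# against float16 gradient underflow, mirroring the setup in train.py.
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O3",
              loss_scale_manager=FixedLossScaleManager(1024, drop_overflow_update=False))
```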
# [Environmental Requirements](#contents)
- Hardware
    - Use a GPU to build the hardware environment
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For details, see the following resources:
    - [MindSpore tutorial](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# [Script Description](#contents)
## Script and sample code
```text
Inception-v2
├── scripts
│   ├── run_distributed_train_gpu.sh    # GPU 8-card training
│   ├── run_eval_gpu.sh                 # Evaluation on GPU
│   └── run_standalone_train_gpu.sh     # GPU single-card training
├── src
│   ├── config.py                       # GPU, Ascend, and CPU configuration parameters
│   ├── dataset.py                      # Dataset creation
│   ├── inception_v2.py                 # Network definition
│   ├── loss.py                         # Custom cross-entropy loss function
│   └── lr_generator.py                 # Learning rate generator
├── eval.py                             # Network evaluation script
├── export.py                           # Script for export into AIR and MINDIR formats
├── README.md                           # Readme (in English)
├── requirements.txt                    # List of required packages
└── train.py                            # Network training script
```
## Script parameters
The main parameters in config.py are as follows:
```text
'decay_method': 'cosine'     # Learning rate scheduler mode
'loss_scale': 1024           # Loss scale
'batch_size': 128            # Batch size
'epoch_size': 250            # Number of epochs
'num_classes': 1000          # Number of classes
'smooth_factor': 0.1         # Label smoothing factor
'lr_init': 4e-5              # Initial learning rate
'lr_max': 4e-1               # Maximum learning rate
'lr_end': 4e-6               # Minimum learning rate
'warmup_epochs': 1           # Number of warmup epochs
'weight_decay': 4e-5         # Weight decay
'momentum': 0.9              # Momentum
'opt_eps': 1.0               # Epsilon
'dropout_keep_prob': 0.8     # Probability to keep an input in the dropout layer
'amp_level': 'O3'            # The `level` parameter of `mindspore.amp.build_train_network`; choose from ['O0', 'O2', 'O3']
```
Refer to the script `config.py` for more configuration details.
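Because these configs are `EasyDict` objects, individual values can also be overridden in code before training, instead of editing `config.py`. A minimal sketch (the overridden values are examples only):

```python
from src.config import config_gpu

config_gpu.batch_size = 64   # e.g. to fit a smaller GPU
config_gpu.epoch_size = 10   # short debugging run
print(config_gpu.lr_max, config_gpu.decay_method)
```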
## Training process
After installing MindSpore through the official website, run the following command from the `research/cv/Inception-v2` directory:
```bash
pip install -r requirements.txt
```
Then you can follow the steps below for training and evaluation:
- GPU:
Training on GPU:
```bash
# standalone training
bash scripts/run_standalone_train_gpu.sh [DEVICE_ID] [DATASET_PATH] [<PRE_TRAINED_PATH>]
# multi-gpu training
bash scripts/run_distributed_train_gpu.sh [RANK_SIZE] [DATASET_PATH] [<PRE_TRAINED_PATH>]
```
Example:
```bash
# standalone training
bash scripts/run_standalone_train_gpu.sh 0 /path/to/imagenet
# multi-gpu training
bash scripts/run_distributed_train_gpu.sh 8 /path/to/imagenet
```
## Evaluation
- GPU:
Evaluation on GPU:
```bash
bash scripts/run_eval_gpu.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]
```
Example:
```bash
bash scripts/run_eval_gpu.sh 0 /path/to/imagenet /path/to/trained.ckpt
```
## Export process
```bash
python export.py --ckpt_file [CKPT_FILE] --platform [PLATFORM] --file_format [FILE_FORMAT]
```
For `FILE_FORMAT`, choose `MINDIR` or `AIR`.
Example:
```bash
python export.py --ckpt_file /path/to/trained.ckpt --platform GPU --file_format MINDIR
```
The exported model will be named after the structure of the model and saved in the current directory.
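To sanity-check an exported model, the MINDIR file can be loaded back and run on a random input. A minimal sketch, assuming the default `--file_name` of `export.py` and a MindSpore build that supports MindIR inference (this script is not shipped with the model):

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn

graph = ms.load("inceptionv2.mindir")  # file produced by export.py
net = nn.GraphCell(graph)
dummy = ms.Tensor(np.random.uniform(-1.0, 1.0, (1, 3, 224, 224)).astype(np.float32))
print(net(dummy).shape)                # expected: (1, 1000)
```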
# [Model Description](#contents)
## Performance
### Training performance
#### Inception-v2 on ImageNet-1k
| Parameters | GPU |
|---------------------|--------------------------------------------------------|
| Model | Inception-v2 |
| Resources | GPU: GeForce RTX 3090; CPU: 2.90 GHz, 64 cores; RAM: 252 GB |
| Upload date | 05 / 13 / 2022 (mm / dd / yyyy) |
| MindSpore version | 1.6.1 |
| Dataset | ImageNet-1k train, 1,281,167 images |
| Training parameters | epoch=250, batch_size=128 |
| Optimizer | Momentum |
| Loss function | CrossEntropy |
| Loss | 1.8897 |
| Output | Probability |
| Accuracy | 8P: top-1 76.25%, top-5 92.92% |
| Speed | 8P: 295 ms/step |
| Training time | 25h 37m 41s |
### Evaluation performance
#### Inception-v2 on ImageNet-1k
| Parameters | GPU |
|-------------------|--------------------------------------------------------|
| Model | Inception-v2 |
| Resources | GPU: GeForce RTX 3090; CPU: 2.90 GHz, 64 cores; RAM: 252 GB |
| Upload date | 05 / 13 / 2022 (mm / dd / yyyy) |
| MindSpore version | 1.6.1 |
| Dataset | ImageNet-1k val, 50,000 images |
| Accuracy | top-1 76.25%, top-5 92.92% |
# [ModelZoo Homepage](#contents)
Please visit the official [website](https://gitee.com/mindspore/models).
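The remaining files added by this merge request follow.

**eval.py**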
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate_imagenet"""
import argparse
import ast
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_gpu, config_ascend, config_cpu
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.inception_v2 import inception_v2_base
from src.loss import CrossEntropy_Val
CFG_DICT = {
    "Ascend": config_ascend,
    "GPU": config_gpu,
    "CPU": config_cpu,
}
DS_DICT = {
    "imagenet": create_dataset_imagenet,
    "cifar10": create_dataset_cifar10,
}


def run_eval():
    """run evaluation"""
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument("--data_url", type=str, help='data path for eval')
    parser.add_argument("--train_url", type=str, help='log')
    parser.add_argument("--run_online", type=ast.literal_eval)
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v2 (Default: None)')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()
    cfg = CFG_DICT[args_opt.platform]
    cfg.work_nums = 1
    if args_opt.run_online:
        import moxing as mox
        cfg.ckpt_path = "/cache/checkpoint_inceptionv2/checkpoint.ckpt"
        Imagenet_root = "/cache/data_eval_url"
        mox.file.copy_parallel(args_opt.data_url, Imagenet_root)
        mox.file.copy_parallel(args_opt.checkpoint, cfg.ckpt_path)
    else:
        cfg.ckpt_path = args_opt.checkpoint
        Imagenet_root = args_opt.data_url
    if args_opt.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)
    create_dataset = DS_DICT[cfg.ds_type]
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
    net = inception_v2_base(num_classes=cfg.num_classes)
    ckpt = load_checkpoint(cfg.ckpt_path)
    load_param_into_net(net, ckpt)
    net.set_train(False)
    cfg.rank = 0
    cfg.group_size = 1
    # eval dataset
    root = os.path.join(Imagenet_root, 'val')
    dataset = create_dataset(root, cfg, False)
    loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    metrics = model.eval(dataset, dataset_sink_mode=cfg.ds_sink_mode)
    print("metric: ", metrics)


if __name__ == '__main__':
    run_eval()
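**export.py**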
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.common import dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config_gpu, config_ascend, config_cpu
from src.inception_v2 import inception_v2_base
CFG_DICT = {
    "Ascend": config_ascend,
    "GPU": config_gpu,
    "CPU": config_cpu,
}


def run_export():
    """run export"""
    parser = argparse.ArgumentParser(description='Inception-v2 export')
    parser.add_argument("--device_id", type=int, default=0, help="Device id")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
    parser.add_argument("--file_name", type=str, default="inceptionv2", help="output file name.")
    parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
    parser.add_argument("--platform", type=str, choices=["Ascend", "GPU"], default="Ascend", help="platform")
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
    config = CFG_DICT[args.platform]
    if args.platform == "Ascend":
        context.set_context(device_id=args.device_id)
    net = inception_v2_base(num_classes=config.num_classes)
    param_dict = load_checkpoint(args.ckpt_file)
    net.init_parameters_data()
    load_param_into_net(net, param_dict)
    net.set_train(False)
    input_shp = [args.batch_size, 3, config.image_height, config.image_width]
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp), mstype.float32)
    export(net, input_array, file_name=args.file_name, file_format=args.file_format)


if __name__ == '__main__':
    run_export()
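**requirements.txt**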
easydict
numpy
pyyaml
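**scripts/run_distributed_train_gpu.sh**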
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: bash scripts/run_distributed_train_gpu.sh [RANK_SIZE] [DATASET_PATH]"
echo "or bash scripts/run_distributed_train_gpu.sh [RANK_SIZE] [DATASET_PATH] [PRE_TRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
export DEVICE_NUM=$1
export RANK_SIZE=$1
DATASET_PATH=$(get_real_path $2)
PRE_TRAINED_PATH=$(get_real_path $3)
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ $# == 3 ] && [ ! -f $PRE_TRAINED_PATH ]
then
    echo "error: PRE_TRAINED_PATH=$PRE_TRAINED_PATH is not a file"
    exit 1
fi
if [ -d "./distributed_train" ]
then
rm -rf ./distributed_train
echo "Remove dir ./distributed_train"
fi
mkdir ./distributed_train
echo "Create a dir ./distributed_train"
cp ./train.py ./distributed_train
cp -r ./src ./distributed_train
cd ./distributed_train || exit
echo "Start training for $DEVICE_NUM devices"
env > env.log
PRE_TRAINED_ARG=""
if [ $# == 3 ]
then
    # train.py accepts the checkpoint via --resume
    PRE_TRAINED_ARG="--resume $PRE_TRAINED_PATH"
fi
mpirun -n $RANK_SIZE --allow-run-as-root \
    --output-filename log_output \
    --merge-stderr-to-stdout \
    python train.py \
        --is_distributed True \
        --platform GPU \
        --data_url $DATASET_PATH \
        --train_url train_output \
        --device_num $DEVICE_NUM $PRE_TRAINED_ARG > log 2>&1 &
cd ..
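**scripts/run_eval_gpu.sh**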
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: bash scripts/run_eval_gpu.sh [DEVICE_ID] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $2)
CHECKPOINT_PATH=$(get_real_path $3)
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a file"
exit 1
fi
export DEVICE_ID=$1
export DEVICE_NUM=1
export RANK_SIZE=1
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ./eval.py ./eval
cp -r ./src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --data_url $DATASET_PATH \
--run_online False --platform GPU \
--checkpoint $CHECKPOINT_PATH > log 2>&1 &
cd ..
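**scripts/run_standalone_train_gpu.sh**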
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: bash scripts/run_standalone_train_gpu.sh [DEVICE_ID] [DATASET_PATH]"
echo "or bash scripts/run_standalone_train_gpu.sh [DEVICE_ID] [DATASET_PATH] [PRE_TRAINED_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $2)
PRE_TRAINED_PATH=$(get_real_path $3)
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ $# == 3 ] && [ ! -f $PRE_TRAINED_PATH ]
then
    echo "error: PRE_TRAINED_PATH=$PRE_TRAINED_PATH is not a file"
    exit 1
fi
export DEVICE_ID=$1
export DEVICE_NUM=1
export RANK_SIZE=1
export RANK_ID=0
if [ -d "./standalone_train" ]
then
rm -rf ./standalone_train
echo "Remove dir ./standalone_train"
fi
mkdir ./standalone_train
echo "Create a dir ./standalone_train."
cp ./train.py ./standalone_train
cp -r ./src ./standalone_train
cd ./standalone_train || exit
echo "Start training for device $DEVICE_ID"
env > env.log
if [ $# == 2 ]
then
    python train.py \
        --data_url $DATASET_PATH \
        --train_url train_output \
        --platform GPU \
        --device_num 1 > log 2>&1 &
else
    # train.py accepts the checkpoint via --resume (it has no --pre_trained flag)
    python train.py \
        --data_url $DATASET_PATH \
        --train_url train_output \
        --platform GPU --device_num 1 \
        --resume $PRE_TRAINED_PATH > log 2>&1 &
fi
cd ..
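**src/config.py**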
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
config_gpu = edict({
'platform': 'GPU',
'image_height': 224,
'image_width': 224,
'random_seed': 1,
'work_nums': 8,
'decay_method': 'cosine',
"loss_scale": 1024,
'batch_size': 128,
'epoch_size': 250,
'num_classes': 1000,
'ds_type': 'imagenet',
'ds_sink_mode': False,
'smooth_factor': 0.1,
'aux_factor': 0.2,
'lr_init': 0.00004,
'lr_max': 0.4,
'lr_end': 0.000004,
'warmup_epochs': 1,
'weight_decay': 0.00004,
'momentum': 0.9,
'opt_eps': 1.0,
'keep_checkpoint_max': 10,
'ckpt_path': './',
'is_save_on_master': 1,
'dropout_keep_prob': 0.8,
'has_bias': False,
'amp_level': 'O3'
})
config_ascend = edict({
'platform': 'Ascend',
'image_height': 224,
'image_width': 224,
'random_seed': 1,
'work_nums': 8,
'decay_method': 'cosine',
"loss_scale": 1024,
'batch_size': 128,
'epoch_size': 250,
'num_classes': 1000,
'ds_type': 'imagenet',
'ds_sink_mode': True,
'smooth_factor': 0.1,
'aux_factor': 0.2,
'lr_init': 0.00004,
'lr_max': 0.4,
'lr_end': 0.000004,
'warmup_epochs': 1,
'weight_decay': 0.00004,
'momentum': 0.9,
'opt_eps': 1.0,
'keep_checkpoint_max': 10,
'ckpt_path': './',
'is_save_on_master': 0,
'dropout_keep_prob': 0.8,
'has_bias': False,
'amp_level': 'O3'
})
config_cpu = edict({
'random_seed': 1,
'work_nums': 8,
'decay_method': 'cosine',
"loss_scale": 1024,
'batch_size': 128,
'epoch_size': 120,
'num_classes': 10,
'ds_type': 'cifar10',
'ds_sink_mode': False,
'smooth_factor': 0.1,
'aux_factor': 0.2,
'lr_init': 0.00004,
'lr_max': 0.1,
'lr_end': 0.000004,
'warmup_epochs': 1,
'weight_decay': 0.00004,
'momentum': 0.9,
'opt_eps': 1.0,
'keep_checkpoint_max': 10,
'ckpt_path': './',
'is_save_on_master': 0,
'dropout_keep_prob': 0.8,
'has_bias': False,
'amp_level': 'O0',
})
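**src/dataset.py**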
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
def create_dataset_cifar10(dataset_path, cfg, training, repeat_num=1):
    """Data operations."""
    dataset_path = os.path.join(dataset_path, "cifar-10-batches-bin" if training else "cifar-10-verify-bin")
    if cfg.group_size == 1:
        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
    else:
        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
                                     num_shards=cfg.group_size, shard_id=cfg.rank)
    resize_height = cfg.image_height
    resize_width = cfg.image_width
    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
    random_horizontal_op = vision.RandomHorizontalFlip()
    resize_op = vision.Resize((resize_height, resize_width))  # interpolation default BILINEAR
    rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
    normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)
    c_trans = []
    if training:
        c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
    # apply map operations on images
    data_set = data_set.map(operations=type_cast_op, input_columns="label")
    data_set = data_set.map(operations=c_trans, input_columns="image")
    # apply batch operations
    data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
    # apply repeat operations
    data_set = data_set.repeat(repeat_num)
    return data_set


def create_dataset_imagenet(dataset_path, cfg, training, repeat_num=1):
    """
    Create a train or eval imagenet2012 dataset for inceptionv2.

    Args:
        dataset_path(string): the path of dataset.
        cfg: config
        training(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1

    Returns:
        dataset
    """
    if cfg.group_size == 1:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
    else:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
                                         num_shards=cfg.group_size, shard_id=cfg.rank)
    image_size = cfg.image_height
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    # define map operations
    if training:
        transform_img = [
            vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        transform_img = [
            vision.Decode(),
            vision.Resize(256),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    transform_label = [C.TypeCast(mstype.int32)]
    data_set = data_set.map(input_columns="image", num_parallel_workers=cfg.work_nums, operations=transform_img,
                            python_multiprocessing=True)
    data_set = data_set.map(input_columns="label", num_parallel_workers=cfg.work_nums, operations=transform_label)
    # apply batch operations
    data_set = data_set.batch(cfg.batch_size, drop_remainder=True)
    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set
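A minimal usage sketch for `create_dataset_imagenet` on a single device, assuming the working directory is `research/cv/Inception-v2` (the path, worker count, and batch size are placeholders):

```python
from easydict import EasyDict as edict

from src.dataset import create_dataset_imagenet

# Only the fields the function actually reads are needed here.
cfg = edict(group_size=1, rank=0, work_nums=4,
            image_height=224, image_width=224, batch_size=32)
val_ds = create_dataset_imagenet("/path/to/imagenet/val", cfg, training=False)
print(val_ds.get_dataset_size())  # number of batches per epoch
```

**src/inception_v2.py**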
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""inceptionv2 net"""
import mindspore.nn as nn
import mindspore.ops as op
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
def weight_variable(stddev):
    """Weight variable."""
    return TruncatedNormal(stddev)


class Conv2dBlock(nn.Cell):
    """Conv2dBlock: Conv2d -> BatchNorm2d -> ReLU"""
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0,
                 pad_mode="same", is_variable=True):
        super(Conv2dBlock, self).__init__()
        self.is_variable = is_variable
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                padding=padding, pad_mode=pad_mode, weight_init="XavierUniform")
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.9997)
        self.relu = nn.ReLU()

    def construct(self, x):
        """construct"""
        x = self.conv_1(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class depthwise_separable_conv(nn.Cell):
    """Depthwise conv + Pointwise conv"""
    def __init__(self, in_channels, out_channels, kernel_size=1,
                 stride=1, padding=0, pad_mode="same"):
        super(depthwise_separable_conv, self).__init__()
        self.is_use_pointwise = out_channels
        if out_channels is not None:
            self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   stride, padding=padding, pad_mode=pad_mode, group=in_channels,
                                   weight_init="XavierUniform")
            self.conv2 = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=1, stride=1, padding=0, weight_init="XavierUniform")
        else:
            self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, group=in_channels,
                                   padding=padding, pad_mode=pad_mode,
                                   weight_init="XavierUniform")
            # construct() applies self.bn on this branch, so the BatchNorm
            # layer must be created here (it was missing in the original).
            self.bn = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.9997)
        self.Relu = nn.ReLU()

    def construct(self, x):
        """construct"""
        if self.is_use_pointwise is not None:
            x = self.conv1(x)
            x = self.conv2(x)
        else:
            x = self.conv1(x)
            x = self.bn(x)
        x = self.Relu(x)
        return x


class Inception(nn.Cell):
    """Inception Block"""
    def __init__(self, in_channels, n1x1, n3x3red_a, n3x3, n3x3red_b, n3x3red_b_2, pool_planes):
        super(Inception, self).__init__()
        # branch 1: 1x1 conv
        self.b1 = Conv2dBlock(in_channels, n1x1, kernel_size=1, is_variable=False)
        # branch 2: 1x1 reduction followed by a 3x3 conv
        self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red_a, kernel_size=1),
                                     Conv2dBlock(n3x3red_a, n3x3, kernel_size=3, padding=0, is_variable=False)])
        # branch 3: 1x1 reduction followed by two stacked 3x3 convs
        self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red_b, kernel_size=1),
                                     Conv2dBlock(n3x3red_b, n3x3red_b_2, kernel_size=3, padding=0, is_variable=False),
                                     Conv2dBlock(n3x3red_b_2, n3x3red_b_2, kernel_size=3, padding=0,
                                                 is_variable=False)])
        # branch 4: 3x3 average pooling followed by a 1x1 conv
        self.avgpool_op = op.AvgPool(pad_mode="SAME", kernel_size=3, strides=1)
        self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1)
        self.concat = op.Concat(axis=1)

    def construct(self, x):
        """construct"""
        branch1 = self.b1(x)
        branch2 = self.b2(x)
        branch3 = self.b3(x)
        cell = self.avgpool_op(x)
        branch4 = self.b4(cell)
        return self.concat((branch1, branch2, branch3, branch4))


class Inception_2(nn.Cell):
    """Inception_2 Block (grid-size reduction: two stride-2 branches plus max-pooling)"""
    def __init__(self, in_channels, n3x3red_a, n3x3, n3x3red_b, n3x3red_b_2):
        super(Inception_2, self).__init__()
        self.b1 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red_a, kernel_size=1),
                                     Conv2dBlock(n3x3red_a, n3x3, kernel_size=3,
                                                 stride=2, padding=0, is_variable=False)])
        self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red_b, kernel_size=1),
                                     Conv2dBlock(n3x3red_b, n3x3red_b_2, kernel_size=3, padding=0, is_variable=False),
                                     Conv2dBlock(n3x3red_b_2, n3x3red_b_2, kernel_size=3,
                                                 padding=0, stride=2, is_variable=False)])
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="SAME")
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        """construct"""
        branch1 = self.b1(x)
        branch2 = self.b2(x)
        branch3 = self.maxpool(x)
        return self.concat((branch1, branch2, branch3))


class Logits(nn.Cell):
    """Classification head: average pooling -> dropout -> dense layer"""
    def __init__(self, num_classes=10, dropout_keep_prob=0.8):
        super(Logits, self).__init__()
        self.avg_pool = nn.AvgPool2d(7, pad_mode='valid')
        self.dropout = nn.Dropout(keep_prob=dropout_keep_prob)
        self.flatten = P.Flatten()
        self.fc = nn.Dense(1024, num_classes)

    def construct(self, x):
        """construct"""
        x = self.avg_pool(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class inception_v2_base(nn.Cell):
    """Inception-v2 backbone with an optional classification head"""
    def __init__(self, num_classes=10, input_channels=3, use_separable_conv=False,
                 dropout_keep_prob=0.8, include_top=True):
        super(inception_v2_base, self).__init__()
        self.feature_map_channels = {'Conv2d_1a_7x7': 64, 'MaxPool_2a_3x3': 64,
                                     'Conv2d_2b_1x1': 64, 'Conv2d_2c_3x3': 192,
                                     'MaxPool_3a_3x3': 192, 'Mixed_3b': 256,
                                     'Mixed_3c': 320, 'Mixed_4a': 576, 'Mixed_4b': 576,
                                     'Mixed_4c': 576, 'Mixed_4d': 576, 'Mixed_4e': 576,
                                     'Mixed_5a': 1024, 'Mixed_5b': 1024, 'Mixed_5c': 1024}
        if use_separable_conv:
            self.Conv2d_1a_7x7 = depthwise_separable_conv(input_channels, 64,
                                                          kernel_size=7, stride=2, padding=0)
        else:
            self.Conv2d_1a_7x7 = Conv2dBlock(input_channels, 64, kernel_size=7, stride=2)
        self.MaxPool_2a_3x3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.Conv2d_2b_1x1 = Conv2dBlock(64, 64, kernel_size=1)
        self.Conv2d_2c_3x3 = Conv2dBlock(64, 192, kernel_size=3, padding=0)
        self.MaxPool_3a_3x3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
        self.Mixed_3b = Inception(192, 64, 64, 64, 64, 96, 32)
        self.Mixed_3c = Inception(256, 64, 64, 96, 64, 96, 64)
        self.Mixed_4a = Inception_2(320, 128, 160, 64, 96)
        self.Mixed_4b = Inception(576, 224, 64, 96, 96, 128, 128)
        self.Mixed_4c = Inception(576, 192, 96, 128, 96, 128, 128)
        self.Mixed_4d = Inception(576, 160, 128, 160, 128, 160, 96)
        self.Mixed_4e = Inception(576, 96, 128, 192, 160, 192, 96)
        self.Mixed_5a = Inception_2(576, 128, 192, 192, 256)
        self.Mixed_5b = Inception(1024, 352, 192, 320, 160, 224, 128)
        self.Mixed_5c = Inception(1024, 352, 192, 320, 192, 224, 128)
        self.include_top = include_top
        if self.include_top:
            self.logits = Logits(num_classes, dropout_keep_prob)

    def construct(self, inputs):
        """inceptionv2 construct"""
        end_points = {}
        temp_point = 'Conv2d_1a_7x7'
        net = self.Conv2d_1a_7x7(inputs)
        end_points[temp_point] = net
        temp_point = 'MaxPool_2a_3x3'
        net = self.MaxPool_2a_3x3(net)
        end_points[temp_point] = net
        temp_point = 'Conv2d_2b_1x1'
        net = self.Conv2d_2b_1x1(net)
        end_points[temp_point] = net
        temp_point = 'Conv2d_2c_3x3'
        net = self.Conv2d_2c_3x3(net)
        end_points[temp_point] = net
        temp_point = 'MaxPool_3a_3x3'
        net = self.MaxPool_3a_3x3(net)
        end_points[temp_point] = net
        temp_point = 'Mixed_3b'
        net = self.Mixed_3b(net)
        end_points[temp_point] = net
        temp_point = 'Mixed_3c'
        net = self.Mixed_3c(net)
        end_points[temp_point] = net
        # 28 x 28 x 320
        temp_point = 'Mixed_4a'
        net = self.Mixed_4a(net)
        end_points[temp_point] = net
        # 14 x 14 x 576
        temp_point = 'Mixed_4b'
        net = self.Mixed_4b(net)
        end_points[temp_point] = net
        # 14 x 14 x 576
        temp_point = 'Mixed_4c'
        net = self.Mixed_4c(net)
        end_points[temp_point] = net
        # 14 x 14 x 576
        temp_point = 'Mixed_4d'
        net = self.Mixed_4d(net)
        end_points[temp_point] = net
        # 14 x 14 x 576
        temp_point = 'Mixed_4e'
        net = self.Mixed_4e(net)
        end_points[temp_point] = net
        # 14 x 14 x 576
        temp_point = 'Mixed_5a'
        net = self.Mixed_5a(net)
        end_points[temp_point] = net
        temp_point = 'Mixed_5b'
        net = self.Mixed_5b(net)
        end_points[temp_point] = net
        temp_point = 'Mixed_5c'
        net = self.Mixed_5c(net)
        end_points[temp_point] = net
        if not self.include_top:
            return net
        logits = self.logits(net)
        return logits
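A quick sanity-check sketch (not part of the file): feed a random batch through `inception_v2_base` and confirm the logits shape. This assumes a working directory of `research/cv/Inception-v2` and any installed MindSpore backend:

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

from src.inception_v2 import inception_v2_base

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")  # or "GPU"/"Ascend"
net = inception_v2_base(num_classes=1000)
net.set_train(False)
x = Tensor(np.random.uniform(-1.0, 1.0, (2, 3, 224, 224)), mstype.float32)
print(net(x).shape)  # expected: (2, 1000)
```

**src/loss.py**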
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropy(_Loss):
    """CrossEntropy loss with label smoothing, used for training."""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        """construct"""
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


class CrossEntropy_Val(_Loss):
    """The redefined loss function with SoftmaxCrossEntropyWithLogits, used in the inference process."""
    def __init__(self, smooth_factor=0, num_classes=1000):
        super(CrossEntropy_Val, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        """construct"""
        one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss_logit = self.ce(logits, one_hot_label)
        loss_logit = self.mean(loss_logit, 0)
        return loss_logit
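An illustrative sketch exercising `CrossEntropy` with label smoothing on dummy data; `smooth_factor=0.1` matches the training configuration, while the class count and batch size here are arbitrary:

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

from src.loss import CrossEntropy

loss_fn = CrossEntropy(smooth_factor=0.1, num_classes=5)
logits = Tensor(np.random.randn(4, 5).astype(np.float32))
labels = Tensor(np.array([0, 2, 1, 4]), mstype.int32)
print(loss_fn(logits, labels))  # scalar mean loss over the batch
```

**src/lr_generator.py**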
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps, global_step=0):
    """
    Applies three-step decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.
        global_step(int): step to resume from. Default: 0

    Returns:
        np.array, learning rate array.
    """
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
        lr_each_step.append(lr)
    lr_each_step = np.array(lr_each_step).astype(np.float32)[global_step:]
    return lr_each_step


def _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch):
    """
    Applies exponential decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.
        steps_per_epoch(int): steps of one epoch.

    Returns:
        list, learning rate array.
    """
    lr_each_step = []
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            # decay by a factor of 0.94 every two epochs
            decay_nums = math.floor((float(i - warmup_steps) / steps_per_epoch) / 2)
            decay_rate = pow(0.94, decay_nums)
            lr = float(lr_max) * decay_rate
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)
    return lr_each_step


def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps, global_step=0):
    """
    Applies cosine decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.
        global_step(int): step to resume from. Default: 0

    Returns:
        np.array, learning rate array.
    """
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine_decay + lr_end
        lr_each_step.append(lr)
    lr_each_step = np.array(lr_each_step).astype(np.float32)[global_step:]
    return lr_each_step


def _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies linear decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        list, learning rate array.
    """
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr)
    return lr_each_step


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode, global_step=0):
    """
    Generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate.
        lr_max(float): max learning rate.
        warmup_epochs(int): number of warmup epochs.
        total_epochs(int): total epochs of training.
        steps_per_epoch(int): steps of one epoch.
        lr_decay_mode(string): learning rate decay mode, one of 'steps', 'steps_decay', 'cosine' or 'linear' (default).
        global_step(int): step to resume from. Default: 0

    Returns:
        np.array, learning rate array.
    """
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if lr_decay_mode == 'steps':
        lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps, global_step)
    elif lr_decay_mode == 'steps_decay':
        lr_each_step = _generate_exponential_lr(lr_init, lr_max, total_steps, warmup_steps, steps_per_epoch)
    elif lr_decay_mode == 'cosine':
        lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps, global_step)
    else:
        lr_each_step = _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    learning_rate = np.array(lr_each_step).astype(np.float32)
    return learning_rate
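An illustrative sketch of the cosine schedule used on GPU. `steps_per_epoch=1251` assumes 1,281,167 training images with batch size 128 on 8 cards (1281167 / 1024 ≈ 1251); adjust it for your setup:

```python
from src.lr_generator import get_lr

lrs = get_lr(lr_init=4e-5, lr_end=4e-6, lr_max=0.4,
             warmup_epochs=1, total_epochs=250,
             steps_per_epoch=1251, lr_decay_mode='cosine')
print(lrs[0], lrs[1250], lrs[-1])  # warmup start, peak (lr_max), final (close to lr_end)
```

**train.py**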
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import argparse
import ast
import os
import random
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_gpu, config_ascend, config_cpu
from src.dataset import create_dataset_imagenet
from src.inception_v2 import inception_v2_base
from src.loss import CrossEntropy
from src.lr_generator import get_lr
CFG_DICT = {
    "Ascend": config_ascend,
    "GPU": config_gpu,
    "CPU": config_cpu,
}
DS_DICT = {
    "imagenet": create_dataset_imagenet
}


def set_random_seed(i):
    """sets random seed"""
    set_seed(i)
    np.random.seed(i)
    random.seed(i)


def run_train():
    """run train"""
    parser = argparse.ArgumentParser(description='image classification training')
    parser.add_argument("--data_url", type=str, help="dataset path.")
    parser.add_argument("--device_num", type=int, default=8, help="Number of devices to use, default is 8.")
    parser.add_argument("--train_url", type=str, help="train output path.")
    parser.add_argument("--run_online", type=ast.literal_eval, default=False, help="whether to run online.")
    parser.add_argument('--resume', type=str, default='', help='resume training from an existing checkpoint')
    parser.add_argument("--is_distributed", type=ast.literal_eval, default=False,
                        help="Use one card or multiple cards for training.")
    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()
    cfg = CFG_DICT[args_opt.platform]
    set_random_seed(cfg.random_seed)
    create_dataset = DS_DICT[cfg.ds_type]
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
    Imagenet_root = args_opt.data_url
    if not os.path.exists(args_opt.train_url):
        os.makedirs(args_opt.train_url, exist_ok=True)
    local_train_url = args_opt.train_url
    # create dataset on cache
    if args_opt.run_online:
        import moxing as mox
        Imagenet_root = "/cache/data_train"
        mox.file.copy_parallel(args_opt.data_url, Imagenet_root)
        local_train_url = "/cache/train_out_si"
    if args_opt.is_distributed:
        init()
        cfg.rank = get_rank()
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
                                          gradients_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1
        if os.getenv('DEVICE_ID', "not_set").isdigit():
            context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    # dataloader
    root = os.path.join(Imagenet_root, 'train')
    dataset = create_dataset(root, cfg, True)
    batches_per_epoch = dataset.get_dataset_size()
    net = inception_v2_base(num_classes=cfg.num_classes, dropout_keep_prob=cfg.dropout_keep_prob)
    # loss
    loss = CrossEntropy(smooth_factor=cfg.smooth_factor, num_classes=cfg.num_classes)
    # learning rate schedule
    lr = Tensor(get_lr(lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
                       total_epochs=cfg.epoch_size, steps_per_epoch=batches_per_epoch,
                       lr_decay_mode=cfg.decay_method))
    group_params = filter(lambda x: x.requires_grad, net.get_parameters())
    opt = nn.Momentum(group_params, lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay,
                      loss_scale=cfg.loss_scale)
    if args_opt.resume != '':
        ckpt = load_checkpoint(args_opt.resume)
        load_param_into_net(net, ckpt)
    if args_opt.platform in ("Ascend", "GPU"):
        loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, amp_level=cfg.amp_level,
                      loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, amp_level=cfg.amp_level)
    print("============== Starting Training ==============", flush=True)
    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    callbacks = [loss_cb, time_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    save_ckpt_path = os.path.join(local_train_url, 'ckpt_' + str(cfg.rank) + '/')
    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionv2-rank{cfg.rank}", directory=save_ckpt_path, config=config_ck)
    if args_opt.is_distributed and cfg.is_save_on_master:
        # only rank 0 saves checkpoints when is_save_on_master is set
        if cfg.rank == 0:
            callbacks.append(ckpoint_cb)
        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=cfg.ds_sink_mode)
    else:
        callbacks.append(ckpoint_cb)
        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=cfg.ds_sink_mode)
    if args_opt.run_online:
        mox.file.copy_parallel(local_train_url, args_opt.train_url)
    print("train success", flush=True)


if __name__ == '__main__':
    run_train()