Skip to content
Snippets Groups Projects
Unverified Commit ad02bd41 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1651 PDarts-GPU模型提交

Merge pull request !1651 from wqx/PDarts-GPU
parents 097cd4a3 72aeaec9
No related branches found
No related tags found
No related merge requests found
Showing
with 182 additions and 65 deletions
......@@ -3,7 +3,7 @@
<!-- TOC -->
- [目录](#目录)
- [PDarts描述](#pdarts描述)
- [PDarts描述](#PDarts描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [特性](#特性)
......@@ -15,10 +15,12 @@
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [训练](#训练)
- [分布式训练](#分布式训练)
- [评估过程](#评估过程)
- [评估](#评估)
- [推理过程](#推理过程)
- [导出MindIR](#导出mindir)
- [在Ascend310执行推理](#在ascend310执行推理)
- [导出MindIR](#导出MindIR)
- [在Ascend310执行推理](#在Ascend310执行推理)
- [模型描述](#模型描述)
- [性能](#性能)
- [训练准确率结果](#训练准确率结果)
......@@ -61,12 +63,13 @@
# 环境要求
- 硬件(Ascend910)
- 硬件(Ascend/GPU)
- 使用Ascend或GPU处理器来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install)
- [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
- [MindSpore教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# 快速入门
......@@ -77,9 +80,11 @@
bash ./scripts/run_standalone_train_ascend /data/cifar-10-binary ./output
bash ./scripts/run_distribution_train_ascend ../rank_table.json /data/cifar-10-binary ../output
bash ./scripts/run_distribution_train_gpu.sh 8 0,1,2,3,4,5,6,7 ../cifar-10-binary/ ./output/
# 评估示例
bash ./scripts/run_standalone_eval_ascend.sh /data/cifar-10-binary/val ./output/model_checkpoint.ckpt
bash ./scripts/run_standalone_eval_gpu.sh ../cifar-10-binary/val/ model_checkpoint.ckpt
```
# 脚本说明
......@@ -93,9 +98,12 @@
├── README.md // PDarts相关说明
├── scripts
│ ├── run_standalone_eval_ascend.sh // Ascend评估shell脚本
│ ├── run_standalone_eval_gpu.sh // GPU评估shell脚本
│ ├── run_export.sh // 导出模型shell脚本
│ ├── run_standalone_train_ascend.sh // Ascend单卡训练shell脚本
│ ├── run_standalone_train_gpu.sh // GPU单卡训练shell脚本
│ ├── run_distribution_train_ascend.sh // Ascend 8卡训练shell脚本
│ ├── run_distribution_train_gpu.sh // GPU 8卡训练shell脚本
│ ├── run_infer_310.sh // Ascend310环境推理shell脚本
├── src
│ ├── call_backs // 训练过程中的回调方法
......@@ -118,7 +126,7 @@
可通过`train.py`脚本中的参数修改训练行为。`train.py`脚本中的参数如下:
```bash
--device_target 设备类型,支持Ascend
--device_target 设备类型,支持Ascend、GPU
--local_data_root 数据拷贝的缓存目录(主要针对在modelarts上运行时使用)
--data_url 数据路径
--train_url 训练结果输出路径
......@@ -135,7 +143,7 @@
--auxiliary_weight auxiliary loss的权重比例,当auxiliary为True时有效
--drop_path_prob dropout的比例
--arch 模型架构,默认值为'PDARTS'
--amp_level 混合精度级别,Ascend910环境建议使用O3
--amp_level 混合精度级别,Ascend910环境建议使用O3,GPU环境建议用O2
--optimizer 训练用的优化器,默认使用Momentum
--cutout_length 数据的裁剪长度,默认为16
```
......@@ -144,14 +152,20 @@
### 训练
- Ascend910处理器环境运行
- Ascend910处理器、GPU环境运行
```bash
单卡
单卡Ascend910
bash ./scripts/run_standalone_train_ascend /data/cifar-10-binary ./output
8卡
8卡Ascend910
bash ./scripts/run_distribution_train_ascend ../rank_table.json /data/cifar-10-binary ../output
注:单卡训练启动脚本一共有2个参数,8卡训练脚本有3个参数,分别为[rank_table配置文件(8卡训练脚本需要使用)] [cifar10数据集路径] [训练输出路径]
注:单卡Ascend910训练启动脚本一共有2个参数,8卡训练脚本有3个参数,分别为[rank_table配置文件(8卡训练脚本需要使用)] [cifar10数据集路径] [训练输出路径]
单卡GPU
bash ./scripts/run_standalone_train_gpu.sh ./cifar-10-binary/ ./output/
8卡GPU
bash ./scripts/run_distribution_train_gpu.sh 8 0,1,2,3,4,5,6,7 ../cifar-10-binary/ ./output/
注:单卡GPU训练启动脚本一共有2个参数,8卡训练脚本有4个参数,分别为[DEVICE_NUM(8卡环境需要设置为8)][VISIABLE_DEVICES(0,1,2,3,4,5,6,7,即每个GPU分配的id,中间用逗号隔开)] [cifar10数据集路径] [训练输出路径]
```
cifar10数据集的要求格式
......@@ -205,6 +219,18 @@
8卡训练最终精度acc top1为97.01%,acc top5为99.91%
- GPU环境
运行以下命令进行评估。
```bash
bash ./scripts/run_standalone_eval_gpu.sh ../cifar-10-binary/val/ model_checkpoint.ckpt
注:评估脚本参数一共为两个,分别是[验证集路径] [ckpt文件路径];
数据集格式与上面训练过程相同,并选择cifar-10-binary/val部分进行评估;
```
8卡训练最终精度acc top1为97.205%,acc top5为99.939%
## 推理过程
### 导出MindIR
......@@ -233,31 +259,31 @@ bash ./scripts/run_infer_310.sh [MINDIR_PATH] [DATASET_PATH]
### 训练准确率结果
| 参数 | PDarts |
| ------------------- | --------------------------- |
| 模型版本 | PDarts |
| 资源 | Ascend 910 |
| 上传日期 | 2021/6/9 |
| MindSpore版本 | 1.2.0 Ascend |
| 数据集 | cifar10 |
| 轮次 | 600 |
| 输出 | 概率 |
| 损失 | 0.1574 |
| 总时间 | 单卡:约15小时 8卡:约3.2小时 |
| 训练精度 | 单卡:Top1:97.1%; Top5:99.93% 8卡:Top1:97.01%; Top5:99.91% |
| 模型 | PDarts | PDarts |
| ------------------- | --------------------------- | --------------------------- |
| 模型版本 | PDarts-Ascend | PDarts-GPU |
| 资源 | Ascend 910 | V100 |
| 上传日期 | 2021/6/9 | 2021/12/28 |
| MindSpore版本 | 1.2.0 Ascend | 1.5.0 GPU |
| 数据集 | cifar10 | cifar10 |
| 轮次 | 600 | 600 |
| 输出 | 概率 | 概率 |
| 损失 | 0.1574 | 0.1241 |
| 总时间 | 单卡:约15小时 8卡:约3.2小时 | 8卡:约7小时 |
| 训练精度 | 单卡:Top1:97.1%; Top5:99.93% 8卡:Top1:97.01%; Top5:99.91% | 8卡:Top1:97.205%; Top5:99.939% |
### 训练性能结果
| 参数 | PDarts |
| ------------------- | --------------------------- |
| 模型版本 | PDarts |
| 资源 | Ascend 910 |
| 上传日期 | 2021/6/9 |
| MindSpore版本 | 1.2.0 Ascend |
| 数据集 | cifar10 |
| batch_size | 单卡:128 8卡:32 |
| 输出 | 概率 |
| 速度 | 单卡:189.4ms/step 8卡:65.5ms/step |
| 模型 | PDarts | PDarts |
| ------------------- | --------------------------- | --------------------------- |
| 模型版本 | PDarts-Ascend | PDarts-GPU |
| 资源 | Ascend 910 | V100 |
| 上传日期 | 2021/6/9 | 2021/12/28 |
| MindSpore版本 | 1.2.0 Ascend | 1.5.0 GPU |
| 数据集 | cifar10 |cifar10|
| batch_size | 单卡:128 8卡:32 | 8卡:32 |
| 输出 | 概率 |概率|
| 速度 | 单卡:189.4ms/step 8卡:65.5ms/step | 8卡:约180ms/step |
# 随机情况说明
......
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -56,7 +56,7 @@ parser.add_argument("--train_batch_size", type=int,
parser.add_argument("--ckpt_file", type=str, required=True,
help="Checkpoint file path.")
parser.add_argument("--device_target", type=str, choices=["Ascend"], default="Ascend",
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU"], default="Ascend",
help="device target")
parser.add_argument('--amp_level', type=str, default='O3', help='')
args = parser.parse_args()
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# an simple tutorial as follows, more parameters can be setting
if [ $# != 4 ]
then
echo "Usage: bash run_distribution_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [CIFAR10_DATA_PATH] [OUTPUT_PATH]"
exit 1
fi
if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi
export RANK_SIZE=$1
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export CUDA_VISIBLE_DEVICES="$2"
CIFAR10_DATA_PATH=$3
OUTPUT_PATH=$4
mpirun -n $1 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ${BASEPATH}/../train.py --device_target=GPU --data_url $CIFAR10_DATA_PATH \
--train_url $OUTPUT_PATH --optimizer SGD --load_weight None --no_top False \
--learning_rate 0.075 --batch_size 32 --amp_level=O2 > log 2>&1 &
\ No newline at end of file
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# an simple tutorial as follows, more parameters can be setting
if [ $# != 2 ]
then
echo "Usage: bash run_standalone_eval_gpu.sh [VAL_DATA_PATH] [CKPT_FILE_PATH]"
exit 1
fi
VAL_DATA_PATH=$1
CKPT_FILE_PATH=$2
python eval.py --val_path=$VAL_DATA_PATH --ckpt_file=$CKPT_FILE_PATH --device_target=GPU --amp_level=O0
\ No newline at end of file
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# an simple tutorial as follows, more parameters can be setting
if [ $# != 2 ]
then
echo "Usage: bash run_standalone_train_gpu.sh [CIFAR10_DATA_PATH] [OUTPUT_PATH]"
exit 1
fi
export RANK_SIZE=1
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
CIFAR10_DATA_PATH=$1
OUTPUT_PATH=$2
python ${BASEPATH}/../train.py --data_url $CIFAR10_DATA_PATH --train_url $OUTPUT_PATH --optimizer Momentum \
--load_weight None --no_top False --device_target=GPU --amp_level=O2
\ No newline at end of file
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -83,7 +83,7 @@ class Val_Callback(Callback):
"""
def __init__(self, model, train_dataset, val_dataset, checkpoint_path, prefix,
network, img_size, is_eval_train_dataset='False'):
network, img_size, device_id=0, is_eval_train_dataset='False'):
super(Val_Callback, self).__init__()
self.model = model
self.train_dataset = train_dataset
......@@ -93,6 +93,7 @@ class Val_Callback(Callback):
self.prefix = prefix
self.network = network
self.img_size = img_size
self.device_id = device_id
self.is_eval_train_dataset = is_eval_train_dataset
def epoch_end(self, run_context):
......@@ -117,8 +118,8 @@ class Val_Callback(Callback):
self.max_val_acc = val_acc
cb_params = run_context.original_args()
epoch = cb_params.cur_epoch_num
model_info = self.prefix + '_valacc' + \
str(val_acc) + '_epoch' + str(epoch)
model_info = self.prefix + '_id' + str(self.device_id) + \
'_epoch' + str(epoch) + '_valacc' + str(val_acc)
if self.checkpoint_path.startswith('s3://') or self.checkpoint_path.startswith('obs://'):
save_path = '/cache/save_model/'
else:
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,8 +13,6 @@
# limitations under the License.
# ============================================================================
"""Read train and eval data"""
import os
import mindspore.dataset as ds
from mindspore.common import dtype as mstype
import mindspore.dataset.transforms.c_transforms as C
......@@ -24,12 +22,11 @@ from mindspore.dataset.vision.utils import Inter
def create_cifar10_dataset(data_dir, training=True, repeat_num=1, num_parallel_workers=5,
resize_height=32, resize_width=32, batch_size=512,
num_samples=None, shuffle=None, cutout_length=0):
num_samples=None, shuffle=None, cutout_length=0, device_id=0, device_num=1):
"""Data operations."""
ds.config.set_seed(1)
ds.config.set_num_parallel_workers(num_parallel_workers)
device_id, device_num = get_device_info()
if training:
data_set = ds.Cifar10Dataset(data_dir, num_samples=num_samples,
shuffle=shuffle, num_shards=device_num, shard_id=device_id)
......@@ -71,9 +68,3 @@ def create_cifar10_dataset(data_dir, training=True, repeat_num=1, num_parallel_w
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)
return data_set
def get_device_info():
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
return device_id, device_num
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment