# Contents

<!-- TOC -->

- [Contents](#contents)
- [Description](#description)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
- [Training and Evaluation](#training-and-evaluation)
- [Training Results](#training-results)
    - [Results](#results)
- [Export Process](#export-process)
    - [Export](#export)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->
# [Description](#contents)

Graph neural networks (GNNs) were originally designed for tasks on graph-structured data. The ViG paper is the first to propose a purely GNN-based model for general computer vision problems such as image recognition and object detection. Vision GNN (ViG) treats an image as a graph whose nodes are image patches, and uses a GNN to process this graph, exchanging information between nodes and transforming their features. By stacking ViG blocks, the authors build ViG models for image recognition.

Paper: Kai Han, Yunhe Wang, Jianyuan Guo, Yehui Tang, Enhua Wu. Vision GNN: An Image is Worth Graph of Nodes. 2022. [paper link](https://arxiv.org/abs/2206.00272)
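
To make the patch-to-graph idea concrete, here is a minimal illustrative sketch (not the repository's implementation, which lives under `src/models/vig`): patch embeddings become graph nodes, and each node is linked to its k nearest neighbors in feature space before GNN layers aggregate along those edges.

```python
# Illustrative only: building the k-NN graph over patch nodes.
import numpy as np

def knn_graph(patch_features, k=9):
    """patch_features: (N, C) patch embeddings; returns (N, k) neighbor indices."""
    sq = np.sum(patch_features ** 2, axis=1)
    # pairwise squared Euclidean distances between all nodes
    dist = sq[:, None] + sq[None, :] - 2.0 * patch_features @ patch_features.T
    # k nearest neighbors per node, skipping the node itself at position 0
    return np.argsort(dist, axis=1)[:, 1:k + 1]

# a 224x224 image with 16x16 patches gives 14*14 = 196 nodes
features = np.random.randn(196, 192).astype(np.float32)
edges = knn_graph(features)  # each GNN layer aggregates features along these edges
```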
# [Dataset](#contents)

Dataset used: [ImageNet2012](http://www.image-net.org/)

- Dataset size: 1000 classes of 224*224 color images
    - Training set: 1,281,167 images
    - Test set: 50,000 images
- Data format: JPEG
    - Note: data is processed in dataset.py.
- Download the dataset; the directory structure is as follows:

```text
└─dataset
    ├─train                 # training set
    └─val                   # evaluation set
```
# [Features](#contents)

## Mixed Precision

Training uses [mixed precision](https://www.mindspore.cn/tutorials/experts/zh-CN/master/others/mixed_precision.html): the network computes with a mix of single-precision and half-precision data to speed up training while preserving the accuracy reachable with pure single-precision training. Besides faster computation and lower memory use, mixed precision makes it possible to train larger models, or with larger batch sizes, on supported hardware.
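
As a minimal sketch (assuming a standard classification setup rather than this repository's exact training wrapper in `src/tools`), mixed precision is enabled in MindSpore by passing `amp_level` to `Model`:

```python
import mindspore.nn as nn
from mindspore import Model

net = nn.Dense(32, 10)  # stand-in for the ViG network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=5e-4, weight_decay=0.05)
# O2 casts the network to float16 while keeping batch norm (and the loss) in float32.
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")
```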
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For details, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/r1.3/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/r1.3/index.html)
# [Script Description](#contents)

## Scripts and Sample Code

```text
├── ViG
    ├── README_CN.md                        // ViG description
    ├── scripts
        ├── run_standalone_train_ascend.sh  // single-device Ascend 910 training script
        ├── run_distribute_train_ascend.sh  // multi-device Ascend 910 training script
        └── run_eval_ascend.sh              // evaluation script
    ├── src
        ├── configs                         // ViG configuration files
        ├── data                            // dataset configuration
            ├── imagenet.py                 // ImageNet configuration
            ├── augment                     // data augmentation functions
            └── data_utils                  // dataset copy helpers for ModelArts runs
        ├── models                          // model definitions
            └── vig                         // ViG definition
        ├── trainers                        // custom TrainOneStep cells
        └── tools                           // utilities
            ├── callback.py                 // custom callback: evaluate when training ends
            ├── cell.py                     // generic cell utilities
            ├── criterion.py                // loss-function utilities
            ├── get_misc.py                 // miscellaneous utilities
            ├── optimizer.py                // optimizer and parameter helpers
            └── schedulers.py               // learning-rate decay utilities
    ├── train.py                            // training entry point
    ├── eval.py                             // evaluation entry point
    ├── export.py                           // model export
    ├── postprocess.py                      // accuracy computation for inference
    └── preprocess.py                       // image preprocessing for inference
```
## Script Parameters

Training and evaluation parameters can both be configured in vig_s_patch16_224.yaml.

- Configuration for ViG with the ImageNet-1k dataset:

```python
# Architecture
arch: vig_s_patch16_224                # ViG variant to use
# ===== Dataset ===== #
data_url: ./data/imagenet              # dataset path
set: ImageNet                          # dataset name
num_classes: 1000                      # number of classes
mix_up: 0.8                            # MixUp augmentation parameter
cutmix: 1.0                            # CutMix augmentation parameter
auto_augment: rand-m9-mstd0.5-inc1     # AutoAugment parameters
interpolation: bicubic                 # image resize interpolation method
re_prob: 0.25                          # random-erasing probability
re_mode: pixel                         # random-erasing mode
re_count: 1                            # random-erasing count
mixup_prob: 1.                         # probability of applying MixUp/CutMix
switch_prob: 0.5                       # probability of switching from MixUp to CutMix
mixup_mode: batch                      # how MixUp/CutMix parameters are applied
# ===== Learning Rate Policy ======== #
optimizer: adamw                       # optimizer type
base_lr: 0.0005                        # base learning rate
warmup_lr: 0.00000007                  # initial warmup learning rate
min_lr: 0.000006                       # minimum learning rate
lr_scheduler: cosine_lr                # learning-rate decay policy
warmup_length: 20                      # number of warmup epochs
image_size: 224                        # image size
# ===== Network training config ===== #
amp_level: O2                          # mixed-precision level
beta: [ 0.9, 0.999 ]                   # AdamW beta parameters
clip_global_norm_value: 5.             # global gradient-norm clipping threshold
is_dynamic_loss_scale: True            # whether to use dynamic loss scaling
epochs: 300                            # number of training epochs
label_smoothing: 0.1                   # label smoothing factor
weight_decay: 0.05                     # weight decay
momentum: 0.9                          # optimizer momentum
batch_size: 128                        # batch size
# ===== Hardware setup ===== #
num_parallel_workers: 16               # number of data-preprocessing workers
device_target: Ascend                  # GPU or Ascend
```

For more configuration details, see the script `vig_s_patch16_224.yaml`. After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
# [Training and Evaluation](#contents)

- Running on Ascend processors

```bash
# run single-device training with python
python train.py --device_id 0 --device_target Ascend --vig_config ./src/configs/vig_s_patch16_224.yaml \
> train.log 2>&1 &

# run single-device training with the shell script
bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH]

# run multi-device training with the shell script
bash ./scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [CONFIG_PATH]

# run single-device evaluation with python
python eval.py --device_id 0 --device_target Ascend --vig_config ./src/configs/vig_s_patch16_224.yaml \
--pretrained ./ckpt_0/vig_s_patch16_224.ckpt > ./eval.log 2>&1 &

# run single-device evaluation with the shell script
bash ./scripts/run_eval_ascend.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]
```

For distributed training, an hccl configuration file in JSON format must be created in advance.
Follow the instructions at:
[hccl tools](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)
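
For orientation, a generated rank table for a single 8-device server typically has the following shape; the IDs and IPs below are illustrative placeholders and only the first two of the eight device entries are shown, so always generate the real file with the tool above:

```json
{
    "version": "1.0",
    "server_count": "1",
    "server_list": [
        {
            "server_id": "10.0.0.1",
            "device": [
                {"device_id": "0", "device_ip": "192.168.100.101", "rank_id": "0"},
                {"device_id": "1", "device_ip": "192.168.100.102", "rank_id": "1"}
            ],
            "host_nic_ip": "reserve"
        }
    ],
    "status": "completed"
}
```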
## Training Results

### Results

Training and evaluating on ImageNet-1k produces results similar to the following:

```shell
# result
Top1 acc: 0.804
Top5 acc: 0.952
```

| Parameter | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model | ViG |
| Model version | vig_s_patch16_224 |
| Resource | Ascend 910 |
| Upload date | 2022-06-06 |
| MindSpore version | 1.7.0 |
| Dataset | ImageNet-1k Train, 1,281,167 images |
| Training parameters | epoch=300, batch_size=128 |
| Optimizer | AdamWeightDecay |
| Loss function | SoftTargetCrossEntropy |
| Final loss | 0.9680 |
| Output | probability |
| Accuracy | 8 devices: top1 80.4%, top5 95.2% |
| Speed | 8 devices: 1755 ms/step |
| Total training time | 212 h 40 min (run on ModelArts) |
## Export Process

### Export

```shell
python export.py --pretrained [CKPT_FILE] --vig_config [CONFIG_PATH] --device_target [DEVICE_TARGET]
```

The exported model is named after the model architecture and saved in the current directory.
# ModelZoo Homepage

Please visit the official [homepage](https://gitee.com/mindspore/models).
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval"""
from mindspore import Model
from mindspore import context
from mindspore import nn
from mindspore.common import set_seed
from src.args import args
from src.tools.cell import cast_amp
from src.tools.criterion import get_criterion, NetWithLoss
from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
from src.tools.optimizer import get_optimizer
set_seed(args.seed)
def main():
mode = {
0: context.GRAPH_MODE,
1: context.PYNATIVE_MODE
}
context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
context.set_context(enable_graph_kernel=False)
set_device(args)
# get model
net = get_model(args)
cast_amp(net)
criterion = get_criterion(args)
net_with_loss = NetWithLoss(net, criterion)
if args.pretrained:
pretrained(args, net)
data = get_dataset(args, training=False)
batch_num = data.val_dataset.get_dataset_size()
optimizer = get_optimizer(args, net, batch_num)
# wrap with the same TrainOneStep cell used in training (loss scaling, gradient clipping)
net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
eval_indexes = [0, 1, 2]
eval_metrics = {'Loss': nn.Loss(),
'Top1-Acc': nn.Top1CategoricalAccuracy(),
'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(net_with_loss, metrics=eval_metrics,
eval_network=eval_network,
eval_indexes=eval_indexes)
print(f"=> begin eval")
results = model.eval(data.val_dataset)
print(f"=> eval results:{results}")
print(f"=> eval success")
if __name__ == '__main__':
main()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Export a trained checkpoint into an AIR or MINDIR model:
python export.py --pretrained [CKPT_FILE] --vig_config [CONFIG_PATH]
"""
import numpy as np
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from mindspore import dtype as mstype
from src.args import args
from src.tools.cell import cast_amp
from src.tools.criterion import get_criterion, NetWithLoss
from src.tools.get_misc import get_model
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target in ["Ascend", "GPU"]:
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
net = get_model(args)
criterion = get_criterion(args)
cast_amp(net)
net_with_loss = NetWithLoss(net, criterion)
assert args.pretrained is not None, "checkpoint_path is None."
param_dict = load_checkpoint(args.pretrained)
load_param_into_net(net, param_dict)
net.set_train(False)
net.to_float(mstype.float32)
input_arr = Tensor(np.zeros([1, 3, args.image_size, args.image_size], np.float32))
export(net, input_arr, file_name=args.arch, file_format=args.file_format)
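# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): the exported MINDIR can
# be loaded back and executed with MindSpore's built-in graph runner, e.g.
#   import numpy as np
#   import mindspore as ms
#   from mindspore import nn, Tensor
#   graph = ms.load("vig_s_patch16_224.mindir")  # file name assumes the default --arch
#   net = nn.GraphCell(graph)
#   logits = net(Tensor(np.zeros([1, 3, 224, 224], np.float32)))  # (1, 1000)
# ----------------------------------------------------------------------------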
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""postprocess for 310 inference"""
import argparse
import json
import os
import numpy as np
from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy
parser = argparse.ArgumentParser(description="postprocess")
parser.add_argument("--result_dir", type=str, default="./result_Files", help="result files path.")
parser.add_argument('--dataset_name', type=str, choices=["imagenet2012"], default="imagenet2012")
args = parser.parse_args()
def calcul_acc(lab, preds):
return sum(1 for x, y in zip(lab, preds) if x == y) / len(lab)
if __name__ == '__main__':
batch_size = 1
top1_acc = Top1CategoricalAccuracy()
rst_path = args.result_dir
label_list = []
pred_list = []
file_list = os.listdir(rst_path)
top5_acc = Top5CategoricalAccuracy()
with open('./preprocess_Result/imagenet_label.json', "r") as label:
labels = json.load(label)
for f in file_list:
label = f.split("_0.bin")[0] + ".JPEG"  # result file "xxx_0.bin" corresponds to source image "xxx.JPEG"
label_list.append(labels[label])
pred = np.fromfile(os.path.join(rst_path, f), np.float32)
pred = pred.reshape(batch_size, int(pred.shape[0] / batch_size))
top1_acc.update(pred, [labels[label],])
top5_acc.update(pred, [labels[label],])
print("Top1 acc: ", top1_acc.eval())
print("Top5 acc: ", top5_acc.eval())
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""preprocess"""
import argparse
import json
import os
parser = argparse.ArgumentParser('preprocess')
parser.add_argument('--dataset_name', type=str, choices=["imagenet2012"], default="imagenet2012")
parser.add_argument('--data_path', type=str, default='', help='eval data dir')
def create_label(result_path, dir_path):
"""
create_label
"""
dirs = os.listdir(dir_path)
file_list = []
for file in dirs:
file_list.append(file)
file_list = sorted(file_list)
total = 0
img_label = {}
for i, file_dir in enumerate(file_list):
files = os.listdir(os.path.join(dir_path, file_dir))
for f in files:
img_label[f] = i
total += len(files)
json_file = os.path.join(result_path, "imagenet_label.json")
with open(json_file, "w+") as label:
json.dump(img_label, label)
print("[INFO] Completed! Total {} data.".format(total))
args = parser.parse_args()
if __name__ == "__main__":
create_label('./preprocess_Result/', args.data_path)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: bash ./scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [CONFIG_PATH]"
exit 1
fi
export RANK_TABLE_FILE=$1
CONFIG_PATH=$2
export RANK_SIZE=8
export DEVICE_NUM=8
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf train_parallel$i
mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i
cp *.py ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $i, device $DEVICE_ID rank_id $RANK_ID"
env > env.log
taskset -c $cmdopt python -u ../train.py \
--device_target Ascend \
--device_id $i \
--vig_config=$CONFIG_PATH > log.txt 2>&1 &
cd ../
done
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 3 ]
then
echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]"
exit 1
fi
export DEVICE_ID=$1
CONFIG_PATH=$2
CHECKPOINT_PATH=$3
export RANK_SIZE=1
export DEVICE_NUM=1
rm -rf evaluation_ascend
mkdir ./evaluation_ascend
cd ./evaluation_ascend || exit
echo "start training for device id $DEVICE_ID"
env > env.log
python ../eval.py --device_target=Ascend --device_id=$DEVICE_ID --vig_config=$CONFIG_PATH --pretrained=$CHECKPOINT_PATH > eval.log 2>&1 &
cd ../
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH]"
exit 1
fi
export RANK_SIZE=1
export DEVICE_NUM=1
export DEVICE_ID=$1
CONFIG_PATH=$2
rm -rf train_standalone
mkdir ./train_standalone
cd ./train_standalone || exit
echo "start training for device id $DEVICE_ID"
env > env.log
python -u ../train.py \
--device_id=$DEVICE_ID \
--device_target="Ascend" \
--vig_config=$CONFIG_PATH > log.txt 2>&1 &
cd ../
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""global args for Vision GNN (ViG)"""
import argparse
import ast
import os
import sys
import yaml
from src.configs import parser as _parser
args = None
def parse_arguments():
"""parse_arguments"""
global args
parser = argparse.ArgumentParser(description="MindSpore ViG Training")
parser.add_argument("-a", "--arch", metavar="ARCH", default="ResNet18", help="model architecture")
parser.add_argument("--accumulation_step", default=1, type=int, help="accumulation step")
parser.add_argument("--amp_level", default="O2", choices=["O0", "O2", "O3"], help="AMP Level")
parser.add_argument("--batch_size", default=128, type=int, metavar="N",
help="mini-batch size (default: 256), this is the total "
"batch size of all GPUs on the current node when "
"using Data Parallel or Distributed Data Parallel")
parser.add_argument("--beta", default=[0.9, 0.999], type=lambda x: [float(a) for a in x.split(",")],
help="beta for optimizer")
parser.add_argument("--clip_global_norm_value", default=5., type=float, help="Clip grad value")
parser.add_argument('--data_url', default="./data", help='Location of data.')
parser.add_argument("--device_id", default=0, type=int, help="Device Id")
parser.add_argument("--device_num", default=1, type=int, help="device num")
parser.add_argument("--device_target", default="GPU", choices=["GPU", "Ascend", "CPU"], type=str)
parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
parser.add_argument("--eps", default=1e-8, type=float)
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--in_channel", default=3, type=int)
parser.add_argument("--is_dynamic_loss_scale", default=1, type=int, help="is_dynamic_loss_scale ")
parser.add_argument("--keep_checkpoint_max", default=20, type=int, help="keep checkpoint max num")
parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd")
parser.add_argument("--set", help="name of dataset", type=str, default="ImageNet")
parser.add_argument("--graph_mode", default=1, type=int, help="graph mode with 0, python with 1")
parser.add_argument("--mix_up", default=0., type=float, help="mix up")
parser.add_argument("--mlp_ratio", help="mlp ", default=4., type=float)
parser.add_argument("-j", "--num_parallel_workers", default=20, type=int, metavar="N",
help="number of data loading workers (default: 20)")
parser.add_argument("--start_epoch", default=0, type=int, metavar="N",
help="manual epoch number (useful on restarts)")
parser.add_argument("--warmup_length", default=0, type=int, help="Number of warmup iterations")
parser.add_argument("--warmup_lr", default=5e-7, type=float, help="warm up learning rate")
parser.add_argument("--wd", "--weight_decay", default=0.05, type=float, metavar="W",
help="weight decay (default: 1e-4)", dest="weight_decay")
parser.add_argument("--loss_scale", default=1024, type=int, help="loss_scale")
parser.add_argument("--lr", "--learning_rate", default=2e-3, type=float, help="initial lr", dest="lr")
parser.add_argument("--lr_scheduler", default="cosine_annealing", help="Schedule for the learning rate.")
parser.add_argument("--lr_adjust", default=30, type=float, help="Interval to drop lr")
parser.add_argument("--lr_gamma", default=0.97, type=int, help="Multistep multiplier")
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument("--num_classes", default=1000, type=int)
parser.add_argument("--pretrained", dest="pretrained", default=None, type=str, help="use pre-trained model")
parser.add_argument("--vig_config", help="Config file to use (see configs dir)", default=None, required=True)
parser.add_argument("--seed", default=0, type=int, help="seed for initializing training. ")
parser.add_argument("--save_every", default=2, type=int, help="Save every ___ epochs(default:2)")
parser.add_argument("--label_smoothing", type=float, help="Label smoothing to use, default 0.0", default=0.1)
parser.add_argument("--image_size", default=224, help="Image Size.", type=int)
parser.add_argument('--train_url', default="./", help='Location of training outputs.')
parser.add_argument("--run_modelarts", type=ast.literal_eval, default=False, help="Whether run on modelarts")
args = parser.parse_args()
# Allow for use from notebook without config file
if len(sys.argv) > 1:
get_config()
def get_config():
"""get_config"""
global args
override_args = _parser.argv_to_vars(sys.argv)
# load yaml file
if args.run_modelarts:
import moxing as mox
if not args.vig_config.startswith("obs:/"):
args.vig_config = "obs:/" + args.vig_config
with mox.file.File(args.vig_config, 'r') as f:
yaml_txt = f.read()
else:
yaml_txt = open(args.vig_config).read()
# override args
loaded_yaml = yaml.load(yaml_txt, Loader=yaml.FullLoader)
for v in override_args:
loaded_yaml[v] = getattr(args, v)
print(f"=> Reading YAML config from {args.vig_config}")
args.__dict__.update(loaded_yaml)
print(args)
if "DEVICE_NUM" not in os.environ.keys():
os.environ["DEVICE_NUM"] = str(args.device_num)
os.environ["RANK_SIZE"] = str(args.device_num)
def run_args():
"""run and get args"""
global args
if args is None:
parse_arguments()
run_args()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""parser function"""
USABLE_TYPES = set([float, int])
def trim_preceding_hyphens(st):
i = 0
while st[i] == "-":
i += 1
return st[i:]
def arg_to_varname(st: str):
st = trim_preceding_hyphens(st)
st = st.replace("-", "_")
return st.split("=")[0]
def argv_to_vars(argv):
var_names = []
for arg in argv:
if arg.startswith("-") and arg_to_varname(arg) != "vig_config":
var_names.append(arg_to_varname(arg))
return var_names
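# Example (illustrative): only explicit command-line flags become override
# names, and --vig_config itself is excluded, so
#   argv_to_vars(["train.py", "--vig_config=cfg.yaml", "--batch_size=64",
#                 "--device-target", "Ascend"])
# returns ["batch_size", "device_target"]; get_config() then keeps the argparse
# values for exactly those keys when it merges the YAML file.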
# Architecture
arch: vig_s_patch16_224
# ===== Dataset ===== #
data_url: ../data/imagenet
set: ImageNet
num_classes: 1000
mix_up: 0.8
cutmix: 1.0
auto_augment: rand-m9-mstd0.5-inc1
interpolation: bicubic
re_prob: 0.25
re_mode: pixel
re_count: 1
mixup_prob: 1.
switch_prob: 0.5
mixup_mode: batch
image_size: 224
# ===== Learning Rate Policy ======== #
optimizer: adamw
base_lr: 0.002
drop_path_rate: 0.1
warmup_lr: 0.00000007
min_lr: 0.000006
lr_scheduler: cosine_lr
warmup_length: 20
# ===== Network training config ===== #
amp_level: O2
keep_bn_fp32: True
beta: [ 0.9, 0.999 ]
clip_global_norm_value: 5.
is_dynamic_loss_scale: True
epochs: 300
label_smoothing: 0.1
weight_decay: 0.05
momentum: 0.9
batch_size: 16
# ===== Hardware setup ===== #
num_parallel_workers: 8
device_target: Ascend
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init datasets"""
from .imagenet import ImageNet
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init augment"""
from .auto_augment import _pil_interp, rand_augment_transform
from .mixup import Mixup
from .random_erasing import RandomErasing
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Mixup and Cutmix
Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch
Hacked together by / Copyright 2020 Ross Wightman
"""
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore import ops as P
def one_hot(x, num_classes, on_value=1., off_value=0.):
"""one hot to label"""
x = x.reshape(-1)
x = np.eye(num_classes)[x]
x = np.clip(x, a_min=off_value, a_max=on_value, dtype=np.float32)
return x
def mixup_target(target, num_classes, lam=1., smoothing=0.0):
"""mixup_target"""
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
y2 = one_hot(np.flip(target, axis=0), num_classes, on_value=on_value, off_value=off_value)
return y1 * lam + y2 * (1. - lam)
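# Worked example (illustrative): for target=[0, 2], num_classes=4,
# smoothing=0.1 and lam=0.7, the smoothed on/off values are 0.925/0.025, so
#   mixup_target(np.array([0, 2]), 4, lam=0.7, smoothing=0.1)[0]
#   == 0.7*[0.925, 0.025, 0.025, 0.025] + 0.3*[0.025, 0.025, 0.925, 0.025]
#   == [0.655, 0.025, 0.295, 0.025]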
def rand_bbox(img_shape, lam, margin=0., count=None):
""" Standard CutMix bounding-box
Generates a random square bbox based on lambda value. This impl includes
support for enforcing a border margin as percent of bbox dimensions.
Args:
img_shape (tuple): Image shape as tuple
lam (float): Cutmix lambda value
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
count (int): Number of bbox to generate
"""
ratio = np.sqrt(1 - lam)
img_h, img_w = img_shape[-2:]
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
yl = np.clip(cy - cut_h // 2, 0, img_h)
yh = np.clip(cy + cut_h // 2, 0, img_h)
xl = np.clip(cx - cut_w // 2, 0, img_w)
xh = np.clip(cx + cut_w // 2, 0, img_w)
return yl, yh, xl, xh
def rand_bbox_minmax(img_shape, minmax, count=None):
""" Min-Max CutMix bounding-box
Inspired by Darknet cutmix impl, generates a random rectangular bbox
based on min/max percent values applied to each dimension of the input image.
Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
Args:
img_shape (tuple): Image shape as tuple
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
count (int): Number of bbox to generate
"""
assert len(minmax) == 2
img_h, img_w = img_shape[-2:]
cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
yl = np.random.randint(0, img_h - cut_h, size=count)
xl = np.random.randint(0, img_w - cut_w, size=count)
yu = yl + cut_h
xu = xl + cut_w
return yl, yu, xl, xu
def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
""" Generate bbox and apply lambda correction.
"""
if ratio_minmax is not None:
yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
else:
yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
if correct_lam or ratio_minmax is not None:
bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
return (yl, yu, xl, xu), lam
class Mixup:
""" Mixup/Cutmix that applies different params to each element or whole batch
Args:
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
prob (float): probability of applying mixup or cutmix per batch or element
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), or 'elem' (element))
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
label_smoothing (float): apply label smoothing to the mixed target tensor
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
if self.cutmix_minmax is not None:
assert len(self.cutmix_minmax) == 2
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
self.cutmix_alpha = 1.0
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mode = mode
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
self.mixup_enabled = True # set to false to disable mixing (intended to be set by the train loop)
def _params_per_elem(self, batch_size):
"""_params_per_elem"""
lam = np.ones(batch_size, dtype=np.float32)
use_cutmix = np.zeros(batch_size, dtype=bool)
if self.mixup_enabled:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand(batch_size) < self.switch_prob
lam_mix = np.where(
use_cutmix,
np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
elif self.cutmix_alpha > 0.:
use_cutmix = np.ones(batch_size, dtype=bool)
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
return lam, use_cutmix
def _params_per_batch(self):
"""_params_per_batch"""
lam = 1.
use_cutmix = False
if self.mixup_enabled and np.random.rand() < self.mix_prob:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand() < self.switch_prob
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.cutmix_alpha > 0.:
use_cutmix = True
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = float(lam_mix)
return lam, use_cutmix
def _mix_elem(self, x):
"""_mix_elem"""
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size)
x_orig = x.copy()  # keep an unmodified copy as the mixing source (x is a numpy batch)
for i in range(batch_size):
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)
def _mix_pair(self, x):
"""_mix_pair"""
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
x_orig = x.copy()  # keep an unmodified copy as the mixing source (x is a numpy batch)
for i in range(batch_size // 2):
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
x[j] = x[j] * lam + x_orig[i] * (1 - lam)
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)
def _mix_batch(self, x):
"""_mix_batch"""
lam, use_cutmix = self._params_per_batch()
if lam == 1.:
return 1.
if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
x[:, :, yl:yh, xl:xh] = np.flip(x, axis=0)[:, :, yl:yh, xl:xh]
else:
x_flipped = np.flip(x, axis=0) * (1. - lam)
x *= lam
x += x_flipped
return lam
def __call__(self, x, target):
"""Mixup apply"""
# mix the images in place and build the matching soft labels
assert len(x) % 2 == 0, 'Batch size should be even when using this'
if self.mode == 'elem':
lam = self._mix_elem(x)
elif self.mode == 'pair':
lam = self._mix_pair(x)
else:
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
return x.astype(np.float32), target.astype(np.float32)
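# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): Mixup is meant to be
# applied as a dataset.map operation on numpy batches; in 'batch' mode it
# mixes the batch with its flipped copy in place and returns soft labels.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    images = rng.standard_normal((8, 3, 224, 224)).astype(np.float32)
    labels = rng.integers(0, 1000, size=8)  # batch size must be even
    mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0,
                     switch_prob=0.5, mode='batch', label_smoothing=0.1,
                     num_classes=1000)
    mixed, soft = mixup_fn(images, labels)
    print(mixed.shape, soft.shape)  # (8, 3, 224, 224) (8, 1000)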
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Random Erasing (Cutout)
Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
Copyright Zhun Zhong & Liang Zheng
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import random
import numpy as np
def _get_pixels(per_pixel, rand_color, patch_size, dtype=np.float32):
"""_get_pixels"""
if per_pixel:
func = np.random.normal(size=patch_size).astype(dtype)
elif rand_color:
func = np.random.normal(size=(patch_size[0], 1, 1)).astype(dtype)
else:
func = np.zeros((patch_size[0], 1, 1), dtype=dtype)
return func
class RandomErasing:
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
This variant of RandomErasing is intended to be applied to either a batch
or single image tensor after it has been normalized by dataset mean and std.
Args:
probability: Probability that the Random Erasing operation will be performed.
min_area: Minimum percentage of erased area wrt input image area.
max_area: Maximum percentage of erased area wrt input image area.
min_aspect: Minimum aspect ratio of erased area.
mode: pixel color mode, one of 'const', 'rand', or 'pixel'
'const' - erase block is constant color of 0 for all channels
'rand' - erase block is same per-channel random (normal) color
'pixel' - erase block is per-pixel random (normal) color
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is randomly chosen between 1 and this value.
"""
def __init__(self, probability=0.5, min_area=0.02, max_area=1 / 3, min_aspect=0.3,
max_aspect=None, mode='const', min_count=1, max_count=None, num_splits=0):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
self.min_count = min_count
self.max_count = max_count or min_count
self.num_splits = num_splits
mode = mode.lower()
self.rand_color = False
self.per_pixel = False
if mode == 'rand':
self.rand_color = True # per block random normal
elif mode == 'pixel':
self.per_pixel = True # per pixel random normal
else:
assert not mode or mode == 'const'
def _erase(self, img, chan, img_h, img_w, dtype):
"""_erase"""
# erase only with the configured probability
if random.random() <= self.probability:
area = img_h * img_w
count = self.min_count if self.min_count == self.max_count else \
random.randint(self.min_count, self.max_count)
for _ in range(count):
for _ in range(10):
target_area = random.uniform(self.min_area, self.max_area) * area / count
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img_w and h < img_h:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w),
dtype=dtype)
break
return img
def __call__(self, x):
"""RandomErasing apply"""
if len(x.shape) == 3:
output = self._erase(x, *x.shape, x.dtype)
else:
output = x.copy()  # start from the input so non-erased samples pass through unchanged
batch_size, chan, img_h, img_w = x.shape
# skip first slice of batch if num_splits is set (for clean portion of samples)
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
for i in range(batch_start, batch_size):
output[i] = self._erase(x[i], chan, img_h, img_w, x.dtype)
return output
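# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): RandomErasing expects a
# normalized CHW float image (or an NCHW batch) and may replace one or more
# rectangles with random noise in 'pixel' mode.
if __name__ == "__main__":
    img = np.random.randn(3, 224, 224).astype(np.float32)
    eraser = RandomErasing(probability=0.25, mode='pixel', max_count=1)
    erased = eraser(img)  # same shape; a random rectangle may now be noise
    print(erased.shape)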
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id else "default"  # getenv returns None when JOB_ID is unset
return job_id
def sync_data(from_path, to_path, threads=16):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains at most 8 devices.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path, threads=threads)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore.dataset.vision.utils import Inter
from src.data.augment.auto_augment import _pil_interp, rand_augment_transform
from src.data.augment.mixup import Mixup
from src.data.augment.random_erasing import RandomErasing
from .data_utils.moxing_adapter import sync_data
class ImageNet:
"""ImageNet Define"""
def __init__(self, args, training=True):
if args.run_modelarts:
print('Download data.')
local_data_path = '/cache/data'
sync_data(args.data_url, local_data_path, threads=128)
print('Create train and evaluate dataset.')
train_dir = os.path.join(local_data_path, "train")
val_dir = os.path.join(local_data_path, "val")
self.train_dataset = create_dataset_imagenet(train_dir, training=True, args=args)
self.val_dataset = create_dataset_imagenet(val_dir, training=False, args=args)
else:
train_dir = os.path.join(args.data_url, "train")
val_dir = os.path.join(args.data_url, "val")
if training:
    self.train_dataset = create_dataset_imagenet(train_dir, training=True, args=args)
self.val_dataset = create_dataset_imagenet(val_dir, training=False, args=args)
def create_dataset_imagenet(dataset_dir, args, repeat_num=1, training=True):
"""
create a train or eval imagenet2012 dataset for ViG
Args:
    dataset_dir(string): the path of dataset.
    training(bool): whether dataset is used for train or eval.
    repeat_num(int): the repeat times of dataset. Default: 1
Returns:
dataset
"""
device_num, rank_id = _get_rank_info()
shuffle = bool(training)
if device_num == 1 or not training:
data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers,
shuffle=shuffle)
else:
data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers, shuffle=shuffle,
num_shards=device_num, shard_id=rank_id)
image_size = args.image_size
# define map operations
# BICUBIC: 3
if training:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
aa_params = dict(
translate_const=int(image_size * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
)
interpolation = args.interpolation
auto_augment = args.auto_augment
assert auto_augment.startswith('rand')
aa_params['interpolation'] = _pil_interp(interpolation)
transform_img = [
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
interpolation=Inter.BICUBIC),
vision.RandomHorizontalFlip(prob=0.5),
py_vision.ToPIL()
]
transform_img += [rand_augment_transform(auto_augment, aa_params)]
transform_img += [
py_vision.ToTensor(),
py_vision.Normalize(mean=mean, std=std),
RandomErasing(args.re_prob, mode=args.re_mode, max_count=args.re_count)
]
else:
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# test transform complete
transform_img = [
vision.Decode(),
vision.Resize(int(256 / 224 * image_size), interpolation=Inter.BICUBIC),
vision.CenterCrop(image_size),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
transform_label = C.TypeCast(mstype.int32)
data_set = data_set.map(input_columns="image", num_parallel_workers=args.num_parallel_workers,
operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
operations=transform_label)
if (args.mix_up > 0. or args.cutmix > 0.) and not training:
# MixUp/CutMix produce soft (one-hot style) labels in training, so one-hot the eval labels to match
one_hot = C.OneHot(num_classes=args.num_classes)
data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
operations=one_hot)
# apply batch operations
data_set = data_set.batch(args.batch_size, drop_remainder=True,
num_parallel_workers=args.num_parallel_workers)
if (args.mix_up > 0. or args.cutmix > 0.) and training:
mixup_fn = Mixup(
mixup_alpha=args.mix_up, cutmix_alpha=args.cutmix, cutmix_minmax=None,
prob=args.mixup_prob, switch_prob=args.switch_prob, mode=args.mixup_mode,
label_smoothing=args.label_smoothing, num_classes=args.num_classes)
data_set = data_set.map(operations=mixup_fn, input_columns=["image", "label"],
num_parallel_workers=args.num_parallel_workers)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
from mindspore.communication.management import get_rank, get_group_size
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id
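# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): driving
# create_dataset_imagenet directly. In the repository `args` comes from
# src.args; this stand-in namespace carries only the fields this file reads,
# and the dataset path is a placeholder.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        num_parallel_workers=8, image_size=224, interpolation="bicubic",
        auto_augment="rand-m9-mstd0.5-inc1", re_prob=0.25, re_mode="pixel",
        re_count=1, mix_up=0.8, cutmix=1.0, mixup_prob=1.0, switch_prob=0.5,
        mixup_mode="batch", label_smoothing=0.1, num_classes=1000, batch_size=128)
    train_set = create_dataset_imagenet("/path/to/imagenet/train", demo_args, training=True)
    print(train_set.get_dataset_size())  # number of batches per epoch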