Unverified commit faf749b2 authored by i-robot, committed by Gitee

!3567 ONNX Infer:RepVGG Support

Merge pull request !3567 from wuqingdian/repvgg
parents 57ca2855 d33538f0
...@@ -76,6 +76,7 @@ RepVGG is a novel CNN design paradigm proposed by Tsinghua University & Megvii Technology
        ├──run_distribute_train_ascend.sh   // Multi-device Ascend 910 training script
        ├──run_infer_310.sh                 // Ascend 310 inference script
        ├──run_eval_ascend.sh               // Evaluation script
        ├──run_infer_onnx.sh                // ONNX inference script
    ├── src
        ├──configs                          // RepVGG configuration files
        ├──data                             // Dataset configuration files
...@@ -99,6 +100,7 @@ RepVGG is a novel CNN design paradigm proposed by Tsinghua University & Megvii Technology
    ├── export.py                           // Model export script
    ├── preprocess.py                       // Inference dataset preprocessing script
    ├── postprocess.py                      // Inference accuracy post-processing script
    ├── infer_onnx.py                       // ONNX inference script
```

## Script Parameters
...@@ -111,7 +113,7 @@ RepVGG is a novel CNN design paradigm proposed by Tsinghua University & Megvii Technology
# Architecture
arch: RepVGG-A0-A0tiny          # RepVGG architecture selection
# ===== Dataset ===== #
data_url: ./dataset             # Dataset path
set: ImageNet                   # Dataset name
num_classes: 1000               # Number of classes in the dataset
mix_up: 0.0                     # MixUp data augmentation parameter
...@@ -173,6 +175,9 @@ RepVGG is a novel CNN design paradigm proposed by Tsinghua University & Megvii Technology
# Launch a single-device evaluation example with the script
bash ./scripts/run_eval_ascend.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]

# Launch a single-device ONNX evaluation example with the script
bash ./scripts/run_infer_onnx.sh [ONNX_PATH] [DATASET_PATH] [DEVICE_TARGET] [DEVICE_ID]
```
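For instance, a single-device Ascend evaluation run might look like the following; the device id, config file, and checkpoint path are illustrative and should be replaced with your own:

```bash
# Illustrative example: evaluate a RepVGG-A0 checkpoint on Ascend device 0
bash ./scripts/run_eval_ascend.sh 0 ./src/configs/RepVGG-A0.yaml ./RepVGG-A0_best.ckpt
```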
For distributed training, an hccl configuration file in JSON format needs to be created in advance.
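This rank table file is usually generated rather than written by hand; a minimal sketch, assuming the hccl_tools helper from the MindSpore models repository is available at utils/hccl_tools/hccl_tools.py in your checkout:

```bash
# Illustrative example: generate an hccl rank table (JSON) for devices 0-7 on the
# current server; pass the resulting file to the distributed training script.
python utils/hccl_tools/hccl_tools.py --device_num "[0,8)"
```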
...@@ -189,13 +194,13 @@ RepVGG is a novel CNN design paradigm proposed by Tsinghua University & Megvii Technology
python export.py --pretrained [CKPT_FILE] --config [CONFIG_PATH] --device_target [DEVICE_TARGET] --file_format [FILE_FORMAT]
```

The exported model is named after the model architecture and saved in the current directory. Note: FILE_FORMAT must be one of ["AIR", "MINDIR", "ONNX"].
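For example, an ONNX export on GPU might look like this; the checkpoint path is illustrative, and the config file is the RepVGG-A0 config used elsewhere in this repository:

```bash
# Illustrative example: export a trained RepVGG-A0 checkpoint to ONNX on GPU
python export.py --pretrained ./RepVGG-A0_best.ckpt \
                 --config ./src/configs/RepVGG-A0.yaml \
                 --device_target GPU \
                 --file_format ONNX
```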
## Inference Process

### Inference

Before running inference, the model must be exported first. A MindIR model can be exported in any environment, an AIR model can only be exported on Ascend 910, and an ONNX model can be exported on CPU/GPU/Ascend. The following shows an example of running inference with a MindIR model.

- Inference on Ascend 310 with the ImageNet-1k dataset

...@@ -208,6 +213,17 @@ RepVGG is a novel CNN design paradigm proposed by Tsinghua University & Megvii Technology
  Top5 acc: 0.90734
  ```
- ONNX inference on GPU/CPU with the ImageNet-1k dataset

  The inference results are saved in the project root directory and can be found in the infer_onnx.log file.

  ```bash
  # ONNX inference
  bash run_infer_onnx.sh [ONNX_PATH] [DATASET_PATH] [DEVICE_TARGET] [DEVICE_ID]
  # Expected results, as reported in infer_onnx.log:
  # top-1 accuracy: 0.72024
  # top-5 accuracy: 0.90394
  ```
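  A concrete invocation might look like the following; the ONNX file name and the dataset path are illustrative:

  ```bash
  # Illustrative example: ONNX inference on GPU device 0; results are written to infer_onnx.log
  bash ./scripts/run_infer_onnx.sh ./RepVGG-A0.onnx /path/to/imagenet/val GPU 0
  ```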
# [Model Description](#目录)

## Performance

...
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval"""
import onnxruntime as ort
from mindspore import nn
from src.args import args
from src.data.imagenet import create_dataset_imagenet
def create_session(onnx_path, target_device):
if target_device == 'GPU':
providers = ['CUDAExecutionProvider']
elif target_device == 'CPU':
providers = ['CPUExecutionProvider']
else:
raise ValueError(
f'Unsupported target device {target_device}, '
f'Expected one of: "CPU", "GPU"'
)
session = ort.InferenceSession(onnx_path, providers=providers)
input_name = session.get_inputs()[0].name
return session, input_name
def run_eval(onnx_path, data_dir, target_device):
session, input_name = create_session(onnx_path, target_device)
args.batch_size = 1
dataset = create_dataset_imagenet(data_dir, args, training=False)
metrics = {
'top-1 accuracy': nn.Top1CategoricalAccuracy(),
'top-5 accuracy': nn.Top5CategoricalAccuracy(),
}
for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
y_pred = session.run(None, {input_name: batch['image']})[0]
for metric in metrics.values():
metric.update(y_pred, batch['label'])
return {name: metric.eval() for name, metric in metrics.items()}
if __name__ == '__main__':
results = run_eval(args.onnx_path, args.dataset_path, args.device_target)
for name, value in results.items():
print(f'{name}: {value:.5f}')
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 4 ]]; then
    echo "Usage: bash run_infer_onnx.sh [ONNX_PATH] [DATASET_PATH] [DEVICE_TARGET](optional) [DEVICE_ID](optional)"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

onnx_path=$(get_real_path "$1")
dataset_path=$(get_real_path "$2")

# Default to GPU device 0 when the optional arguments are not given.
device_target="GPU"
device_id=0
if [ $# -ge 3 ]; then
    device_target=$3
fi
if [ $# -eq 4 ]; then
    device_id=$4
fi

echo "onnx_path: $onnx_path"
echo "dataset_path: $dataset_path"
echo "device_target: $device_target"
echo "device_id: $device_id"

function infer()
{
    python ./infer_onnx.py --onnx_path="$onnx_path" \
                           --dataset_path="$dataset_path" \
                           --device_target="$device_target" \
                           --device_id="$device_id" &> infer_onnx.log
}

infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""global args for Transformer in Transformer(TNT)"""
import argparse
import ast
import os
import sys

import yaml

from src.configs import parser as _parser

args = None


def parse_arguments():
    """parse_arguments"""
    global args
    parser = argparse.ArgumentParser(description="MindSpore TNT Training")

    parser.add_argument("-a", "--arch", metavar="ARCH", default="ResNet18", help="model architecture")
    parser.add_argument("--accumulation_step", default=1, type=int, help="accumulation step")
    parser.add_argument("--amp_level", default="O0", choices=["O0", "O1", "O2", "O3"], help="AMP Level")
    parser.add_argument("--batch_size", default=256, type=int, metavar="N",
                        help="mini-batch size (default: 256), this is the total "
                             "batch size of all Devices on the current node when "
                             "using Data Parallel or Distributed Data Parallel")
    parser.add_argument("--beta", default=[0.9, 0.999], type=lambda x: [float(a) for a in x.split(",")],
                        help="beta for optimizer")
    parser.add_argument("--with_ema", default=False, type=ast.literal_eval, help="training with ema")
    parser.add_argument("--ema_decay", default=0.9999, type=float, help="ema decay")
    parser.add_argument('--data_url', default="./data", help='location of data.')
    parser.add_argument("--device_id", default=0, type=int, help="device id")
    parser.add_argument("--device_num", default=1, type=int, help="device num")
    parser.add_argument("--device_target", default="Ascend", choices=["GPU", "Ascend"], type=str)
    parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--eps", default=1e-8, type=float)
    parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR", "ONNX"], default="MINDIR",
                        help="file format")
    parser.add_argument("--in_chans", default=3, type=int)
    parser.add_argument("--is_dynamic_loss_scale", default=1, type=int, help="is_dynamic_loss_scale ")
    parser.add_argument("--keep_checkpoint_max", default=20, type=int, help="keep checkpoint max num")
    parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd")
    parser.add_argument("--set", help="name of dataset", type=str, default="ImageNet")
    parser.add_argument("--mix_up", default=0., type=float, help="mix up")
    parser.add_argument("--mlp_ratio", help="mlp ", default=4., type=float)
    parser.add_argument("-j", "--num_parallel_workers", default=20, type=int, metavar="N",
                        help="number of data loading workers (default: 20)")
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N",
                        help="manual epoch number (useful on restarts)")
    parser.add_argument("--warmup_length", default=0, type=int, help="number of warmup iterations")
    parser.add_argument("--warmup_lr", default=5e-7, type=float, help="warm up learning rate")
    parser.add_argument("--wd", "--weight_decay", default=0.05, type=float, metavar="W",
                        help="weight decay (default: 0.05)", dest="weight_decay")
    parser.add_argument("--loss_scale", default=1024, type=int, help="loss_scale")
    parser.add_argument("--lr", "--learning_rate", default=5e-4, type=float, help="initial lr", dest="lr")
    parser.add_argument("--lr_scheduler", default="cosine_annealing", help="schedule for the learning rate.")
    parser.add_argument("--lr_adjust", default=30, type=float, help="interval to drop lr")
    parser.add_argument("--lr_gamma", default=0.97, type=int, help="multistep multiplier")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--num_classes", default=1000, type=int)
    parser.add_argument("--pretrained", dest="pretrained", default=None, type=str,
                        help="use pre-trained model")
    parser.add_argument("--config", help="Config file to use (see configs dir)",
                        default="./src/configs/RepVGG-A0.yaml", required=False)
    parser.add_argument("--seed", default=0, type=int, help="seed for initializing training. ")
    parser.add_argument("--save_every", default=2, type=int, help="save every ___ epochs(default:2)")
    parser.add_argument("--label_smoothing", type=float, help="label smoothing to use, default 0.1", default=0.1)
    parser.add_argument("--image_size", default=224, help="image Size.", type=int)
    parser.add_argument('--train_url', default="./", help='location of training outputs.')
    parser.add_argument("--run_modelarts", type=ast.literal_eval, default=False, help="whether run on modelarts")
    parser.add_argument("--deploy", type=ast.literal_eval, default=False, help="whether run deploy")
    parser.add_argument("--onnx_path", type=str, default=None, help="ONNX file path")
    parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path")
    args = parser.parse_args()

    get_config()


def get_config():
    """get_config"""
    global args
    override_args = _parser.argv_to_vars(sys.argv)
    # load yaml file
    if args.run_modelarts:
        import moxing as mox
        if not args.config.startswith("obs:/"):
            args.config = "obs:/" + args.config
        with mox.file.File(args.config, 'r') as f:
            yaml_txt = f.read()
    else:
        yaml_txt = open(args.config).read()

    # override args
    loaded_yaml = yaml.load(yaml_txt, Loader=yaml.FullLoader)
    for v in override_args:
        loaded_yaml[v] = getattr(args, v)

    print(f"=> Reading YAML config from {args.config}")

    args.__dict__.update(loaded_yaml)
    print(args)
    os.environ["DEVICE_TARGET"] = args.device_target
    if "DEVICE_NUM" not in os.environ.keys():
        os.environ["DEVICE_NUM"] = str(args.device_num)
    if "RANK_SIZE" not in os.environ.keys():
        os.environ["RANK_SIZE"] = str(args.device_num)


def run_args():
    """run and get args"""
    global args
    if args is None:
        parse_arguments()


run_args()
...@@ -45,4 +45,4 @@ ema_decay: 0.9999
# ===== Hardware setup ===== #
num_parallel_workers: 16
device_target: Ascend