Unverified commit c0e10c8c authored by i-robot, committed by Gitee

!3055 Ascend Zhongzhi - Wuhan University of Technology - MindSpore ONNX - efficientnetb0

Merge pull request !3055 from EdisonLee/Efficientnet
parents 091aecb0 6b8c1d60
@@ -69,6 +69,7 @@ The overall EfficientNet network architecture is as follows:
├──run_standalone_train.sh        # shell script for single-device training
├──run_distribute_train.sh        # shell script for 8-device training
├──run_infer_310.sh               # shell script for Ascend inference
├──run_infer_onnx.sh              # shell script for ONNX inference
└──run_eval.sh                    # shell script for evaluation
├── src
├──models                         # model architectures
@@ -81,6 +82,7 @@ The overall EfficientNet network architecture is as follows:
└──Monitor.py                     # monitor network loss and other data
├── create_imagenet2012_label.py  # create dataset labels
├── eval.py                       # evaluation script
├── infer_onnx.py                 # ONNX evaluation
├── export.py                     # model format conversion script
├── postprogress.py               # post-processing script for Ascend 310 inference
└── train.py                      # training script
@@ -173,7 +175,7 @@ result: {'Loss': 1.8745046273255959, 'Top_1_Acc': 0.7668870192307692, 'Top_5_Acc
python export.py --checkpoint_path [CKPT_PATH] --file_name [OUT_FILE] --file_format [EXPORT_FORMAT]
```
`EXPORT_FORMAT` can be one of ["AIR", "MINDIR", "ONNX"].
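For example, to export a trained checkpoint to ONNX for the inference step below (the checkpoint path and output name here are illustrative placeholders, not files shipped with the repository):

```shell
# Hypothetical paths: replace with your own checkpoint and output name.
python export.py --checkpoint_path ./efficientnet_b0.ckpt --file_name efficientnet_b0 --file_format ONNX
```

MindSpore's export utility typically appends the format suffix, so this command should produce `efficientnet_b0.onnx`.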
## Inference Process
@@ -196,6 +198,15 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
accuracy:0.767
```
### ONNX Inference
Before running inference, the model must first be exported in a GPU environment.
```shell
# ONNX inference
bash scripts/run_infer_onnx.sh [ONNX_PATH] [DATASET_PATH] [DEVICE_TARGET]
```
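A concrete invocation might look like this (the ONNX file and ImageNet validation directory are assumed paths used only for illustration):

```shell
# Hypothetical paths: point these at your exported model and validation set.
bash scripts/run_infer_onnx.sh ./efficientnet_b0.onnx /path/to/imagenet/val GPU
```

The script redirects its output to `infer_onnx.log`; on success the log should end with the top-1 and top-5 accuracy values printed by `infer_onnx.py`.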
# Model Description
## Training Performance
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Run evaluation for a model exported to ONNX"""
import argparse
import mindspore.nn as nn
import onnxruntime as ort
from src.dataset import create_dataset


def create_session(checkpoint_path, target_device):
    """Create an ONNX Runtime inference session for the requested device."""
    if target_device == 'GPU':
        providers = ['CUDAExecutionProvider']
    elif target_device == 'CPU':
        providers = ['CPUExecutionProvider']
    else:
        raise ValueError(
            f'Unsupported target device {target_device}, '
            f'expected one of: "CPU", "GPU"'
        )
    session = ort.InferenceSession(checkpoint_path, providers=providers)
    input_name = session.get_inputs()[0].name
    return session, input_name


def run_eval(checkpoint_path, data_dir, target_device):
    """Evaluate the exported ONNX model and return top-1/top-5 accuracy."""
    session, input_name = create_session(checkpoint_path, target_device)
    dataset = create_dataset(data_dir, False, 1)
    metrics = {
        'top-1 accuracy': nn.Top1CategoricalAccuracy(),
        'top-5 accuracy': nn.Top5CategoricalAccuracy(),
    }
    for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        y_pred = session.run(None, {input_name: batch['image']})[0]
        for metric in metrics.values():
            metric.update(y_pred, batch['label'])
    return {name: metric.eval() for name, metric in metrics.items()}


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Image classification')
    # ONNX inference parameters
    parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
    parser.add_argument('--onnx_path', type=str, default=None, help='ONNX file path')
    parser.add_argument('--device_target', type=str, default='GPU', help='Device target')
    parser.add_argument('--device_id', type=int, default=0, help='Device id')
    args = parser.parse_args()

    results = run_eval(args.onnx_path, args.dataset_path, args.device_target)
    for name, value in results.items():
        print(f'{name}: {value:.5f}')
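For reference, the evaluation script can also be run directly instead of through the wrapper script; the paths below are illustrative only:

```shell
# Hypothetical paths: substitute your exported model and validation set.
python infer_onnx.py --onnx_path ./efficientnet_b0.onnx --dataset_path /path/to/imagenet/val --device_target GPU
```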
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# != 3 ]]; then
    echo "Usage: bash run_infer_onnx.sh [ONNX_PATH] [DATASET_PATH] [DEVICE_TARGET]"
    exit 1
fi

get_real_path(){
    # Resolve a possibly relative path to an absolute one.
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

onnx_path=$(get_real_path "$1")
dataset_path=$(get_real_path "$2")
device_target=$3

echo "onnx_path: $onnx_path"
echo "dataset_path: $dataset_path"
echo "device_target: $device_target"

python ./infer_onnx.py --onnx_path="$onnx_path" \
                       --dataset_path="$dataset_path" \
                       --device_target="$device_target" &> infer_onnx.log
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi
@@ -36,9 +36,9 @@ def create_dataset(dataset_path, do_train, batch_size=16, device_num=1, rank=0):
        dataset
    """
    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=40, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=40, shuffle=True,
                                   num_shards=device_num, shard_id=rank)
    # define map operations
    if do_train: