Commit 35c7438e authored by m_meng

add onnx

parent e4502e6b
@@ -115,11 +115,12 @@ The overall network architecture of GhostNet is as follows: [link](https://arxiv.org/pdf/1911.11907.
├── CMakeLists.txt # Ascend 310 inference
├── main.cc # Ascend 310 inference
└── utils.cc # Ascend 310 inference
├── scripts
├── run_distribute_train.sh # launch Ascend distributed training (8 devices)
├── run_eval.sh # launch Ascend evaluation
├── run_infer_310.sh # launch Ascend 310 inference
└── run_standalone_train.sh # launch Ascend standalone training (single device)
├── scripts
├── run_distribute_train.sh # launch Ascend distributed training (8 devices)
├── run_eval.sh # launch Ascend evaluation
├── run_eval_onnx.sh # launch ONNX evaluation
├── run_infer_310.sh # launch Ascend 310 inference
└── run_standalone_train.sh # launch Ascend standalone training (single device)
├── src
├── config.py # parameter configuration
├── dataset.py # data preprocessing
@@ -131,6 +132,7 @@ The overall network architecture of GhostNet is as follows: [link](https://arxiv.org/pdf/1911.11907.
├── launch.py
└── ghostnet.py # ghostnet network
├── eval.py # evaluation network
├── eval_onnx.py # ONNX evaluation
├── create_imagenet2012_label.py # create ImageNet2012 labels
├── export.py # export MindIR model
├── postprocess.py # post-processing for 310 inference
@@ -240,11 +242,11 @@ ckpt = /home/lzu/ghost_Mindspore/scripts/device0/ghostnet-500_1251.ckpt
## [Export MindIR](#contents)
```shell
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
python export.py --device_target [DEVICE_TARGET] --file_format [FILE_FORMAT] --checkpoint_path [CKPT_PATH]
```
The `checkpoint_path` parameter is required;
`FILE_FORMAT` must be chosen from ["AIR", "MINDIR"].
`FILE_FORMAT` must be chosen from ["AIR", "ONNX", "MINDIR"].
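With `ONNX` added to the supported formats, a typical workflow is to export the trained checkpoint to ONNX and then evaluate it with the new `eval_onnx.py`. A minimal sketch only: the checkpoint name, dataset path, and GPU target are illustrative, and `ghostnet.onnx` assumes the exporter keeps the `file_name='ghostnet'` prefix.

```shell
# Sketch only: adjust checkpoint and dataset paths to your environment.
# 1) Export the trained checkpoint to ONNX.
python export.py --device_target GPU --file_format ONNX --checkpoint_path ghostnet-500_1251.ckpt
# 2) Evaluate the exported model with ONNX Runtime.
python eval_onnx.py --onnx_path ghostnet.onnx --dataset_path /path/to/imagenet2012 --device_target GPU
```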
## Run inference on Ascend 310
@@ -258,9 +260,9 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
## Results
The inference results are saved in the current directory where the script is executed; the following accuracy results can be found in acc.log.

- Evaluating ghostnet with the ImageNet2012 dataset
```shell
Total data: 50000, top1 accuracy: 0.73816, top5 accuracy: 0.9178.
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python eval.py
"""
import argparse
import time
import mindspore.nn as nn
import onnxruntime as ort
from src.dataset import create_dataset
def create_session(checkpoint_path, target_device):
if target_device == 'GPU':
providers = ['CUDAExecutionProvider']
elif target_device == 'CPU':
providers = ['CPUExecutionProvider']
else:
raise ValueError(
f'Unsupported target device {target_device}, '
f'Expected one of: "CPU", "GPU"'
)
session = ort.InferenceSession(checkpoint_path, providers=providers)
input_name = session.get_inputs()[0].name
return session, input_name
def run_eval(checkpoint_path, data_dir, target_device):
session, input_name = create_session(checkpoint_path, target_device)
dataset = create_dataset(dataset_path=data_dir, do_train=False, infer_910=False, batch_size=1)
metrics = {
'top-1 accuracy': nn.Top1CategoricalAccuracy(),
'top-5 accuracy': nn.Top5CategoricalAccuracy(),
}
for batch in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
y_pred = session.run(None, {input_name: batch['image']})[0]
for metric in metrics.values():
metric.update(y_pred, batch['label'])
return {name: metric.eval() for name, metric in metrics.items()}
if __name__ == '__main__':
start = time.time()
parser = argparse.ArgumentParser(description='Image classification')
# onnx parameter
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--onnx_path', type=str, default=None, help='ONNX file path')
parser.add_argument('--device_target', type=str, default='GPU', help='Device target')
parser.add_argument('--device_id', type=int, default=0, help='Device id')
args = parser.parse_args()
results = run_eval(args.onnx_path, args.dataset_path, args.device_target)
for name, value in results.items():
print(f'{name}: {value:.5f}')
end = time.time()
print(str(end))
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ if __name__ == '__main__':
parser.add_argument('--device_target', type=str, choices=['Ascend', 'GPU', 'CPU'], default='Ascend',
help='device where the code will be implemented')
parser.add_argument('--device_id', type=int, default=0, help='device id')
parser.add_argument('--file_format', type=str, choices=['AIR', 'MINDIR'], default='MINDIR',
parser.add_argument('--file_format', type=str, choices=['AIR', 'ONNX', 'MINDIR'], default='ONNX',
help='file format')
parser.add_argument('--checkpoint_path', required=True, default=None, help='ckpt file path')
args = parser.parse_args()
@@ -41,4 +41,4 @@ if __name__ == '__main__':
input_data = Tensor(np.zeros([1, 3, 224, 224]), ms.float32)
print(input_data.shape)
export(net, input_data, file_name='ghost', file_format=args.file_format)
export(net, input_data, file_name='ghostnet', file_format=args.file_format)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]; then
    echo "Usage: bash run_eval_onnx.sh [ONNX_PATH] [DATASET_PATH] [DEVICE_TARGET]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

onnx_path=$(get_real_path $1)
dataset_path=$(get_real_path $2)
device_target=$3

echo "onnx_path: "$onnx_path
echo "dataset_path: "$dataset_path
echo "device_target: "$device_target

function infer()
{
    python ./eval_onnx.py --onnx_path=$onnx_path \
                          --dataset_path=$dataset_path \
                          --device_target=$device_target &> infer_onnx.log
}

infer
if [ $? -ne 0 ]; then
    echo "execute inference failed"
    exit 1
fi
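Assuming the wrapper is installed as `scripts/run_eval_onnx.sh` (as listed in the README tree) and is launched from the project root so that `./eval_onnx.py` resolves, a hypothetical invocation with placeholder paths looks like:

```shell
# Placeholder paths; evaluation output is redirected to infer_onnx.log.
bash scripts/run_eval_onnx.sh ./ghostnet.onnx /path/to/imagenet2012 GPU
cat infer_onnx.log
```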
@@ -43,9 +43,9 @@ def create_dataset(dataset_path, do_train, infer_910=False, device_id=0, batch_s
        device_num = int(os.getenv('RANK_SIZE'))
    if not do_train:
        dataset_path = os.path.join(dataset_path, 'test')
        dataset_path = os.path.join(dataset_path)
    else:
        dataset_path = os.path.join(dataset_path, 'train')
        dataset_path = os.path.join(dataset_path)
    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=True)