Skip to content
Snippets Groups Projects
Unverified Commit 4a94e258 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!3182 ONNX:Hourglass

Merge pull request !3182 from Dammond/onnx_hourglas
parents af81d353 1c10c391
No related branches found
No related tags found
No related merge requests found
......@@ -61,6 +61,7 @@ Stacked Hourglass 是一个用于人体姿态检测的模型,它采用堆叠
├──scripts
│ ├──run_distribute_train.sh # 分布式训练脚本
│ ├──run_eval.sh # 评估脚本
│ ├──run_eval_onnx.sh # 评估onnx推理模型精度脚本
│ └──run_standalone_train.sh # 单卡训练脚本
├──src
│ ├──dataset
......@@ -75,6 +76,7 @@ Stacked Hourglass 是一个用于人体姿态检测的模型,它采用堆叠
│ │ └──inference.py # 推理相关的函数,包含了推理的准确率计算等
│ └── config.py # 参数配置
├── eval.py # 评估脚本
├── eval_onnx.py # ONNX模型评估脚本
├── export.py # 导出脚本
├── README_CN.md # 项目相关描述
└── train.py # 训练脚本
......@@ -164,10 +166,12 @@ Tra PCK @, 0.5 , hip : 0.918 , count: 587
可以使用 `export.py` 脚本进行模型导出,使用方法为:
```sh
python export.py --ckpt_file [ckpt 文件路径]
python export.py --ckpt_file [ckpt 文件路径] --device_target [device 环境设备] --file_format [导出文件格式]
```
参数`ckpt_file` 是必需的
- `ckpt_file` 导出的ckpt模型文件,参数`ckpt_file` 是必需的
- `device_target`环境设备【Ascend】【GPU】【CPU】
- `file_format`导出文件格式【ONNX】【MINDIR】【AIR】
## 推理过程
......@@ -201,6 +205,46 @@ Tra PCK @, 0.5 , hip : 0.918 , count: 587
[...]
```
### 运行
在导出onnx模型后,进行onnx模型推理评估,使用方法为:
```bash
python eval_onnx.py --onnx_file [onnx onnx模型文件路径] --device_target [device 环境设备]
```
- `onnx_file` 导出的onnx模型文件
- `device_target`环境设备【Ascend】【GPU】【CPU】
或则可以运行onn推理脚本。
```shell
bash ./scripts/run_eval_onnx.sh [MINDIR_PATH] [ANNOT_PATH] [IMAGES_PATH] [DEVICE_TARGET]
```
- `ONNX_PATH` ONNX模型的路径
- `ANNOT_PATH` ANNO文件路径
- `IMAGES_PATH` 图像路径
- `DEVICE_TARGET` 环境设备【Ascend】【GPU】【CPU】
### 结果
运行完成,可以看到最终的精度结果。
```text
all :
Val PCK @, 0.5 , total : 0.877 , count: 44239
Tra PCK @, 0.5 , total : 0.943 , count: 4443
Val PCK @, 0.5 , ankle : 0.762 , count: 4234
Tra PCK @, 0.5 , ankle : 0.855 , count: 392
Val PCK @, 0.5 , knee : 0.808 , count: 4963
Tra PCK @, 0.5 , knee : 0.908 , count: 499
Val PCK @, 0.5 , hip : 0.863 , count: 5777
Tra PCK @, 0.5 , hip : 0.945 , count: 587
[...]
```
# 模型说明
## 训练性能(2HG)
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import onnxruntime as ort
import mindspore.context as context
from src.utils.inference import get_img, onnx_inference, MPIIEval, parse_args
args = parse_args()
def create_session(onnx_path, target_device):
"""
Create onnx session.
"""
if target_device == 'GPU':
providers = ['CUDAExecutionProvider']
elif target_device == 'CPU':
providers = ['CPUExecutionProvider']
else:
raise ValueError(
f'Unsupported target device {target_device}, '
f'Expected one of: "CPU", "GPU"'
)
session = ort.InferenceSession(onnx_path, providers=providers)
input_name = [x.name for x in session.get_inputs()]
return session, input_name
def hourglass_onnx_inference():
"""
Onnx inference
"""
session, input_name = create_session(args.onnx_file, args.device_target)
gts = []
preds = []
normalizing = []
num_eval = args.num_eval
num_train = args.train_num_eval
for anns, img, c, s, n in get_img(num_eval, num_train):
gts.append(anns)
ans = onnx_inference(img, session, input_name, c, s)
if ans.size > 0:
ans = ans[:, :, :3]
pred = []
for i in range(ans.shape[0]):
pred.append({"keypoints": ans[i, :, :]})
preds.append(pred)
normalizing.append(n)
mpii_eval = MPIIEval()
mpii_eval.eval(preds, gts, normalizing, num_train)
if __name__ == "__main__":
if not os.path.exists(args.onnx_file):
print("onnx file not valid")
exit()
if not os.path.exists(args.img_dir) or not os.path.exists(args.annot_dir):
print("Dataset not found.")
exit()
# Set context mode
if args.context_mode == "GRAPH":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
hourglass_onnx_inference()
......@@ -20,14 +20,12 @@ from src.models.StackedHourglassNet import StackedHourglassNet
import src.dataset.MPIIDataLoader as ds
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context, export, load_checkpoint, load_param_into_net
from mindspore import Tensor, export, load_checkpoint, load_param_into_net
args = parse_args()
class MaxPool2dFilter(nn.Cell):
"""
maxpool 2d for filter
Maxpool 2d for filter
"""
def __init__(self):
......@@ -54,12 +52,13 @@ class Hourglass(nn.Cell):
self.pool = nn.MaxPool2d(3, 1, "same")
self.eq = ops.Equal()
def construct(self, input1, input2):
def construct(self, x):
"""
forward
"""
tmp1 = self.net(input1)
tmp2 = self.net(input2)
tmp1 = self.net(x)
tmp2 = self.net(x[:, ::-1])
tmp = ops.Concat(0)((tmp1, tmp2))
det = tmp[0, -1] + tmp[1, -1, :, :, ::-1][ds.flipped_parts["mpii"]]
......@@ -80,17 +79,9 @@ if __name__ == "__main__":
print("ckpt file not valid")
exit()
# Set context mode
if args.context_mode == "GRAPH":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
else:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
# Import net
net = StackedHourglassNet(args.nstack, args.inp_dim, args.oup_dim)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([1, args.input_res, args.input_res, 3], np.float32))
input_arr2 = Tensor(np.zeros([1, args.input_res, args.input_res, 3], np.float32))
net = Hourglass(net)
export(net, input_arr, input_arr2, file_name='Hourglass', file_format=args.file_format)
export(net, input_arr, file_name='Hourglass', file_format=args.file_format)
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]; then
echo "Usage: ./scripts/run_eval_onnx.sh [onnx path] [annot path] [image path] [device target]"
exit 1
fi
TARGET="./Eval_onnx"
#set -e
rm -rf $TARGET
mkdir $TARGET
ONNX_PATH=$1
ANNOT_PATH=$2
IMAGES_PATH=$3
DEVICE_TARGET=$4
python eval_onnx.py \
--onnx_file ${ONNX_PATH} \
--annot_dir ${ANNOT_PATH} \
--img_dir ${IMAGES_PATH} \
--device_target ${DEVICE_TARGET} > Eval_onnx/result.txt 2> Eval_onnx/err.txt
\ No newline at end of file
......@@ -33,9 +33,10 @@ def parse_args():
parser.add_argument("--output_res", type=int, default=64)
parser.add_argument("--annot_dir", type=str, default="./MPII/annot")
parser.add_argument("--img_dir", type=str, default="./MPII/images")
# Context
parser.add_argument("--context_mode", type=str, default="GRAPH", choices=["GRAPH", "PYNATIVE"])
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "CPU"])
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"])
# Train
parser.add_argument("--parallel", type=ast.literal_eval, default=False)
parser.add_argument("--amp_level", type=str, default="O2", choices=["O0", "O1", "O2", "O3"])
......@@ -53,7 +54,9 @@ def parse_args():
parser.add_argument("--ckpt_file", type=str, default="")
# Export
parser.add_argument("--file_name", type=str, default="Hourglass")
parser.add_argument("--file_format", type=str, default="MINDIR")
parser.add_argument("--file_format", type=str, default="ONNX")
# Onnx
parser.add_argument("--onnx_file", type=str, default="")
# infer_310
parser.add_argument("--result_path", type=str, default="../ascend310_infer/preprocess_Result")
parser.add_argument("--out_path", type=str, default="../ascend310_infer/result_Files")
......
......@@ -181,6 +181,30 @@ def inference(img, net, c, s):
return post_process(det, mat_, "valid", c, s, res)
def onnx_inference(img, session, input_name, c, s):
"""
Onnx inference.
"""
scale_ratio = 200
height, width = img.shape[0:2]
center = (width / 2, height / 2)
scale = max(height, width) / scale_ratio
res = (args.input_res, args.input_res)
mat_ = src.utils.img.get_transform(center, scale, res, scale_ratio)[:2]
inp = img / 255
tmp1 = session.run(None, dict(zip(input_name, [mindspore.Tensor([inp], dtype=mindspore.float32).asnumpy()])))[0]
tmp2 = session.run(None, dict(zip(input_name, [mindspore.Tensor([inp[:, ::-1]],
dtype=mindspore.float32).asnumpy()])))[0]
tmp = np.concatenate((tmp1, tmp2), axis=0)
det = tmp[0, -1] + tmp[1, -1, :, :, ::-1][ds.flipped_parts["mpii"]]
if det is None:
return [], []
det = det / 2
det = np.minimum(det, 1)
return post_process(det, mat_, "valid", c, s, res)
class MPIIEval:
"""
eval for MPII dataset
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment