Commit 32c9cfce authored by kqzhang

add onnx

parent 19444b8a
@@ -139,6 +139,7 @@ The directory structure of the original BraTS 2017 dataset is as follows:
│ ├── run_distribute_train_gpu.sh # Launch distributed training on GPU (8 devices)
│ ├── run_standalone_train_gpu.sh # Launch standalone training on GPU (single device)
│ ├── run_eval_gpu.sh # Launch evaluation on GPU
│ ├── run_onnx_eval.sh # Shell script for ONNX inference
├── src
│ ├── config.py # Parses the yaml file
│ ├── dataset.py # Creates the dataset
@@ -152,6 +153,7 @@ The directory structure of the original BraTS 2017 dataset is as follows:
│ ├── train.txt # Training data list
├── train.py # Training script
├── eval.py # Evaluation script
├── eval_onnx.py # ONNX evaluation script
├── export.py # Inference model export script
├── preprocess.py # Data preprocessing for Ascend 310 inference
├── postprocess.py # Data postprocessing for Ascend 310 inference
@@ -303,7 +305,7 @@ mean dice enhance:
python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
The parameter ckpt_file is required, and FILE_FORMAT must be chosen from ["AIR", "MINDIR"].
The parameter ckpt_file is required, and FILE_FORMAT must be chosen from ["AIR", "MINDIR", "ONNX"].
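For example, an ONNX export might look like the following (the checkpoint and file names are illustrative, taken from the defaults shown in the config file):

```shell
python export.py --ckpt_file ./dense24-5_4200.ckpt --file_name dense24 --file_format ONNX
```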
## Inference Process
@@ -322,6 +324,21 @@ python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [TEST_PATH] [DEVICE_ID]
```
### ONNX Inference

Before running inference, you need to export the ONNX model file via export.py.

- Run inference on GPU using the corrected BraTS 2017 dataset
Run inference with the command below, where `DATA_PATH` is the dataset path, `TEST_PATH` is the inference data path, `ONNX_PATH` is the ONNX file path, `CONFIG_PATH` is the configuration file path, and `DEVICE_ID` is the GPU device ID (the script expects all five arguments).
```shell
# ONNX inference
bash run_onnx_eval.sh [DATA_PATH] [TEST_PATH] [ONNX_PATH] [CONFIG_PATH] [DEVICE_ID]
```
The command above runs in the background; you can view the results in the eval_onnx.log file.
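A hypothetical invocation, mirroring the usage example printed by the script itself (all paths are illustrative), might be:

```shell
bash run_onnx_eval.sh ./data ./test ./dense24.onnx ./default_config.yaml 0
```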
# Model Description

## Performance
......
@@ -51,6 +51,9 @@ image_width: 240
image_height: 240
image_channel: 155
# onnx eval
onnx_path: ""
---
# Help description for each configuration
data_path: "The directory of data."
......
@@ -40,6 +40,9 @@ ckpt_file: "./dense24-5_4200.ckpt"
file_name: "dense24"
file_format: "MINDIR"
# onnx eval
onnx_path: ""
---
# Help description for each configuration
data_path: "The directory of data."
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python eval.py
"""
import numpy as np
import onnxruntime as ort
from src.config import config
from src.dataset import vox_generator_test
def one_hot(label, num_classes):
""" one-hot encode """
label_ = np.zeros([len(label), num_classes])
label_[np.arange(len(label)), label] = 1
return label_
def calculate_dice(true_label, pred_label, num_classes):
"""
calculate dice
Args:
true_label: true sparse labels
pred_label: predict sparse labels
num_classes: number of classes
Returns:
dice evaluation index
"""
true_label = true_label.astype(int)
pred_label = pred_label.astype(int)
true_label = true_label.flatten()
true_label = one_hot(true_label, num_classes)
pred_label = pred_label.flatten()
pred_label = one_hot(pred_label, num_classes)
intersection = np.sum(true_label * pred_label, axis=0)
return (2. * intersection) / (np.sum(true_label, axis=0) + np.sum(pred_label, axis=0))
def create_session(checkpoint_path, target_device):
    """
    Create an ONNX Runtime inference session and return it together with its two input names.
    """
if target_device == 'GPU':
providers = ['CUDAExecutionProvider']
elif target_device == 'CPU':
providers = ['CPUExecutionProvider']
else:
raise ValueError(f'Unsupported target device {target_device!r}. Expected one of: "CPU", "GPU"')
session = ort.InferenceSession(checkpoint_path, providers=providers)
input_name_1 = session.get_inputs()[0].name
input_name_2 = session.get_inputs()[1].name
return session, input_name_1, input_name_2
if __name__ == '__main__':
# test dataset
test_files = []
with open(config.test_path) as f:
for line in f:
test_files.append(line[:-1])
data_gen_test = vox_generator_test(config.data_path, test_files, config.correction)
# network
network, flair_t2_node_name, t1_t1ce_node_name = create_session(config.onnx_path, config.device_target)
OFFSET_H = config.offset_height
OFFSET_W = config.offset_width
OFFSET_C = config.offset_channel
HSIZE = config.height_size
WSIZE = config.width_size
CSIZE = config.channel_size
PSIZE = config.pred_size
OFFSET_PH = (HSIZE - PSIZE) // 2
OFFSET_PW = (WSIZE - PSIZE) // 2
OFFSET_PC = (CSIZE - PSIZE) // 2
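    # Number of sliding-window positions along each axis of the 240x240x155 volume;
    # the last window on each axis is clamped to the volume boundary below.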
batches_w = int(np.ceil((240 - WSIZE) / float(OFFSET_W))) + 1
batches_h = int(np.ceil((240 - HSIZE) / float(OFFSET_H))) + 1
batches_c = int(np.ceil((155 - CSIZE) / float(OFFSET_C))) + 1
dice_whole, dice_core, dice_et = [], [], []
for i in range(len(test_files)):
print('predicting %s' % test_files[i])
        x, x_n, y = next(data_gen_test)
pred = np.zeros([240, 240, 155, 5])
for hi in range(batches_h):
offset_h = min(OFFSET_H * hi, 240 - HSIZE)
offset_ph = offset_h + OFFSET_PH
for wi in range(batches_w):
offset_w = min(OFFSET_W * wi, 240 - WSIZE)
offset_pw = offset_w + OFFSET_PW
for ci in range(batches_c):
offset_c = min(OFFSET_C * ci, 155 - CSIZE)
offset_pc = offset_c + OFFSET_PC
data = x[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
data_norm = x_n[offset_h:offset_h + HSIZE, offset_w:offset_w + WSIZE, offset_c:offset_c + CSIZE, :]
data_norm = np.expand_dims(data_norm, 0)
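                    # Process only patches that contain image content: some non-zero
                    # voxels (max != 0) and at least one background voxel (min == 0).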
                    if np.max(data) != 0 and np.min(data) == 0:
flair_t2_node = data_norm[:, :, :, :, :2]
t1_t1ce_node = data_norm[:, :, :, :, 2:]
flair_t2_node = np.transpose(flair_t2_node, axes=[0, 4, 1, 2, 3])
t1_t1ce_node = np.transpose(t1_t1ce_node, axes=[0, 4, 1, 2, 3])
flair_t2_score, t1_t1ce_score = network.run(None, {flair_t2_node_name: flair_t2_node,
t1_t1ce_node_name: t1_t1ce_node})
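                        # Only the t1/t1ce branch output contributes to the final
                        # prediction volume; the flair/t2 score is unused here.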
t1_t1ce_score = np.transpose(t1_t1ce_score, axes=[0, 2, 3, 4, 1])
pred[offset_ph:offset_ph + PSIZE, offset_pw:offset_pw + PSIZE, offset_pc:offset_pc + PSIZE, :] \
+= np.squeeze(t1_t1ce_score)
pred = np.argmax(pred, axis=-1)
pred = pred.astype(int)
print('calculating dice...')
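        # Derive the three BraTS evaluation regions from the sparse labels:
        # whole tumor (label > 0), tumor core (labels 1 and 4), enhancing tumor (label 4).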
whole_pred = (pred > 0).astype(int)
whole_gt = (y > 0).astype(int)
core_pred = (pred == 1).astype(int) + (pred == 4).astype(int)
core_gt = (y == 1).astype(int) + (y == 4).astype(int)
et_pred = (pred == 4).astype(int)
et_gt = (y == 4).astype(int)
dice_whole_batch = calculate_dice(whole_gt, whole_pred, 2)
dice_core_batch = calculate_dice(core_gt, core_pred, 2)
dice_et_batch = calculate_dice(et_gt, et_pred, 2)
dice_whole.append(dice_whole_batch)
dice_core.append(dice_core_batch)
dice_et.append(dice_et_batch)
print(dice_whole_batch)
print(dice_core_batch)
print(dice_et_batch)
dice_whole = np.array(dice_whole)
dice_core = np.array(dice_core)
dice_et = np.array(dice_et)
print('mean dice whole:')
print(np.mean(dice_whole, axis=0))
print('mean dice core:')
print(np.mean(dice_core, axis=0))
print('mean dice enhance:')
print(np.mean(dice_et, axis=0))
onnxruntime-gpu
numpy
pyyaml
nibabel
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
set -e
if [ $# -ne 5 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bbash run_onnx_eval.sh DATA_PATH TEST_PATH ONNX_PATH CONFIG_PATH DEVICE_ID"
echo "For example: bash run_onnx_eval.sh ./data ./test ./onnx default_config.ymal 0"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
exit 1
fi
DATA_PATH=$1
TEST_PATH=$2
ONNX_PATH=$3
CONFIG_PATH=$4
DEVICE_ID=$5
export DATA_PATH=${DATA_PATH}
export TEST_PATH=${TEST_PATH}
export ONNX_PATH=${ONNX_PATH}
export CONFIG_PATH=${CONFIG_PATH}
export CUDA_VISIBLE_DEVICES=${DEVICE_ID}
if [ ! -d "$DATA_PATH" ]; then
echo "dataset does not exit"
exit
fi
echo "eval_onnx begin."
cd ../
nohup python eval_onnx.py \
    --batch_size=1 \
    --data_path=$DATA_PATH \
    --test_path=$TEST_PATH \
    --onnx_path=$ONNX_PATH \
    --config_path=$CONFIG_PATH \
    --device_id=$DEVICE_ID > eval_onnx.log 2>&1 &
echo "eval_onnx background..."