Skip to content
Snippets Groups Projects
Unverified Commit c74846be authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!3130 [武汉理工大学][ONNX][RAS]

Merge pull request !3130 from 杜闯/RAS-ONNX
parents 08de4a0f cc8bae24
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,9 @@
- [用法](#用法-1)
- [结果](#结果-1)
- [评估推理结果](#评估推理结果)
- [ONNX模型导出及评估](#onnx模型导出及评估)
- [ONNX模型导出](#onnx模型导出)
- [ONNX模型评估](#onnx模型评估)
- [模型描述](#模型描述)
- [评估精度](#评估性能)
- [随机情况说明](#随机情况说明)
......@@ -104,6 +107,7 @@ RAS总体网络架构如下:
├──run_distribute_train_gpu.sh # 使用GPU进行多卡训练的shell脚本
├──run_train_gpu.sh # 使用GPU进行单卡训练的shell脚本
├──run_eval_gpu.sh # 使用GPU进行评估的单卡shell脚本
├──run_eval_onnx_gpu.sh # 使用GPU对导出的onnx模型进行评估的单卡shell脚本
├──src
├──dataset_train.py #创建训练数据集
├──dataset_test.py # 创建推理数据集
......@@ -113,7 +117,9 @@ RAS总体网络架构如下:
├──TrainOneStepMyself.py #自定义训练,参数更新过程
├── train.py # 训练脚本
├── eval.py # 推理脚本
├── eval_onnx.py # onnx推理脚本
├── export.py
├── export_onnx.py # onnx导出脚本
```
### 脚本参数
......@@ -131,6 +137,7 @@ RAS总体网络架构如下:
'print_flag' : 20 //训练时每print_flag个step输出一次loss
'device_id' : 5 //训练时硬件的ID
'data_url' : xxx //数据路径
'onnx_file' : xxx //导出的onnx模型路径
'pretrained_model':xxx //resnet50预训练模型路径 在eval该参数为"pre_model"
```
......@@ -254,6 +261,27 @@ The Consumption of per step is 0.136 s
推理完成后,要对结果进行处理,为了方便,已经将评估部分加入到推理中,在推理完成后即可看到
该推理结果的Fmeasure,在推理的log中可以找到
## ONNX模型导出及评估
### ONNX模型导出
```bash
python export_onnx.py --device_target [DEVICE_TARGET] --pre_model [PRE_MODEL] --ckpt_file [CKPT_FILE]
ckpt_file 为训练保存的ckpt路径
pre_model 为网络resnet50预训练模型路径
# example: python export_onnx.py --device_target "GPU" --pre_model resnet50_gpu_v130_imagenet_official_cv_bs32_acc0.ckpt --ckpt_file ras_ascend_v170_dutstrain_research_cv_ECSSD91_DUTStest81_DUTOMRON75_HKUIS90.ckpt
```
### ONNX模型评估
```bash
bash script/run_eval_onnx_gpu.sh [data_url] [save_url] [onnx_file]
data_url 为推理数据路径
save_url 为生成结果图片的路径
onnx_file 为导出的onnx文件路径
# example: bash script/run_eval_onnx_gpu.sh dataset/HKU-IS/ ./output_hku_is ras_onnx.onnx
```
# 模型描述
## 评估精度
......
"""
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
import os
import sys
import argparse
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
import mindspore.ops as ops
from mindspore import Tensor
from src.dataset_test import TrainDataLoader
sys.path.append("../")
# data_url is the directory where the data set is located,
# and there must be two folders, images and gts, under data_url;
parser = argparse.ArgumentParser()
parser.add_argument('--device_target', type=str, default="Ascend", help="Ascend, GPU, CPU")
parser.add_argument('--data_url', type=str)
parser.add_argument('--save_url', type=str)
parser.add_argument('--onnx_file', type=str)
par = parser.parse_args()
def image_loader(imagename):
image = Image.open(imagename).convert("L")
return np.array(image)
def Fmeasure(predict_, groundtruth):
"""
Args:
predict: predict image
gt: ground truth
Returns:
Calculate F-measure
"""
sumLabel = 2 * np.mean(predict_)
if sumLabel > 1:
sumLabel = 1
Label3 = predict_ >= sumLabel
NumRec = np.sum(Label3)
#LabelAnd = (Label3 is True)
LabelAnd = Label3
#NumAnd = np.sum(np.logical_and(LabelAnd, groundtruth))
gt_t = gt > 0.5
NumAnd = np.sum(LabelAnd * gt_t)
num_obj = np.sum(groundtruth)
if NumAnd == 0:
p = 0
r = 0
FmeasureF = 0
else:
p = NumAnd / NumRec
r = NumAnd / num_obj
FmeasureF = (1.3 * p * r) / (0.3 * p + r)
return FmeasureF
def create_session(onnx_checkpoint_path, target_device='GPU'):
if target_device == 'GPU':
providers = ['CUDAExecutionProvider']
elif target_device == 'CPU':
providers = ['CPUExecutionProvider']
else:
raise ValueError(
f'Unsupported target device {target_device}, '
f'Expected one of: "CPU", "GPU"'
)
session = ort.InferenceSession(onnx_checkpoint_path, providers=providers)
input_name = session.get_inputs()[0].name
return session, input_name
if __name__ == "__main__":
filename = os.path.join(par.data_url, 'images/')
gtname = os.path.join(par.data_url, 'gts/')
save_path = par.save_url
if not os.path.exists(save_path):
os.makedirs(save_path)
testdataloader = TrainDataLoader(filename)
sess, input_sess = create_session(par.onnx_file, par.device_target)
Names = []
for data in os.listdir(filename):
name = data.split('.')[0]
Names.append(name)
Names = sorted(Names)
i = 0
sigmoid = ops.Sigmoid()
for data in testdataloader.dataset.create_dict_iterator(output_numpy=True):
data, data_org = data["data"], data["data_org"]
img = sess.run(None, {input_sess: data})[0]
img = Tensor(img)
upsample = ops.ResizeBilinear((data_org.shape[1], data_org.shape[2]), align_corners=False)
img = upsample(img)
img = sigmoid(img)
img = img.asnumpy().squeeze()
img = (img - img.min()) / (img.max() - img.min() + 1e-8)
img = img * 255
data_name = Names[i]
save_path_end = os.path.join(save_path, data_name + '.png')
cv2.imwrite(save_path_end, img)
print("--------------- %d OK ----------------" % i)
i += 1
print("-------------- EVALUATION END --------------------")
predictpath = par.save_url
# calculate F-measure
gtfiles = sorted([gtname + gt_file for gt_file in os.listdir(gtname)])
predictfiles = sorted([os.path.join(predictpath, predictfile) for predictfile in os.listdir(predictpath)])
Fs = []
for i in range(len(gtfiles)):
gt = image_loader(gtfiles[i]) / 255
predict = image_loader(predictfiles[i]) / 255
fmea = Fmeasure(predict, gt)
Fs.append(fmea)
print("Fmeasure is %.3f" % np.mean(Fs))
"""
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import load_checkpoint, load_param_into_net, export
from src.model import BoneModel
def run_export(device_target, device_id, pretrained_model, model_ckpt, batchsize):
ms.context.set_context(mode=ms.context.GRAPH_MODE, device_target=device_target, device_id=device_id)
net = BoneModel(device_target, pretrained_model)
param_dict = load_checkpoint(model_ckpt)
load_param_into_net(net, param_dict)
input_arr = ms.Tensor(np.ones((batchsize, 3, 352, 352)).astype(np.float32))
export(net, input_arr, file_name="ras_onnx", file_format='ONNX')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--device_target', type=str, default='GPU', help="device's name, Ascend,GPU,CPU")
parser.add_argument('--device_id', type=int, default=5, help="Number of device")
parser.add_argument('--batchsize', type=int, default=1, help="training batch size")
parser.add_argument('--pre_model', type=str)
parser.add_argument('--ckpt_file', type=str)
par = parser.parse_args()
run_export(par.device_target, int(par.device_id), par.pre_model, par.ckpt_file, par.batchsize)
numpy
PIL
argparse
\ No newline at end of file
argparse
onnxruntime-gpu
\ No newline at end of file
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "===================================================================================================="
echo "Please run the script as:"
echo "bash script/run_eval_onnx_gpu.sh [data_url] [save_url] [onnx_file]"
echo "for example: bash script/run_eval_onnx_gpu.sh /home/data/Test/ /home/data/results/ /home/data/models/RAS800.onnx"
echo "**********
data_url: The data_url directory is the directory where the dataset is located,and there must be two
folders, images and gts, under data_url;
save_url: This is a save path of evaluation results;
onnx_file: The save path of exported onnx model file.
**********"
echo "===================================================================================================="
exit 1
fi
set -e
rm -rf output_eval_onnx
mkdir output_eval_onnx
data_url=$1
save_url=$2
onnx_file=$3
python3 -u eval_onnx.py --data_url ${data_url} --save_url ${save_url} --onnx_file ${onnx_file} --device_target GPU > output_eval_onnx/eval_onnx_log.log 2>&1 &
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment