From cc8bae24c97817c84d2bfcee31041b34b097ec52 Mon Sep 17 00:00:00 2001 From: chauneahhin <2645168370@qq.com> Date: Fri, 1 Jul 2022 21:36:29 +0800 Subject: [PATCH] =?UTF-8?q?[=E6=AD=A6=E6=B1=89=E7=90=86=E5=B7=A5=E5=A4=A7?= =?UTF-8?q?=E5=AD=A6][ONNX][RAS]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- research/cv/ras/README.md | 28 ++++ research/cv/ras/eval_onnx.py | 146 ++++++++++++++++++++ research/cv/ras/export_onnx.py | 46 ++++++ research/cv/ras/requirements.txt | 3 +- research/cv/ras/script/run_eval_onnx_gpu.sh | 43 ++++++ 5 files changed, 265 insertions(+), 1 deletion(-) create mode 100644 research/cv/ras/eval_onnx.py create mode 100644 research/cv/ras/export_onnx.py create mode 100644 research/cv/ras/script/run_eval_onnx_gpu.sh diff --git a/research/cv/ras/README.md b/research/cv/ras/README.md index 1dc300f6e..a85df4dff 100644 --- a/research/cv/ras/README.md +++ b/research/cv/ras/README.md @@ -19,6 +19,9 @@ - [鐢ㄦ硶](#鐢ㄦ硶-1) - [缁撴灉](#缁撴灉-1) - [璇勪及鎺ㄧ悊缁撴灉](#璇勪及鎺ㄧ悊缁撴灉) + - [ONNX妯″瀷瀵煎嚭鍙婅瘎浼癩(#onnx妯″瀷瀵煎嚭鍙婅瘎浼�) + - [ONNX妯″瀷瀵煎嚭](#onnx妯″瀷瀵煎嚭) + - [ONNX妯″瀷璇勪及](#onnx妯″瀷璇勪及) - [妯″瀷鎻忚堪](#妯″瀷鎻忚堪) - [璇勪及绮惧害](#璇勪及鎬ц兘) - [闅忔満鎯呭喌璇存槑](#闅忔満鎯呭喌璇存槑) @@ -104,6 +107,7 @@ RAS鎬讳綋缃戠粶鏋舵瀯濡備笅: 鈹� 鈹溾攢鈹€run_distribute_train_gpu.sh # 浣跨敤GPU杩涜澶氬崱璁粌鐨剆hell鑴氭湰 鈹� 鈹溾攢鈹€run_train_gpu.sh # 浣跨敤GPU杩涜鍗曞崱璁粌鐨剆hell鑴氭湰 鈹� 鈹溾攢鈹€run_eval_gpu.sh # 浣跨敤GPU杩涜璇勪及鐨勫崟鍗hell鑴氭湰 + 鈹� 鈹溾攢鈹€run_eval_onnx_gpu.sh # 浣跨敤GPU瀵瑰鍑虹殑onnx妯″瀷杩涜璇勪及鐨勫崟鍗hell鑴氭湰 鈹溾攢鈹€src 鈹� 鈹溾攢鈹€dataset_train.py #鍒涘缓璁粌鏁版嵁闆� 鈹� 鈹溾攢鈹€dataset_test.py # 鍒涘缓鎺ㄧ悊鏁版嵁闆� @@ -113,7 +117,9 @@ RAS鎬讳綋缃戠粶鏋舵瀯濡備笅: 鈹� 鈹溾攢鈹€TrainOneStepMyself.py #鑷畾涔夎缁冿紝鍙傛暟鏇存柊杩囩▼ 鈹溾攢鈹€ train.py # 璁粌鑴氭湰 鈹溾攢鈹€ eval.py # 鎺ㄧ悊鑴氭湰 + 鈹溾攢鈹€ eval_onnx.py # onnx鎺ㄧ悊鑴氭湰 鈹溾攢鈹€ export.py + 鈹溾攢鈹€ export_onnx.py # onnx瀵煎嚭鑴氭湰 ``` ### 鑴氭湰鍙傛暟 @@ -131,6 +137,7 @@ RAS鎬讳綋缃戠粶鏋舵瀯濡備笅: 'print_flag' : 20 //璁粌鏃舵瘡print_flag涓猻tep杈撳嚭涓€娆oss 'device_id' : 5 //璁粌鏃剁‖浠剁殑ID 'data_url' : xxx //鏁版嵁璺緞 + 'onnx_file' : xxx //瀵煎嚭鐨刼nnx妯″瀷璺緞 'pretrained_model':xxx //resnet50棰勮缁冩ā鍨嬭矾寰� 鍦╡val璇ュ弬鏁颁负"pre_model" ``` @@ -254,6 +261,27 @@ The Consumption of per step is 0.136 s 鎺ㄧ悊瀹屾垚鍚庯紝瑕佸缁撴灉杩涜澶勭悊锛屼负浜嗘柟渚匡紝宸茬粡灏嗚瘎浼伴儴鍒嗗姞鍏ュ埌鎺ㄧ悊涓紝鍦ㄦ帹鐞嗗畬鎴愬悗鍗冲彲鐪嬪埌 璇ユ帹鐞嗙粨鏋滅殑Fmeasure锛屽湪鎺ㄧ悊鐨刲og涓彲浠ユ壘鍒� +## ONNX妯″瀷瀵煎嚭鍙婅瘎浼� + +### ONNX妯″瀷瀵煎嚭 + +```bash + python export_onnx.py --device_target [DEVICE_TARGET] --pre_model [PRE_MODEL] --ckpt_file [CKPT_FILE] + ckpt_file 涓鸿缁冧繚瀛樼殑ckpt璺緞 + pre_model 涓虹綉缁渞esnet50棰勮缁冩ā鍨嬭矾寰� + # example: python export_onnx.py --device_target "GPU" --pre_model resnet50_gpu_v130_imagenet_official_cv_bs32_acc0.ckpt --ckpt_file ras_ascend_v170_dutstrain_research_cv_ECSSD91_DUTStest81_DUTOMRON75_HKUIS90.ckpt +``` + +### ONNX妯″瀷璇勪及 + +```bash + bash script/run_eval_onnx_gpu.sh [data_url] [save_url] [onnx_file] + data_url 涓烘帹鐞嗘暟鎹矾寰� + save_url 涓虹敓鎴愮粨鏋滃浘鐗囩殑璺緞 + onnx_file 涓哄鍑虹殑onnx鏂囦欢璺緞 + # example: bash script/run_eval_onnx_gpu.sh dataset/HKU-IS/ ./output_hku_is ras_onnx.onnx +``` + # 妯″瀷鎻忚堪 ## 璇勪及绮惧害 diff --git a/research/cv/ras/eval_onnx.py b/research/cv/ras/eval_onnx.py new file mode 100644 index 000000000..3fcd2bbd7 --- /dev/null +++ b/research/cv/ras/eval_onnx.py @@ -0,0 +1,146 @@ +""" +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" + +import os +import sys +import argparse +import cv2 +import numpy as np +from PIL import Image +import onnxruntime as ort +import mindspore.ops as ops +from mindspore import Tensor + +from src.dataset_test import TrainDataLoader + +sys.path.append("../") + + +# data_url is the directory where the data set is located, +# and there must be two folders, images and gts, under data_url; + + +parser = argparse.ArgumentParser() +parser.add_argument('--device_target', type=str, default="Ascend", help="Ascend, GPU, CPU") +parser.add_argument('--data_url', type=str) +parser.add_argument('--save_url', type=str) +parser.add_argument('--onnx_file', type=str) + +par = parser.parse_args() + + +def image_loader(imagename): + image = Image.open(imagename).convert("L") + return np.array(image) + + +def Fmeasure(predict_, groundtruth): + """ + + Args: + predict: predict image + gt: ground truth + + Returns: + Calculate F-measure + """ + sumLabel = 2 * np.mean(predict_) + if sumLabel > 1: + sumLabel = 1 + Label3 = predict_ >= sumLabel + NumRec = np.sum(Label3) + #LabelAnd = (Label3 is True) + LabelAnd = Label3 + #NumAnd = np.sum(np.logical_and(LabelAnd, groundtruth)) + gt_t = gt > 0.5 + NumAnd = np.sum(LabelAnd * gt_t) + num_obj = np.sum(groundtruth) + if NumAnd == 0: + p = 0 + r = 0 + FmeasureF = 0 + else: + p = NumAnd / NumRec + r = NumAnd / num_obj + FmeasureF = (1.3 * p * r) / (0.3 * p + r) + return FmeasureF + + +def create_session(onnx_checkpoint_path, target_device='GPU'): + if target_device == 'GPU': + providers = ['CUDAExecutionProvider'] + elif target_device == 'CPU': + providers = ['CPUExecutionProvider'] + else: + raise ValueError( + f'Unsupported target device {target_device}, ' + f'Expected one of: "CPU", "GPU"' + ) + session = ort.InferenceSession(onnx_checkpoint_path, providers=providers) + + input_name = session.get_inputs()[0].name + return session, input_name + + +if __name__ == "__main__": + filename = os.path.join(par.data_url, 'images/') + gtname = os.path.join(par.data_url, 'gts/') + save_path = par.save_url + if not os.path.exists(save_path): + os.makedirs(save_path) + + testdataloader = TrainDataLoader(filename) + + sess, input_sess = create_session(par.onnx_file, par.device_target) + + Names = [] + for data in os.listdir(filename): + name = data.split('.')[0] + Names.append(name) + Names = sorted(Names) + i = 0 + sigmoid = ops.Sigmoid() + for data in testdataloader.dataset.create_dict_iterator(output_numpy=True): + data, data_org = data["data"], data["data_org"] + img = sess.run(None, {input_sess: data})[0] + img = Tensor(img) + upsample = ops.ResizeBilinear((data_org.shape[1], data_org.shape[2]), align_corners=False) + img = upsample(img) + img = sigmoid(img) + img = img.asnumpy().squeeze() + img = (img - img.min()) / (img.max() - img.min() + 1e-8) + img = img * 255 + data_name = Names[i] + save_path_end = os.path.join(save_path, data_name + '.png') + cv2.imwrite(save_path_end, img) + print("--------------- %d OK ----------------" % i) + i += 1 + print("-------------- EVALUATION END --------------------") + predictpath = par.save_url + + # calculate F-measure + gtfiles = sorted([gtname + gt_file for gt_file in os.listdir(gtname)]) + predictfiles = sorted([os.path.join(predictpath, predictfile) for predictfile in os.listdir(predictpath)]) + + Fs = [] + for i in range(len(gtfiles)): + gt = image_loader(gtfiles[i]) / 255 + predict = image_loader(predictfiles[i]) / 255 + fmea = Fmeasure(predict, gt) + Fs.append(fmea) + + print("Fmeasure is %.3f" % np.mean(Fs)) diff --git a/research/cv/ras/export_onnx.py b/research/cv/ras/export_onnx.py new file mode 100644 index 000000000..70197f75c --- /dev/null +++ b/research/cv/ras/export_onnx.py @@ -0,0 +1,46 @@ +""" +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" + +import argparse +import numpy as np +import mindspore as ms +from mindspore import load_checkpoint, load_param_into_net, export +from src.model import BoneModel + + +def run_export(device_target, device_id, pretrained_model, model_ckpt, batchsize): + ms.context.set_context(mode=ms.context.GRAPH_MODE, device_target=device_target, device_id=device_id) + net = BoneModel(device_target, pretrained_model) + param_dict = load_checkpoint(model_ckpt) + load_param_into_net(net, param_dict) + input_arr = ms.Tensor(np.ones((batchsize, 3, 352, 352)).astype(np.float32)) + + export(net, input_arr, file_name="ras_onnx", file_format='ONNX') + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--device_target', type=str, default='GPU', help="device's name, Ascend,GPU,CPU") + parser.add_argument('--device_id', type=int, default=5, help="Number of device") + parser.add_argument('--batchsize', type=int, default=1, help="training batch size") + parser.add_argument('--pre_model', type=str) + parser.add_argument('--ckpt_file', type=str) + par = parser.parse_args() + + + run_export(par.device_target, int(par.device_id), par.pre_model, par.ckpt_file, par.batchsize) diff --git a/research/cv/ras/requirements.txt b/research/cv/ras/requirements.txt index 608b88228..7373800ac 100644 --- a/research/cv/ras/requirements.txt +++ b/research/cv/ras/requirements.txt @@ -1,3 +1,4 @@ numpy PIL -argparse \ No newline at end of file +argparse +onnxruntime-gpu \ No newline at end of file diff --git a/research/cv/ras/script/run_eval_onnx_gpu.sh b/research/cv/ras/script/run_eval_onnx_gpu.sh new file mode 100644 index 000000000..8d4e8b32c --- /dev/null +++ b/research/cv/ras/script/run_eval_onnx_gpu.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] +then + echo "====================================================================================================" + echo "Please run the script as:" + echo "bash script/run_eval_onnx_gpu.sh [data_url] [save_url] [onnx_file]" + echo "for example: bash script/run_eval_onnx_gpu.sh /home/data/Test/ /home/data/results/ /home/data/models/RAS800.onnx" + echo "********** + data_url: The data_url directory is the directory where the dataset is located,and there must be two + folders, images and gts, under data_url; + save_url: This is a save path of evaluation results; + onnx_file: The save path of exported onnx model file. +**********" + echo "====================================================================================================" +exit 1 +fi + +set -e +rm -rf output_eval_onnx +mkdir output_eval_onnx + +data_url=$1 +save_url=$2 +onnx_file=$3 + +python3 -u eval_onnx.py --data_url ${data_url} --save_url ${save_url} --onnx_file ${onnx_file} --device_target GPU > output_eval_onnx/eval_onnx_log.log 2>&1 & + -- GitLab