From dcf27a92ce689b577007890f8caa1d0b66cf2dd8 Mon Sep 17 00:00:00 2001
From: yang-how <985871988@qq.com>
Date: Wed, 6 Jul 2022 18:28:36 +0800
Subject: [PATCH] add glore_res onnx eval

---
 research/cv/glore_res/README_CN.md            | 24 +++++++
 .../config/config_resnet101_gpu.yaml          |  1 +
 .../config/config_resnet200_ascend.yaml       |  1 +
 .../config/config_resnet200_gpu.yaml          |  1 +
 .../config/config_resnet50_ascend.yaml        |  1 +
 .../glore_res/config/config_resnet50_gpu.yaml |  1 +
 research/cv/glore_res/eval_onnx.py            | 68 +++++++++++++++++++
 .../cv/glore_res/scripts/run_eval_onnx.sh     | 30 ++++++++
 research/cv/glore_res/src/dataset.py          |  2 +
 9 files changed, 129 insertions(+)
 create mode 100644 research/cv/glore_res/eval_onnx.py
 create mode 100644 research/cv/glore_res/scripts/run_eval_onnx.sh

diff --git a/research/cv/glore_res/README_CN.md b/research/cv/glore_res/README_CN.md
index 4a7afb8ba..3deb40d12 100644
--- a/research/cv/glore_res/README_CN.md
+++ b/research/cv/glore_res/README_CN.md
@@ -26,6 +26,7 @@
             - [Ascend处理器环境运行](#ascend处理器环境运行-1)
             - [GPU处理器环境运行](#gpu处理器环境运行-1)
     - [推理结果](#推理结果)
+    - [onnx模型导出与推理](#onnx模型导出与推理)
 - [模型描述](#模型描述)
     - [性能](#性能)
         - [训练性能](#训练性能)
@@ -394,6 +395,29 @@ bash run_eval.sh ~/Imagenet/val/  ~/glore_resnet200-150_2502.ckpt ../config/conf
 result:{'top_1 acc':0.802303685897436}
 ```
 
+## onnx模型导出与推理
+
+- 导出 ONNX:  
+
+  ```shell
+  python export.py --config_path /path/to/glore.yaml --ckpt_url /path/to/glore_res50.ckpt --file_name /path/to/glore_res50 --batch_size 1 --file_format ONNX --device_target CPU
+  ```
+
+- 运行推理-python方式:
+
+  ```shell
+  python eval_onnx.py --config_path /path/to/glore.yaml --data_path /path/to/image_val/ --onnx_path /path/to/glore_res50.onnx --batch_size 1 --device_target GPU > output.eval.log 2>&1
+  ```
+
+- 运行推理-bash方式:
+
+  ```shell
+  # 需要先修改对应yaml配置文件中的配置项（如 data_path、onnx_path、batch_size、device_target）
+  bash scripts/run_eval_onnx.sh /path/to/glore.yaml
+  ```
+
+- python方式的推理结果将存放在 output.eval.log 中，bash方式的推理结果将存放在 output.eval_onnx.log 中。
+
 # 模型描述
 
 ## 性能
diff --git a/research/cv/glore_res/config/config_resnet101_gpu.yaml b/research/cv/glore_res/config/config_resnet101_gpu.yaml
index d72b2bd04..73d2935c4 100644
--- a/research/cv/glore_res/config/config_resnet101_gpu.yaml
+++ b/research/cv/glore_res/config/config_resnet101_gpu.yaml
@@ -12,6 +12,7 @@ output_path: "/cache/train"
 load_path: "/cache/checkpoint_path/"
 device_target: "GPU"
 checkpoint_path: "./checkpoint/"
+onnx_path: "resnet101.onnx"
 
 # ==============================================================================
 # Training options
diff --git a/research/cv/glore_res/config/config_resnet200_ascend.yaml b/research/cv/glore_res/config/config_resnet200_ascend.yaml
index e9c84ef63..b7eee2b93 100644
--- a/research/cv/glore_res/config/config_resnet200_ascend.yaml
+++ b/research/cv/glore_res/config/config_resnet200_ascend.yaml
@@ -12,6 +12,7 @@ output_path: "/cache/train"
 load_path: "/cache/checkpoint_path/"
 device_target: "Ascend"
 checkpoint_path: "./checkpoint/"
+onnx_path: "resnet200.onnx"
 
 # ==============================================================================
 # Training options
diff --git a/research/cv/glore_res/config/config_resnet200_gpu.yaml b/research/cv/glore_res/config/config_resnet200_gpu.yaml
index d4c78f463..2c66bba3b 100644
--- a/research/cv/glore_res/config/config_resnet200_gpu.yaml
+++ b/research/cv/glore_res/config/config_resnet200_gpu.yaml
@@ -12,6 +12,7 @@ output_path: "/cache/train"
 load_path: "/cache/checkpoint_path/"
 device_target: "GPU"
 checkpoint_path: "./checkpoint/"
+onnx_path: "resnet200.onnx"
 
 # ==============================================================================
 # Training options
diff --git a/research/cv/glore_res/config/config_resnet50_ascend.yaml b/research/cv/glore_res/config/config_resnet50_ascend.yaml
index 69b139364..a598f830b 100644
--- a/research/cv/glore_res/config/config_resnet50_ascend.yaml
+++ b/research/cv/glore_res/config/config_resnet50_ascend.yaml
@@ -12,6 +12,7 @@ output_path: "/cache/train"
 load_path: "/cache/checkpoint_path/"
 device_target: "Ascend"
 checkpoint_path: "./checkpoint/"
+onnx_path: "resnet50.onnx"
 
 # ==============================================================================
 # Training options
diff --git a/research/cv/glore_res/config/config_resnet50_gpu.yaml b/research/cv/glore_res/config/config_resnet50_gpu.yaml
index e2eede34e..51b504a57 100644
--- a/research/cv/glore_res/config/config_resnet50_gpu.yaml
+++ b/research/cv/glore_res/config/config_resnet50_gpu.yaml
@@ -12,6 +12,7 @@ output_path: "/cache/train"
 load_path: "/cache/checkpoint_path/"
 device_target: "GPU"
 checkpoint_path: "./checkpoint/"
+onnx_path: "resnet50.onnx"
 
 # ==============================================================================
 # Training options
diff --git a/research/cv/glore_res/eval_onnx.py b/research/cv/glore_res/eval_onnx.py
new file mode 100644
index 000000000..001f77f7b
--- /dev/null
+++ b/research/cv/glore_res/eval_onnx.py
@@ -0,0 +1,68 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Run evaluation for a model exported to ONNX"""
+
+import mindspore.nn as nn
+import onnxruntime as ort
+
+from src.config import config
+
+def create_session(checkpoint_path, target_device):
+    """Create an onnxruntime session for the model file; return (session, first input name)."""
+    if target_device not in ('CPU', 'GPU'):
+        raise ValueError(
+            f'Unsupported target device {target_device}, '
+            f'Expected one of: "CPU", "GPU"'
+        )
+    # Exactly one execution provider is requested, matching the chosen device.
+    providers = ['CUDAExecutionProvider'] if target_device == 'GPU' else ['CPUExecutionProvider']
+    onnx_session = ort.InferenceSession(checkpoint_path, providers=providers)
+    # Callers key the feed dict on the name of the first declared input.
+    first_input_name = onnx_session.get_inputs()[0].name
+    return onnx_session, first_input_name
+
+def eval_acc(eval_arg):
+    """Evaluate the ONNX model on the dataset; return a dict of top-1/top-5 accuracy."""
+    ort_session, feed_name = create_session(eval_arg.onnx_path, eval_arg.device_target)
+
+    # resnet50 uses create_eval_dataset; every other net uses the ImageNet builder.
+    if eval_arg.net == 'resnet50':
+        from src.dataset import create_eval_dataset
+        eval_data = create_eval_dataset(dataset_path=eval_arg.data_path,
+                                        repeat_num=1,
+                                        batch_size=eval_arg.batch_size)
+    else:
+        from src.dataset import create_dataset_ImageNet as ImageNet
+        eval_data = ImageNet(dataset_path=eval_arg.data_path, do_train=False, repeat_num=1,
+                             batch_size=eval_arg.batch_size, target='CPU')
+
+    top1 = nn.Top1CategoricalAccuracy()
+    top5 = nn.Top5CategoricalAccuracy()
+    scorers = {'top-1 accuracy': top1, 'top-5 accuracy': top5}
+
+    for sample in eval_data.create_dict_iterator(num_epochs=1, output_numpy=True):
+        # Only the first output of the session is scored against the labels.
+        scores = ort_session.run(None, {feed_name: sample['image']})[0]
+        for scorer in scorers.values():
+            scorer.update(scores, sample['label'])
+
+    # Map each metric name to its final value.
+    return {key: scorer.eval() for key, scorer in scorers.items()}
+
+if __name__ == '__main__':
+    # Evaluate with the config object imported from src.config, then print each metric.
+    accuracy_report = eval_acc(config)
+    for metric_name, score in accuracy_report.items():
+        print(f'{metric_name}: {score:.4f}')
diff --git a/research/cv/glore_res/scripts/run_eval_onnx.sh b/research/cv/glore_res/scripts/run_eval_onnx.sh
new file mode 100644
index 000000000..78d4b643c
--- /dev/null
+++ b/research/cv/glore_res/scripts/run_eval_onnx.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# Exactly one argument is required: the path to the yaml config file.
+if [ $# -ne 1 ]; then
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash scripts/run_eval_onnx.sh CONFIG_PATH "
+    echo "for example: bash scripts/run_eval_onnx.sh /path/to/resnet50.yaml "
+    echo "=============================================================================================================="
+    exit 1
+fi
+
+CONFIG_PATH=$1
+
+# Quote the path so spaces survive; stdout and stderr both go to the log file.
+python eval_onnx.py --config_path="$CONFIG_PATH" > output.eval_onnx.log 2>&1
diff --git a/research/cv/glore_res/src/dataset.py b/research/cv/glore_res/src/dataset.py
index 9a6ecc0d0..b283f15e1 100644
--- a/research/cv/glore_res/src/dataset.py
+++ b/research/cv/glore_res/src/dataset.py
@@ -195,6 +195,8 @@ def create_dataset_ImageNet(dataset_path, do_train, use_randaugment=False, repea
         init("nccl")
         rank_id = get_rank()
         device_num = get_group_size()
+    else:
+        device_num = 1
 
     if device_num == 1:
         da = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
-- 
GitLab