diff --git a/research/cv/PAMTRI/MultiTaskNet/README_CN.md b/research/cv/PAMTRI/MultiTaskNet/README_CN.md
index c5535a52ae7107966fc9754ab872315a3c0ccb94..8b4db28f179c6c499dabc4d68f77098259da9adc 100644
--- a/research/cv/PAMTRI/MultiTaskNet/README_CN.md
+++ b/research/cv/PAMTRI/MultiTaskNet/README_CN.md
@@ -181,6 +181,7 @@ bash scripts/run_eval_gpu.sh [DATASET_NAME] [CKPT_PATH] [DEVICE_ID] [HEATMAP_SEG
             |   ├── run_single_train_gpu.sh             // Shell script for single-device GPU training
             |   ├── run_eval_ascend.sh                  // Shell script for Ascend evaluation
             |   ├── run_eval_gpu.sh                     // Shell script for GPU evaluation
+            |   ├── run_onnx_eval_gpu.sh                // Shell script for ONNX inference
             |   ├── run_infer_310.sh                    // Shell script for Ascend 310 inference
             ├── src
             |   ├── dataset
@@ -201,6 +202,7 @@ bash scripts/run_eval_gpu.sh [DATASET_NAME] [CKPT_PATH] [DEVICE_ID] [HEATMAP_SEG
             |   |   ├── save_callback.py                // Implementation of inference during training
             |   |   ├── pthtockpt.py                    // Converts a pth-format pretrained model to ckpt
             ├── eval.py                             // Accuracy evaluation script
+            ├── eval_onnx.py                        // ONNX accuracy evaluation script
             ├── train.py                            // Training script
             ├── export.py                           // Inference model export script
             ├── preprocess.py                       // Preprocessing script for 310 inference
@@ -440,6 +442,18 @@ python export.py --root /path/dataset --ckpt_path /path/ckpt --segmentaware --he
     bash scripts/run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID] [NEED_HEATMAP] [NEED_SEGMENT]
     ```
 
+### ONNX Inference
+
+Before running inference, the model needs to be exported first.
+
+- Run inference with the VeRi dataset on GPU
+
+    Run inference with the command below, where `DATASET_NAME` is the path to the inference dataset, `ONNX_PATH` is the path to the ONNX model file, `DEVICE_ID` is the ID of the GPU to use, and `HEATMAP_SEGMENT` selects whether to embed heatmaps (`h`) or segments (`s`).
+
+    ```bash
+    bash scripts/run_onnx_eval_gpu.sh [DATASET_NAME] [ONNX_PATH] [DEVICE_ID] [HEATMAP_SEGMENT]
+    ```
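+
+    For example (paths are illustrative):
+
+    ```bash
+    bash scripts/run_onnx_eval_gpu.sh /path/dataset /path/model.onnx 0 h
+    ```
+
+    The accuracy results are written to `eval_onnx_gpu.log` in the MultiTaskNet directory.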
+
 # Model Description
 
 ## Performance
diff --git a/research/cv/PAMTRI/MultiTaskNet/eval_onnx.py b/research/cv/PAMTRI/MultiTaskNet/eval_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fc5b959f452a1aea67bc41614d1821881ad7bfa
--- /dev/null
+++ b/research/cv/PAMTRI/MultiTaskNet/eval_onnx.py
@@ -0,0 +1,68 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""MultiTaskNet onnx_eval"""
+import ast
+import argparse
+from src.utils.evaluate import onnx_test
+from src.dataset.dataset import eval_create_dataset
+import onnxruntime as ort
+parser = argparse.ArgumentParser(description='eval MultiTaskNet')
+
+parser.add_argument('--device_target', type=str, default="GPU")
+parser.add_argument('--device_id', type=int, default=0)
+parser.add_argument('--root', type=str, default='./data', help="root path to data directory")
+parser.add_argument('-d', '--dataset', type=str, default='veri', help="name of the dataset")
+parser.add_argument('--height', type=int, default=256, help="height of an image (default: 256)")
+parser.add_argument('--width', type=int, default=256, help="width of an image (default: 256)")
+parser.add_argument('--test-batch', default=1, type=int, help="test batch size")
+parser.add_argument('--heatmapaware', type=ast.literal_eval, default=False, help="embed heatmaps to images")
+parser.add_argument('--segmentaware', type=ast.literal_eval, default=False, help="embed segments to images")
+parser.add_argument('--onnx_path', type=str, default='')
+args = parser.parse_args()
+
+def create_session(onnx_path, target_device):
+    """Create an ONNX Runtime session and collect the names of the model's two inputs."""
+    if target_device == 'GPU':
+        providers = ['CUDAExecutionProvider']
+    elif target_device == 'CPU':
+        providers = ['CPUExecutionProvider']
+    else:
+        raise ValueError(f'Unsupported target device {target_device!r}. Expected one of: "CPU", "GPU"')
+
+    session = ort.InferenceSession(onnx_path, providers=providers)
+    # The MultiTaskNet ONNX model takes two inputs: images and keypoints
+    input_names = [model_input.name for model_input in session.get_inputs()[:2]]
+    return session, input_names
+
+if __name__ == '__main__':
+    target = args.device_target
+    train_dataset_path = args.root
+
+    query_dataloader, gallery_dataloader, num_train_vids, \
+        num_train_vcolors, num_train_vtypes, _vcolor2label, \
+            _vtype2label = eval_create_dataset(dataset_dir=args.dataset,
+                                               root=train_dataset_path,
+                                               width=args.width,
+                                               height=args.height,
+                                               keyptaware=True,
+                                               heatmapaware=args.heatmapaware,
+                                               segmentaware=args.segmentaware,
+                                               train_batch=args.test_batch)
+
+    session, input_name = create_session(args.onnx_path, target)
+
+    _distmat = onnx_test(session, input_name, True, True, query_dataloader, gallery_dataloader, _vcolor2label,
+                         _vtype2label, return_distmat=True)
diff --git a/research/cv/PAMTRI/MultiTaskNet/requirements.txt b/research/cv/PAMTRI/MultiTaskNet/requirements.txt
index 88ef814d79e6464f4dd1007e8d85083501e2b6dd..44584dbe8e914a74b928330af44883afe954b17f 100644
--- a/research/cv/PAMTRI/MultiTaskNet/requirements.txt
+++ b/research/cv/PAMTRI/MultiTaskNet/requirements.txt
@@ -4,4 +4,5 @@ numpy==1.14.5
 Pillow==5.1.0
 scipy==1.2.0
 opencv-python
-matplotlib
\ No newline at end of file
+matplotlib
+onnxruntime-gpu
\ No newline at end of file
diff --git a/research/cv/PAMTRI/MultiTaskNet/scripts/run_onnx_eval_gpu.sh b/research/cv/PAMTRI/MultiTaskNet/scripts/run_onnx_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c1d6c14bdd778bfcfe399ee8fe0396c02977617a
--- /dev/null
+++ b/research/cv/PAMTRI/MultiTaskNet/scripts/run_onnx_eval_gpu.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval_gpu.sh DATASET_NAME ONNX_PATH DEVICE_ID HEATMAP_SEGMENT"
+echo "For example: bash run_eval_gpu.sh ../data/ ./*.onnx 0 h"
+echo "It is better to use the absolute path."
+echo "=============================================================================================================="
+
+if [ $# != 4 ]; then
+  echo "bash run_eval_gpu.sh DATASET_NAME ONNX_PATH DEVICE_ID HEATMAP_SEGMENT"
+  exit 1
+fi
+
+set -e
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m "$PWD/$1")"
+  fi
+}
+DATA_PATH=$(get_real_path "$1")
+ONNX_PATH=$(get_real_path "$2")
+PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
+if [ "$4" == "h" ] || [ "$4" == "s" ];then
+    if [ "$4" == "h" ];then
+      need_heatmap=True
+      need_segment=False
+    else
+      need_heatmap=False
+      need_segment=True
+    fi
+else
+    echo "heatmap_segment must be h or s"
+    exit 1
+fi
+
+EXEC_PATH=$(pwd)
+echo "$EXEC_PATH"
+
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+export RANK_SIZE=1
+
+python ${PROJECT_DIR}/../eval_onnx.py --root ${DATA_PATH} \
+  --onnx_path ${ONNX_PATH} \
+  --device_id $3 \
+  --device_target GPU \
+  --heatmapaware ${need_heatmap} \
+  --segmentaware ${need_segment} > ${PROJECT_DIR}/../eval_onnx_gpu.log 2>&1 &
diff --git a/research/cv/PAMTRI/MultiTaskNet/src/utils/evaluate.py b/research/cv/PAMTRI/MultiTaskNet/src/utils/evaluate.py
index c059ae6fdbbc59443506301a3d9c8e2dd1f44434..5e218d6eedd17cfceca36b23a705ce6dc5bf768e 100644
--- a/research/cv/PAMTRI/MultiTaskNet/src/utils/evaluate.py
+++ b/research/cv/PAMTRI/MultiTaskNet/src/utils/evaluate.py
@@ -188,3 +188,120 @@ def test(model, keyptaware, multitask, queryloader, galleryloader,
     if return_distmat:
         return distmat
     return cmc[0]
+
+def onnx_test(session, input_name, keyptaware, multitask, queryloader, galleryloader,
+              vcolor2label, vtype2label, ranks=range(1, 51), return_distmat=False):
+    """Evaluate an exported MultiTaskNet ONNX model on the query and gallery sets."""
+
+    qf = []
+    q_vids = []
+    q_camids = []
+    q_vcolors = []
+    q_vtypes = []
+    pred_q_vcolors = []
+    pred_q_vtypes = []
+    for _, data in enumerate(queryloader.create_dict_iterator()):
+        imgs = data["img"]
+        vids = data["vid"]
+        camids = data["camid"]
+        vcolors = data["vcolor"]
+        vtypes = data["vtype"]
+        vkeypts = data["vkeypt"]
+
+        _, output_vcolors, output_vtypes, features = session.run(None, {input_name[0]: imgs.asnumpy(),
+                                                                        input_name[1]: vkeypts.asnumpy()})
+
+        qf.append(features)
+        q_vids.extend(vids.asnumpy())
+        q_camids.extend(camids.asnumpy())
+        q_vcolors.extend(vcolors.asnumpy())
+        q_vtypes.extend(vtypes.asnumpy())
+        pred_q_vcolors.extend(output_vcolors)
+        pred_q_vtypes.extend(output_vtypes)
+
+    qf = np.concatenate(qf, axis=0)
+    q_vids = np.asarray(q_vids)
+    q_camids = np.asarray(q_camids)
+    q_vcolors = np.asarray(q_vcolors)
+    q_vtypes = np.asarray(q_vtypes)
+    pred_q_vcolors = np.asarray(pred_q_vcolors)
+    pred_q_vtypes = np.asarray(pred_q_vtypes)
+
+    print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.shape[0], qf.shape[1]))
+
+    gf = []
+    g_vids = []
+    g_camids = []
+    g_vcolors = []
+    g_vtypes = []
+    pred_g_vcolors = []
+    pred_g_vtypes = []
+    for _, data in enumerate(galleryloader.create_dict_iterator()):
+        imgs = data["img"]
+        vids = data["vid"]
+        camids = data["camid"]
+        vcolors = data["vcolor"]
+        vtypes = data["vtype"]
+        vkeypts = data["vkeypt"]
+
+        _, output_vcolors, output_vtypes, features = session.run(None, {input_name[0]: imgs.asnumpy(),
+                                                                        input_name[1]: vkeypts.asnumpy()})
+        gf.append(features)
+        g_vids.extend(vids.asnumpy())
+        g_camids.extend(camids.asnumpy())
+        g_vcolors.extend(vcolors.asnumpy())
+        g_vtypes.extend(vtypes.asnumpy())
+        pred_g_vcolors.extend(output_vcolors)
+        pred_g_vtypes.extend(output_vtypes)
+
+    gf = np.concatenate(gf, axis=0)
+    g_vids = np.asarray(g_vids)
+    g_camids = np.asarray(g_camids)
+    g_vcolors = np.asarray(g_vcolors)
+    g_vtypes = np.asarray(g_vtypes)
+    pred_g_vcolors = np.asarray(pred_g_vcolors)
+    pred_g_vtypes = np.asarray(pred_g_vtypes)
+
+    print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.shape[0], gf.shape[1]))
+
+    # Squared Euclidean distance matrix: ||q||^2 + ||g||^2 - 2 * q . g
+    m, n = qf.shape[0], gf.shape[0]
+    distmat = np.broadcast_to(np.power(qf, 2).sum(axis=1, keepdims=True), (m, n)) + \
+              np.broadcast_to(np.power(gf, 2).sum(axis=1, keepdims=True), (n, m)).T
+    distmat = distmat - 2 * np.matmul(qf, gf.T)
+
+    print("Computing CMC and mAP")
+    cmc, mAP = evaluate(distmat, q_vids, g_vids, q_camids, g_camids)
+
+    print("Results ----------")
+    print("mAP: {:.2%}".format(mAP))
+    print("CMC curve")
+    for r in ranks:
+        print("Rank-{:<3}: {:.2%}".format(r, cmc[r-1]))
+    print("------------------")
+
+    print("Compute attribute classification accuracy")
+
+    for q in range(q_vcolors.size):
+        q_vcolors[q] = vcolor2label[q_vcolors[q]]
+    for g in range(g_vcolors.size):
+        g_vcolors[g] = vcolor2label[g_vcolors[g]]
+    q_vcolor_errors = np.argmax(pred_q_vcolors, axis=1) - q_vcolors
+    g_vcolor_errors = np.argmax(pred_g_vcolors, axis=1) - g_vcolors
+    vcolor_error_num = np.count_nonzero(q_vcolor_errors) + np.count_nonzero(g_vcolor_errors)
+    vcolor_accuracy = 1.0 - (float(vcolor_error_num) / float(distmat.shape[0] + distmat.shape[1]))
+    print("Color classification accuracy: {:.2%}".format(vcolor_accuracy))
+
+    for q in range(q_vtypes.size):
+        q_vtypes[q] = vtype2label[q_vtypes[q]]
+    for g in range(g_vtypes.size):
+        g_vtypes[g] = vtype2label[g_vtypes[g]]
+    q_vtype_errors = np.argmax(pred_q_vtypes, axis=1) - q_vtypes
+    g_vtype_errors = np.argmax(pred_g_vtypes, axis=1) - g_vtypes
+    vtype_error_num = np.count_nonzero(q_vtype_errors) + np.count_nonzero(g_vtype_errors)
+    vtype_accuracy = 1.0 - (float(vtype_error_num) / float(distmat.shape[0] + distmat.shape[1]))
+    print("Type classification accuracy: {:.2%}".format(vtype_accuracy))
+
+    if return_distmat:
+        return distmat
+    return cmc[0]
diff --git a/research/cv/PAMTRI/PoseEstNet/README_CN.md b/research/cv/PAMTRI/PoseEstNet/README_CN.md
index 236091997764f9946cfe6992b8fd0fd9e92b3dc7..754f286297e926cc008326c4f7c10bf180856cb2 100644
--- a/research/cv/PAMTRI/PoseEstNet/README_CN.md
+++ b/research/cv/PAMTRI/PoseEstNet/README_CN.md
@@ -153,6 +153,7 @@ bash run_trans_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID]
             |   ├── run_distribute_train_gpu.sh     // Shell script for distributed GPU training
             |   ├── run_eval_ascend.sh              // Shell script for Ascend evaluation
             |   ├── run_eval_gpu.sh                 // Shell script for GPU evaluation
+            |   ├── run_onnx_eval_gpu.sh            // Shell script for ONNX inference
             |   ├── run_trans_ascend.sh             // Script to generate the dataset on Ascend
             |   ├── run_trans_gpu.sh                // Script to generate the dataset on GPU
             |   ├── run_infer_310.sh                // Shell script for Ascend 310 inference
@@ -186,6 +187,7 @@ bash run_trans_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID]
             ├── config.yaml                         // Fixed parameters
             ├── config_gpu.yaml                     // Fixed parameters (GPU)
             ├── eval.py                             // Accuracy evaluation script
+            ├── eval_onnx.py                        // ONNX accuracy evaluation script
             ├── train.py                            // Training script
             ├── trans.py                            // Dataset conversion script
             ├── export.py                           // Inference model export script
@@ -329,6 +331,25 @@ epcoh:2 step:, loss is 3.0897788
   | pose_hrnet | 85.968 | 81.682 | 70.630 | 76.568 | 86.492 | 89.771 | 82.577 | 16.681 |
   ```
 
+### ONNX Inference
+
+Before running inference, the model needs to be exported first. The ONNX model can be exported in a GPU environment.
+
+- Run inference with the VeRi dataset on GPU
+
+    Run inference with the command below, where `DATA_PATH` is the path to the inference dataset, `ONNX_PATH` is the path to the ONNX model file, and `DEVICE_ID` is the ID of the GPU to use.
+
+    ```bash
+    bash scripts/run_onnx_eval_gpu.sh [DATA_PATH] [ONNX_PATH] [DEVICE_ID]
+    ```
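+
+    For example (paths are illustrative): `bash scripts/run_onnx_eval_gpu.sh /path/dataset /path/model.onnx 0`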
+
+    The accuracy results are saved to the `eval_onnx_gpu.log` file in the PoseEstNet directory, where classification accuracy results similar to the following can be found.
+
+    ```bash
+    | Arch | Wheel | Fender | Back | Front | WindshieldBack | WindshieldFront | Mean | Mean@0.1 |
+    | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+    | pose_hrnet | 85.968 | 81.669 | 70.630 | 77.568 | 86.492 | 89.771 | 82.574 | 16.681 |
+    ```
+
 ## Export Process
 
 ### Export
diff --git a/research/cv/PAMTRI/PoseEstNet/eval_onnx.py b/research/cv/PAMTRI/PoseEstNet/eval_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a681518d77aca042f5d3d9c046e638f6edc8970
--- /dev/null
+++ b/research/cv/PAMTRI/PoseEstNet/eval_onnx.py
@@ -0,0 +1,67 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+########################## eval_onnx PoseEstNet ##########################
+eval lenet according to model file:
+python eval_onnx.py --cfg config.yaml --data_dir datapath --onnx_path onnxpath
+"""
+import json
+import argparse
+from pathlib import Path
+import onnxruntime as ort
+from src.dataset import create_dataset, get_label
+from src.utils.function import onnx_validate
+from src.config import cfg, update_config
+
+parser = argparse.ArgumentParser(description='Eval PoseEstNet')
+
+parser.add_argument('--cfg', type=str, default='')
+parser.add_argument('--data_dir', type=str, default='')
+parser.add_argument('--onnx_path', type=str, default='')
+parser.add_argument('--device_target', type=str, default="GPU")
+
+args = parser.parse_args()
+
+def create_session(onnx_path, target_device, is_train):
+    """Create an ONNX Runtime session; is_train is passed through for onnx_validate."""
+    if target_device == 'GPU':
+        providers = ['CUDAExecutionProvider']
+    elif target_device == 'CPU':
+        providers = ['CPUExecutionProvider']
+    else:
+        raise ValueError(f'Unsupported target device {target_device!r}. Expected one of: "CPU", "GPU"')
+    session = ort.InferenceSession(onnx_path, providers=providers)
+    input_name = session.get_inputs()[0].name
+    return session, input_name, is_train
+
+
+if __name__ == '__main__':
+    update_config(cfg, args)
+    target = args.device_target
+    data, dataset = create_dataset(cfg, args.data_dir, is_train=False)
+    json_path = get_label(cfg, args.data_dir)
+    dst_json_path = Path(json_path)
+
+    with dst_json_path.open('r') as dst_file:
+        allImage = json.load(dst_file)
+
+    print('Device target:', target)
+    session, input_name, train = create_session(args.onnx_path, target, False)
+
+    print("============== Starting Testing ==============")
+
+    perf_indicator = onnx_validate(cfg, dataset, data, session, input_name, train, allImage)
diff --git a/research/cv/PAMTRI/PoseEstNet/requirements.txt b/research/cv/PAMTRI/PoseEstNet/requirements.txt
index 44b48adcd6b23e41f2e7e9fb81623e9571a1ca7e..54f47bf0435ad6aa3db28cfcbc083b76e7f2939e 100644
--- a/research/cv/PAMTRI/PoseEstNet/requirements.txt
+++ b/research/cv/PAMTRI/PoseEstNet/requirements.txt
@@ -8,3 +8,4 @@ pyyaml
 json_tricks
 scikit-image
 yacs
+onnxruntime-gpu
\ No newline at end of file
diff --git a/research/cv/PAMTRI/PoseEstNet/scripts/run_onnx_eval_gpu.sh b/research/cv/PAMTRI/PoseEstNet/scripts/run_onnx_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..688582f246321eccd54ac78002392807ee1e8a9d
--- /dev/null
+++ b/research/cv/PAMTRI/PoseEstNet/scripts/run_onnx_eval_gpu.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_eval_gpu.sh DATA_PATH ONNX_PATH DEVICE_ID"
+echo "For example: bash run_eval_gpu.sh /path/dataset /path/onnx 0"
+echo "It is better to use the absolute path."
+echo "=============================================================================================================="
+
+if [ $# != 3 ]; then
+  echo "bash run_eval_gpu.sh DATA_PATH ONNX_PATH DEVICE_ID"
+  exit 1
+fi
+
+set -e
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m "$PWD/$1")"
+  fi
+}
+DATA_PATH=$(get_real_path "$1")
+ONNX_PATH=$(get_real_path "$2")
+PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
+
+
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+export DEVICE_ID=$3
+export RANK_SIZE=1
+
+python ${PROJECT_DIR}/../eval_onnx.py --cfg ${PROJECT_DIR}/../config_gpu.yaml \
+  --data_dir ${DATA_PATH} \
+  --device_target GPU \
+  --onnx_path ${ONNX_PATH} > ${PROJECT_DIR}/../eval_onnx_gpu.log 2>&1 &
diff --git a/research/cv/PAMTRI/PoseEstNet/src/dataset/dataset.py b/research/cv/PAMTRI/PoseEstNet/src/dataset/dataset.py
index b8cf26b86cbdaac24ba4f278fcd1841ffde77eba..2c9aa8d2caf3850dedfc2cb508093a649f41144d 100644
--- a/research/cv/PAMTRI/PoseEstNet/src/dataset/dataset.py
+++ b/research/cv/PAMTRI/PoseEstNet/src/dataset/dataset.py
@@ -18,9 +18,8 @@ import copy
 import json
 from pathlib import Path
 import mindspore.dataset as ds
-
-import mindspore.dataset.vision.py_transforms as py_vision
-from mindspore.dataset.transforms.py_transforms import Compose
+import mindspore.dataset.vision as vision
+from mindspore.dataset.transforms.transforms import Compose
 from mindspore.communication.management import get_rank
 
 from .veri import VeRiDataset
@@ -44,8 +43,8 @@ def create_dataset(cfg, data_dir, is_train=True):
                 "joints", "joints_vis"], num_parallel_workers=1, shuffle=False, num_shards=1, shard_id=0)
 
     trans = Compose([
-        py_vision.ToTensor(),
-        py_vision.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), is_hwc=False)
+        vision.ToTensor(),
+        vision.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), is_hwc=False)
     ])
 
     dataset = dataset.map(operations=trans, input_columns="input", num_parallel_workers=8)
diff --git a/research/cv/PAMTRI/PoseEstNet/src/dataset/veri.py b/research/cv/PAMTRI/PoseEstNet/src/dataset/veri.py
index 9b67412f9753629a4144a2aea8a24fb3e7585577..26570e7f43f35fbd1fe6c32e61399dc53ab1226d 100644
--- a/research/cv/PAMTRI/PoseEstNet/src/dataset/veri.py
+++ b/research/cv/PAMTRI/PoseEstNet/src/dataset/veri.py
@@ -117,9 +117,9 @@ class VeRiDataset(JointsDataset):
             for row in reader:
                 joints = []
                 vis = []
-                top_lft = btm_rgt = [int(row[3]), int(row[4])]
+                top_lft = btm_rgt = [int(float(row[3])), int(float(row[4]))]
                 for j in range(36):
-                    joint = [int(row[j*3+3]), int(row[j*3+4]), int(row[j*3+5])]
+                    joint = [int(float(row[j*3+3])), int(float(row[j*3+4])), int(float(row[j*3+5]))]
                     joints.append(joint)
                     vis.append(joint[2])
                     if joint[0] < top_lft[0]:
diff --git a/research/cv/PAMTRI/PoseEstNet/src/utils/function.py b/research/cv/PAMTRI/PoseEstNet/src/utils/function.py
index 0f8c81958b49628226692c939081338a2cf5d776..a443edd9c9713afb3177a7053079848dd0979d52 100644
--- a/research/cv/PAMTRI/PoseEstNet/src/utils/function.py
+++ b/research/cv/PAMTRI/PoseEstNet/src/utils/function.py
@@ -136,6 +136,97 @@ def validate(config, val_loader, val_dataset, model, allImage):
 
     return perf_indicator
 
+def onnx_validate(config, val_loader, val_dataset, session, input_name, is_train, allImage):
+    """Validate PoseEstNet predictions produced by an ONNX Runtime session."""
+
+    num_samples = len(val_dataset)
+    all_preds = np.zeros(
+        (num_samples, config.MODEL.NUM_JOINTS, 3),
+        dtype=np.float32
+    )
+    all_boxes = np.zeros((num_samples, 6))
+    image_path = []
+    filenames = []
+    imgnums = []
+    idx = 0
+    for i, data in enumerate(val_loader.create_dict_iterator()):
+        _input = data["input"]
+        center = data["center"]
+        scale = data["scale"]
+        score = data["score"]
+        image_label = data["image"]
+        joints = data["joints"]
+        joints_vis = data["joints_vis"]
+
+        joints = joints.asnumpy()
+        joints_vis = joints_vis.asnumpy()
+
+        image = []
+        for j in range(config.TEST.BATCH_SIZE):
+            image.append(allImage['{}'.format(image_label[j])])
+
+        outputs = session.run(None, {input_name: _input.asnumpy()})[0]
+
+        if isinstance(outputs, list):
+            print("output is tuple")
+            output = outputs[-1]
+        else:
+            output = outputs
+        if config.TEST.FLIP_TEST:
+            # np.flip returns a non-contiguous view; make it contiguous before feeding it to ONNX Runtime
+            input_flipped = np.ascontiguousarray(np.flip(_input.asnumpy(), 3))
+            outputs_flipped = session.run(None, {input_name: input_flipped})[0]
+
+            if isinstance(outputs_flipped, list):
+                output_flipped = outputs_flipped[-1]
+            else:
+                output_flipped = outputs_flipped
+
+            output_flipped = flip_back(output_flipped, val_dataset.flip_pairs)
+
+            # feature is not aligned, shift flipped heatmap for higher accuracy
+            if config.TEST.SHIFT_HEATMAP: # true
+                output_flipped_copy = output_flipped.copy()
+                output_flipped[:, :, :, 1:] = output_flipped_copy[:, :, :, 0:-1]
+
+            output = (output + output_flipped) * 0.5
+
+        # measure accuracy and record loss
+        num_images = _input.shape[0]
+
+        c = center.asnumpy()
+        s = scale.asnumpy()
+        score = score.asnumpy()
+
+
+        preds, maxvals = get_final_preds(config, output, c, s)
+
+        all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
+        all_preds[idx:idx + num_images, :, 2:3] = maxvals
+        # double check this all_boxes parts
+        all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
+        all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
+        all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
+        all_boxes[idx:idx + num_images, 5] = score
+        image_path.extend(image)
+
+        idx += num_images
+
+        print('-------- Test: [{0}/{1}] ---------'.format(i, val_loader.get_dataset_size()))
+
+    # Evaluate once after predictions for all batches have been collected
+    name_values, perf_indicator = val_dataset.evaluate(
+        all_preds, '', all_boxes, image_path,
+        filenames, imgnums
+    )
+    model_name = config.MODEL.NAME
+    if isinstance(name_values, list):
+        for name_value in name_values:
+            _print_name_value(name_value, model_name)
+    else:
+        _print_name_value(name_values, model_name)
+
+    return perf_indicator
+
 def output_preds(config, val_loader, val_dataset, model, root, test_set, output_dir):
     """output_preds"""
     gt_file = os.path.join(root, 'label_{}.csv'.format(test_set))