diff --git a/research/cv/ReIDStrongBaseline/README.md b/research/cv/ReIDStrongBaseline/README.md
index f13e1e3e96f5c83abaa8ed05e48d65b1cf61eef7..363c9c8b48a1f20c6f652de99658aefce4af1b85 100644
--- a/research/cv/ReIDStrongBaseline/README.md
+++ b/research/cv/ReIDStrongBaseline/README.md
@@ -84,8 +84,8 @@ DukeMTMC-reID
 
 ## [Environment Requirements](#contents)
 
-- Hardware(GPU)
-    - Prepare hardware environment with GPU processor.
+- Hardware(Ascend/GPU)
+    - Prepare hardware environment with Ascend or GPU processor.
 - Framework
     - [MindSpore](https://gitee.com/mindspore/mindspore)
 - For more information, please check the resources below:
@@ -94,14 +94,24 @@ DukeMTMC-reID
 
 ## [Quick Start](#contents)
 
-### [Running scripts](#contents)
-
 Model uses pre-trained backbone ResNet50 trained on ImageNet2012. [Link](https://download.mindspore.cn/model_zoo/r1.3/resnet50_ascend_v130_imagenet2012_official_cv_bs256_top1acc76.97__top5acc_93.44/)
 
 After dataset preparation, you can start training and evaluation as follows:
 
 (Note that you must specify dataset path in `configs/market1501_config.yml`)
 
+### [Running on Ascend](#contents)
+
+```bash
+# run distributed training example
+bash scripts/run_distribute_train_ascend.sh ./configs/market1501_config.yml /path/to/dataset/ /path/to/output/ /path/to/pretrained_resnet50.ckpt rank_table_8pcs.json 8
+
+# run evaluation example
+bash scripts/run_eval_ascend.sh ./configs/market1501_config.yml /your/path/checkpoint_file /path/to/dataset/
+```
+
+### [Running on GPU](#contents)
+
 ```bash
 # run training example
 bash scripts/run_standalone_train_gpu.sh  ./configs/market1501_config.yml 0 /path/to/dataset/ /path/to/output/ /path/to/pretrained_resnet50.pth
@@ -119,6 +129,7 @@ bash scripts/run_eval_gpu.sh ./configs/market1501_config.yml /your/path/checkpoi
 
 ```text
 ReIDStrongBaseline
+├── ascend310_infer  # Application for Ascend 310 inference
 ├── configs
 │   ├── dukemtmc_config.yml  # Training/evaluation config on DukeMTMC dataset
 │   └── market1501_config.yml  # Training/evaluation config on Market1501 dataset
@@ -129,8 +140,11 @@ ReIDStrongBaseline
 │   ├── local_adapter.py # Environment variables parser
 │   └── moxing_adapter.py # Moxing adapter for ModelArts
 ├── scripts
+│   ├── run_distribute_train_ascend.sh  # Start multi Ascend training
 │   ├── run_distribute_train_gpu.sh  # Start multi GPU training
+│   ├── run_eval_ascend.sh  # Start single Ascend evaluation
 │   ├── run_eval_gpu.sh # Start single GPU evaluation
+│   ├── run_infer_310.sh  # Start Ascend 310 inference
 │   └── run_standalone_train_gpu.sh  # Start single GPU training
 ├── src
 │   ├── callbacks.py # Logging to file callbacks
@@ -154,6 +168,7 @@ ReIDStrongBaseline
 │   └── triplet_loss.py  # Triplet  Loss definition
 ├── eval.py # Evaluate the network
 ├── export.py # Export the network
+├── postprocess.py # Post process for Ascend 310 inference
 ├── train.py # Train the network
 ├── requirements.txt # Required libraries
 └── README.md
@@ -176,7 +191,7 @@ usage: train.py  --config_path CONFIG_PATH [--distribute DISTRIBUTE] [--device_t
 options:
     --config_path              path to .yml config file
     --distribute               pre_training by several devices: "true"(training by more than 1 device) | "false", default is "false"
-    --device_target            target device ("GPU" | "CPU")
+    --device_target            target device ("Ascend" | "GPU" | "CPU")
     --max_epoch                epoch size: N, default is 120
     --start_decay_epoch        epoch to decay, default is '40,70'
     --ids_per_batch            number of person in batch, default is 16 (8 for distributed)
@@ -225,6 +240,19 @@ Parameters for learning rate:
 - Set options in `configs/market1501_config.yaml` or `configs/dukemtmc_config.yaml`,
   including paths, learning rate and network hyperparameters.
 
+### Usage
+
+#### on Ascend
+
+- Run `run_distribute_train_ascend.sh` for distributed training of ReID Strong Baseline model.
+- The `RANK_TABLE_FILE` must be placed under `scripts/`.
+
+    ```bash
+    bash scripts/run_distribute_train_ascend.sh CONFIG_PATH DATA_DIR OUTPUT_PATH PRETRAINED_RESNET50 RANK_TABLE_FILE RANK_SIZE
+    ```
+
+#### on GPU
+
 - Run `run_standalone_train_gpu.sh` for non-distributed training of  model.
 
     ```bash
@@ -241,6 +269,18 @@ Parameters for learning rate:
 
 - Set options in `market1501_config.yaml`.
 
+### Usage
+
+#### on Ascend
+
+- Run `bash scripts/run_eval_ascend.sh` for evaluation of ReID Strong Baseline model.
+
+    ```bash
+    bash scripts/run_eval_ascend.sh CONFIG_PATH CKPT_PATH DATA_DIR
+    ```
+
+#### on GPU
+
 - Run `bash scripts/run_eval_gpu.sh` for evaluation of ReID Strong Baseline model.
 
     ```bash
@@ -274,61 +314,61 @@ Inference result will be shown in the terminal
 
 #### Market1501 Training Performance
 
-| Parameters                 | GPU                                                            |
-| -------------------------- | -------------------------------------------------------------- |
-| Resource                   | 8x Tesla V100-PCIE 32G                                         |
-| uploaded Date              | 03/11/2022 (month/day/year)                                    |
-| MindSpore Version          | 1.5.0                                                          |
-| Dataset                    | Market1501                                                     |
-| Training Parameters        | max_epoch=120, ids_per_batch=8, start_decay_epoch=151, lr_init=0.0014, lr_cri=1.0, decay_epochs='40,70' |
-| Optimizer                  | Adam, SGD                                                      |
-| Loss Function              | Triplet, Smooth Identity, Center                               |
-| Speed                      | 182ms/step (8pcs)                                              |
-| Loss                       | 0.24                                                           |
-| Params (M)                 | 25.1                                                           |
-| Checkpoint for inference   | 319Mb (.ckpt file)                                             |
-| Scripts                    | [ReID Strong Baseline scripts](scripts)                        |
+| Parameters                 | Ascend                      | GPU                         |
+| -------------------------- | --------------------------- | --------------------------- |
+| Resource                   | 8x Ascend 910 32G           | 8x Tesla V100-PCIE 32G      |
+| uploaded Date              | 04/21/2022 (month/day/year) | 03/11/2022 (month/day/year) |
+| MindSpore Version          | 1.3.0                       | 1.5.0                       |
+| Dataset                    | Market1501                  | Market1501                  |
+| Training Parameters        | max_epoch=120, ids_per_batch=8, start_decay_epoch=151, lr_init=0.0014, lr_cri=1.0, decay_epochs='40,70' | max_epoch=120, ids_per_batch=8, start_decay_epoch=151, lr_init=0.0014, lr_cri=1.0, decay_epochs='40,70' |
+| Optimizer                  | Adam, SGD                   | Adam, SGD                   |
+| Loss Function              | Triplet, Smooth Identity, Center | Triplet, Smooth Identity, Center |
+| Speed                      | 536ms/step (8pcs)           | 182ms/step (8pcs)           |
+| Loss                       | 0.28                        | 0.24                        |
+| Params (M)                 | 24.8                        | 25.1                        |
+| Checkpoint for inference   | 305Mb (.ckpt file)          | 319Mb (.ckpt file)          |
+| Scripts                    | [ReID Strong Baseline scripts](scripts) | [ReID Strong Baseline scripts](scripts) |
 
 #### Market1501 Evaluation Performance
 
-| Parameters          | GPU                         |
-| ------------------- | --------------------------- |
-| Resource            | 1x Tesla V100-PCIE 32G      |
-| Uploaded Date       | 03/11/2022 (month/day/year) |
-| MindSpore Version   | 1.5.0                       |
-| Dataset             | Market1501                  |
-| batch_size          | 32                          |
-| outputs             | mAP, Rank-1                 |
-| Accuracy            | mAP: 86.99%, rank-1: 94.48% |
+| Parameters          | Ascend                        | GPU                         |
+| ------------------- | ----------------------------- | --------------------------- |
+| Resource            | 1x Ascend 910 32G             | 1x Tesla V100-PCIE 32G      |
+| Uploaded Date       | 04/21/2022 (month/day/year)   | 03/11/2022 (month/day/year) |
+| MindSpore Version   | 1.3.0                         | 1.5.0                       |
+| Dataset             | Market1501                    | Market1501                  |
+| batch_size          | 32                            | 32                          |
+| outputs             | mAP, Rank-1                   | mAP, Rank-1                 |
+| Accuracy            | mAP: 86.85%, rank-1: 94.36%   | mAP: 86.99%, rank-1: 94.48% |
 
 #### DukeMTMC-reID Training Performance
 
-| Parameters                 | GPU                                                            |
-| -------------------------- | -------------------------------------------------------------- |
-| Resource                   | 8x Tesla V100-PCIE 32G                                         |
-| uploaded Date              | 03/11/2022 (month/day/year)                                    |
-| MindSpore Version          | 1.5.0                                                          |
-| Dataset                    | DukeMTMC-reID                                                  |
-| Training Parameters        | max_epoch=120, ids_per_batch=8, start_decay_epoch=151, lr_init=0.0014, lr_cri=1.0, decay_epochs='40,70' |
-| Optimizer                  | Adam, SGD                                                      |
-| Loss Function              | Triplet, Smooth Identity, Center                               |
-| Speed                      | 180ms/step (8pcs)                                              |
-| Loss                       | 0.24                                                           |
-| Params (M)                 | 25.1                                                           |
-| Checkpoint for inference   | 319Mb (.ckpt file)                                             |
-| Scripts                    | [ReID Strong Baseline scripts](scripts)                        |
+| Parameters                 | Ascend                      | GPU                         |
+| -------------------------- | --------------------------- | --------------------------- |
+| Resource                   | 8x Ascend 910 32G           | 8x Tesla V100-PCIE 32G      |
+| uploaded Date              | 04/21/2022 (month/day/year) | 03/11/2022 (month/day/year) |
+| MindSpore Version          | 1.3.0                       | 1.5.0                       |
+| Dataset                    | DukeMTMC-reID               | DukeMTMC-reID               |
+| Training Parameters        | max_epoch=120, ids_per_batch=8, start_decay_epoch=151, lr_init=0.0014, lr_cri=1.0, decay_epochs='40,70' | max_epoch=120, ids_per_batch=8, start_decay_epoch=151, lr_init=0.0014, lr_cri=1.0, decay_epochs='40,70' |
+| Optimizer                  | Adam, SGD                   | Adam, SGD                   |
+| Loss Function              | Triplet, Smooth Identity, Center | Triplet, Smooth Identity, Center |
+| Speed                      | 524ms/step (8pcs)           | 180ms/step (8pcs)           |
+| Loss                       | 0.27                        | 0.24                        |
+| Params (M)                 | 24.8                        | 25.1                        |
+| Checkpoint for inference   | 302Mb (.ckpt file)          | 319Mb (.ckpt file)          |
+| Scripts                    | [ReID Strong Baseline scripts](scripts) | [ReID Strong Baseline scripts](scripts) |
 
 #### DukeMTMC-reID Evaluation Performance
 
-| Parameters          | GPU                         |
-| ------------------- | --------------------------- |
-| Resource            | 1x Tesla V100-PCIE 32G      |
-| Uploaded Date       | 03/11/2022 (month/day/year) |
-| MindSpore Version   | 1.5.0                       |
-| Dataset             | DukeMTMC-reID               |
-| batch_size          | 32                          |
-| outputs             | mAP, Rank-1                 |
-| Accuracy            | mAP: 76.68%, rank-1: 87.34% |
+| Parameters          | Ascend                      | GPU                         |
+| ------------------- | --------------------------- | --------------------------- |
+| Resource            | 1x Ascend 910 32G           | 1x Tesla V100-PCIE 32G      |
+| Uploaded Date       | 04/21/2022 (month/day/year) | 03/11/2022 (month/day/year) |
+| MindSpore Version   | 1.3.0                       | 1.5.0                       |
+| Dataset             | DukeMTMC-reID               | DukeMTMC-reID               |
+| batch_size          | 32                          | 32                          |
+| outputs             | mAP, Rank-1                 | mAP, Rank-1                 |
+| Accuracy            | mAP: 76.58%, rank-1: 87.43% | mAP: 76.68%, rank-1: 87.34% |
 
 ## [Description of Random Situation](#contents)
 
diff --git a/research/cv/ReIDStrongBaseline/ascend310_infer/CMakeLists.txt b/research/cv/ReIDStrongBaseline/ascend310_infer/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ee3c85447340e0449ff2b70ed24f60a17e07b2b6
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/ascend310_infer/CMakeLists.txt
@@ -0,0 +1,14 @@
+cmake_minimum_required(VERSION 3.14.1)
+project(Ascend310Infer)
+add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
+set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
+option(MINDSPORE_PATH "mindspore install path" "")
+include_directories(${MINDSPORE_PATH})
+include_directories(${MINDSPORE_PATH}/include)
+include_directories(${PROJECT_SRC_ROOT})
+find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
+file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
+
+add_executable(main src/main.cc src/utils.cc)
+target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
diff --git a/research/cv/ReIDStrongBaseline/ascend310_infer/build.sh b/research/cv/ReIDStrongBaseline/ascend310_infer/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a6a9ed6f01913ebb27018273a2412fece3dc3d3c
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/ascend310_infer/build.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ -d out ]; then
+    rm -rf out
+fi
+
+mkdir out
+cd out || exit
+
+cmake .. \
+    -DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
+make
diff --git a/research/cv/ReIDStrongBaseline/ascend310_infer/inc/utils.h b/research/cv/ReIDStrongBaseline/ascend310_infer/inc/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..a128cdfec8f6438f3b00a5be0a826cbd86f1c365
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/ascend310_infer/inc/utils.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_INFERENCE_UTILS_H_
+#define MINDSPORE_INFERENCE_UTILS_H_
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <vector>
+#include <string>
+#include <memory>
+#include "include/api/types.h"
+
+std::vector<std::string> GetAllFiles(std::string_view dirName, const std::string& inputtype);
+DIR *OpenDir(std::string_view dirName);
+std::string RealPath(std::string_view path);
+mindspore::MSTensor ReadFileToTensor(const std::string &file);
+int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
+#endif
diff --git a/research/cv/ReIDStrongBaseline/ascend310_infer/src/main.cc b/research/cv/ReIDStrongBaseline/ascend310_infer/src/main.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b6e42d60cb815267d8383900f372e08ea232e494
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/ascend310_infer/src/main.cc
@@ -0,0 +1,183 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <sys/time.h>
+#include <gflags/gflags.h>
+#include <dirent.h>
+#include <iostream>
+#include <string>
+#include <algorithm>
+#include <iosfwd>
+#include <vector>
+#include <fstream>
+#include <sstream>
+
+#include "../inc/utils.h"
+#include "include/dataset/execute.h"
+#include "include/dataset/transforms.h"
+#include "include/dataset/vision.h"
+#include "include/dataset/vision_ascend.h"
+#include "include/api/types.h"
+#include "include/api/model.h"
+#include "include/api/serialization.h"
+#include "include/api/context.h"
+
+using mindspore::Serialization;
+using mindspore::Model;
+using mindspore::Context;
+using mindspore::Status;
+using mindspore::ModelType;
+using mindspore::Graph;
+using mindspore::GraphCell;
+using mindspore::kSuccess;
+using mindspore::MSTensor;
+using mindspore::DataType;
+using mindspore::dataset::Execute;
+using mindspore::dataset::TensorTransform;
+using mindspore::dataset::vision::Decode;
+using mindspore::dataset::vision::Resize;
+using mindspore::dataset::vision::Normalize;
+using mindspore::dataset::vision::HWC2CHW;
+using mindspore::dataset::transforms::TypeCast;
+
+DEFINE_string(model_path, "", "model path");
+DEFINE_string(dataset_path, "", "dataset path");
+DEFINE_string(input_type, "", "gallery or query");
+
+DEFINE_int32(input_width, 128, "input width");
+DEFINE_int32(input_height, 256, "input height");
+DEFINE_int32(device_id, 0, "device id");
+DEFINE_string(precision_mode, "allow_fp32_to_fp16", "precision mode");
+DEFINE_string(op_select_impl_mode, "", "op select impl mode");
+DEFINE_string(aipp_path, "./aipp.cfg", "aipp path");
+DEFINE_string(device_target, "Ascend310", "device target");
+
+size_t GetFeature(MSTensor data);
+
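+// Dump every float of the output feature tensor, one value per line, to
+// ./<input_type>_result_Files/feature_data.txt and return the index of the largest value.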
+size_t GetFeature(MSTensor data) {
+    std::string homePath = "./" + FLAGS_input_type + "_result_Files";
+    std::string outFileName = homePath + "/feature_data.txt";
+    float max_value = -1;
+    size_t max_idx = 0;
+    const float *p = reinterpret_cast<const float *>(data.MutableData());
+    std::ofstream outfile;
+    outfile.open(outFileName, std::ios::app);
+
+    for (size_t i = 0; i < data.DataSize() / sizeof(float); ++i) {
+        outfile << p[i] << std::endl;
+        if (p[i] > max_value) {
+            max_value = p[i];
+            max_idx = i;
+        }
+    }
+    return max_idx;
+}
+
+int main(int argc, char **argv) {
+    gflags::ParseCommandLineFlags(&argc, &argv, true);
+    if (RealPath(FLAGS_model_path).empty()) {
+      std::cout << "Invalid model" << std::endl;
+      return 1;
+    }
+
+    auto context = std::make_shared<Context>();
+    auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
+    ascend310_info->SetDeviceID(FLAGS_device_id);
+    context->MutableDeviceInfo().push_back(ascend310_info);
+
+    Graph graph;
+    Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph);
+    if (ret != kSuccess) {
+        std::cout << "Load model failed." << std::endl;
+        return 1;
+    }
+
+    Model model;
+    ret = model.Build(GraphCell(graph), context);
+    if (ret != kSuccess) {
+        std::cout << "ERROR: Build failed." << std::endl;
+        return 1;
+    }
+
+    std::vector<MSTensor> modelInputs = model.GetInputs();
+
+    auto all_files = GetAllFiles(FLAGS_dataset_path, FLAGS_input_type);
+    if (all_files.empty()) {
+        std::cout << "ERROR: no input data." << std::endl;
+        return 1;
+    }
+
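+    // Preprocessing: decode JPEG, resize to 256x128 (HxW), normalize with ImageNet
+    // mean/std scaled to the 0-255 range, then transpose HWC -> CHW.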
+    auto decode = Decode();
+    auto resize = Resize({256, 128});
+    auto normalize = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375});
+    auto hwc2chw = HWC2CHW();
+    auto typeCast = TypeCast(DataType::kNumberTypeFloat16);
+
+    mindspore::dataset::Execute transformDecode(decode);
+    mindspore::dataset::Execute transform({resize, normalize, hwc2chw});
+    mindspore::dataset::Execute transformCast(typeCast);
+
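+    // Per-image timing: map of start time (ms) -> end time (ms), averaged after the loop.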
+    std::map<double, double> costTime_map;
+
+    size_t size = all_files.size();
+    for (size_t i = 0; i < size; ++i) {
+        struct timeval start;
+        struct timeval end;
+        double startTime_ms;
+        double endTime_ms;
+        std::vector<MSTensor> inputs;
+        std::vector<MSTensor> outputs;
+
+        std::cout << "Start predict input files:" << all_files[i] << std::endl;
+        mindspore::MSTensor image =  ReadFileToTensor(all_files[i]);
+
+        transformDecode(image, &image);
+        std::vector<int64_t> shape = image.Shape();
+        transform(image, &image);
+
+        inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(),
+                            image.Data().get(), image.DataSize());
+        gettimeofday(&start, NULL);
+        model.Predict(inputs, &outputs);
+        gettimeofday(&end, NULL);
+
+        GetFeature(outputs[0]);
+
+        startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
+        endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
+        costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
+    }
+    double average = 0.0;
+    int infer_cnt = 0;
+
+    for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
+        double diff = 0.0;
+        diff = iter->second - iter->first;
+        average += diff;
+        infer_cnt++;
+    }
+
+    average = average / infer_cnt;
+
+    std::stringstream timeCost;
+    timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << infer_cnt << std::endl;
+    std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl;
+    std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
+    std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
+    file_stream << timeCost.str();
+    file_stream.close();
+    costTime_map.clear();
+  return 0;
+}
diff --git a/research/cv/ReIDStrongBaseline/ascend310_infer/src/utils.cc b/research/cv/ReIDStrongBaseline/ascend310_infer/src/utils.cc
new file mode 100644
index 0000000000000000000000000000000000000000..60154d7e34ff2aee5d6edde4e428889cfab74868
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/ascend310_infer/src/utils.cc
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "inc/utils.h"
+
+#include <fstream>
+#include <algorithm>
+#include <iostream>
+#include <string>
+#include <vector>
+using mindspore::MSTensor;
+using mindspore::DataType;
+
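+// Split s on any character contained in separator, skipping empty tokens.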
+std::vector<std::string> split(const std::string &s, const std::string &separator) {
+    std::vector<std::string> result;
+    std::size_t i = 0;
+
+    while (i != s.size()) {
+        int flag = 0;
+        while (i != s.size() && flag == 0) {
+            flag = 1;
+            for (std::size_t x = 0; x < separator.size(); ++x) {
+                if (s[i] == separator[x]) {
+                    ++i;
+                    flag = 0;
+                    break;
+                }
+            }
+        }
+        flag = 0;
+        std::size_t j = i;
+        while (j != s.size() && flag == 0) {
+            for (std::size_t x = 0; x < separator.size(); ++x) {
+                if (s[j] == separator[x]) {
+                    flag = 1;
+                    break;
+                }
+            }
+            if (flag == 0) {
+                ++j;
+            }
+        }
+        if (i != j) {
+            result.push_back(s.substr(i, j-i));
+            i = j;
+        }
+    }
+    return result;
+}
+
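+// List regular image files in dirName. File names follow the Market1501/DukeMTMC-reID
+// convention "<pid>_c<camid>...": the person id and camera id are appended to
+// ./<inputtype>_result_Files/savepid.txt and savecamid.txt; junk images (pid == -1) are skipped.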
+std::vector<std::string> GetAllFiles(std::string_view dirName, const std::string& inputtype) {
+    struct dirent *filename;
+    DIR *dir = OpenDir(dirName);
+    if (dir == nullptr) {
+        return {};
+    }
+    std::vector<std::string> res;
+    std::string homePath = "./" + inputtype + "_result_Files";
+    std::string pidfilename = homePath + "/savepid.txt";
+    std::string camidfilename = homePath + "/savecamid.txt";
+    std::ofstream outfilepid;
+    std::ofstream outfilecmd;
+    outfilepid.open(pidfilename, std::ios::app);
+    outfilecmd.open(camidfilename, std::ios::app);
+    std::string pid;
+    std::string camid;
+
+    while ((filename = readdir(dir)) != nullptr) {
+        std::string dName = std::string(filename->d_name);
+        if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
+            continue;
+        }
+        if (dName == "Thumbs.db") {
+            continue;
+        }
+        std::vector<std::string> v = split(dName, "_");
+        pid = v[0];
+        camid = v[1].substr(1, 1);
+        if (pid == "-1") {
+            continue;
+        }
+        outfilepid << pid << std::endl;
+        outfilecmd << camid << std::endl;
+        res.emplace_back(std::string(dirName) + "/" + filename->d_name);
+    }
+    for (auto &f : res) {
+        std::cout << "image file: " << f << std::endl;
+    }
+    return res;
+}
+
+int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
+    std::string homePath = "./result_Files";
+    const int INVALID_POINTER = -1;
+    const int ERROR = -2;
+    for (size_t i = 0; i < outputs.size(); ++i) {
+        size_t outputSize;
+        std::shared_ptr<const void> netOutput = outputs[i].Data();
+        outputSize = outputs[i].DataSize();
+
+        int pos = imageFile.rfind('/');
+        std::string fileName(imageFile, pos + 1);
+        fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".txt");
+        std::string outFileName = homePath + "/" + fileName;
+        FILE *outputFile = fopen(outFileName.c_str(), "wb");
+        if (outputFile == nullptr) {
+            std::cout << "open result file " << outFileName << " failed" << std::endl;
+            return INVALID_POINTER;
+        }
+        size_t size = fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
+
+        if (size != outputSize) {
+            fclose(outputFile);
+            outputFile = nullptr;
+            std::cout << "write result file " << outFileName << " failed, write size[" << size <<
+                "] is smaller than output size[" << outputSize << "], maybe the disk is full." << std::endl;
+            return ERROR;
+        }
+        fclose(outputFile);
+        outputFile = nullptr;
+    }
+    return 0;
+}
+
+mindspore::MSTensor ReadFileToTensor(const std::string &file) {
+  if (file.empty()) {
+    std::cout << "Pointer file is nullptr" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  std::ifstream ifs(file);
+  if (!ifs.good()) {
+    std::cout << "File: " << file << " is not exist" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  if (!ifs.is_open()) {
+    std::cout << "File: " << file << "open failed" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  ifs.seekg(0, std::ios::end);
+  size_t size = ifs.tellg();
+  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
+  ifs.close();
+
+  return buffer;
+}
+
+DIR *OpenDir(std::string_view dirName) {
+    if (dirName.empty()) {
+        std::cout << " dirName is null ! " << std::endl;
+        return nullptr;
+    }
+    std::string realPath = RealPath(dirName);
+    struct stat s;
+    lstat(realPath.c_str(), &s);
+    if (!S_ISDIR(s.st_mode)) {
+        std::cout << "dirName is not a valid directory !" << std::endl;
+        return nullptr;
+    }
+    DIR *dir = opendir(realPath.c_str());
+    if (dir == nullptr) {
+        std::cout << "Can not open dir " << dirName << std::endl;
+        return nullptr;
+    }
+    std::cout << "Successfully opened the dir " << dirName << std::endl;
+    return dir;
+}
+
+std::string RealPath(std::string_view path) {
+    char realPathMem[PATH_MAX] = {0};
+    char *realPathRet = nullptr;
+    realPathRet = realpath(path.data(), realPathMem);
+    if (realPathRet == nullptr) {
+        std::cout << "File: " << path << " is not exist.";
+        return "";
+    }
+
+    std::string realPath(realPathMem);
+    std::cout << path << " realpath is: " << realPath << std::endl;
+    return realPath;
+}
diff --git a/research/cv/ReIDStrongBaseline/configs/dukemtmc_config.yml b/research/cv/ReIDStrongBaseline/configs/dukemtmc_config.yml
index 6c417f66d4862d9ddc6222ff8b112949266c3538..aac10450eaaf8c5093c28299dcda8b55ff246514 100644
--- a/research/cv/ReIDStrongBaseline/configs/dukemtmc_config.yml
+++ b/research/cv/ReIDStrongBaseline/configs/dukemtmc_config.yml
@@ -8,9 +8,9 @@ checkpoint_url: ""
 data_path: "/cache/data"
 output_path: "/cache/train"
 load_path: "/cache/checkpoint_path"
-device_target: "GPU"
+device_target: "Ascend" # ['Ascend', 'GPU']
 need_modelarts_dataset_unzip: False
-modelarts_dataset_unzip_name: "market1501"
+modelarts_dataset_unzip_name: "DukeMTMC-reID"
 
 # ==============================================================================
 # options
diff --git a/research/cv/ReIDStrongBaseline/configs/market1501_config.yml b/research/cv/ReIDStrongBaseline/configs/market1501_config.yml
index cb61f20a7b531efee8ac4ee0bbd8c5c94d2c4f7e..dc98208d45fa67777d8265f4d806b655fd32490a 100644
--- a/research/cv/ReIDStrongBaseline/configs/market1501_config.yml
+++ b/research/cv/ReIDStrongBaseline/configs/market1501_config.yml
@@ -8,7 +8,7 @@ checkpoint_url: ""
 data_path: "/cache/data"
 output_path: "/cache/train"
 load_path: "/cache/checkpoint_path"
-device_target: "GPU"
+device_target: "Ascend" # ['Ascend', 'GPU']
 need_modelarts_dataset_unzip: False
 modelarts_dataset_unzip_name: "market1501"
 
diff --git a/research/cv/ReIDStrongBaseline/postprocess.py b/research/cv/ReIDStrongBaseline/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec79f032cd33b51e6705d28a5b6b5eb4c6db2fe2
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/postprocess.py
@@ -0,0 +1,157 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""post process for 310 inference"""
+
+import argparse
+import numpy as np
+
+parser = argparse.ArgumentParser(description='Train StrongBaseline')
+parser.add_argument('--q_feature', type=str, default='')
+parser.add_argument('--q_pid', type=str, default='')
+parser.add_argument('--q_camid', type=str, default='')
+parser.add_argument('--g_feature', type=str, default='')
+parser.add_argument('--g_pid', type=str, default='')
+parser.add_argument('--g_camid', type=str, default='')
+
+parser.add_argument('--train_url', type=str, default='log')
+parser.add_argument('--reranking', type=lambda x: x.lower() == 'true', default=True, help='re_rank')
+parser.add_argument('--test_distance', type=str, default='global_local', help='test distance type')
+parser.add_argument('--unaligned', action='store_true')
+args = parser.parse_args()
+
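+# The Ascend 310 inference step dumps one float per line to feature_data.txt and one integer
+# per line to savepid.txt / savecamid.txt; features are reshaped into (-1, 2048) row vectors below.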
+def get_query(feature_file, pid_file, camid_file):
+    """get query data"""
+    qf, q_pids, q_camids = [], [], []
+    openfilef = open(feature_file, 'r')
+    for line in openfilef.readlines():
+        temp = float(line)
+        qf.append(temp)
+    openfilep = open(pid_file, 'r')
+    for line in openfilep.readlines():
+        temp = int(line)
+        q_pids.append(temp)
+    openfilec = open(camid_file, 'r')
+    for line in openfilec.readlines():
+        temp = int(line)
+        q_camids.append(temp)
+    qf = np.array(qf)
+    qf = qf.reshape(-1, 2048)
+    #add norm
+    c = np.linalg.norm(qf, ord=2, axis=1, keepdims=True)
+    qf = qf/c
+    q_pids = np.asarray(q_pids)
+    q_camids = np.asarray(q_camids)
+    print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.shape[0], qf.shape[1]))
+    return qf, q_pids, q_camids
+
+def get_gallery(feature_file, pid_file, camid_file):
+    """get gallery data"""
+    gf, g_pids, g_camids = [], [], []
+    openfilef = open(feature_file, 'r')
+    for line in openfilef.readlines():
+        temp = float(line)
+        gf.append(temp)
+    openfilep = open(pid_file, 'r')
+    for line in openfilep.readlines():
+        temp = int(line)
+        g_pids.append(temp)
+    openfilec = open(camid_file, 'r')
+    for line in openfilec.readlines():
+        temp = int(line)
+        g_camids.append(temp)
+    gf = np.array(gf)
+    gf = gf.reshape(-1, 2048)
+    #add norm
+    c = np.linalg.norm(gf, ord=2, axis=1, keepdims=True)
+    gf = gf/c
+    g_pids = np.asarray(g_pids)
+    g_camids = np.asarray(g_camids)
+    print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.shape[0], gf.shape[1]))
+    return gf, g_pids, g_camids
+
+
+def eval_func(dist_mat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
+    """Evaluation with market1501 metric
+        Key: for each query identity, its gallery images from the same camera view are discarded.
+        """
+    num_q, num_g = dist_mat.shape
+    if num_g < max_rank:
+        max_rank = num_g
+        print("Note: number of gallery samples is quite small, got {}".format(num_g))
+    indices = np.argsort(dist_mat, axis=1)
+    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
+
+    # compute cmc curve for each query
+    print("==============================compute cmc curve for each query")
+    all_cmc = []
+    all_AP = []
+    num_valid_q = 0.  # number of valid query
+    for q_idx in range(num_q):
+        # get query pid and camid
+        q_pid = q_pids[q_idx]
+        q_camid = q_camids[q_idx]
+
+        # remove gallery samples that have the same pid and camid with query
+        order = indices[q_idx]
+        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
+        keep = np.invert(remove)
+
+        # compute cmc curve
+        # binary vector, positions with value 1 are correct matches
+        orig_cmc = matches[q_idx][keep]
+        if not np.any(orig_cmc):
+            # this condition is true when query identity does not appear in gallery
+            continue
+
+        cmc = orig_cmc.cumsum()
+        cmc[cmc > 1] = 1
+
+        all_cmc.append(cmc[:max_rank])
+        num_valid_q += 1.
+
+        # compute average precision
+        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
+        num_rel = orig_cmc.sum()
+        tmp_cmc = orig_cmc.cumsum()
+        tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
+        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
+        AP = tmp_cmc.sum() / num_rel
+        all_AP.append(AP)
+
+    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
+
+    all_cmc = np.asarray(all_cmc).astype(np.float32)
+    all_cmc = all_cmc.sum(0) / num_valid_q
+    mAP = np.mean(all_AP)
+
+    return all_cmc, mAP
+
+
+if __name__ == '__main__':
+    sqf, sq_pids, sq_camids = get_query(args.q_feature, args.q_pid, args.q_camid)
+    sgf, sg_pids, sg_camids = get_gallery(args.g_feature, args.g_pid, args.g_camid)
+    m, n = sqf.shape[0], sgf.shape[0]
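+    # Squared Euclidean distance matrix: ||q||^2 + ||g||^2 - 2 * q . g^T (features are L2-normalized above).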
+    distmat = np.power(sqf, 2).sum(axis=1, keepdims=True).repeat(n, axis=1) + \
+        np.power(sgf, 2).sum(axis=1, keepdims=True).repeat(m, axis=1).T
+    distmat = 1 * distmat - 2 * np.dot(sqf, sgf.transpose())
+    r, m_ap = eval_func(distmat, sq_pids, sg_pids, sq_camids, sg_camids)
+    s = 'After BNNeck'
+    print(f'[INFO] {s}')
+    print(
+        '[INFO] mAP: {:.4f} rank1: {:.4f} rank3: {:.4f} rank5: {:.4f} rank10: {:.4f}'.format(
+            m_ap,
+            r[0], r[2], r[4], r[9],
+        )
+    )
diff --git a/research/cv/ReIDStrongBaseline/scripts/run_distribute_train_ascend.sh b/research/cv/ReIDStrongBaseline/scripts/run_distribute_train_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f701d720f4cf897ca0c605105dc206054443fed8
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/scripts/run_distribute_train_ascend.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script at the diractory same with train.py: "
+echo "bash scripts/run_distribute_train_ascend.sh CONFIG_PATH DATA_DIR OUTPUT_PATH PRETRAINED_RESNET50 RANK_TABLE_FILE RANK_SIZE"
+echo "for example: bash scripts/run_distribute_train_ascend.sh ./configs/market1501_config.yml /path/to/dataset/ /path/to/output/ /path/to/resnet50_ascend_v130_imagenet2012_official_cv_bs256_top1acc76.97__top5acc_93.44.ckpt rank_table_8pcs.json 8"
+echo "It is better to use the absolute path."
+echo "=============================================================================================================="
+set -e
+
+config_path=$1
+DATA_DIR=$2
+OUTPUT_PATH=$3
+PRETRAINED_RESNET50=$4
+rank_table_8pcs_file=$5
+
+EXEC_PATH=$(pwd)
+
+export RANK_TABLE_FILE=${EXEC_PATH}/scripts/$rank_table_8pcs_file
+export RANK_SIZE=$6
+
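+# Launch one training process per device: each rank runs from its own ./device$i directory
+# with DEVICE_ID and RANK_ID exported for that rank.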
+for((i=0;i<${RANK_SIZE};i++))
+do
+    rm -rf device$i
+    mkdir device$i
+    cp ./train.py ./device$i
+    cd ./device$i
+    export DEVICE_ID=$i
+    export RANK_ID=$i
+    echo "start training for device $i"
+    env > env$i.log
+    python ${EXEC_PATH}/train.py  \
+      --config_path="$config_path" \
+      --device_target="Ascend" \
+      --data_dir="$DATA_DIR" \
+      --ckpt_path="$OUTPUT_PATH" \
+      --train_log_path="$OUTPUT_PATH" \
+      --pre_trained_backbone="$PRETRAINED_RESNET50" \
+      --lr_init=0.00140 \
+      --lr_cri=1.0 \
+      --ids_per_batch=8 \
+      --is_distributed=1 > output.train.log 2>&1 &
+    cd ../
+done
diff --git a/research/cv/ReIDStrongBaseline/scripts/run_eval_ascend.sh b/research/cv/ReIDStrongBaseline/scripts/run_eval_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f04532d23bf53072bcf46ac4d91b43c7a2c5796b
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/scripts/run_eval_ascend.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ] ; then
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash scripts/run_eval_gpu.sh CONFIG_PATH CKPT_PATH DATA_DIR"
+echo "for example: bash scripts/run_eval_gpu.sh ./configs/market1501_config.yml /your/path/checkpoint_file /path/to/dataset/"
+echo "It is better to use absolute path."
+echo "=============================================================================================================="
+exit 1;
+fi
+
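+# Resolve a possibly relative path into an absolute one.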
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+config_path=$(get_real_path "$1")
+
+PATH1=$(get_real_path "$2")
+echo "$PATH1"
+DATA_DIR=$(get_real_path "$3")
+
+python eval.py  \
+    --config_path="$config_path" \
+    --device_target="Ascend" \
+    --data_dir="$DATA_DIR" \
+    --eval_model="$PATH1" > output.eval.log 2>&1 &
diff --git a/research/cv/ReIDStrongBaseline/scripts/run_infer_310.sh b/research/cv/ReIDStrongBaseline/scripts/run_infer_310.sh
new file mode 100644
index 0000000000000000000000000000000000000000..85ae7b405548c09162d00ab06e25fa6d852a100c
--- /dev/null
+++ b/research/cv/ReIDStrongBaseline/scripts/run_infer_310.sh
@@ -0,0 +1,107 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
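+# Usage: bash scripts/run_infer_310.sh MINDIR_MODEL_PATH QUERY_DATA_PATH GALLERY_DATA_PATH
+# Run from the repository root (absolute paths are recommended); per-image features are dumped to
+# ascend310_infer/{query,gallery}_result_Files and the final metrics are computed by postprocess.py (acc.log).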
+model_path=$1
+query_datapath=$2
+gallery_datapath=$3
+echo "-----------------------------------"
+echo "mindir name: "$model_path
+echo "query dataset path: "$query_datapath
+echo "gallery dataset path: "$gallery_datapath
+echo "-----------------------------------"
+
+export ASCEND_HOME=/usr/local/Ascend
+if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
+    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
+    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
+    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
+else
+    export ASCEND_HOME=/usr/local/Ascend/latest/
+    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
+    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
+fi
+
+function compile_app()
+{
+    cd ascend310_infer || exit
+    bash build.sh &> build.log
+}
+
+function query_infer()
+{
+    if [ -d query_result_Files ]; then
+        rm -rf ./query_result_Files
+    fi
+    mkdir query_result_Files
+    input_type="query"
+    ../ascend310_infer/out/main --model_path=$model_path --dataset_path=$query_datapath --input_type=$input_type &> query_infer.log
+}
+
+function gallery_infer()
+{
+    cd ../ascend310_infer || exit
+    if [ -d gallery_result_Files ]; then
+        rm -rf ./gallery_result_Files
+    fi
+    mkdir gallery_result_Files
+    input_type="gallery"
+    ../ascend310_infer/out/main --model_path=$model_path --dataset_path=$gallery_datapath --input_type=$input_type &> gallery_infer.log
+}
+
+function cal_acc()
+{
+    cd ..
+    qf=ascend310_infer/query_result_Files/feature_data.txt
+    qp=ascend310_infer/query_result_Files/savepid.txt
+    qc=ascend310_infer/query_result_Files/savecamid.txt
+    gf=ascend310_infer/gallery_result_Files/feature_data.txt
+    gp=ascend310_infer/gallery_result_Files/savepid.txt
+    gc=ascend310_infer/gallery_result_Files/savecamid.txt
+    python postprocess.py --q_feature=$qf --q_pid=$qp --q_camid=$qc --g_feature=$gf --g_pid=$gp --g_camid=$gc &> acc.log
+}
+
+compile_app
+if [ $? -ne 0 ]; then
+    echo "compile app code failed"
+    exit 1
+fi
+echo "compile app code success"
+
+query_infer
+if [ $? -ne 0 ]; then
+    echo "execute query inference failed"
+    exit 1
+fi
+echo "execute query inference success"
+
+gallery_infer
+if [ $? -ne 0 ]; then
+    echo "execute gallery inference failed"
+    exit 1
+fi
+echo "execute gallery inference success"
+
+cal_acc
+if [ $? -ne 0 ]; then
+    echo "calculate accuracy failed"
+    exit 1
+fi
+echo "calculate accuracy success"
+echo "ascend 310 infer success"
diff --git a/research/cv/ReIDStrongBaseline/src/dataset.py b/research/cv/ReIDStrongBaseline/src/dataset.py
index 6955853c0844210dcbda48e1f118db222d907b79..88fd13504e7a7720e124b2348719cda2dadbf7f4 100644
--- a/research/cv/ReIDStrongBaseline/src/dataset.py
+++ b/research/cv/ReIDStrongBaseline/src/dataset.py
@@ -186,6 +186,7 @@ def create_dataset(
         column_names=['image', 'label'],
         sampler=sampler,
         shuffle=shuffle,
+        num_parallel_workers=num_parallel_workers,
     )
 
     dataset = dataset.map(
diff --git a/research/cv/ReIDStrongBaseline/train.py b/research/cv/ReIDStrongBaseline/train.py
index f7bd3d93faa98fc02b2ca0a62bc1f8942a15288a..39c385d7398d64a4459fa89aaebd94a119bc8582 100644
--- a/research/cv/ReIDStrongBaseline/train.py
+++ b/research/cv/ReIDStrongBaseline/train.py
@@ -128,16 +128,23 @@ def _prepare_configuration():
                 if not config.need_modelarts_dataset_unzip:
                     init()
 
-        config.group_size = get_group_size()
-        config.rank = get_rank()
-
-        device_num = config.group_size
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(
-            device_num=device_num,
-            parallel_mode=ParallelMode.DATA_PARALLEL,
-            gradients_mean=True,
-        )
+            config.group_size = get_group_size()
+            config.rank = get_rank()
+
+            device_num = config.group_size
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(
+                device_num=device_num,
+                parallel_mode=ParallelMode.DATA_PARALLEL,
+                gradients_mean=True,
+            )
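+        # Ascend distributed training: each process reads its device id from the DEVICE_ID
+        # environment variable exported by scripts/run_distribute_train_ascend.sh.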
+        if config.device_target == "Ascend":
+            device_id = int(os.getenv('DEVICE_ID'))
+            context.set_context(device_id=device_id)
+            init()
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
+            config.rank = get_rank()
+            config.group_size = get_group_size()
     else:
         config.group_size = 1
         config.rank = 0