diff --git a/research/cv/metric_learn/infer/convert/aipp.cfg b/research/cv/metric_learn/infer/convert/aipp.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..34348d30852f5a23679fa5879aeadd54383232d4
--- /dev/null
+++ b/research/cv/metric_learn/infer/convert/aipp.cfg
@@ -0,0 +1,17 @@
+aipp_op {
+    aipp_mode : static
+    input_format : RGB888_U8
+    related_input_rank : 0
+    csc_switch : false
+    crop: false
+    rbuv_swap_switch : false
+    mean_chn_0 : 0
+    mean_chn_1 : 0
+    mean_chn_2 : 0
+    min_chn_0 : 123.675
+    min_chn_1 : 116.28
+    min_chn_2 : 103.53
+    var_reci_chn_0 : 0.01712475383166366983474612552445
+    var_reci_chn_1 : 0.01750700280112044817927170868347
+    var_reci_chn_2 : 0.01742919389978213507625272331155
+}
diff --git a/research/cv/metric_learn/infer/convert/convert.sh b/research/cv/metric_learn/infer/convert/convert.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d6bbf907f9c5504faabf01098283352ed47845ca
--- /dev/null
+++ b/research/cv/metric_learn/infer/convert/convert.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# bash convert.sh /home/data/xd_mindx/gxl/metric_learn/resnet50.air resnet50
+input_air_path=$1
+output_om_path=$2
+aipp_cfg=$3
+
+
+
+echo "Input AIR file path: ${input_air_path}"
+echo "Output OM file path: ${output_om_path}"
+
+atc --input_format=NCHW --framework=1 \
+    --model=${input_air_path} \
+    --output=${output_om_path} \
+    --soc_version=Ascend310 \
+    --disable_reuse_memory=0 \
+    --insert_op_conf=${aipp_cfg} \
+    --precision_mode=allow_mix_precision  \
+    --op_select_implmode=high_precision
\ No newline at end of file
diff --git a/research/cv/metric_learn/infer/docker_start_infer.sh b/research/cv/metric_learn/infer/docker_start_infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2678ff3f94b2b0be1bb20af554f3787f58b70aef
--- /dev/null
+++ b/research/cv/metric_learn/infer/docker_start_infer.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+docker_image=$1
+model_dir=$2
+
+
+function show_help() {
+    echo "Usage: docker_start.sh docker_image model_dir data_dir"
+}
+
+function param_check() {
+    if [ -z "${docker_image}" ]; then
+        echo "please input docker_image"
+        show_help
+        exit 1
+    fi
+
+    if [ -z "${model_dir}" ]; then
+        echo "please input model_dir"
+        show_help
+        exit 1
+    fi
+}
+
+param_check
+
+docker run -it -u root \
+  --device=/dev/davinci0 \
+  --device=/dev/davinci_manager \
+  --device=/dev/devmm_svm \
+  --device=/dev/hisi_hdc \
+  -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+  -v ${model_dir}:${model_dir} \
+  ${docker_image} \
+  /bin/bash
diff --git a/research/cv/metric_learn/infer/mxbase/CMakeLists.txt b/research/cv/metric_learn/infer/mxbase/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..178df116e3ed30f4fbc2fad2286bab289b8d5cc5
--- /dev/null
+++ b/research/cv/metric_learn/infer/mxbase/CMakeLists.txt
@@ -0,0 +1,55 @@
+cmake_minimum_required(VERSION 3.14.0)
+project(metric_learn)
+set(TARGET metric_learn)
+
+add_definitions(-DENABLE_DVPP_INTERFACE)
+add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+add_definitions(-Dgoogle=mindxsdk_private)
+add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
+add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -pie)
+
+# Check environment variable
+if(NOT DEFINED ENV{ASCEND_HOME})
+    message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
+endif()
+if(NOT DEFINED ENV{ASCEND_VERSION})
+    message(WARNING "please define environment variable:ASCEND_VERSION")
+endif()
+if(NOT DEFINED ENV{ARCH_PATTERN})
+    message(WARNING "please define environment variable:ARCH_PATTERN")
+endif()
+set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
+set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
+
+set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
+set(MXBASE_INC $ENV{MX_SDK_HOME}/include)
+set(MXBASE_LIB_DIR $ENV{MX_SDK_HOME}/lib)
+set(MXBASE_POST_LIB_DIR $ENV{MX_SDK_HOME}/lib/modelpostprocessors)
+set(MXBASE_POST_PROCESS_DIR $ENV{MX_SDK_HOME}/include/MxBase/postprocess/include/)
+
+if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
+    set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
+else()
+    set(OPENSOURCE_DIR $ENV{MX_SDK_HOME}/opensource)
+endif()
+
+include_directories(${ACL_INC_DIR})
+include_directories(${MXBASE_INC})
+include_directories(${MXBASE_POST_PROCESS_DIR})
+include_directories(${OPENSOURCE_DIR}/include)
+include_directories(${OPENSOURCE_DIR}/include/opencv4)
+
+link_directories(${ACL_LIB_DIR})
+link_directories(${MXBASE_LIB_DIR})
+link_directories(${MXBASE_POST_LIB_DIR})
+link_directories(${OPENSOURCE_DIR}/lib)
+
+include_directories($ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/x86_64-linux/runtime/include)
+link_directories($ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/x86_64-linux/x86_64-linux/lib64/)
+include_directories($ENV{MX_SDK_HOME}/opensource/lib/glib-2.0/include)
+include_directories($ENV{ASCEND_HOME}/ascend-toolkit/5.0.4/x86_64-linux/runtime/include)
+
+add_executable(${TARGET} src/main.cpp src/MetricLearn.cpp)
+target_link_libraries(${TARGET} glog cpprest mxbase opencv_world stdc++fs)
+
+install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
diff --git a/research/cv/metric_learn/infer/mxbase/build.sh b/research/cv/metric_learn/infer/mxbase/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..646de0a9995dacdb84e97fb05a5844aaaac72ffb
--- /dev/null
+++ b/research/cv/metric_learn/infer/mxbase/build.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+path_cur=$(dirname $0)
+
+function check_env()
+{
+    # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
+    if [ ! "${ASCEND_VERSION}" ]; then
+        export ASCEND_VERSION=nnrt/latest
+        echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
+    else
+        echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
+    fi
+
+    if [ ! "${ARCH_PATTERN}" ]; then
+        # set ARCH_PATTERN to ./ when it was not specified by user
+        export ARCH_PATTERN=./
+        echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
+    else
+        echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
+    fi
+}
+
+function build_metric_learn()
+{
+    cd $path_cur
+    rm -rf build
+    mkdir -p build
+    cd build
+    cmake ..
+    make
+    ret=$?
+    if [ ${ret} -ne 0 ]; then
+        echo "Failed to build metric_learn."
+        exit ${ret}
+    fi
+    make install
+}
+
+check_env
+build_metric_learn
diff --git a/research/cv/metric_learn/infer/mxbase/src/MetricLearn.cpp b/research/cv/metric_learn/infer/mxbase/src/MetricLearn.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..44c71a88e4d62ebf941cd275f7a8469c5534af04
--- /dev/null
+++ b/research/cv/metric_learn/infer/mxbase/src/MetricLearn.cpp
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) 2022 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <dirent.h>
+#include <unistd.h>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include <iostream>
+#include <fstream>
+#include <algorithm>
+
+#include "MxBase/Log/Log.h"
+#include "MxBase/DeviceManager/DeviceManager.h"
+#include "MetricLearn.h"
+
+void getfilename(std::string *filename, std::string *filedir, const std::string &imgpath);
+
+APP_ERROR MetricLearn::Init(const InitParam &initParam) {
+    deviceId_ = initParam.deviceId;
+    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
+    if (ret != APP_ERR_OK) {
+        LogError << "Init devices failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
+    if (ret != APP_ERR_OK) {
+        LogError << "Set context failed, ret=" << ret << ".";
+        return ret;
+    }
+    dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
+    ret = dvppWrapper_->Init();
+    if (ret != APP_ERR_OK) {
+        LogError << "DvppWrapper init failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
+    ret = model_->Init(initParam.modelPath, modelDesc_);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    return APP_ERR_OK;
+}
+
+APP_ERROR MetricLearn::DeInit() {
+    dvppWrapper_->DeInit();
+    model_->DeInit();
+    MxBase::DeviceManager::GetInstance()->DestroyDevices();
+    return APP_ERR_OK;
+}
+
+APP_ERROR MetricLearn::ReadImage(const std::string &imgPath, cv::Mat &imageMat) {
+    imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
+    return APP_ERR_OK;
+}
+
+APP_ERROR MetricLearn::ResizeShortImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
+    int height = srcImageMat.rows;
+    int width = srcImageMat.cols;
+    float percent = static_cast<float>(224.0) / std::min(height, width);
+    int INTER_LANCZOS4 = 4;
+    cv::resize(srcImageMat, dstImageMat, cv::Size(round(width * percent),
+    round(height * percent)), 0, 0, INTER_LANCZOS4);
+    return APP_ERR_OK;
+}
+
+APP_ERROR MetricLearn::ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
+    static constexpr uint32_t resizeHeight = 224;
+    static constexpr uint32_t resizeWidth = 224;
+    cv::resize(srcImageMat, dstImageMat, cv::Size(resizeWidth, resizeHeight));
+    return APP_ERR_OK;
+}
+
+APP_ERROR MetricLearn::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) {
+    const uint32_t dataSize =  imageMat.cols *  imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU;
+    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+    MxBase::MemoryData memoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+
+    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc failed.";
+        return ret;
+    }
+
+    std::vector<uint32_t> shape = {imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU, static_cast<uint32_t>(imageMat.cols)};
+    tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
+    return APP_ERR_OK;
+}
+
+APP_ERROR MetricLearn::Inference(const std::vector<MxBase::TensorBase> &inputs, \
+        std::vector<MxBase::TensorBase> &outputs) {
+    auto dtypes = model_->GetOutputDataType();
+    for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
+        }
+        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
+        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
+        if (ret != APP_ERR_OK) {
+            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
+            return ret;
+        }
+        outputs.push_back(tensor);
+    }
+
+    MxBase::DynamicInfo dynamicInfo = {};
+    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
+    dynamicInfo.batchSize = 1;
+
+    auto startTime = std::chrono::high_resolution_clock::now();
+
+    APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
+    auto endTime = std::chrono::high_resolution_clock::now();
+    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();  // save time
+    inferCostTimeMilliSec += costMs;
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInference failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    return APP_ERR_OK;
+}
+
+APP_ERROR MetricLearn::SaveResult(MxBase::TensorBase *tensor, const std::string &resultpath) {
+    std::ofstream outfile(resultpath, std::ios::binary);
+    APP_ERROR ret = (*tensor).ToHost();
+    if (ret != APP_ERR_OK) {
+        LogError << "ToHost failed";
+        return ret;
+    }
+    if (outfile.fail()) {
+        LogError << "Failed to open result file: ";
+        return APP_ERR_COMM_FAILURE;
+    }
+    outfile.write(reinterpret_cast<char *>((*tensor).GetBuffer()), sizeof(float) * FEATURE_NUM);
+    outfile.close();
+
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR MetricLearn::Process(const std::string &imgPath, const std::string &resultPath) {
+    cv::Mat imageMat;
+    APP_ERROR ret = ReadImage(imgPath, imageMat);
+    if (ret != APP_ERR_OK) {
+        LogError << "ReadImage failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = ResizeShortImage(imageMat, imageMat);
+    if (ret != APP_ERR_OK) {
+        LogError << "ResizeShortImage failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    ret = ResizeImage(imageMat, imageMat);
+    if (ret != APP_ERR_OK) {
+        LogError << "ResizeImage failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    std::vector<MxBase::TensorBase> inputs = {};
+    std::vector<MxBase::TensorBase> outputs = {};
+
+    MxBase::TensorBase tensorBase;
+    ret = CVMatToTensorBase(imageMat, tensorBase);
+    if (ret != APP_ERR_OK) {
+        LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    inputs.push_back(tensorBase);
+
+    ret = Inference(inputs, outputs);
+    if (ret != APP_ERR_OK) {
+        LogError << "Inference failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    std::string filename = "";
+    std::string filedir = "";
+    getfilename(&filename, &filedir, imgPath);
+    std::string resultpath = resultPath + "/" + filename + ".bin";
+    std::string resultdir = resultPath + "/" + filedir;
+
+    DIR *dirPtr = opendir(resultdir.c_str());
+    if (dirPtr == nullptr) {
+        std::string sys = "mkdir -p "+ resultdir;
+        system(sys.c_str());
+    }
+
+    ret = SaveResult(&outputs[0], resultpath);
+    if (ret != APP_ERR_OK) {
+        LogError << "SaveResult failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    return APP_ERR_OK;
+}
+
+
+void getfilename(std::string *filename, std::string *filedir, const std::string &imgpath) {
+    int i, j = 0, count = 0;
+    for (i = imgpath.length() - 1; i >= 0; i--) {
+        // '/' is the delimiter between the file name and the parent directory in imgpath
+        if (imgpath[i] == '/') {
+            count++;
+            if (count == 2) {
+                j = i;
+                break;
+            }
+        }
+    }
+    // '.' is the delimiter between the file name and the file suffix
+    while (imgpath[++j] != '.') {
+        *filename += imgpath[j];
+    }
+
+    //'/' is the delimiter between the file name and the file directory
+    j = i;
+    while (imgpath[++j] != '/') {
+        *filedir += imgpath[j];
+    }
+}
diff --git a/research/cv/metric_learn/infer/mxbase/src/MetricLearn.h b/research/cv/metric_learn/infer/mxbase/src/MetricLearn.h
new file mode 100644
index 0000000000000000000000000000000000000000..40d1c814b008ad0f5dfdf4f95754ab2ece453659
--- /dev/null
+++ b/research/cv/metric_learn/infer/mxbase/src/MetricLearn.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2022 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MXBASE_METRIC_LEARN_H
+#define MXBASE_METRIC_LEARN_H
+
+#include <string>
+#include <vector>
+#include <memory>
+#include <opencv2/opencv.hpp>
+
+#include "MxBase/DvppWrapper/DvppWrapper.h"
+#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
+#include "MxBase/Tensor/TensorContext/TensorContext.h"
+#include "MxBase/DeviceManager/DeviceManager.h"
+
+struct InitParam {
+    uint32_t deviceId;
+    std::string modelPath;
+};
+
+
+class MetricLearn {
+ public:
+    static const int IMG_C = 3;
+    static const int IMG_H = 224;
+    static const int IMG_W = 224;
+    static const int FEATURE_NUM = 2048;
+    APP_ERROR Init(const InitParam &initParam);
+    APP_ERROR DeInit();
+    APP_ERROR ReadImage(const std::string &imgPath, cv::Mat &imageMat);
+    APP_ERROR ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
+    APP_ERROR ResizeShortImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
+    APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase);
+    APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
+    APP_ERROR Process(const std::string &imgPath, const std::string &resultPath);
+    // get infer time
+    double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
+
+ private:
+    APP_ERROR SaveResult(MxBase::TensorBase *tensor,
+                       const std::string &resultpath);
+
+ private:
+    std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
+    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
+    MxBase::ModelDesc modelDesc_;
+    uint32_t deviceId_ = 0;
+    // infer time
+    double inferCostTimeMilliSec = 0.0;
+};
+
+#endif  // MXBASE_METRIC_LEARN_H
diff --git a/research/cv/metric_learn/infer/mxbase/src/main.cpp b/research/cv/metric_learn/infer/mxbase/src/main.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4766927d4b485d0b566cf007b122d2bd85d8fa77
--- /dev/null
+++ b/research/cv/metric_learn/infer/mxbase/src/main.cpp
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2022 Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <dirent.h>
+#include <unistd.h>
+#include<fstream>
+#include<string>
+#include "MetricLearn.h"
+#include "MxBase/Log/Log.h"
+
+namespace {
+    const uint32_t DEVICE_ID = 0;
+    const char RESULT_PATH[] = "../data/preds/mxbase";
+    const char MODEL_PATH[] = "../convert/resnet50_acc74_aippnorm.om";
+}  // namespace
+
+
+int main(int argc, char* argv[]) {
+    if (argc <= 1) {
+        LogWarn << "Please input image path, such as './metric_learn image_dir'.";
+        return APP_ERR_OK;
+    }
+
+    InitParam initParam = {};
+    initParam.deviceId = DEVICE_ID;
+    initParam.modelPath = MODEL_PATH;
+    auto metric_learn = std::make_shared<MetricLearn>();
+    APP_ERROR ret = metric_learn->Init(initParam);
+    if (ret != APP_ERR_OK) {
+        LogError << "MetricLearn init failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    std::string imgPath = argv[1];
+    std::vector<std::string> imgFilePaths;
+
+    // read test_half.txt
+    DIR *dirPtr = opendir(imgPath.c_str());
+    if (dirPtr == nullptr) {
+        LogError << "opendir failed. dir:" << imgPath;
+        return APP_ERR_INTERNAL_ERROR;
+    }
+
+    std::fstream f(imgPath + "/test_half.txt");
+    std::string line;
+    while (getline(f, line)) {
+        int count = 0;
+        std::string filePath;
+        for (std::size_t i = 0; i < line.size(); i++) {
+            count++;
+            if (line[i] == ' ') {
+                filePath = line.substr(i - count + 1, count - 1);
+            }
+        }
+        imgFilePaths.emplace_back(imgPath + "/" + filePath);
+    }
+    f.close();
+
+    auto startTime = std::chrono::high_resolution_clock::now();
+    for (auto &imgFile : imgFilePaths) {
+        ret = metric_learn->Process(imgFile, RESULT_PATH);
+        if (ret != APP_ERR_OK) {
+            LogError << "MetricLearn process failed, ret=" << ret << ".";
+            metric_learn->DeInit();
+            return ret;
+        }
+    }
+    auto endTime = std::chrono::high_resolution_clock::now();
+    metric_learn->DeInit();
+    double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
+    double fps = 1000.0 * imgFilePaths.size() / metric_learn->GetInferCostMilliSec();
+    LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
+    return APP_ERR_OK;
+}
diff --git a/research/cv/metric_learn/infer/sdk/api/infer.py b/research/cv/metric_learn/infer/sdk/api/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1100e541d20b4dc7e89b6112a1a5dfb53160e1d
--- /dev/null
+++ b/research/cv/metric_learn/infer/sdk/api/infer.py
@@ -0,0 +1,128 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""sdk infer"""
+import json
+import logging
+import MxpiDataType_pb2 as MxpiDataType
+from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, MxProtobufIn, StringVector
+
+from config import config as cfg
+
+import cv2
+
+class SdkApi:
+    """sdk api"""
+    INFER_TIMEOUT = cfg.INFER_TIMEOUT
+    STREAM_NAME = cfg.STREAM_NAME
+
+    def __init__(self, pipeline_cfg):
+        self.pipeline_cfg = pipeline_cfg
+        self._stream_api = None
+        self._data_input = None
+        self._device_id = None
+
+    def init(self):
+        """sdk init """
+        with open(self.pipeline_cfg, 'r') as fp:
+            self._device_id = int(
+                json.loads(fp.read())[self.STREAM_NAME]["stream_config"]
+                ["deviceId"])
+            print("The device id: {}.".format(self._device_id))
+
+        # create api
+        self._stream_api = StreamManagerApi()
+
+        # init stream mgr
+        ret = self._stream_api.InitManager()
+        if ret != 0:
+            return False
+
+        # create streams
+        with open(self.pipeline_cfg, 'rb') as fp:
+            pipe_line = fp.read()
+
+        ret = self._stream_api.CreateMultipleStreams(pipe_line)
+        if ret != 0:
+            return False
+
+        self._data_input = MxDataInput()
+
+        return True
+
+    def __del__(self):
+        """del sdk"""
+        if not self._stream_api:
+            return
+
+        self._stream_api.DestroyAllStreams()
+
+    def send_data_input(self, stream_name, plugin_id, input_data):
+        """input data use SendData"""
+        data_input = MxDataInput()
+        encoded_image = cv2.imencode(".jpg", input_data)[1]
+        img_bytes = encoded_image.tobytes()
+        data_input.data = img_bytes
+        unique_id = self._stream_api.SendData(stream_name, plugin_id,
+                                              data_input)
+        if unique_id < 0:
+            logging.error("Fail to send data to stream.")
+            return False
+        return True
+
+    def _send_protobuf(self, stream_name, plugin_id, element_name, buf_type,
+                       pkg_list):
+        """input data use SendProtobuf"""
+        protobuf = MxProtobufIn()
+        protobuf.key = element_name.encode("utf-8")
+        protobuf.type = buf_type
+        protobuf.protobuf = pkg_list.SerializeToString()
+        protobuf_vec = InProtobufVector()
+        protobuf_vec.push_back(protobuf)
+        err_code = self._stream_api.SendProtobuf(stream_name, plugin_id,
+                                                 protobuf_vec)
+        if err_code != 0:
+            logging.error(
+                "Failed to send data to stream, stream_name(%s), plugin_id(%s), element_name(%s), "
+                "buf_type(%s), err_code(%s).", stream_name, plugin_id,
+                element_name, buf_type, err_code)
+            return False
+        return True
+
+    def send_img_input(self, stream_name, plugin_id, element_name, input_data,
+                       img_size):
+
+        """use cv input to sdk"""
+        vision_list = MxpiDataType.MxpiVisionList()
+        vision_vec = vision_list.visionVec.add()
+        vision_vec.visionInfo.format = 1
+        vision_vec.visionInfo.width = img_size[1]
+        vision_vec.visionInfo.height = img_size[0]
+        vision_vec.visionInfo.widthAligned = img_size[1]
+        vision_vec.visionInfo.heightAligned = img_size[0]
+        vision_vec.visionData.memType = 0
+        vision_vec.visionData.dataStr = input_data
+        vision_vec.visionData.dataSize = len(input_data)
+        buf_type = b"MxTools.MxpiVisionList"
+        return self._send_protobuf(stream_name, plugin_id, element_name, buf_type, vision_list)
+
+    def get_result(self, stream_name, out_plugin_id=0):
+        """get_result"""
+        key_vec = StringVector()
+        key_vec.push_back(b'mxpi_tensorinfer0')
+        infer_result = self._stream_api.GetProtobuf(
+            stream_name, out_plugin_id, key_vec)
+        result = MxpiDataType.MxpiTensorPackageList()
+        result.ParseFromString(infer_result[0].messageBuf)
+        return result.tensorPackageVec[0].tensorVec[0].dataStr
diff --git a/research/cv/metric_learn/infer/sdk/config/config.py b/research/cv/metric_learn/infer/sdk/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b508659fd1e3a4a8f78c241a8b3733f7fac197
--- /dev/null
+++ b/research/cv/metric_learn/infer/sdk/config/config.py
@@ -0,0 +1,25 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""config"""
+
+STREAM_NAME = "im_metric_learn"
+MODEL_WIDTH = 224
+MODEL_HEIGHT = 224
+
+INFER_TIMEOUT = 100000
+
+TENSOR_DTYPE_FLOAT32 = 0
+TENSOR_DTYPE_FLOAT16 = 1
+TENSOR_DTYPE_INT8 = 2
diff --git a/research/cv/metric_learn/infer/sdk/config/metric_learn.pipeline b/research/cv/metric_learn/infer/sdk/config/metric_learn.pipeline
new file mode 100644
index 0000000000000000000000000000000000000000..918b17bff93880b1d84ec0a9346d5016733a9e2b
--- /dev/null
+++ b/research/cv/metric_learn/infer/sdk/config/metric_learn.pipeline
@@ -0,0 +1,33 @@
+{
+    "im_metric_learn": {
+        "stream_config": {
+            "deviceId": "0"
+        },
+        "appsrc0": {
+            "props": {
+                "blocksize": "409600"
+            },
+            "factory": "appsrc",
+            "next": "mxpi_tensorinfer0"
+        },
+        "mxpi_tensorinfer0": {
+            "props": {
+                "dataSource": "appsrc0",
+                "modelPath": "../convert/resnet50_acc74_aippnorm.om",
+                "tensorFormat": "1"
+            },
+            "factory": "mxpi_modelinfer",
+            "next": "mxpi_dataserialize0"
+        },
+        "mxpi_dataserialize0": {
+            "props": {
+                "outputDataKeys": "mxpi_tensorinfer0"
+            },
+            "factory": "mxpi_dataserialize",
+            "next": "appsink0"
+        },
+        "appsink0": {
+            "factory": "appsink"
+        }
+    }
+}
\ No newline at end of file
diff --git a/research/cv/metric_learn/infer/sdk/main.py b/research/cv/metric_learn/infer/sdk/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4fdde513b42ad4281049573892b6e464f3d98f4
--- /dev/null
+++ b/research/cv/metric_learn/infer/sdk/main.py
@@ -0,0 +1,135 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""main"""
+
+import argparse
+import os
+import time
+
+import cv2
+from api.infer import SdkApi
+from config import config as cfg
+from StreamManagerApi import StreamManagerApi
+
+
+
+
+def parser_args():
+    """parser_args"""
+    parser = argparse.ArgumentParser(description="metric_learn inference")
+
+    parser.add_argument("--img_path",
+                        type=str,
+                        required=False,
+                        default="../../data/Stanford_Online_Products",
+                        help="image directory.")
+    parser.add_argument(
+        "--pipeline_path",
+        type=str,
+        required=False,
+        default="./config/metric_learn.pipeline",
+        help="image file path. The default is '/metric_learn/infer/sdk/config/metric_learn.pipeline'. ")
+    parser.add_argument(
+        "--model_type",
+        type=str,
+        required=False,
+        default="dvpp",
+        help=
+        "rgb: high-precision, dvpp: high performance. The default is 'dvpp'.")
+    parser.add_argument(
+        "--infer_mode",
+        type=str,
+        required=False,
+        default="infer",
+        help=
+        "infer:only infer, eval: accuracy evaluation. The default is 'infer'.")
+    parser.add_argument(
+        "--infer_result_dir",
+        type=str,
+        required=False,
+        default="../../data/infer_result",
+        help=
+        "cache dir of inference result. The default is '../data/infer_result'.")
+    arg = parser.parse_args()
+    return arg
+
+def process_img(img_file):
+    img0 = cv2.imread(img_file)
+    img = resize_i(img0, height=cfg.MODEL_HEIGHT, width=cfg.MODEL_WIDTH)
+    return img
+
+
+def resize_i(img, height=224, width=224):
+    """resize img"""
+    percent = float(height) / min(img.shape[0], img.shape[1])
+    resized_width = int(round(img.shape[1] * percent))
+    resized_height = int(round(img.shape[0] * percent))
+    img = cv2.resize(img, (resized_width, resized_height), interpolation=cv2.INTER_LANCZOS4)
+    shape = (224, 224)
+    resized = cv2.resize(img, shape, interpolation=cv2.INTER_LINEAR)
+    return resized
+
+def image_inference(pipeline_path, stream_name, data_dir, result_dir):
+    stream_manager_api = StreamManagerApi()
+    start_time = time.time()
+    sdk_api = SdkApi(pipeline_path)
+    if not sdk_api.init():
+        exit(-1)
+    print(stream_name)
+    if not os.path.exists(result_dir):
+        os.makedirs(result_dir)
+
+    img_data_plugin_id = 0
+
+    print("\nBegin to inference for {}.\n".format(data_dir))
+    TRAIN_LIST = "../data/Stanford_Online_Products/test_half.txt"
+    TRAIN_LISTS = open(TRAIN_LIST, "r").readlines()
+    max_len = 30003
+
+    # cal_acc
+    for _, item in enumerate(TRAIN_LISTS):
+        if _ >= max_len:
+            break
+        items = item.strip().split()
+        path = items[0]
+        father = path.split("/")[0]
+        father_path = os.path.join(result_dir, father)
+        if not os.path.exists(father_path):
+            os.makedirs(father_path)
+        file_path = os.path.join(data_dir, path)
+        save_bin_path = os.path.join(result_dir, "{}.bin".format(path.split(".")[0]))
+        img_np = process_img(file_path)
+        img_shape = img_np.shape
+        # SDK
+        sdk_api.send_img_input(stream_name,
+                               img_data_plugin_id, "appsrc0",
+                               img_np.tobytes(), img_shape)
+
+        result = sdk_api.get_result(stream_name)
+        with open(save_bin_path, "wb") as fp:
+            fp.write(result)
+            print(
+                "End-2end inference, file_name:", file_path,
+                "\n"
+            )
+    end_time = time.time()
+    print("cost: ", end_time-start_time, "s")
+    print("fps: ", 30003.0/(end_time-start_time), "imgs/sec")
+    stream_manager_api.DestroyAllStreams()
+
+if __name__ == "__main__":
+    args = parser_args()
+    image_inference(args.pipeline_path, cfg.STREAM_NAME.encode("utf-8"), args.img_path,
+                    args.infer_result_dir)
diff --git a/research/cv/metric_learn/infer/sdk/run.sh b/research/cv/metric_learn/infer/sdk/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..cfc5654cc0bff4f3a05dd0b4faa03f541c893b3c
--- /dev/null
+++ b/research/cv/metric_learn/infer/sdk/run.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+img_path=$1
+infer_result_dir=$2
+
+python3 main.py $img_path  $infer_result_dir
+exit 0
\ No newline at end of file
diff --git a/research/cv/metric_learn/infer/util/eval.py b/research/cv/metric_learn/infer/util/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e956f1eff2de3a557c50fbd8261c936c3676edba
--- /dev/null
+++ b/research/cv/metric_learn/infer/util/eval.py
@@ -0,0 +1,83 @@
+import os
+import argparse
+import multiprocessing as mp
+import numpy as np
+
+parser = argparse.ArgumentParser(description="metric_learn inference")
+parser.add_argument("--data_dir", type=str, required=True, help="data files directory.")
+parser.add_argument("--result_dir", type=str, required=True, help="result files directory.")
+args = parser.parse_args()
+
+def functtt(param):
+    """ fun """
+    sharedlist, s, e = param
+    fea, a, b = sharedlist
+    ab = np.dot(fea[s:e], fea.T)
+    d = a[s:e] + b - 2 * ab
+    for i in range(e - s):
+        d[i][s + i] += 1e8
+    sorted_index = np.argsort(d, 1)[:, :10]
+    return sorted_index
+
+
+def recall_topk_parallel(fea, lab, k):
+    """ recall_topk_parallel """
+    fea = np.array(fea)
+    fea = fea.reshape(fea.shape[0], -1)
+    n = np.sqrt(np.sum(fea ** 2, 1)).reshape(-1, 1)
+    fea = fea / n
+    a = np.sum(fea ** 2, 1).reshape(-1, 1)
+    b = a.T
+    sharedlist = mp.Manager().list()
+    sharedlist.append(fea)
+    sharedlist.append(a)
+    sharedlist.append(b)
+    N = 100
+    L = fea.shape[0] / N
+    params = []
+    for i in range(N):
+        if i == N - 1:
+            s, e = int(i * L), int(fea.shape[0])
+        else:
+            s, e = int(i * L), int((i + 1) * L)
+        params.append([sharedlist, s, e])
+    pool = mp.Pool(processes=4)
+    sorted_index_list = pool.map(functtt, params)
+    pool.close()
+    pool.join()
+    sorted_index = np.vstack(sorted_index_list)
+    res = 0
+    for i in range(len(fea)):
+        for j in range(k):
+            pred = lab[sorted_index[i][j]]
+            if lab[i] == pred:
+                res += 1.0
+                break
+    res = res / len(fea)
+    return res
+
+
+def eval_mxbase(data_dir, result_dir):
+    print("\nBegin to eval \n")
+    TRAIN_LIST = os.path.join(data_dir, "test_half.txt")
+    TRAIN_LISTS = open(TRAIN_LIST, "r").readlines()
+
+    # cal_acc
+    result_shape = (1, 2048)
+    f, l = [], []
+    for _, item in enumerate(TRAIN_LISTS):
+        items = item.strip().split()
+        path = items[0]
+        result_bin_path = os.path.join(result_dir, "{}.bin".format(path.split(".")[0]))
+        result = np.fromfile(result_bin_path, dtype=np.float32).reshape(result_shape)
+        gt = int(items[1]) - 1
+        f.append(result)
+        l.append(gt)
+    f = np.vstack(f)
+    l = np.hstack(l)
+    recall = recall_topk_parallel(f, l, k=1)
+    print("eval_recall:", recall)
+
+if __name__ == '__main__':
+    eval_mxbase(args.data_dir, args.result_dir)
+    
\ No newline at end of file
diff --git a/research/cv/metric_learn/modelart/train_start.py b/research/cv/metric_learn/modelart/train_start.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d7e2b3dd10b7011ec3c9438a6cc3bea49bf9b1
--- /dev/null
+++ b/research/cv/metric_learn/modelart/train_start.py
@@ -0,0 +1,248 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train resnet."""
+import os
+import time
+import argparse
+import ast
+import numpy as np
+from mindspore import context
+from mindspore import Tensor
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore import export
+from mindspore.common import set_seed
+from mindspore.communication.management import init
+from mindspore.train.callback import Callback
+
+from src.loss import Softmaxloss
+from src.loss import Tripletloss
+from src.loss import Quadrupletloss
+from src.lr_generator import get_lr
+from src.resnet import resnet50
+from src.utility import GetDatasetGenerator_softmax, GetDatasetGenerator_triplet, GetDatasetGenerator_quadruplet
+
+set_seed(1)
+
+parser = argparse.ArgumentParser(description='Image classification')
+# modelarts parameter
+parser.add_argument('--train_url', type=str, default=None, help='Train output path')
+parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
+parser.add_argument('--ckpt_url', type=str, default=None, help='Pretrained ckpt path')
+parser.add_argument('--checkpoint_name', type=str, default='PreMetric.ckpt', help='Checkpoint file')
+parser.add_argument('--loss_name', type=str, default='softmax',
+                    help='loss name: softmax(pretrained) triplet quadruplet')
+
+# Ascend parameter
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--ckpt_path', type=str, default=None, help='ckpt path name')
+parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
+parser.add_argument('--device_id', type=int, default=0, help='Device id')
+parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run distribute')
+
+# export
+parser.add_argument('--export_batch_size', type=int, default=1, help="export batch size")
+parser.add_argument('--export_file_name', type=str, default='resnet50', help="export file name.")
+parser.add_argument('--export_width', type=int, default=224, help='export width')
+parser.add_argument('--export_height', type=int, default=224, help='export height')
+parser.add_argument('--export_file_format', type=str, choices=['AIR', 'ONNX', 'MINDIR'],
+                    default='AIR', help='export file format')
+args_opt = parser.parse_args()
+
+class Monitor(Callback):
+    """Monitor"""
+    def __init__(self, lr_init=None):
+        super(Monitor, self).__init__()
+        self.lr_init = lr_init
+        self.lr_init_len = len(lr_init)
+    def epoch_begin(self, run_context):
+        self.losses = []
+        self.epoch_time = time.time()
+        dataset_generator.__init__(data_dir=DATA_DIR, train_list=TRAIN_LIST)
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        epoch_mseconds = (time.time() - self.epoch_time) * 1000
+        per_step_mseconds = epoch_mseconds / cb_params.batch_num
+        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.5f}"
+              .format(epoch_mseconds, per_step_mseconds, np.mean(self.losses)))
+        print('batch_size:', config.batch_size, 'epochs_size:', config.epoch_size,
+              'lr_model:', config.lr_decay_mode, 'lr:', config.lr_max, 'step_size:', step_size)
+    def step_begin(self, run_context):
+        self.step_time = time.time()
+    def step_end(self, run_context):
+        """step_end"""
+        cb_params = run_context.original_args()
+        step_mseconds = (time.time() - self.step_time) * 1000
+        step_loss = cb_params.net_outputs
+        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
+            step_loss = step_loss[0]
+        if isinstance(step_loss, Tensor):
+            step_loss = np.mean(step_loss.asnumpy())
+        self.losses.append(step_loss)
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
+        print("epochs:  [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.5f}/{:8.5f}], time:[{:5.3f}], lr:[{:8.5f}]".format(
+            cb_params.cur_epoch_num, config.epoch_size, cur_step_in_epoch, cb_params.batch_num, step_loss,
+            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
+
+if __name__ == '__main__':
+    if args_opt.loss_name == 'softmax':
+        from src.config import config0 as config
+        from src.dataset import create_dataset0 as create_dataset
+    elif args_opt.loss_name == 'triplet':
+        from src.config import config1 as config
+        from src.dataset import create_dataset1 as create_dataset
+    elif args_opt.loss_name == 'quadruplet':
+        from src.config import config2 as config
+        from src.dataset import create_dataset1 as create_dataset
+    else:
+        print('loss no')
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
+    # init distributed
+    if args_opt.run_modelarts:
+        import moxing as mox
+        device_id = int(os.getenv('DEVICE_ID'))
+        device_num = int(os.getenv('RANK_SIZE'))
+        context.set_context(device_id=device_id)
+        local_data_url = '/cache/data'
+        local_ckpt_url = '/cache/ckpt'
+        local_train_url = '/cache/train'
+        if device_num > 1:
+            init()
+            context.set_auto_parallel_context(device_num=device_num,
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            local_data_url = os.path.join(local_data_url, str(device_id))
+            local_ckpt_url = os.path.join(local_ckpt_url, str(device_id))
+        mox.file.copy_parallel(args_opt.data_url, local_data_url)
+        mox.file.copy_parallel(args_opt.ckpt_url, local_ckpt_url)
+        DATA_DIR = local_data_url + '/'
+    else:
+        if args_opt.run_distribute:
+            device_id = int(os.getenv('DEVICE_ID'))
+            device_num = int(os.getenv('RANK_SIZE'))
+            context.set_context(device_id=device_id)
+            init()
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num,
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+        else:
+            context.set_context(device_id=args_opt.device_id)
+            device_num = 1
+            device_id = args_opt.device_id
+        DATA_DIR = args_opt.dataset_path + '/'
+
+    # create dataset
+    TRAIN_LIST = DATA_DIR + 'train_half.txt'
+    if args_opt.loss_name == 'softmax':
+        dataset_generator = GetDatasetGenerator_softmax(data_dir=DATA_DIR,
+                                                        train_list=TRAIN_LIST)
+    elif args_opt.loss_name == 'triplet':
+        dataset_generator = GetDatasetGenerator_triplet(data_dir=DATA_DIR,
+                                                        train_list=TRAIN_LIST)
+    elif args_opt.loss_name == 'quadruplet':
+        dataset_generator = GetDatasetGenerator_quadruplet(data_dir=DATA_DIR,
+                                                           train_list=TRAIN_LIST)
+    else:
+        print('loss no')
+    dataset = create_dataset(dataset_generator, do_train=True, batch_size=config.batch_size,
+                             device_num=device_num, rank_id=device_id)
+    step_size = dataset.get_dataset_size()
+
+    # define net
+    net = resnet50(class_num=config.class_num)
+
+    # init weight
+    if args_opt.run_modelarts:
+        checkpoint_path = os.path.join(local_ckpt_url, args_opt.checkpoint_name)
+    else:
+        checkpoint_path = args_opt.ckpt_path
+    param_dict = load_checkpoint(checkpoint_path)
+    load_param_into_net(net.backbone, param_dict)
+
+    # init lr
+    lr = Tensor(get_lr(lr_init=config.lr_init,
+                       lr_end=config.lr_end,
+                       lr_max=config.lr_max,
+                       warmup_epochs=config.warmup_epochs,
+                       total_epochs=config.epoch_size,
+                       steps_per_epoch=step_size,
+                       lr_decay_mode=config.lr_decay_mode))
+
+    # define opt
+    opt = Momentum(params=net.trainable_params(),
+                   learning_rate=lr,
+                   momentum=config.momentum,
+                   weight_decay=config.weight_decay,
+                   loss_scale=config.loss_scale)
+
+    # define loss, model
+    if args_opt.loss_name == 'softmax':
+        loss = Softmaxloss(sparse=True, smooth_factor=0.1, num_classes=config.class_num)
+    elif args_opt.loss_name == 'triplet':
+        loss = Tripletloss(margin=0.1)
+    elif args_opt.loss_name == 'quadruplet':
+        loss = Quadrupletloss(train_batch_size=config.batch_size, samples_each_class=2, margin=0.1)
+    else:
+        print('loss no')
+
+    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+
+    if args_opt.loss_name == 'softmax':
+        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=None,
+                      amp_level='O3', keep_batchnorm_fp32=False)
+    else:
+        model = Model(net.backbone, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=None,
+                      amp_level='O3', keep_batchnorm_fp32=False)
+
+    #define callback
+    cb = []
+    ckpt_cb = None
+    if config.save_checkpoint and (device_num == 1 or device_id == 0):
+        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
+                                     keep_checkpoint_max=config.keep_checkpoint_max)
+
+        check_name = 'ResNet50_' + args_opt.loss_name
+        if args_opt.run_modelarts:
+            ckpt_cb = ModelCheckpoint(prefix=check_name, directory=local_train_url, config=config_ck)
+        else:
+            save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_'+ str(device_id) +'/')
+            ckpt_cb = ModelCheckpoint(prefix=check_name, directory=save_ckpt_path, config=config_ck)
+
+        cb += [ckpt_cb]
+    cb += [Monitor(lr_init=lr.asnumpy())]
+
+    # train model
+    model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
+    ckpt_name = ckpt_cb.latest_ckpt_file_name
+
+
+    # frozen to air file
+    net = resnet50(config.class_num)
+
+    param_dict = load_checkpoint(ckpt_name)
+    load_param_into_net(net.backbone, param_dict)
+
+    input_arr = Tensor(np.zeros([args_opt.export_batch_size, 3, args_opt.export_height, args_opt.export_width],
+                                np.float32))
+    export(net.backbone, input_arr, file_name='{0}/{1}'.format(local_train_url, args_opt.export_file_name),
+           file_format=args_opt.export_file_format)
+
+    if args_opt.run_modelarts and config.save_checkpoint and (device_num == 1 or device_id == 0):
+        mox.file.copy_parallel(src_url=local_train_url, dst_url=args_opt.train_url)
diff --git a/research/cv/metric_learn/scripts/docker_start.sh b/research/cv/metric_learn/scripts/docker_start.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6b452b3d3f4b1596501ed63d6047052717115c0f
--- /dev/null
+++ b/research/cv/metric_learn/scripts/docker_start.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+# Copyright (c) 2022. Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+docker_image=$1
+data_dir=$2
+model_dir=$3
+
+docker run -it -u root --ipc=host \
+               --device=/dev/davinci0 \
+               --device=/dev/davinci1 \
+               --device=/dev/davinci2 \
+               --device=/dev/davinci3 \
+               --device=/dev/davinci4 \
+               --device=/dev/davinci5 \
+               --device=/dev/davinci6 \
+               --device=/dev/davinci7 \
+               --device=/dev/davinci_manager \
+               --device=/dev/devmm_svm \
+               --device=/dev/hisi_hdc \
+               --privileged \
+               -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+               -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons \
+               -v ${data_dir}:${data_dir} \
+               -v ${model_dir}:${model_dir} \
+               -v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash
\ No newline at end of file