diff --git a/research/cv/ProtoNet/default_config.yaml b/research/cv/ProtoNet/default_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5023b22655b3d5059e747c88e573f3a08e6c0cb4
--- /dev/null
+++ b/research/cv/ProtoNet/default_config.yaml
@@ -0,0 +1,38 @@
+# Built-in configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: True
+device_target: Ascend
+data_path: "/cache/data"
+data_url: ""
+train_url: ""
+output_path: "/cache/out"
+
+# ==============================================================================
+# Training options
+learning_rate: 0.001
+epoch_size: 2
+save_checkpoint_steps: 10
+keep_checkpoint_max: 5
+batch_size: 32
+image_height: 28
+image_width: 28
+air_name: 'protonet.air'
+
+
+# Model Description
+model_name: protonet
+file_name: 'protonet'
+file_format: 'AIR'
+
+
+---
+# Config description for each option
+enable_modelarts: 'Whether training on ModelArts, default: True'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+
+device_target: 'Target device type'
+
+---
+device_target: ['Ascend', 'GPU', 'CPU']
\ No newline at end of file
diff --git a/research/cv/ProtoNet/infer/README.md b/research/cv/ProtoNet/infer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..10addbd0cb08715a63c1f7fdd9c2033d26ee59ed
--- /dev/null
+++ b/research/cv/ProtoNet/infer/README.md
@@ -0,0 +1,143 @@
+# Description
+
+This README describes how to run ProtoNet inference with mxBase and the mindX-SDK.
+
+# Environment Preparation
+
+- (ALL Required) Put the `infer` folder onto the **server environment**; it does not have to be in a subfolder of mxVision.
+- (Convert Required) Configure the environment variables correctly as described [here](https://support.huaweicloud.com/atctool-cann503alpha1infer/atlasatc_16_0004.html); if you use Docker, you may skip this step.
+- (mxBase and mindX-SDK Required) Configure the environment variables, for example:
+
+    ```bash
+    export MX_SDK_HOME="/home/data/xj_mindx/mxVision"
+    export ASCEND_HOME=/usr/local/Ascend
+    export ASCEND_VERSION=nnrt/latest
+    export ARCH_PATTERN=.
+    export MXSDK_OPENSOURCE_DIR=/home/data/xj_mindx/mxVision/opensource
+    export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib/plugins:${MX_SDK_HOME}/opensource/lib64:${MX_SDK_HOME}/lib:${MX_SDK_HOME}/lib/modelpostprocessors:${MX_SDK_HOME}/opensource/lib:/usr/local/Ascend/nnae/latest/fwkacllib/lib64:${LD_LIBRARY_PATH}"
+    export ASCEND_OPP_PATH="/usr/local/Ascend/nnae/latest/opp"
+    export ASCEND_AICPU_PATH="/usr/local/Ascend/nnae/latest"
+    ```
+
+# Model Convert
+
+We provide a shell script, `convert/convert.sh`, that converts the model from AIR to OM format, for example:
+
+```bash
+bash convert.sh
+```
+
+If you want to see the help message of the script, run:
+
+```bash
+bash convert.sh --help
+```
+
+This prints the help message and the default argument settings.
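+
+For example (the AIR file path below is illustrative; substitute your own model location and output directory), a complete conversion call looks like:
+
+```bash
+bash convert.sh /path/to/protonet.air ../data/model
+```
+
+The resulting OM file is written to `<om_path>/protonet.om`.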
+
+# Input image
+
+Put the **Omniglot Dataset** into the `infer/data/input/dataset` folder; the processed data produced below is placed under `data/input`.
+
+e.g. **Original**
+
+```shell
+└─dataset
+    ├─raw
+    ├─splits
+    │     vinyals
+    │         test.txt
+    │         train.txt
+    │         val.txt
+    │         trainval.txt
+    └─data
+           Alphabet_of_the_Magi
+           Angelic
+```
+
+**Processing:** we provide a shell script (`convert/dataprocess.sh`) to preprocess the dataset. Run it from the `convert` directory, since it uses relative paths:
+
+```bash
+bash dataprocess.sh
+```
+
+e.g. **Processed**
+
+```shell
+└─data
+    ├─dataset
+    ├─data_preprocess_Result
+    │     data_1.bin
+    │     data_2.bin
+    │         ···
+    │     data_100.bin
+    └─label_classes_preprocess_Result
+          label_1.bin
+          label_2.bin
+              ···
+          label_100.bin
+          classes_1.bin
+          classes_2.bin
+              ···
+          classes_100.bin
+
+```
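+
+As a quick sanity check (a minimal sketch; the shape and dtype follow the SDK inference script, and the file name is an example), each data bin can be inspected with NumPy:
+
+```python
+import numpy as np
+
+# each data_*.bin holds one episode: 100 images of shape 1x28x28, stored as float32
+episode = np.fromfile("data_1.bin", dtype=np.float32).reshape(100, 1, 28, 28)
+print(episode.shape, episode.dtype)
+```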
+
+# Infer by mxBase
+
+Put the OM file into `data/model`, then build the project with `build.sh`, for example:
+
+```bash
+cd mxbase
+bash build.sh
+```
+
+If the build succeeds, a new executable named `protonet` is generated; run it to start inference:
+
+```bash
+./protonet
+```
+
+The inference results are stored in the `result` folder.
+
+# Infer by mindX-SDK
+
+To infer with the mindX-SDK, enter the `infer/sdk` folder and run:
+
+```bash
+bash run.sh ../data/config/protonet.pipeline ../data/input/data_preprocess_Result/ ../data/input/label_classes_preprocess_Result
+```
+
+The inference results are saved in the `result` folder.
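+
+Equivalently, `run.sh` simply wraps the following Python invocation, which you can also run directly from `infer/sdk`:
+
+```bash
+python3.7 main.py --pipeline=../data/config/protonet.pipeline \
+                  --data_dir=../data/input/data_preprocess_Result/ \
+                  --infer_result_path=./result
+```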
+
+# Calculate Inference Precision
+
+We provide a Python script, `postprocess.py`, to calculate the precision:
+
+```bash
+python postprocess.py --result_path=./infer/XXX/result \
+                      --label_classes_path=./infer/data/input/label_classes_preprocess_Result
+```
+
+**note:**
+
+- `XXX` can be `mxbase` or `sdk`
+- `--label_classes_path` points to the preprocessed label and class data.
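+
+For example, to score the mxBase results (assuming the command is run from the ProtoNet root directory, as the relative paths imply):
+
+```bash
+python postprocess.py --result_path=./infer/mxbase/result \
+                      --label_classes_path=./infer/data/input/label_classes_preprocess_Result
+```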
+
+# Self-Inspection Report
+
+- We have obtained the following result through mindX-SDK and mxBase inference:
+    |  | Accuracy |
+    | :----: | :----: |
+    | mindX-SDK | 0.9943 |
+    | mxBase | 0.9943 |
+
+- The model accuracy during training:
+
+    |  | Accuracy |
+    | :----: | :----: |
+    | Train | 0.9954 |
+
+# ModelZoo Homepage
+
+ Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/research/cv/ProtoNet/infer/convert/convert.sh b/research/cv/ProtoNet/infer/convert/convert.sh
new file mode 100644
index 0000000000000000000000000000000000000000..11d235b7c94e2eb12c925f0f9741de4b76530b8f
--- /dev/null
+++ b/research/cv/ProtoNet/infer/convert/convert.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+air_path=$1
+om_path=$2
+
+# Help information
+if [[ $1 == --help || $1 == -h ]]; then
+    echo "usage: bash convert.sh <air_path> <om_path>"
+    echo " "
+    echo "parameter explain:
+    air_path                 path of the input AIR model, e.g. /home/xj_mindx/lixiang/protonet.air
+    om_path                  output directory of the OM model; the converted model is written as <om_path>/protonet.om
+    note: soc_version is fixed to Ascend310 and input_shape to \"x:1,1,28,28\" in this script
+    -h/--help                show help message
+    "
+    exit 1
+fi
+
+
+
+rm -rf ../data/model
+mkdir -p ../data/model
+
+echo "Input AIR file path: ${air_path}"
+echo "Output OM file path: ${om_path}"
+    
+atc --input_format=NCHW --framework=1 --model="${air_path}" \
+    --input_shape="x:1,1,28,28" --output="${om_path}/protonet" \
+    --soc_version=Ascend310 --disable_reuse_memory=1
diff --git a/research/cv/ProtoNet/infer/convert/dataprocess.sh b/research/cv/ProtoNet/infer/convert/dataprocess.sh
new file mode 100644
index 0000000000000000000000000000000000000000..aa1dc2888f547f74f9c68efae2e04a6e0d804280
--- /dev/null
+++ b/research/cv/ProtoNet/infer/convert/dataprocess.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+rm -rf ../data/input/data_preprocess_Result
+rm -rf ../data/input/label_classes_preprocess_Result
+mkdir -p ../data/input/data_preprocess_Result
+mkdir -p ../data/input/label_classes_preprocess_Result
+
+python ../../preprocess.py --dataset_path=../data/input/dataset \
+                           --data_output_path=../data/input/data_preprocess_Result \
+                           --label_classses_output_path=../data/input/label_classes_preprocess_Result
\ No newline at end of file
diff --git a/research/cv/ProtoNet/infer/data/config/protonet.pipeline b/research/cv/ProtoNet/infer/data/config/protonet.pipeline
new file mode 100644
index 0000000000000000000000000000000000000000..1f17396c23c2efa980a14d3ecfc53b7dd172a3dc
--- /dev/null
+++ b/research/cv/ProtoNet/infer/data/config/protonet.pipeline
@@ -0,0 +1,39 @@
+{
+    "protonet": {
+        "stream_config": {
+            "deviceId": "0"
+        },
+        "appsrc0": {
+            "props": {
+                "blocksize": "409600"
+            },
+            "factory": "appsrc",
+            "next": "tensorinfer0"
+        },
+        "tensorinfer0": {
+            "props": {
+                "modelPath": "../data/model/protonet.om",
+                "dataSource": "appsrc0",
+                "waitingTime": "2000",
+                "outputDeviceId": "-1"
+            },
+            "factory": "mxpi_tensorinfer",
+            "next": "dataserialize"
+        },
+        "dataserialize": {
+            "props": {
+                "outputDataKeys": "tensorinfer0"
+            },
+            "factory": "mxpi_dataserialize",
+            "next": "appsink0"
+        },
+        "appsink0": {
+            "props": {
+                "blocksize": "4096000"
+            },
+            "factory": "appsink"
+        }
+    }
+}
+
diff --git a/research/cv/ProtoNet/infer/docker_start_infer.sh b/research/cv/ProtoNet/infer/docker_start_infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..64cf90a2311bdfb21d68a4e90e08602670fdf632
--- /dev/null
+++ b/research/cv/ProtoNet/infer/docker_start_infer.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+docker_image=$1
+data_dir=$2
+
+function show_help() {
+    echo "Usage: docker_start.sh docker_image data_dir"
+}
+
+function param_check() {
+    if [ -z "${docker_image}" ]; then
+        echo "please input docker_image"
+        show_help
+        exit 1
+    fi
+
+    if [ -z "${data_dir}" ]; then
+        echo "please input data_dir"
+        show_help
+        exit 1
+    fi
+}
+
+param_check
+
+docker run -it \
+  --device=/dev/davinci0 \
+  --device=/dev/davinci_manager \
+  --device=/dev/devmm_svm \
+  --device=/dev/hisi_hdc \
+  -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+  -v ${data_dir}:${data_dir} \
+  ${docker_image} \
+  /bin/bash
diff --git a/research/cv/ProtoNet/infer/mxbase/CMakeLists.txt b/research/cv/ProtoNet/infer/mxbase/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9b58148ed90d8fc77bcfdbee85d2bf04a7464c32
--- /dev/null
+++ b/research/cv/ProtoNet/infer/mxbase/CMakeLists.txt
@@ -0,0 +1,48 @@
+cmake_minimum_required(VERSION 3.10.0)
+project(protonet)
+set(TARGET protonet)
+
+add_definitions(-DENABLE_DVPP_INTERFACE)
+add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+add_definitions(-Dgoogle=mindxsdk_private)
+add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
+add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
+
+
+# Check environment variable
+if(NOT DEFINED ENV{ASCEND_HOME})
+    message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
+endif()
+if(NOT DEFINED ENV{ASCEND_VERSION})
+    message(WARNING "please define environment variable:ASCEND_VERSION")
+endif()
+if(NOT DEFINED ENV{ARCH_PATTERN})
+    message(WARNING "please define environment variable:ARCH_PATTERN")
+endif()
+set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
+set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
+set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
+set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
+set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
+set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
+set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
+
+if(NOT DEFINED ENV{MXSDK_OPENSOURCE_DIR})
+    message(WARNING "please define environment variable:MXSDK_OPENSOURCE_DIR")
+endif()
+set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
+
+include_directories(src)
+include_directories(${ACL_INC_DIR})
+include_directories(${OPENSOURCE_DIR}/include)
+include_directories(${OPENSOURCE_DIR}/include/opencv4)
+include_directories(${MXBASE_INC})
+include_directories(${MXBASE_POST_PROCESS_DIR})
+link_directories(${ACL_LIB_DIR})
+link_directories(${OPENSOURCE_DIR}/lib)
+link_directories(${MXBASE_LIB_DIR})
+link_directories(${MXBASE_POST_LIB_DIR})
+
+add_executable(${TARGET} src/main.cpp src/Protonet.cpp)
+target_link_libraries(${TARGET} glog cpprest mxbase opencv_world stdc++fs)
+install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
diff --git a/research/cv/ProtoNet/infer/mxbase/build.sh b/research/cv/ProtoNet/infer/mxbase/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..53c51a1120584075bd86bb337f0a9b473a2ad61e
--- /dev/null
+++ b/research/cv/ProtoNet/infer/mxbase/build.sh
@@ -0,0 +1,70 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+export MX_SDK_HOME="/usr/local/sdk_home/mxManufacture"
+export ASCEND_HOME=/usr/local/Ascend
+export ASCEND_VERSION=nnrt/latest
+export ARCH_PATTERN=.
+export MXSDK_OPENSOURCE_DIR=/usr/local/sdk_home/mxManufacture/opensource
+export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib/plugins:${MX_SDK_HOME}/opensource/lib64:${MX_SDK_HOME}/lib:${MX_SDK_HOME}/lib/modelpostprocessors:${MX_SDK_HOME}/opensource/lib:/usr/local/Ascend/nnae/latest/fwkacllib/lib64:${LD_LIBRARY_PATH}"
+export ASCEND_OPP_PATH="/usr/local/Ascend/nnae/latest/opp"
+export ASCEND_AICPU_PATH="/usr/local/Ascend/nnae/latest"
+
+
+function check_env()
+{
+    # set ASCEND_VERSION to nnrt/latest when it was not specified by user
+    if [ ! "${ASCEND_VERSION}" ]; then
+        export ASCEND_VERSION=nnrt/latest
+        echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
+    else
+        echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
+    fi
+
+    if [ ! "${ARCH_PATTERN}" ]; then
+        # set ARCH_PATTERN to ./ when it was not specified by user
+        export ARCH_PATTERN=.
+        echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
+    else
+        echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
+    fi
+}
+
+function build_protonet()
+{
+    rm -rf build
+    mkdir -p build
+    cd build
+    cmake ..
+    make
+    ret=$?
+    if [ ${ret} -ne 0 ]; then
+        echo "Failed to build protonet."
+        exit ${ret}
+    fi
+    make install
+}
+
+rm -rf ./result
+mkdir -p ./result
+
+check_env
+build_protonet
+
+
+
diff --git a/research/cv/ProtoNet/infer/mxbase/src/Protonet.cpp b/research/cv/ProtoNet/infer/mxbase/src/Protonet.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a472fac9d97e4881efc0d2e200360c991a50fd97
--- /dev/null
+++ b/research/cv/ProtoNet/infer/mxbase/src/Protonet.cpp
@@ -0,0 +1,179 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "Protonet.h"
+#include <unistd.h>
+#include <sys/stat.h>
+#include <map>
+#include <fstream>
+#include "MxBase/DeviceManager/DeviceManager.h"
+#include "MxBase/Log/Log.h"
+
+const uint32_t EACH_LABEL_LENGTH = 4;
+const uint32_t MAX_LENGTH = 313600;
+
+
+APP_ERROR Protonet::Init(const InitParam &initParam) {
+    deviceId_ = initParam.deviceId;
+    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
+    if (ret != APP_ERR_OK) {
+        LogError << "Init devices failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
+    if (ret != APP_ERR_OK) {
+        LogError << "Set context failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
+    ret = model_->Init(initParam.modelPath, modelDesc_);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR Protonet::DeInit() {
+    model_->DeInit();
+    MxBase::DeviceManager::GetInstance()->DestroyDevices();
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR Protonet::ReadTensorFromFile(const std::string &file, uint32_t *data, uint32_t size) {
+    if (data == NULL || size < MAX_LENGTH) {
+        LogError << "input data is invalid.";
+        return APP_ERR_COMM_INVALID_POINTER;
+    }
+
+    std::ifstream infile;
+    // open data file
+    infile.open(file, std::ios_base::in | std::ios_base::binary);
+    // check data file validity
+    if (infile.fail()) {
+        LogError << "Failed to open data file: " << file << ".";
+        return APP_ERR_COMM_OPEN_FAIL;
+    }
+    infile.read(reinterpret_cast<char*>(data), sizeof(uint32_t) * MAX_LENGTH);
+    infile.close();
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR Protonet::ReadInputTensor(const std::string &fileName, std::vector<MxBase::TensorBase> *inputs) {
+    uint32_t data[MAX_LENGTH] = {0};
+    APP_ERROR ret = ReadTensorFromFile(fileName, data, MAX_LENGTH);
+    if (ret != APP_ERR_OK) {
+        LogError << "ReadTensorFromFile failed.";
+        return ret;
+    }
+
+    const uint32_t dataSize = modelDesc_.inputTensors[0].tensorSize;
+    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(data), dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+
+    ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc and copy failed.";
+        return ret;
+    }
+
+    std::vector<uint32_t> shape = {1, MAX_LENGTH};
+    inputs->push_back(MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT32));
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR Protonet::Inference(const std::vector<MxBase::TensorBase> &inputs,
+                                 std::vector<MxBase::TensorBase> *outputs) {
+    auto dtypes = model_->GetOutputDataType();
+    for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
+        }
+        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
+        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
+        if (ret != APP_ERR_OK) {
+            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
+            return ret;
+        }
+        outputs->push_back(tensor);
+    }
+
+    MxBase::DynamicInfo dynamicInfo = {};
+    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
+    auto startTime = std::chrono::high_resolution_clock::now();
+    APP_ERROR ret = model_->ModelInference(inputs, *outputs, dynamicInfo);
+    auto endTime = std::chrono::high_resolution_clock::now();
+    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
+    g_inferCost.push_back(costMs);
+
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInference failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR Protonet::WriteResult(const std::string &imageFile, std::vector<MxBase::TensorBase> *outputs) {
+    std::string infer_result_path = "./result/";
+    for (size_t i = 0; i < outputs->size(); ++i) {
+        APP_ERROR ret = (*outputs)[i].ToHost();
+        if (ret != APP_ERR_OK) {
+            LogError << GetError(ret) << "ToHost failed.";
+            return ret;
+        }
+        void *netOutput = (*outputs)[i].GetBuffer();
+        std::vector<uint32_t> out_shape = (*outputs)[i].GetShape();
+        std::string outFileName = infer_result_path + imageFile;
+        FILE *outputFile = fopen(outFileName.c_str(), "wb");
+        if (outputFile == nullptr) {
+            LogError << "Failed to open output file: " << outFileName << ".";
+            return APP_ERR_COMM_OPEN_FAIL;
+        }
+        // fwrite(ptr, element_size, element_count, stream)
+        fwrite(netOutput, sizeof(float), out_shape[0] * out_shape[1], outputFile);
+        fclose(outputFile);
+    }
+    return APP_ERR_OK;
+}
+
+
+APP_ERROR Protonet::Process(const std::string &inferPath, const std::string &fileName) {
+    std::vector<MxBase::TensorBase> inputs = {};
+    std::string inputIdsFile = inferPath + fileName;
+    APP_ERROR ret = ReadInputTensor(inputIdsFile, &inputs);
+    if (ret != APP_ERR_OK) {
+        LogError << "Read input ids failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    std::vector<MxBase::TensorBase> outputs = {};
+    ret = Inference(inputs, &outputs);
+    if (ret != APP_ERR_OK) {
+        LogError << "Inference failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    ret = WriteResult(fileName, &outputs);
+    if (ret != APP_ERR_OK) {
+        LogError << "Write result failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
diff --git a/research/cv/ProtoNet/infer/mxbase/src/Protonet.h b/research/cv/ProtoNet/infer/mxbase/src/Protonet.h
new file mode 100644
index 0000000000000000000000000000000000000000..bcc7ad3c7e5343189702cf3a94bcbd530f8cfaed
--- /dev/null
+++ b/research/cv/ProtoNet/infer/mxbase/src/Protonet.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MXBASE_PROTONET_H
+#define MXBASE_PROTONET_H
+
+#include <memory>
+#include <utility>
+#include <vector>
+#include <string>
+#include <map>
+#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
+#include "MxBase/Tensor/TensorContext/TensorContext.h"
+
+extern std::vector<double> g_inferCost;
+
+struct InitParam {
+    uint32_t deviceId;
+    std::string labelPath;
+    std::string modelPath;
+};
+
+
+class Protonet {
+ public:
+    APP_ERROR Init(const InitParam &initParam);
+    APP_ERROR DeInit();
+    APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> *outputs);
+    APP_ERROR Process(const std::string &inferPath, const std::string &fileName);
+ protected:
+    APP_ERROR ReadTensorFromFile(const std::string &file, uint32_t *data, uint32_t size);
+    APP_ERROR ReadInputTensor(const std::string &fileName, std::vector<MxBase::TensorBase> *inputs);
+    APP_ERROR WriteResult(const std::string &imageFile, std::vector<MxBase::TensorBase> *outputs);
+
+
+ private:
+    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
+    MxBase::ModelDesc modelDesc_ = {};
+    std::vector<std::string> labelMap_ = {};
+    uint32_t deviceId_ = 0;
+    uint32_t classNum_ = 0;
+};
+#endif  // MXBASE_PROTONET_H
diff --git a/research/cv/ProtoNet/infer/mxbase/src/main.cpp b/research/cv/ProtoNet/infer/mxbase/src/main.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f7cea546142af952e326089a7e86c3bb0a0cb064
--- /dev/null
+++ b/research/cv/ProtoNet/infer/mxbase/src/main.cpp
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <unistd.h>
+#include <dirent.h>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include "Protonet.h"
+#include "MxBase/Log/Log.h"
+
+
+std::vector<double> g_inferCost;
+
+
+void InitProtonetParam(InitParam* initParam) {
+    initParam->deviceId = 0;
+    initParam->modelPath = "../data/model/protonet.om";
+}
+
+APP_ERROR ReadFilesFromPath(const std::string &path, std::vector<std::string> *files) {
+    DIR *dir = NULL;
+    struct dirent *ptr = NULL;
+
+    if ((dir = opendir(path.c_str())) == NULL) {
+        LogError << "Open dir error: " << path;
+        return APP_ERR_COMM_OPEN_FAIL;
+    }
+
+    while ((ptr = readdir(dir)) != NULL) {
+        // DT_REG (8): keep regular files only, skip directories and other entry types
+        if (ptr->d_type == DT_REG) {
+            files->push_back(ptr->d_name);
+        }
+    }
+    closedir(dir);
+    return APP_ERR_OK;
+}
+
+
+int main(int argc, char* argv[]) {
+    InitParam initParam;
+    InitProtonetParam(&initParam);
+    auto protonet = std::make_shared<Protonet>();
+    APP_ERROR ret = protonet->Init(initParam);
+    if (ret != APP_ERR_OK) {
+        LogError << "Protonet init failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    std::string inferPath = "../data/input/data_preprocess_Result/";
+    std::vector<std::string> files;
+    ret = ReadFilesFromPath(inferPath, &files);
+    if (ret != APP_ERR_OK) {
+        LogError << "Read files from path failed, ret=" << ret << ".";
+        return ret;
+    }
+
+    // do infer
+    for (uint32_t i = 0; i < files.size(); i++) {
+        ret = protonet->Process(inferPath, files[i]);
+        if (ret != APP_ERR_OK) {
+            LogError << "Protonet process failed, ret=" << ret << ".";
+            protonet->DeInit();
+            return ret;
+        }
+    }
+
+    LogInfo << "infer succeed and write the result data with binary file !";
+
+    protonet->DeInit();
+    double costSum = 0;
+    for (uint32_t i = 0; i < g_inferCost.size(); i++) {
+        costSum += g_inferCost[i];
+    }
+    LogInfo << "Infer images sum " << g_inferCost.size() << ", cost total time: " << costSum << " ms.";
+    LogInfo << "The throughput: " << g_inferCost.size() * 1000 / costSum << " bin/sec.";
+    LogInfo << "\n == The infer result has been saved in ./result ==";
+    return APP_ERR_OK;
+}
diff --git a/research/cv/ProtoNet/infer/sdk/main.py b/research/cv/ProtoNet/infer/sdk/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..25a823f065c47fc796aadb89ea8cfb0808126c0a
--- /dev/null
+++ b/research/cv/ProtoNet/infer/sdk/main.py
@@ -0,0 +1,142 @@
+'''
+The script to execute SDK inference.
+'''
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import argparse
+import os
+import time
+
+import MxpiDataType_pb2 as MxpiDataType
+import numpy as np
+from StreamManagerApi import StreamManagerApi, InProtobufVector, \
+    MxProtobufIn, StringVector
+
+
+def parse_args():
+    """set and check parameters."""
+    parser = argparse.ArgumentParser(description="protonet process")
+    parser.add_argument("--pipeline", type=str, default="", help="SDK infer pipeline")
+    parser.add_argument("--data_dir", type=str, default="")
+    parser.add_argument("--infer_result_path", type=str, default="")
+    args_opt = parser.parse_args()
+    return args_opt
+
+
+def send_source_data(appsrc_id, tensor, stream_name, stream_manager):
+    """
+    Construct the input of the stream,
+    send inputs data to a specified stream based on streamName.
+
+    Returns:
+        bool: send data success or not
+    """
+    tensor_package_list = MxpiDataType.MxpiTensorPackageList()
+    tensor_package = tensor_package_list.tensorPackageVec.add()
+    array_bytes = tensor.tobytes()
+    tensor_vec = tensor_package.tensorVec.add()
+    tensor_vec.deviceId = 0
+    tensor_vec.memType = 0
+    for i in tensor.shape:
+        tensor_vec.tensorShape.append(i)
+    tensor_vec.dataStr = array_bytes
+    tensor_vec.tensorDataSize = len(array_bytes)
+    key = "appsrc{}".format(appsrc_id).encode('utf-8')
+    protobuf_vec = InProtobufVector()
+    protobuf = MxProtobufIn()
+    protobuf.key = key
+    protobuf.type = b'MxTools.MxpiTensorPackageList'
+    protobuf.protobuf = tensor_package_list.SerializeToString()
+    protobuf_vec.push_back(protobuf)
+
+    ret = stream_manager.SendProtobuf(stream_name, appsrc_id, protobuf_vec)
+    if ret < 0:
+        print("Failed to send data to stream.")
+        return False
+    return True
+
+
+def run():
+    """
+    read pipeline and do infer
+    """
+
+    args = parse_args()
+
+    # init stream manager
+    stream_manager_api = StreamManagerApi()
+    ret = stream_manager_api.InitManager()
+    if ret != 0:
+        print("Failed to init Stream manager, ret=%s" % str(ret))
+        return
+
+    # create streams by pipeline config file
+    with open(os.path.realpath(args.pipeline), 'rb') as f:
+        pipeline_str = f.read()
+    ret = stream_manager_api.CreateMultipleStreams(pipeline_str)
+    if ret != 0:
+        print("Failed to create Stream, ret=%s" % str(ret))
+        return
+
+    stream_name = b'protonet'
+    infer_total_time = 0
+    file_list = os.listdir(args.data_dir)
+    infer_result_folder = args.infer_result_path
+    for file_name in file_list:
+        num = file_name.split('_')[1]
+        file_path = os.path.join(args.data_dir, file_name)
+        tensor = np.fromfile(file_path, dtype=np.float32)
+        tensor = np.resize(tensor, (100, 1, 28, 28))
+        array_list = []
+        for tensor0 in tensor:
+            tensor0 = tensor0.reshape((1, 1, 28, 28))
+            if not send_source_data(0, tensor0, stream_name, stream_manager_api):
+                return
+
+            # Obtain the inference result by specifying streamName and uniqueId.
+            key_vec = StringVector()
+            key_vec.push_back(b'tensorinfer0')
+            start_time = time.time()
+            infer_result = stream_manager_api.GetProtobuf(stream_name, 0, key_vec)
+            infer_total_time += time.time() - start_time
+            if infer_result.size() == 0:
+                print("inferResult is null")
+                return
+            if infer_result[0].errorCode != 0:
+                print("GetProtobuf error. errorCode=%d" % (infer_result[0].errorCode))
+                return
+            result = MxpiDataType.MxpiTensorPackageList()
+            result.ParseFromString(infer_result[0].messageBuf)
+            res = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype='<f4')
+            res = res.reshape((1, 64))
+            array_list.append(res)
+
+        # stack the 100 per-image results into one (100, 64) episode tensor
+        tensor = np.vstack(array_list)
+        tensor.tofile(infer_result_folder + "/" + "data_" + num)
+    print("=======================================")
+    print("The total time of inference is {} s".format(infer_total_time))
+    print("=======================================")
+
+    # destroy streams
+    stream_manager_api.DestroyAllStreams()
+
+
+if __name__ == '__main__':
+    run()
diff --git a/research/cv/ProtoNet/infer/sdk/run.sh b/research/cv/ProtoNet/infer/sdk/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3964695c013fd78d877cfc1c17cb6590934556cf
--- /dev/null
+++ b/research/cv/ProtoNet/infer/sdk/run.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+# Simple log helper functions
+info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
+warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
+
+export pipeline_path=$1
+export source_data_path=$2
+
+rm -rf result
+mkdir -p result
+
+
+
+python3.7 main.py --pipeline=$pipeline_path \
+                  --data_dir=$source_data_path \
+                  --infer_result_path=./result
+echo " == The infer result has been saved in ./result =="
+
+exit 0
diff --git a/research/cv/ProtoNet/modelarts/train_start.py b/research/cv/ProtoNet/modelarts/train_start.py
new file mode 100644
index 0000000000000000000000000000000000000000..425114cff076626e5f1ba040f948c9de1c9dcccb
--- /dev/null
+++ b/research/cv/ProtoNet/modelarts/train_start.py
@@ -0,0 +1,224 @@
+'''
+The boot script to train the model.
+'''
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import argparse
+import ast
+import sys
+import time
+import datetime
+import numpy as np
+import moxing as mox
+
+
+import mindspore.nn as nn
+from mindspore.communication.management import init
+from mindspore import context
+from mindspore import export
+from mindspore import Tensor
+from mindspore import dataset as ds
+from mindspore.train import Model
+from mindspore.context import ParallelMode
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.common import set_seed
+
+# make the ProtoNet root importable before loading the src package
+sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../'))
+
+from src.protonet import ProtoNet, WithLossCell
+from src.PrototypicalLoss import PrototypicalLoss
+from src.EvalCallBack import EvalCallBack
+
+from model_init import init_dataloader
+
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument("--enable_modelarts", type=bool, default=True, help="")
+parser.add_argument('--device_target', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
+                    help="Device target, support Ascend, GPU and CPU.")
+parser.add_argument("--data_path", type=str, default="/cache/data", help="path to dataset on modelarts")
+parser.add_argument("--data_url", type=str, default="", help="path to dataset on obs")
+parser.add_argument("--train_url", type=str, default="", help="path to training output on obs")
+parser.add_argument("--output_path", type=str, default="/cache/out", help="path to training output on modelarts")
+
+parser.add_argument("--learning_rate", type=float, default=0.001, help="")
+parser.add_argument("--epoch_size", type=int, default=1, help="")
+parser.add_argument("--save_checkpoint_steps", type=int, default=10, help="")
+parser.add_argument("--keep_checkpoint_max", type=int, default=5, help="")
+
+parser.add_argument("--batch_size", type=int, default=100, help="")
+parser.add_argument("--image_height", type=int, default=28, help="")
+parser.add_argument("--image_width", type=int, default=28, help="")
+parser.add_argument("--file_name", type=str, default="protonet", help="the name of air file ")
+
+parser.add_argument('-cTr', '--classes_per_it_tr',
+                    type=int,
+                    help='number of random classes per episode for training, default=60',
+                    default=20)
+parser.add_argument('-nsTr', '--num_support_tr',
+                    type=int,
+                    help='number of samples per class to use as support for training, default=5',
+                    default=5)
+parser.add_argument('-nqTr', '--num_query_tr',
+                    type=int,
+                    help='number of samples per class to use as query for training, default=5',
+                    default=5)
+parser.add_argument('-cVa', '--classes_per_it_val',
+                    type=int,
+                    help='number of random classes per episode for validation, default=5',
+                    default=20)
+parser.add_argument('-nsVa', '--num_support_val',
+                    type=int,
+                    help='number of samples per class to use as support for validation, default=5',
+                    default=5)
+parser.add_argument('-nqVa', '--num_query_val',
+                    type=int,
+                    help='number of samples per class to use as query for validation, default=15',
+                    default=15)
+parser.add_argument('-its', '--iterations',
+                    type=int,
+                    help='number of episodes per epoch, default=100',
+                    default=100)
+
+
+config = parser.parse_args()
+
+set_seed(1)
+_global_sync_count = 0
+
+def frozen_to_air(net, args):
+    param_dict = load_checkpoint(args.get("ckpt_file"))
+    load_param_into_net(net, param_dict)
+    input_arr = Tensor(np.zeros([args.get("batch_size"), 1,
+                                 args.get("image_height"), args.get("image_width")], np.float32))
+    export(net, input_arr, file_name=args.get("file_name"), file_format=args.get("file_format"))
+
+
+def sync_data(from_path, to_path):
+    """
+    Download data from remote obs to local directory if the first url is remote url and the second one is local path
+    Upload data from local directory to remote obs in contrast.
+    """
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains 8 devices as most.
+    if int(os.getenv('DEVICE_ID', '0')) % min(int(os.getenv('RANK_SIZE', '1')), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            print("Failed to create the sync lock file")
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def wrapped_func(config_name):
+    """
+    Transfer data and files from OBS to ModelArts
+    """
+    if config_name.enable_modelarts:
+        if config_name.data_url:
+            if not os.path.isdir(config_name.data_path):
+                os.makedirs(config_name.data_path)
+            sync_data(config_name.data_url, config_name.data_path)
+            print("Dataset downloaded: ", os.listdir(config_name.data_path))
+        if config_name.train_url:
+            if not os.path.isdir(config_name.output_path):
+                os.makedirs(config_name.output_path)
+            sync_data(config_name.train_url, config_name.output_path)
+            print("Workspace downloaded: ", os.listdir(config_name.output_path))
+
+def train_protonet_model():
+    '''
+    train model
+    '''
+    print(config)
+    print('device id:', os.getenv('DEVICE_ID', '0'))
+    print('device num:', os.getenv('RANK_SIZE', '1'))
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+    context.set_context(save_graphs=False)
+
+    device_target = config.device_target
+    if device_target == "GPU":
+        context.set_context(enable_graph_kernel=True)
+        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
+
+    device_num = int(os.getenv('RANK_SIZE', '1'))
+    if device_num > 1:
+        # init() picks the proper backend (HCCL on Ascend, NCCL on GPU)
+        init()
+        context.reset_auto_parallel_context()
+        context.set_auto_parallel_context(device_num=device_num,
+                                          parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          gradients_mean=True)
+    context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
+
+    tr_dataloader = init_dataloader(config, 'train', config.data_path)
+    val_dataloader = init_dataloader(config, 'val', config.data_path)
+
+    Net = ProtoNet()
+    loss_fn = PrototypicalLoss(config.num_support_tr, config.num_query_tr, config.classes_per_it_tr)
+    eval_loss_fn = PrototypicalLoss(config.num_support_tr, config.num_query_tr,
+                                    config.classes_per_it_val, is_train=False)
+    my_loss_cell = WithLossCell(Net, loss_fn)
+    my_acc_cell = WithLossCell(Net, eval_loss_fn)
+    optim = nn.Adam(params=Net.trainable_params(), learning_rate=config.learning_rate)
+    model = Model(my_loss_cell, optimizer=optim)
+
+    train_data = ds.GeneratorDataset(tr_dataloader, column_names=['data', 'label', 'classes'])
+    eval_data = ds.GeneratorDataset(val_dataloader, column_names=['data', 'label', 'classes'])
+    eval_cb = EvalCallBack(config, my_acc_cell, eval_data, config.output_path)
+
+    config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
+                                 keep_checkpoint_max=config.keep_checkpoint_max,
+                                 saved_network=Net)
+    ckpoint_cb = ModelCheckpoint(prefix='protonet_ckpt', directory=config.output_path, config=config_ck)
+
+    print("============== Starting Training ==============")
+    starttime = datetime.datetime.now()
+    model.train(config.epoch_size, train_data, callbacks=[ckpoint_cb, eval_cb, TimeMonitor()],)
+    endtime = datetime.datetime.now()
+    print('epoch time: ', (endtime - starttime).seconds / 10, 'per step time:', (endtime - starttime).seconds / 1000)
+
+    frozen_to_air_args = {"ckpt_file": config.output_path + "/" + "best_ck.ckpt",
+                          "batch_size": config.batch_size,
+                          "image_height": config.image_height,
+                          "image_width": config.image_width,
+                          "file_name": config.output_path + "/" + config.file_name,
+                          "file_format": "AIR"}
+
+    frozen_to_air(Net, frozen_to_air_args)
+    mox.file.copy_parallel(config.output_path, config.train_url)
+
+if __name__ == "__main__":
+    wrapped_func(config)
+    train_protonet_model()
diff --git a/research/cv/ProtoNet/scripts/docker_start.sh b/research/cv/ProtoNet/scripts/docker_start.sh
new file mode 100644
index 0000000000000000000000000000000000000000..298b3d1d6968983d21387dc95c50852f2ce50b0e
--- /dev/null
+++ b/research/cv/ProtoNet/scripts/docker_start.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+docker_image=$1
+data_dir=$2
+model_dir=$3
+
+docker run -it --ipc=host \
+               --device=/dev/davinci0 \
+               --device=/dev/davinci1 \
+               --device=/dev/davinci2 \
+               --device=/dev/davinci3 \
+               --device=/dev/davinci4 \
+               --device=/dev/davinci5 \
+               --device=/dev/davinci6 \
+               --device=/dev/davinci7 \
+               --device=/dev/davinci_manager \
+               --device=/dev/devmm_svm --device=/dev/hisi_hdc \
+               -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+               -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \
+               -v ${model_dir}:${model_dir} \
+               -v ${data_dir}:${data_dir}  \
+               -v ~/ascend/log/npu/conf/slog/slog.conf:/var/log/npu/conf/slog/slog.conf \
+               -v ~/ascend/log/npu/slog/:/var/log/npu/slog -v ~/ascend/log/npu/profiling/:/var/log/npu/profiling \
+               -v ~/ascend/log/npu/dump/:/var/log/npu/dump -v ~/ascend/log/npu/:/usr/slog ${docker_image} \
+               /bin/bash
\ No newline at end of file