diff --git a/research/cv/dem/infer/convert/convert.sh b/research/cv/dem/infer/convert/convert.sh
new file mode 100644
index 0000000000000000000000000000000000000000..340ae947ca095c27c6660e252642da7d922bf7ee
--- /dev/null
+++ b/research/cv/dem/infer/convert/convert.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [ $# -ne 3 ]
+then
+  echo "Wrong parameter format."
+  echo "Usage:"
+  echo "         bash $0 [INPUT_AIR_PATH] [OUTPUT_OM_PATH_NAME] [DATASET]"
+  echo "Example: "
+  echo "         bash convert.sh  xxx.air xx_name(with no suffix) AwA"
+
+  exit 1
+fi
+
+input_air_path=$1
+output_om_path=$2
+dataset="$3"
+
+if [ "$dataset" = "AwA" ]
+then
+    input_shape="1:85"
+else
+    input_shape="1:312"
+fi
+
+export install_path=/usr/local/Ascend/
+export ASCEND_ATC_PATH=${install_path}/atc
+export LD_LIBRARY_PATH=${install_path}/atc/lib64:$LD_LIBRARY_PATH
+export PATH=/usr/local/python3.7.5/bin:${install_path}/atc/ccec_compiler/bin:${install_path}/atc/bin:$PATH
+export PYTHONPATH=${install_path}/atc/python/site-packages:${install_path}/latest/atc/python/site-packages/auto_tune.egg/auto_tune:${install_path}/atc/python/site-packages/schedule_search.egg:${PYTHONPATH}
+export ASCEND_OPP_PATH=${install_path}/opp
+
+export ASCEND_SLOG_PRINT_TO_STDOUT=1
+
+echo "Input AIR file path: ${input_air_path}"
+echo "Output OM file path: ${output_om_path}"
+echo "Dataset: ${dataset}"
+echo "Input Shape: ${input_shape}"
+
+atc  --input_format=NCHW \
+--framework=1 \
+--input_shape=${input_shape} \
+--model=${input_air_path} \
+--output=${output_om_path} \
+--soc_version=Ascend310 \
+--disable_reuse_memory=0 \
+--precision_mode=allow_fp32_to_fp16  \
+--op_select_implmode=high_precision
\ No newline at end of file
diff --git a/research/cv/dem/infer/docker_start_infer.sh b/research/cv/dem/infer/docker_start_infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..60a78321df0c3d9323a5bb36223f5b08d5572b9d
--- /dev/null
+++ b/research/cv/dem/infer/docker_start_infer.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+docker_image=$1
+model_dir=$2
+
+if [ -z "${docker_image}" ]; then
+    echo "please input docker_image"
+    exit 1
+fi
+
+if [ ! -d "${model_dir}" ]; then
+    echo "please input model_dir"
+    exit 1
+fi
+
+docker run -it \
+           --device=/dev/davinci0 \
+           --device=/dev/davinci_manager \
+           --device=/dev/devmm_svm \
+           --device=/dev/hisi_hdc \
+           -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
+           -v ${model_dir}:${model_dir} \
+           ${docker_image} \
+           /bin/bash
\ No newline at end of file
diff --git a/research/cv/dem/infer/mxbase/CMakeLists.txt b/research/cv/dem/infer/mxbase/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..083d158522de72aa733fcd15fcdb07376617c5e0
--- /dev/null
+++ b/research/cv/dem/infer/mxbase/CMakeLists.txt
@@ -0,0 +1,51 @@
+cmake_minimum_required(VERSION 3.10.0)
+project(dem)
+
+set(TARGET dem)
+
+add_definitions(-DENABLE_DVPP_INTERFACE)
+add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+add_definitions(-Dgoogle=mindxsdk_private)
+add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
+add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
+
+# Check environment variable
+if(NOT DEFINED ENV{ASCEND_HOME})
+    message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
+endif()
+if(NOT DEFINED ENV{ASCEND_VERSION})
+    message(WARNING "please define environment variable:ASCEND_VERSION")
+endif()
+if(NOT DEFINED ENV{ARCH_PATTERN})
+    message(WARNING "please define environment variable:ARCH_PATTERN")
+endif()
+set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
+set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
+
+set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
+set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
+set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
+set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
+set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
+if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
+    set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
+else()
+    set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
+endif()
+
+include_directories(${ACL_INC_DIR})
+include_directories(${OPENSOURCE_DIR}/include)
+include_directories(${OPENSOURCE_DIR}/include/opencv4)
+
+include_directories(${MXBASE_INC})
+include_directories(${MXBASE_POST_PROCESS_DIR})
+
+link_directories(${ACL_LIB_DIR})
+link_directories(${OPENSOURCE_DIR}/lib)
+link_directories(${MXBASE_LIB_DIR})
+link_directories(${MXBASE_POST_LIB_DIR})
+
+add_executable(${TARGET} src/main.cpp src/DEM.cpp)
+target_link_libraries(${TARGET} glog cpprest mxbase opencv_world stdc++fs)
+
+install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
\ No newline at end of file
diff --git a/research/cv/dem/infer/mxbase/build.sh b/research/cv/dem/infer/mxbase/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a3b40fd8cf8e1391e5893ab5f863b8b8010e2a78
--- /dev/null
+++ b/research/cv/dem/infer/mxbase/build.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+export ASCEND_HOME=/usr/local/Ascend
+export ASCEND_VERSION=nnrt/latest
+export ARCH_PATTERN=.
+export MXSDK_OPENSOURCE_DIR=/usr/local/sdk_home/mxManufacture/opensource
+export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib/plugins:${MX_SDK_HOME}/opensource/lib64:${MX_SDK_HOME}/lib:${MX_SDK_HOME}/lib/modelpostprocessors:${MX_SDK_HOME}/opensource/lib:/usr/local/Ascend/nnae/latest/fwkacllib/lib64:${LD_LIBRARY_PATH}"
+export ASCEND_OPP_PATH="/usr/local/Ascend/nnae/latest/opp"
+export ASCEND_AICPU_PATH="/usr/local/Ascend/nnae/latest"
+
+function check_env()
+{
+    # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
+    if [ ! "${ASCEND_VERSION}" ]; then
+        export ASCEND_VERSION=ascend-toolkit/latest
+        echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
+    else
+        echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
+    fi
+
+    if [ ! "${ARCH_PATTERN}" ]; then
+        # set ARCH_PATTERN to ./ when it was not specified by user
+        export ARCH_PATTERN=./
+        echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
+    else
+        echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
+    fi
+}
+
+function build_dem() {
+    cd .
+    rm -f dem
+    rm -rf build
+    mkdir -p build
+    cd build
+    if ! cmake ..;
+    then
+      echo "cmake failed."
+      return 1
+    fi
+
+    if ! (make);
+    then
+      echo "make failed."
+      return 1
+    fi
+    ret=$?
+    if [ ${ret} -ne 0 ]; then
+        echo "Failed to build dem."
+        exit ${ret}
+    fi
+    make install
+}
+
+check_env
+build_dem
+
+echo "build finish"
\ No newline at end of file
diff --git a/research/cv/dem/infer/mxbase/src/DEM.cpp b/research/cv/dem/infer/mxbase/src/DEM.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ccd2e63a93a25929d3a88ef8ef9c1908a4256ef5
--- /dev/null
+++ b/research/cv/dem/infer/mxbase/src/DEM.cpp
@@ -0,0 +1,190 @@
+/*
+* Copyright 2022 Huawei Technologies Co., Ltd.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#include "DEM.h"
+#include <unistd.h>
+#include <sys/stat.h>
+#include <map>
+#include <fstream>
+#include <typeinfo>
+#include <iomanip>
+#include <iostream>
+#include "MxBase/DeviceManager/DeviceManager.h"
+#include "MxBase/Log/Log.h"
+
+APP_ERROR DEM::Init(const InitParam &initParam) {
+    deviceId_ = initParam.deviceId;
+    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
+    if (ret != APP_ERR_OK) {
+        LogError << "Init devices failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
+    if (ret != APP_ERR_OK) {
+        LogError << "Set context failed, ret=" << ret << ".";
+        return ret;
+    }
+    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
+    ret = model_->Init(initParam.modelPath, modelDesc_);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
+
+APP_ERROR DEM::DeInit() {
+    model_->DeInit();
+    MxBase::DeviceManager::GetInstance()->DestroyDevices();
+    return APP_ERR_OK;
+}
+
+APP_ERROR DEM::ReadTensorFromFile(const std::string &file, float *data, uint32_t size) {
+    if (data == NULL) {
+        LogError << "input data is invalid.";
+        return APP_ERR_COMM_INVALID_POINTER;
+    }
+    std::ifstream fin(file);
+    if (fin.fail()) {
+        LogError << "Failed to open file: " << file << ".";
+        return APP_ERR_COMM_OPEN_FAIL;
+    }
+    for (uint32_t i = 0; i < size; ++i) {
+        fin >> data[i];
+    }
+    fin.close();
+    return APP_ERR_OK;
+}
+
+APP_ERROR DEM::ReadInputTensor(const std::string &fileName, uint32_t index,
+                               std::vector<MxBase::TensorBase> *inputs, uint32_t size,
+                               MxBase::TensorDataType type) {
+    float *data = new float[size];
+    APP_ERROR ret = ReadTensorFromFile(fileName, data, size);
+    if (ret != APP_ERR_OK) {
+        LogError << "Read Tensor From File failed.";
+        delete[] data; return ret;
+    }
+    const uint32_t dataSize = modelDesc_.inputTensors[index].tensorSize;
+    LogInfo << "dataSize:" << dataSize;
+    MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
+    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void*>(data), dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
+    ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Memory malloc and copy failed.";
+        delete[] data; return ret;
+    }
+    delete[] data; std::vector<uint32_t> shape = {1, size};
+    inputs->push_back(MxBase::TensorBase(memoryDataDst, false, shape, type));
+    return APP_ERR_OK;
+}
+
+APP_ERROR DEM::Inference(const std::vector<MxBase::TensorBase> &inputs,
+                         std::vector<MxBase::TensorBase> *outputs) {
+    auto dtypes = model_->GetOutputDataType();
+    for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
+        std::vector<uint32_t> shape = {};
+        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
+            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
+        }
+        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
+        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
+        if (ret != APP_ERR_OK) {
+            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
+            return ret;
+        }
+        outputs->push_back(tensor);
+    }
+    MxBase::DynamicInfo dynamicInfo = {};
+    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
+    auto startTime = std::chrono::high_resolution_clock::now();
+    APP_ERROR ret = model_->ModelInference(inputs, *outputs, dynamicInfo);
+    auto endTime = std::chrono::high_resolution_clock::now();
+    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
+    g_inferCost.push_back(costMs);
+    if (ret != APP_ERR_OK) {
+        LogError << "ModelInference failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
+
+APP_ERROR DEM::PostProcess(std::vector<MxBase::TensorBase> *outputs, std::vector<float> *result) {
+    LogInfo << "Outputs size:" << outputs->size();
+    MxBase::TensorBase &tensor = outputs->at(0);
+    APP_ERROR ret = tensor.ToHost();
+    if (ret != APP_ERR_OK) {
+        LogError << GetError(ret) << "Tensor deploy to host failed.";
+        return ret;
+    }
+    // check tensor is available
+    auto outputShape = tensor.GetShape();
+    uint32_t length = outputShape[0];
+    uint32_t classNum = outputShape[1];
+    LogInfo << "output shape is: " << outputShape[0] << " "<< outputShape[1] << std::endl;
+
+    void* data = tensor.GetBuffer();
+    for (uint32_t i = 0; i < length; i++) {
+        for (uint32_t j = 0; j < classNum; j++) {
+            // get real data by index, the variable 'data' is address
+            float value = *(reinterpret_cast<float*>(data) + i * classNum + j);
+            // LogInfo << "value " << value;
+            result->push_back(value);
+        }
+    }
+    return APP_ERR_OK;
+}
+
+APP_ERROR DEM::SaveResult(std::vector<float > *result) {
+    std::ofstream outfile("res", std::ofstream::app);
+    if (outfile.fail()) {
+        LogError << "Failed to open result file: ";
+        return APP_ERR_COMM_FAILURE;
+    }
+    for (size_t i = 0; i < result->size(); ++i) {
+        outfile << std::setiosflags(std::ios::fixed) << std::setprecision(6) << result->at(i) << " ";
+    }
+    outfile << std::endl;
+    outfile.close();
+    return APP_ERR_OK;
+}
+
+APP_ERROR DEM::Process(const std::string &inferPath, uint32_t size) {
+    std::vector<MxBase::TensorBase> inputs = {};
+    std::string dataPath = inferPath;
+    APP_ERROR ret = ReadInputTensor(dataPath, 0, &inputs, size, MxBase::TENSOR_DTYPE_FLOAT32);
+    if (ret != APP_ERR_OK) {
+        LogError << "Read input data failed, ret= " << ret << ".";
+        return ret; }
+    std::vector<MxBase::TensorBase> outputs = {};
+    ret = Inference(inputs, &outputs);
+    if (ret != APP_ERR_OK) {
+        LogError << "Inference failed, ret=" << ret << ".";
+        return ret;
+    }
+    std::vector<float> result;
+    ret = PostProcess(&outputs, &result);
+    if (ret != APP_ERR_OK) {
+        LogError << "PostProcess failed, ret=" << ret << ".";
+        return ret;
+    }
+    ret = SaveResult(&result);
+    if (ret != APP_ERR_OK) {
+        LogError << "SaveResult failed, ret=" << ret << ".";
+        return ret;
+    }
+    return APP_ERR_OK;
+}
diff --git a/research/cv/dem/infer/mxbase/src/DEM.h b/research/cv/dem/infer/mxbase/src/DEM.h
new file mode 100644
index 0000000000000000000000000000000000000000..e03bb9eccc426027ddb6d480fbaec0020b115218
--- /dev/null
+++ b/research/cv/dem/infer/mxbase/src/DEM.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MXBASE_DEM_H
+#define MXBASE_DEM_H
+
+#include <memory>
+#include <utility>
+#include <vector>
+#include <string>
+#include <map>
+#include <opencv2/opencv.hpp>
+#include "MxBase/DvppWrapper/DvppWrapper.h"
+#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
+#include "MxBase/Tensor/TensorContext/TensorContext.h"
+
+extern std::vector<double> g_inferCost;
+
+struct InitParam {
+    uint32_t deviceId;
+    std::string datasetPath;
+    std::string modelPath;
+};
+
+class DEM{
+ public:
+    APP_ERROR Init(const InitParam &initParam);
+    APP_ERROR DeInit();
+    APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> *outputs);
+    APP_ERROR Process(const std::string &inferPath, uint32_t size);
+    APP_ERROR PostProcess(std::vector<MxBase::TensorBase> *outputs, std::vector<float> *result);
+ protected:
+    APP_ERROR ReadTensorFromFile(const std::string &file, float *data, uint32_t size);
+    APP_ERROR ReadInputTensor(const std::string &fileName, uint32_t index, std::vector<MxBase::TensorBase> *inputs,
+                              uint32_t size, MxBase::TensorDataType type);
+    APP_ERROR SaveResult(std::vector<float> *result);
+ private:
+    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
+    MxBase::ModelDesc modelDesc_ = {};
+    uint32_t deviceId_ = 0;
+};
+#endif
diff --git a/research/cv/dem/infer/mxbase/src/main.cpp b/research/cv/dem/infer/mxbase/src/main.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2dedde053acb157ca118f5f116b8db791bf2f2fe
--- /dev/null
+++ b/research/cv/dem/infer/mxbase/src/main.cpp
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <unistd.h>
+#include <dirent.h>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include "DEM.h"
+#include "MxBase/Log/Log.h"
+
+std::vector<double> g_inferCost;
+
+void InitDEMParam(InitParam* initParam) {
+    initParam->deviceId = 0;
+    initParam->modelPath = "../convert/dem.om";
+}
+
+int main(int argc, char* argv[]) {
+    if (argc <= 2) {
+        LogWarn << "Please input dataset and path e.g. ./dem [dataset] [data_path]";
+        return APP_ERR_COMM_FAILURE;
+    }
+
+    InitParam initParam;
+    InitDEMParam(&initParam);
+    auto demBase = std::make_shared<DEM>();
+    APP_ERROR ret = demBase->Init(initParam);
+    if (ret != APP_ERR_OK) {
+        LogError << "DEMBase init failed, ret=" << ret << ".";
+        return ret;
+    }
+    std::string dataset = argv[1];
+    std::string inferPath = argv[2];
+    uint32_t size = (dataset == "CUB") ? 312 : 85;
+    LogInfo << "Infer path :" << inferPath;
+    int len = (dataset == "CUB") ? 50 : 10;
+    for (int i = 0; i < len; ++i) {
+        char file[1024];
+        snprintf(file, sizeof(file), "test_att_%d", i);
+        LogInfo << "reading file name:" << file;
+        ret = demBase->Process(inferPath + file, size);
+        if (ret != APP_ERR_OK) {
+            LogError << "DEMBase process failed, ret=" << ret << ".";
+            demBase->DeInit();
+            return ret;
+        }
+        LogInfo << "Finish " << i << " file";
+    }
+    LogInfo << "======== Inference finished ========";
+    demBase->DeInit();
+    double costSum = 0;
+    for (uint32_t i = 0; i < g_inferCost.size(); i++) {
+        costSum += g_inferCost[i];
+    }
+    LogInfo << "Infer sum " << g_inferCost.size() << ", cost total time: " << costSum << " ms.";
+    LogInfo << "The throughput: " << g_inferCost.size() * 1000 / costSum << " bin/sec.";
+    return APP_ERR_OK;
+}
diff --git a/research/cv/dem/infer/sdk/pipeline/dem.pipeline b/research/cv/dem/infer/sdk/pipeline/dem.pipeline
new file mode 100644
index 0000000000000000000000000000000000000000..2a901b952f63b242f6b7c985413af7c3b8085bb8
--- /dev/null
+++ b/research/cv/dem/infer/sdk/pipeline/dem.pipeline
@@ -0,0 +1,36 @@
+{
+  "dem": {
+    "stream_config": {
+      "deviceId": "0"
+    },
+    "appsrc0": {
+      "props": {
+        "blocksize": "409600"
+      },
+      "factory": "appsrc",
+      "next": "mxpi_tensorinfer0"
+    },
+    "mxpi_tensorinfer0": {
+      "props": {
+        "dataSource": "appsrc0",
+        "modelPath": "../../convert/dem.om",
+        "outputDeviceId": "0"
+      },
+      "factory": "mxpi_tensorinfer",
+      "next": "mxpi_dataserialize0"
+    },
+    "mxpi_dataserialize0": {
+      "props": {
+        "outputDataKeys": "mxpi_tensorinfer0"
+      },
+      "factory": "mxpi_dataserialize",
+      "next": "appsink0"
+    },
+    "appsink0": {
+      "props": {
+        "blocksize": "4096000"
+      },
+      "factory": "appsink"
+    }
+  }
+}
\ No newline at end of file
diff --git a/research/cv/dem/infer/sdk/python_DEM/SdkApi.py b/research/cv/dem/infer/sdk/python_DEM/SdkApi.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7efb324597dc037a14a59846b46d0000cb44804
--- /dev/null
+++ b/research/cv/dem/infer/sdk/python_DEM/SdkApi.py
@@ -0,0 +1,125 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Model Infer """
+import json
+import logging
+import numpy as np
+import MxpiDataType_pb2 as MxpiDataType
+from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, MxProtobufIn, StringVector
+
+
+class SdkApi:
+    """ Class SdkApi """
+    INFER_TIMEOUT = 100000
+    STREAM_NAME = "dem"
+
+    def __init__(self, pipeline_cfg):
+        self.pipeline_cfg = pipeline_cfg
+        self._stream_api = None
+        self._data_input = None
+        self._device_id = None
+
+    def init(self):
+        """ Initialize Stream """
+        with open(self.pipeline_cfg, 'r') as fp:
+            self._device_id = int(
+                json.loads(fp.read())[self.STREAM_NAME]["stream_config"]
+                ["deviceId"])
+            print(f"The device id: {self._device_id}.")
+
+        # create api
+        self._stream_api = StreamManagerApi()
+
+        # init stream mgr
+        ret = self._stream_api.InitManager()
+        if ret != 0:
+            print(f"Failed to init stream manager, ret={ret}.")
+            return False
+
+        # create streams
+        with open(self.pipeline_cfg, 'rb') as fp:
+            pipe_line = fp.read()
+
+        ret = self._stream_api.CreateMultipleStreams(pipe_line)
+        if ret != 0:
+            print(f"Failed to create stream, ret={ret}.")
+            return False
+
+        self._data_input = MxDataInput()
+        return True
+
+    def __del__(self):
+        if not self._stream_api:
+            return
+
+        self._stream_api.DestroyAllStreams()
+
+    def _send_protobuf(self, stream_name, plugin_id, element_name, buf_type,
+                       pkg_list):
+        """ Send Stream """
+        protobuf = MxProtobufIn()
+        protobuf.key = element_name
+        protobuf.type = buf_type
+        protobuf.protobuf = pkg_list.SerializeToString()
+        protobuf_vec = InProtobufVector()
+        protobuf_vec.push_back(protobuf)
+        err_code = self._stream_api.SendProtobuf(stream_name, plugin_id,
+                                                 protobuf_vec)
+        if err_code != 0:
+            logging.error(
+                "Failed to send data to stream, stream_name(%s), plugin_id(%s), element_name(%s), "
+                "buf_type(%s), err_code(%s).", stream_name, plugin_id,
+                element_name, buf_type, err_code)
+            return False
+        return True
+
+    def send_tensor_input(self, stream_name, plugin_id, element_name,
+                          input_data, input_shape, data_type):
+        """ Send Tensor """
+        tensor_list = MxpiDataType.MxpiTensorPackageList()
+
+        data = np.expand_dims(input_data, 0)
+        tensor_pkg = tensor_list.tensorPackageVec.add()
+        # init tensor vector
+        tensor_vec = tensor_pkg.tensorVec.add()
+        tensor_vec.deviceId = self._device_id
+        tensor_vec.memType = 0
+        tensor_vec.tensorShape.extend(data.shape)
+        tensor_vec.tensorDataType = data_type
+        tensor_vec.dataStr = data.tobytes()
+        tensor_vec.tensorDataSize = int(data.shape[1] * 4)
+        print(type(tensor_list))
+        buf_type = b"MxTools.MxpiTensorPackageList"
+        return self._send_protobuf(stream_name, plugin_id, element_name,
+                                   buf_type, tensor_list)
+
+    def get_result(self, stream_name, out_plugin_id=0):
+        """ Get Result """
+        keys = [b"mxpi_tensorinfer0"]
+        keyVec = StringVector()
+        for key in keys:
+            keyVec.push_back(key)
+        infer_result = self._stream_api.GetProtobuf(stream_name, 0, keyVec)
+        if infer_result.size() == 0:
+            print("infer_result is null")
+            exit()
+
+        if infer_result[0].errorCode != 0:
+            print("GetProtobuf error. errorCode=%d" % (
+                infer_result[0].errorCode))
+            exit()
+        TensorList = MxpiDataType.MxpiTensorPackageList()
+        TensorList.ParseFromString(infer_result[0].messageBuf)
+        return TensorList
diff --git a/research/cv/dem/infer/sdk/python_DEM/dem_run.sh b/research/cv/dem/infer/sdk/python_DEM/dem_run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e3e31402418db4bf8b30de0822bc314210f5a05d
--- /dev/null
+++ b/research/cv/dem/infer/sdk/python_DEM/dem_run.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+pipeline_path=$1
+data_dir=$2
+dataset=$3
+
+set -e
+
+# Simple log helper functions
+info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
+warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
+
+python3.7 main.py --pipeline_path=$pipeline_path --data_dir=$data_dir --dataset=$dataset
+exit 0
\ No newline at end of file
diff --git a/research/cv/dem/infer/sdk/python_DEM/main.py b/research/cv/dem/infer/sdk/python_DEM/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..85d35afc17811ff43a7b89ac2f4f1d869cc18cdb
--- /dev/null
+++ b/research/cv/dem/infer/sdk/python_DEM/main.py
@@ -0,0 +1,85 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""main process for sdk infer"""
+import argparse
+import time
+import scipy.io as sio
+import numpy as np
+from SdkApi import SdkApi
+
+STREAM_NAME = b'dem'
+TENSOR_DTYPE_FLOAT16 = 1
+TENSOR_DTYPE_FLOAT32 = 0
+
+
+def parse_args():
+    """set and check parameters."""
+    parser = argparse.ArgumentParser(description="dem process")
+    parser.add_argument("--pipeline_path", type=str, default="../pipeline/dem.pipeline", help="SDK infer pipeline")
+    parser.add_argument("--data_dir", type=str, default="/dataset/DEM_data", help="path where the dataset is saved")
+    parser.add_argument("--dataset", type=str, default="AwA", choices=['AwA', 'CUB'],
+                        help="dataset which is chosen to use")
+    args_opt = parser.parse_args()
+    return args_opt
+
+
+def inference():
+    """infer process function"""
+    args = parse_args()
+
+    # init stream manager
+    sdk_api = SdkApi(args.pipeline_path)
+    if not sdk_api.init():
+        exit(-1)
+
+    start_time = time.time()
+    if args.dataset == 'AwA':
+        input_tensor = dataset_AwA(args.data_dir)
+    elif args.dataset == 'CUB':
+        input_tensor = dataset_cub(args.data_dir)
+    print("================> Input shape:", input_tensor.shape)
+    res_list = np.empty((input_tensor.shape[0], 1, 1024), dtype=np.float32)
+    for i in range(input_tensor.shape[0]):
+        input_data = input_tensor[i]
+        sdk_api.send_tensor_input(STREAM_NAME, 0, b'appsrc0', input_data, input_tensor.shape, TENSOR_DTYPE_FLOAT32)
+        result = sdk_api.get_result(STREAM_NAME)
+        pred = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float32)
+        res_list[i] = pred
+    end_time = time.time() - start_time
+    print(f"The inference time is {end_time}")
+    res = np.squeeze(res_list, axis=1)
+    # save result
+    np.savetxt('res', res, fmt="%f")
+
+
+def dataset_cub(data_path):
+    """input:*.mat, output:array"""
+    f = sio.loadmat(data_path+'/CUB_data/test_proto.mat')
+    test_att_0 = np.array(f['test_proto'])
+    test_att_0 = test_att_0.astype("float32")
+
+    return test_att_0
+
+
+def dataset_AwA(data_path):
+    """input:*.mat, output:array"""
+    f = sio.loadmat(data_path+'/AwA_data/attribute/pca_te_con_10x85.mat')
+    test_att_0 = np.array(f['pca_te_con_10x85'])
+    test_att_0 = test_att_0.astype("float32")
+
+    return test_att_0
+
+
+if __name__ == '__main__':
+    inference()
diff --git a/research/cv/dem/infer/sdk/task_metric.py b/research/cv/dem/infer/sdk/task_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e3ea52ade8a2430666788e0c67d389325f6282
--- /dev/null
+++ b/research/cv/dem/infer/sdk/task_metric.py
@@ -0,0 +1,152 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""calculate infer result accuracy"""
+
+import argparse
+import numpy as np
+import scipy.io as sio
+
+
+def parse_args():
+    """set and check parameters."""
+    parser = argparse.ArgumentParser(description="bert process")
+    parser.add_argument("--res_path", type=str, default="./python_DEM/res", help="result numpy path")
+    parser.add_argument("--data_dir", type=str, default="/home/dataset/DEM_data",
+                        help="path where the dataset is saved")
+    parser.add_argument("--dataset", type=str, default="AwA", choices=['AwA', 'CUB'],
+                        help="dataset which is chosen to use")
+    args_opt = parser.parse_args()
+    return args_opt
+
+
+def kNNClassify(newInput, dataSet, labels, k):
+    """classify using kNN"""
+    numSamples = dataSet.shape[0]
+    diff = np.tile(newInput, (numSamples, 1)) - dataSet
+    squaredDiff = diff ** 2
+    squaredDist = np.sum(squaredDiff, axis=1)
+    distance = squaredDist ** 0.5
+    sortedDistIndices = np.argsort(distance)
+    classCount = {}
+    for i in range(k):
+        voteLabel = labels[sortedDistIndices[i]]
+        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
+    maxCount = 0
+    for key, value in classCount.items():
+        if value > maxCount:
+            maxCount = value
+            maxIndex = key
+    return maxIndex
+    #return sortedDistIndices
+
+
+def compute_accuracy_att(att_pred_0, pred_len, test_att_0, test_visual_0, test_id_0, test_label_0):
+    """calculate accuracy using infer result"""
+    outpred = [0] * pred_len
+    test_label_0 = test_label_0.astype("float32")
+    for i in range(pred_len):
+        outputLabel = kNNClassify(test_visual_0[i, :], att_pred_0, test_id_0, 1)
+        outpred[i] = outputLabel
+    outpred = np.array(outpred)
+    acc_0 = np.equal(outpred, test_label_0).mean()
+    return acc_0
+
+
+def dataset_CUB(data_path):
+    """input:*.mat, output:array"""
+    f = sio.loadmat(data_path+'/CUB_data/train_attr.mat')
+    train_att_0 = np.array(f['train_attr'])
+    # print('train attr:', train_att.shape)
+
+    f = sio.loadmat(data_path+'/CUB_data/train_cub_googlenet_bn.mat')
+    train_x_0 = np.array(f['train_cub_googlenet_bn'])
+    # print('train x:', train_x.shape)
+
+    f = sio.loadmat(data_path+'/CUB_data/test_cub_googlenet_bn.mat')
+    test_x_0 = np.array(f['test_cub_googlenet_bn'])
+    # print('test x:', test_x.shape)
+
+    f = sio.loadmat(data_path+'/CUB_data/test_proto.mat')
+    test_att_0 = np.array(f['test_proto'])
+    test_att_0 = test_att_0.astype("float16")
+    # test_att_0 = Tensor(test_att_0, mindspore.float32)
+    # print('test att:', test_att.shape)
+
+    f = sio.loadmat(data_path+'/CUB_data/test_labels_cub.mat')
+    test_label_0 = np.squeeze(np.array(f['test_labels_cub']))
+    # print('test x2label:', test_x2label)
+
+    f = sio.loadmat(data_path+'/CUB_data/testclasses_id.mat')
+    test_id_0 = np.squeeze(np.array(f['testclasses_id']))
+    # print('test att2label:', test_att2label)
+
+    return train_att_0, train_x_0, test_x_0, test_att_0, test_label_0, test_id_0
+
+
+def dataset_AwA(data_path):
+    """input:*.mat, output:array"""
+    f = sio.loadmat(data_path+'/AwA_data/train_googlenet_bn.mat')
+    train_x_0 = np.array(f['train_googlenet_bn'])
+
+    # useless data
+    train_att_0 = np.empty(1)
+
+    f = sio.loadmat(data_path+'/AwA_data/wordvector/train_word.mat')
+    train_word_0 = np.array(f['train_word'])
+
+    f = sio.loadmat(data_path+'/AwA_data/test_googlenet_bn.mat')
+    test_x_0 = np.array(f['test_googlenet_bn'])
+
+    f = sio.loadmat(data_path+'/AwA_data/attribute/pca_te_con_10x85.mat')
+    test_att_0 = np.array(f['pca_te_con_10x85'])
+    test_att_0 = test_att_0.astype("float16")
+
+    f = sio.loadmat(data_path+'/AwA_data/wordvector/test_vectors.mat')
+    test_word_0 = np.array(f['test_vectors'])
+    test_word_0 = test_word_0.astype("float16")
+
+    f = sio.loadmat(data_path+'/AwA_data/test_labels.mat')
+    test_label_0 = np.squeeze(np.array(f['test_labels']))
+
+    f = sio.loadmat(data_path+'/AwA_data/testclasses_id.mat')
+    test_id_0 = np.squeeze(np.array(f['testclasses_id']))
+
+    return train_x_0, train_att_0, train_word_0, test_x_0, \
+        test_att_0, test_word_0, test_label_0, test_id_0
+
+
+def read_res(res_path):
+    """load result"""
+    return np.loadtxt(res_path, dtype=np.float32)
+
+
+def test_result(dir_name, res_path, dataset):
+    """calculate"""
+    if dataset == 'AwA':
+        pred_len = 6180
+        _, _, _, test_x, test_att, _, test_label, test_id = dataset_AwA(dir_name)
+    elif dataset == 'CUB':
+        pred_len = 2933
+        _, _, test_x, test_att, test_label, test_id = dataset_CUB(dir_name)
+    att_pred_res = read_res(res_path)
+
+    acc = compute_accuracy_att(att_pred_res, pred_len, test_att, test_x, test_id, test_label)
+    return acc
+
+
+if __name__ == '__main__':
+    # Script entry point: parse CLI args, then score the saved infer result.
+    args = parse_args()
+    print("dataset:", args.dataset)
+    final_acc = test_result(args.data_dir, args.res_path, args.dataset)
+    print('accuracy :', final_acc)
diff --git a/research/cv/dem/modelarts/train_modelarts.py b/research/cv/dem/modelarts/train_modelarts.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae37704cf19089fcf1aff5d658f1cef9d9014e72
--- /dev/null
+++ b/research/cv/dem/modelarts/train_modelarts.py
@@ -0,0 +1,212 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+######################## train DEM ########################
+train DEM
+python train.py --data_path = /YourDataPath \
+                --dataset = AwA or CUB \
+                --train_mode = att, word or fusion
+"""
+import os
+import time
+import sys
+import numpy as np
+import moxing as mox
+
+import mindspore
+import mindspore.nn as nn
+from mindspore import context
+from mindspore import save_checkpoint
+from mindspore import dataset as ds
+from mindspore import Model
+from mindspore import set_seed
+from mindspore import export
+from mindspore import load_checkpoint
+from mindspore import Tensor
+from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
+from mindspore.communication.management import init, get_rank, get_group_size
+
+from src.dataset import dataset_AwA, dataset_CUB, SingleDataIterable, DoubleDataIterable
+from src.demnet import MyTrainOneStepCell
+from src.set_parser import set_parser
+from src.utils import acc_cfg, backbone_cfg, param_cfg, withlosscell_cfg
+from src.accuracy import compute_accuracy_att, compute_accuracy_word, compute_accuracy_fusion
+
+if __name__ == "__main__":
+    # Set graph mode, device id
+    set_seed(1000)
+    args = set_parser()
+
+    local_data_path = "/cache/dataset/"
+    model_path = "/cache/model/"
+    if not os.path.exists(local_data_path):
+        os.makedirs(local_data_path, exist_ok=True)
+    if not os.path.exists(model_path):
+        os.makedirs(model_path, exist_ok=True)
+    mox.file.copy_parallel(args.data_path, local_data_path)
+    ckpt_path = os.path.join(model_path, "train.ckpt")
+
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=args.device_target)
+    if args.distribute:
+        if args.device_target == "Ascend":
+            context.set_context(device_id=args.device_id)
+
+        init()
+        args.device_num = get_group_size()
+        rank_id = get_rank()
+        context.reset_auto_parallel_context()
+        context.set_auto_parallel_context(
+            parallel_mode=context.ParallelMode.DATA_PARALLEL,
+            gradients_mean=True,
+            device_num=args.device_num
+        )
+    else:
+        rank_id = 0
+    # Initialize parameters
+    pred_len = acc_cfg(args)
+    lr, weight_decay, clip_param = param_cfg(args)
+    if np.equal(args.distribute, True):
+        lr = lr * 5
+
+    # Loading datasets and iterators
+    if args.dataset == 'AwA':
+        train_x, train_att, train_word, \
+        test_x, test_att, test_word, \
+        test_label, test_id = dataset_AwA(local_data_path)
+        if args.train_mode == 'att':
+            custom_data = ds.GeneratorDataset(SingleDataIterable(train_att, train_x),
+                                              ['label', 'data'],
+                                              num_shards=args.device_num,
+                                              shard_id=rank_id,
+                                              shuffle=True)
+        elif args.train_mode == 'word':
+            custom_data = ds.GeneratorDataset(SingleDataIterable(train_word, train_x),
+                                              ['label', 'data'],
+                                              num_shards=args.device_num,
+                                              shard_id=rank_id,
+                                              shuffle=True)
+        elif args.train_mode == 'fusion':
+            custom_data = ds.GeneratorDataset(DoubleDataIterable(train_att, train_word, train_x),
+                                              ['label1', 'label2', 'data'],
+                                              num_shards=args.device_num,
+                                              shard_id=rank_id,
+                                              shuffle=True)
+    elif args.dataset == 'CUB':
+        train_att, train_x, \
+        test_x, test_att, \
+        test_label, test_id = dataset_CUB(local_data_path)
+        if args.train_mode == 'att':
+            custom_data = ds.GeneratorDataset(SingleDataIterable(train_att, train_x),
+                                              ['label', 'data'],
+                                              num_shards=args.device_num,
+                                              shard_id=rank_id,
+                                              shuffle=True)
+        elif args.train_mode == 'word':
+            print("Warning: Do not support word vector mode training in CUB dataset.")
+            print("Only attribute mode is supported in this dataset.")
+            sys.exit(0)
+        elif args.train_mode == 'fusion':
+            print("Warning: Do not support fusion mode training in CUB dataset.")
+            print("Only attribute mode is supported in this dataset.")
+            sys.exit(0)
+    # Note: Must set "drop_remainder = True" in parallel mode.
+    batch_size = args.batch_size
+    custom_data = custom_data.batch(batch_size, drop_remainder=True)
+
+    # Build network
+    net = backbone_cfg(args)
+    loss_fn = nn.MSELoss(reduction='mean')
+    optim = nn.Adam(net.trainable_params(), lr, weight_decay)
+    MyWithLossCell = withlosscell_cfg(args)
+    loss_net = MyWithLossCell(net, loss_fn)
+    train_net = MyTrainOneStepCell(loss_net, optim)
+    model = Model(train_net)
+
+    # Train
+    start = time.time()
+    acc_max = 0
+    save_min_acc = 0
+    save_ckpt = model_path
+    epoch_size = args.epoch_size
+    interval_step = args.interval_step
+    if os.path.exists(ckpt_path):
+        print("============== Starting Loading ==============")
+        load_checkpoint(ckpt_path, net)
+    else:
+        print("============== Starting Training ==============")
+        if np.equal(args.distribute, True):
+            now = time.localtime()
+            nowt = time.strftime("%Y-%m-%d-%H:%M:%S", now)
+            print(nowt)
+            loss_cb = LossMonitor(interval_step)
+            if args.device_target == "Ascend":
+                ckpt_config = CheckpointConfig(save_checkpoint_steps=interval_step)
+                ckpt_callback = ModelCheckpoint(prefix='auto_parallel', config=ckpt_config)
+            t1 = time.time()
+
+            if args.device_target == "Ascend":
+                model.train(
+                    epoch_size,
+                    train_dataset=custom_data,
+                    callbacks=[loss_cb, ckpt_callback],
+                    dataset_sink_mode=True
+                )
+            elif args.device_target == "GPU":
+                model.train(epoch_size, train_dataset=custom_data, callbacks=[loss_cb], dataset_sink_mode=False)
+                ckpt_file_name = save_ckpt + f'/train_{rank_id}.ckpt'
+                save_checkpoint(net, ckpt_file_name)
+
+            end = time.time()
+
+            t3 = 1000 * (end - t1) / (88 * epoch_size)
+            print('total time:', end - start)
+            print('speed_8p = %.3f ms/step'%t3)
+            now = time.localtime()
+            nowt = time.strftime("%Y-%m-%d-%H:%M:%S", now)
+            print(nowt)
+        else:
+            for i in range(epoch_size):
+                t1 = time.time()
+                loss_cb = LossMonitor(interval_step)
+                model.train(1, train_dataset=custom_data, callbacks=loss_cb, dataset_sink_mode=True)
+                t2 = time.time()
+                t3 = 1000 * (t2 - t1) / 88
+                if args.train_mode == 'att':
+                    acc = compute_accuracy_att(net, pred_len, test_att, test_x, test_id, test_label)
+                elif args.train_mode == 'word':
+                    acc = compute_accuracy_word(net, pred_len, test_word, test_x, test_id, test_label)
+                else:
+                    acc = compute_accuracy_fusion(net, pred_len, test_att, test_word, test_x, test_id, test_label)
+                if acc > acc_max:
+                    acc_max = acc
+                    if acc_max > save_min_acc:
+                        save_checkpoint(net, ckpt_path)
+                print('epoch:', i + 1, 'accuracy = %.5f'%acc, 'speed = %.3f ms/step'%t3)
+            end = time.time()
+            print("total time:", end - start)
+
+    acc = compute_accuracy_att(net, pred_len, test_att, test_x, test_id, test_label)
+    print("current accuracy:", acc)
+    print("============== Starting Exporting ==============")
+    if args.train_mode == 'att':
+        if args.dataset == 'AwA':
+            input0 = Tensor(np.zeros([1, 85]), mindspore.float32)
+        elif args.dataset == 'CUB':
+            input0 = Tensor(np.zeros([1, 312]), mindspore.float32)
+        save_ckpt = save_ckpt + '/train'
+        export(net, input0, file_name=save_ckpt, file_format=args.file_format)
+        print("Successfully convert to", args.file_format)
+    mox.file.copy_parallel(model_path, args.save_ckpt)