diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt index 16e06fff9edbf10dc6ec8ab50f634d456512b47b..3a26d4dd08cd8c7bc7cae0e74ffe01b4420284f1 100644 --- a/.jenkins/check/config/filter_cpplint.txt +++ b/.jenkins/check/config/filter_cpplint.txt @@ -122,6 +122,10 @@ "models/research/recommend/autodis/infer/mxbase/src/Autodis.h" "runtime/references" "models/research/recommend/autodis/infer/mxbase/src/Autodis.cpp" "runtime/references" +"models/research/recommend/mmoe/infer/mxbase/src/main.cpp" "runtime/references" +"models/research/recommend/mmoe/infer/mxbase/src/MMoE.h" "runtime/references" +"models/research/recommend/mmoe/infer/mxbase/src/MMoE.cpp" "runtime/references" + "models/research/cv/textfusenet/infer/mxbase/src/Textfusenet.h" "runtime/references" "models/research/cv/textfusenet/infer/mxbase/src/Textfusenet.cpp" "runtime/references" "models/research/cv/textfusenet/infer/mxbase/src/PostProcess/TextfusenetMindsporePost.h" "runtime/references" diff --git a/research/recommend/mmoe/infer/convert/convert.sh b/research/recommend/mmoe/infer/convert/convert.sh new file mode 100644 index 0000000000000000000000000000000000000000..0dceb01ccab182bea3d9ce182a2fad2387242958 --- /dev/null +++ b/research/recommend/mmoe/infer/convert/convert.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# -ne 2 ] +then + echo "Need two parameters: one for air model input file path, another for om model output dir path!" + exit 1 +fi + +model=$1 +output=$2 + +atc --model="${model}" \ + --framework=1 \ + --output="${output}" \ + --soc_version=Ascend310 \ + --input_shape="data:1,499" \ + --output_type=FP16 \ No newline at end of file diff --git a/research/recommend/mmoe/infer/data/config/MMoE.pipeline b/research/recommend/mmoe/infer/data/config/MMoE.pipeline new file mode 100644 index 0000000000000000000000000000000000000000..2924dd3b5e59fba308a7d2e41060da971859feea --- /dev/null +++ b/research/recommend/mmoe/infer/data/config/MMoE.pipeline @@ -0,0 +1,35 @@ +{ + "MMoE": { + "stream_config": { + "deviceId": "0" + }, + "appsrc0": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_tensorinfer0:0" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource":"appsrc0", + "modelPath": "../data/model/MMoE.om" + }, + "factory": "mxpi_tensorinfer", + "next": "mxpi_dataserialize0" + }, + "mxpi_dataserialize0": { + "props": { + "outputDataKeys": "mxpi_tensorinfer0" + }, + "factory": "mxpi_dataserialize", + "next": "appsink0" + }, + "appsink0": { + "props": { + "blocksize": "4096000" + }, + "factory": "appsink" + } + } +} \ No newline at end of file diff --git a/research/recommend/mmoe/infer/docker_start_infer.sh b/research/recommend/mmoe/infer/docker_start_infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..b650ad2b7b760d89fbf363d1dac8eed715c00cc9 --- /dev/null +++ b/research/recommend/mmoe/infer/docker_start_infer.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash + +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +docker_image=$1 +model_dir=$2 + + +function show_help() { + echo "Usage: docker_start.sh docker_image model_dir data_dir" +} + +function param_check() { + if [ -z "${docker_image}" ]; then + echo "please input docker_image" + show_help + exit 1 + fi + + if [ -z "${model_dir}" ]; then + echo "please input model_dir" + show_help + exit 1 + fi +} + +param_check + +docker run -it \ + --device=/dev/davinci0 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v ${model_dir}:${model_dir} \ + ${docker_image} \ + /bin/bash diff --git a/research/recommend/mmoe/infer/mxbase/CMakeLists.txt b/research/recommend/mmoe/infer/mxbase/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..eb1ea072337ad3f5bfc48dcbfed0dbd53951d6b7 --- /dev/null +++ b/research/recommend/mmoe/infer/mxbase/CMakeLists.txt @@ -0,0 +1,57 @@ +cmake_minimum_required(VERSION 3.10.0) +project(mmoe) + +set(TARGET mmoe) + +add_definitions(-DENABLE_DVPP_INTERFACE) +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-Dgoogle=mindxsdk_private) +add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall) +add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie) + +# Check environment variable +if(NOT DEFINED ENV{ASCEND_HOME}) + message(FATAL_ERROR "please define environment variable:ASCEND_HOME") +endif() +if(NOT DEFINED ENV{ASCEND_VERSION}) + message(WARNING "please define environment variable:ASCEND_VERSION") +endif() +if(NOT DEFINED ENV{ARCH_PATTERN}) + 
message(WARNING "please define environment variable:ARCH_PATTERN") +endif() + +set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include) +set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64) + +set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME}) +set(MXBASE_INC ${MXBASE_ROOT_DIR}/include) +set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib) +set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors) +set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include) + +if(DEFINED ENV{MXSDK_OPENSOURCE_DIR}) + set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR}) +else() + set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource) +endif() + + +include_directories(${ACL_INC_DIR}) +include_directories($ENV{ASCEND_HOME}/ascend-toolkit/5.0.4/x86_64-linux/runtime/include) +include_directories(${OPENSOURCE_DIR}/include) +include_directories(${OPENSOURCE_DIR}/include/opencv4) + +include_directories(${MXBASE_INC}) +include_directories(${MXBASE_POST_PROCESS_DIR}) + +link_directories(${ACL_LIB_DIR}) +link_directories(${OPENSOURCE_DIR}/lib) +link_directories(${MXBASE_LIB_DIR}) +link_directories(${MXBASE_POST_LIB_DIR}) + + +add_executable(${TARGET} src/main.cpp src/MMoE.cpp) + +target_link_libraries(${TARGET} glog cpprest mxbase opencv_world) + +install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/) \ No newline at end of file diff --git a/research/recommend/mmoe/infer/mxbase/build.sh b/research/recommend/mmoe/infer/mxbase/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..ef0b0732534aa66f7b07c6133f93e5a37b03b408 --- /dev/null +++ b/research/recommend/mmoe/infer/mxbase/build.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +path_cur=$(dirname $0) + +function check_env() +{ + # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user + if [ ! "${ASCEND_HOME}" ]; then + export ASCEND_HOME=/usr/local/Ascend/ + echo "Set ASCEND_HOME to the default value: ${ASCEND_HOME}" + else + echo "ASCEND_HOME is set to ${ASCEND_HOME} by user" + fi + + if [ ! "${ASCEND_VERSION}" ]; then + export ASCEND_VERSION=nnrt/latest + echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}" + else + echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user" + fi + + if [ ! "${ARCH_PATTERN}" ]; then + # set ARCH_PATTERN to ./ when it was not specified by user + export ARCH_PATTERN=./ + echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}" + else + echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user" + fi +} + +function build_mmoe() +{ + cd $path_cur + rm -rf build + mkdir -p build + cd build + cmake .. + make + ret=$? + if [ ${ret} -ne 0 ]; then + echo "Failed to build mmoe." + exit ${ret} + fi + make install +} + +check_env +build_mmoe \ No newline at end of file diff --git a/research/recommend/mmoe/infer/mxbase/src/MMoE.cpp b/research/recommend/mmoe/infer/mxbase/src/MMoE.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3338b9b674e269f62862c764f9ce287f43005e77 --- /dev/null +++ b/research/recommend/mmoe/infer/mxbase/src/MMoE.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2022. 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <unistd.h> +#include <sys/stat.h> +#include <math.h> +#include <memory> +#include <string> +#include <fstream> +#include <algorithm> +#include <vector> +#include "MMoE.h" +#include "half.hpp" +#include "MxBase/DeviceManager/DeviceManager.h" +#include "MxBase/Log/Log.h" + +using half_float::half; + +APP_ERROR MMoE::Init(const InitParam &initParam) { + deviceId_ = initParam.deviceId; + APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices(); + if (ret != APP_ERR_OK) { + LogError << "Init devices failed, ret=" << ret << "."; + return ret; + } + ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId); + if (ret != APP_ERR_OK) { + LogError << "Set context failed, ret=" << ret << "."; + return ret; + } + model_ = std::make_shared<MxBase::ModelInferenceProcessor>(); + ret = model_->Init(initParam.modelPath, modelDesc_); + if (ret != APP_ERR_OK) { + LogError << "ModelInferenceProcessor init failed, ret=" << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +APP_ERROR MMoE::DeInit() { + model_->DeInit(); + MxBase::DeviceManager::GetInstance()->DestroyDevices(); + return APP_ERR_OK; +} + +template<class dtype> +APP_ERROR MMoE::VectorToTensorBase(const std::vector<std::vector<dtype>> &input, uint32_t inputId + , MxBase::TensorBase &tensorBase) { + uint32_t dataSize = modelDesc_.inputTensors[inputId].tensorDims[1]; + dtype *metaFeatureData 
= new dtype[dataSize]; + uint32_t idx = 0; + for (size_t bs = 0; bs < input.size(); bs++) { + for (size_t c = 0; c < input[bs].size(); c++) { + metaFeatureData[idx++] = input[bs][c]; + } + } + + MxBase::MemoryData memoryDataDst(dataSize * 2, MxBase::MemoryData::MEMORY_DEVICE, deviceId_); + MxBase::MemoryData memoryDataSrc(reinterpret_cast<void *>(metaFeatureData), dataSize * 2 + , MxBase::MemoryData::MEMORY_HOST_MALLOC); + + APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Memory malloc failed."; + return ret; + } + + std::vector<uint32_t> shape = {1, dataSize}; + if (typeid(dtype) == typeid(half)) { + tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT16); + } else { + tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_INT32); + } + return APP_ERR_OK; +} + +APP_ERROR MMoE::Inference(const std::vector<MxBase::TensorBase> &inputs, + std::vector<MxBase::TensorBase> &outputs) { + auto dtypes = model_->GetOutputDataType(); + for (size_t i = 0; i < modelDesc_.outputTensors.size(); i++) { + std::vector<uint32_t> shape = {}; + for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); j++) { + shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]); + } + MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_); + APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor); + if (ret != APP_ERR_OK) { + LogError << "TensorBaseMalloc failed, ret=" << ret << "."; + return ret; + } + outputs.push_back(tensor); + } + MxBase::DynamicInfo dynamicInfo = {}; + dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH; + auto startTime = std::chrono::high_resolution_clock::now(); + APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo); + auto endTime = std::chrono::high_resolution_clock::now(); + double costMs = 
std::chrono::duration<double, std::milli>(endTime - startTime).count(); + inferCostTimeMilliSec += costMs; + if (ret != APP_ERR_OK) { + LogError << "ModelInference failed, ret=" << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +APP_ERROR MMoE::PostProcess(std::vector<std::vector<half>> &income_preds + , std::vector<std::vector<half>> &married_preds + , const std::vector<MxBase::TensorBase> &inputs) { + size_t index = 0; + for (auto retTensor : inputs) { + std::vector<uint32_t> shape = retTensor.GetShape(); + uint32_t N = shape[0]; + uint32_t C = shape[1]; + if (!retTensor.IsHost()) { + retTensor.ToHost(); + } + void* data = retTensor.GetBuffer(); + std::vector<half> temp; + for (uint32_t i = 0; i < N; i++) { + for (uint32_t j = 0; j < C; j++) { + half value = *(reinterpret_cast<half*>(data) + i * C + j); + temp.push_back(value); + } + } + if (index == 0) { + income_preds.emplace_back(temp); + } else { + married_preds.emplace_back(temp); + } + index++; + } + return APP_ERR_OK; +} + +APP_ERROR MMoE::PrintInputInfo(std::vector<MxBase::TensorBase> inputs) { + LogInfo << "input size: " << inputs.size(); + for (size_t i = 0; i < inputs.size(); i++) { + // check tensor is available + MxBase::TensorBase &tensor_input = inputs[i]; + auto inputShape = tensor_input.GetShape(); + uint32_t inputDataType = tensor_input.GetDataType(); + LogInfo << "input_" + std::to_string(i) + "_shape is: " << inputShape[0] + << " " << inputShape[1] << " " << inputShape.size(); + LogInfo << "input_" + std::to_string(i) + "_dtype is: " << inputDataType; + } + return APP_ERR_OK; +} + +APP_ERROR MMoE::Process(const std::vector<std::vector<half>> &data, const InitParam &initParam + , std::vector<std::vector<half>> &income_preds + , std::vector<std::vector<half>> &married_preds) { + std::vector<MxBase::TensorBase> inputs = {}; + std::vector<MxBase::TensorBase> outputs = {}; + APP_ERROR ret; + MxBase::TensorBase tensorBase; + ret = VectorToTensorBase(data, 0, tensorBase); + if (ret != 
APP_ERR_OK) { + LogError << "ToTensorBase failed, ret=" << ret << "."; + return ret; + } + inputs.push_back(tensorBase); + + // run inference + ret = Inference(inputs, outputs); + if (ret != APP_ERR_OK) { + LogError << "Inference failed, ret=" << ret << "."; + return ret; + } + + ret = PostProcess(income_preds, married_preds, outputs); + if (ret != APP_ERR_OK) { + LogError << "Save model infer results into file failed. ret = " << ret << "."; + return ret; + } + return APP_ERR_OK; } diff --git a/research/recommend/mmoe/infer/mxbase/src/MMoE.h b/research/recommend/mmoe/infer/mxbase/src/MMoE.h new file mode 100644 index 0000000000000000000000000000000000000000..352a49f2d145e7378eb9ec30266d42e111bce14a --- /dev/null +++ b/research/recommend/mmoe/infer/mxbase/src/MMoE.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2022. Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MxBase_MMoE_H +#define MxBase_MMoE_H +#include <string> +#include <vector> +#include <memory> +#include "half.hpp" + +#include "MxBase/DvppWrapper/DvppWrapper.h" +#include "MxBase/ModelInfer/ModelInferenceProcessor.h" +#include "MxBase/Tensor/TensorContext/TensorContext.h" + +using half_float::half; + +struct InitParam { + uint32_t deviceId; + bool checkTensor; + std::string modelPath; +}; + +class MMoE { + public: + APP_ERROR Init(const InitParam &initParam); + APP_ERROR DeInit(); + template<class dtype> + APP_ERROR VectorToTensorBase(const std::vector<std::vector<dtype>> &input_x, uint32_t inputId + , MxBase::TensorBase &tensorBase); + APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs); + APP_ERROR Process(const std::vector<std::vector<half>> &data, const InitParam &initParam + , std::vector<std::vector<half>> &income_preds + , std::vector<std::vector<half>> &married_preds); + APP_ERROR PostProcess(std::vector<std::vector<half>> &income_preds, std::vector<std::vector<half>> &married_preds, + const std::vector<MxBase::TensorBase> &inputs); + APP_ERROR PrintInputInfo(std::vector<MxBase::TensorBase> inputs); + double GetInferCostMilliSec() const {return inferCostTimeMilliSec;} + private: + std::shared_ptr<MxBase::ModelInferenceProcessor> model_; + MxBase::ModelDesc modelDesc_; + uint32_t deviceId_ = 0; + double inferCostTimeMilliSec = 0.0; +}; +#endif diff --git a/research/recommend/mmoe/infer/mxbase/src/main.cpp b/research/recommend/mmoe/infer/mxbase/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98ebef21a6403888b7a8dac6f154818b96508957 --- /dev/null +++ b/research/recommend/mmoe/infer/mxbase/src/main.cpp @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2022. Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <dirent.h> +#include <unistd.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <fstream> +#include <string> +#include <sstream> +#include <cstdlib> +#include <vector> +#include <cmath> +#include <cstdio> +#include "MMoE.h" +#include "half.hpp" +#include "MxBase/Log/Log.h" + +using half_float::half; + +const char mode[] = "eval"; +template<class dtype> +APP_ERROR ReadTxt(const std::string &path, std::vector<std::vector<dtype>> &dataset) { + std::ifstream fp(path); + std::string line; + while (std::getline(fp, line)) { + std::vector<dtype> data_line; + std::string number; + std::istringstream readstr(line); + + while (std::getline(readstr, number, '\t')) { + data_line.push_back(half(atof(number.c_str()))); + } + dataset.push_back(data_line); + } + return APP_ERR_OK; +} + +APP_ERROR WriteResult(const std::string &output_dir, const std::string &filename, + const std::vector<std::vector<half>> &result) { + std::string output_path = output_dir + "/" + filename; + if (access(output_dir.c_str(), F_OK) == -1) { + mkdir(output_dir.c_str(), S_IRWXO|S_IRWXG|S_IRWXU); + } + std::ofstream outfile(output_path, std::ios::out | std::ios::trunc); + if (outfile.fail()) { + LogError << "Failed to open result file: "; + return APP_ERR_COMM_FAILURE; + } + for (size_t i = 0; i < result.size(); i ++) { + std::string temp = std::to_string(result[i][0]) + "\t" +std::to_string(result[i][1]) + "\n"; + outfile << temp; + } + outfile.close(); + return APP_ERR_OK; +} + +float get_auc(const std::vector<std::vector<half>> &preds, const 
std::vector<std::vector<half>> &labels + , size_t n_bins = 1000000) { + std::vector<half> flatten_preds; + std::vector<half> flatten_labels; + size_t rows = preds.size(); + for (size_t i = 0; i < rows; i ++) { + flatten_preds.push_back(preds[i][0]); + flatten_preds.push_back(preds[i][1]); + flatten_labels.push_back(labels[i][0]); + flatten_labels.push_back(labels[i][1]); + } + size_t positive_len = 0; + for (size_t i = 0; i < flatten_labels.size(); i++) { + positive_len += static_cast<int>(flatten_labels[i]); + } + size_t negative_len = flatten_labels.size()-positive_len; + if (positive_len == 0 || negative_len == 0) { + return 0.0; + } + uint64_t total_case = positive_len*negative_len; + std::vector<size_t> pos_histogram(n_bins+1, 0); + std::vector<size_t> neg_histogram(n_bins+1, 0); + float bin_width = 1.0/n_bins; + for (size_t i = 0; i < flatten_preds.size(); i ++) { + size_t nth_bin = static_cast<int>(flatten_preds[i]/bin_width); + if (static_cast<int>(flatten_labels[i]) == 1) { + pos_histogram[nth_bin] += 1; + } else { + neg_histogram[nth_bin] += 1; + } + } + size_t accumulated_neg = 0; + float satisfied_pair = 0; + for (size_t i = 0; i < n_bins+1; i ++) { + satisfied_pair += (pos_histogram[i]*accumulated_neg + pos_histogram[i]*neg_histogram[i]*0.5); + accumulated_neg += neg_histogram[i]; + } + return satisfied_pair/total_case; +} + +int main(int argc, char* argv[]) { + InitParam initParam = {}; + initParam.deviceId = 0; + initParam.checkTensor = true; + initParam.modelPath = "../data/model/MMoE.om"; + auto mmoe = std::make_shared<MMoE>(); + printf("Start running\n"); + APP_ERROR ret = mmoe->Init(initParam); + if (ret != APP_ERR_OK) { + mmoe->DeInit(); + LogError << "mmoe init failed, ret=" << ret << "."; + return ret; + } + + // read data from txt + std::string data_path = "../data/input/data_" + std::string(mode) + std::string(".txt"); + std::string income_path = "../data/input/income_labels_" + std::string(mode) +std::string(".txt"); + std::string married_path 
= "../data/input/married_labels_" + std::string(mode) +std::string(".txt"); + std::vector<std::vector<half>> data; + std::vector<std::vector<half>> income; + std::vector<std::vector<half>> married; + ret = ReadTxt(data_path, data); + if (ret != APP_ERR_OK) { + LogError << "read ids failed, ret=" << ret << "."; + return ret; + } + + ret = ReadTxt(income_path, income); + if (ret != APP_ERR_OK) { + LogError << "read wts failed, ret=" << ret << "."; + return ret; + } + + ret = ReadTxt(married_path, married); + if (ret != APP_ERR_OK) { + LogError << "read label failed, ret=" << ret << "."; + return ret; + } + + int data_rows = data.size(); + int income_rows = income.size(); + int married_rows = married.size(); + if (data_rows != income_rows || income_rows != married_rows) { + LogError << "size of data, income and married are not equal"; + return -1; + } + int rows = data_rows; + std::vector<std::vector<half>> income_preds; + std::vector<std::vector<half>> married_preds; + + for (int i = 0; i < rows; i++) { + std::vector<std::vector<half>> data_batch; + data_batch.emplace_back(data[i]); + ret = mmoe->Process(data_batch, initParam, income_preds, married_preds); + if (ret !=APP_ERR_OK) { + LogError << "mmoe process failed, ret=" << ret << "."; + mmoe->DeInit(); + return ret; + } + } + + // write results + std::string output_dir = "./output"; + std::string filename = "income_preds_" + std::string(mode) + std::string(".txt"); + WriteResult(output_dir, filename, income_preds); + filename = "income_labels_" + std::string(mode) +std::string(".txt"); + WriteResult(output_dir, filename, income); + filename = "married_preds_" + std::string(mode) + std::string(".txt"); + WriteResult(output_dir, filename, married_preds); + filename = "married_labels_" + std::string(mode) +std::string(".txt"); + WriteResult(output_dir, filename, married); + + float infer_total_time = mmoe->GetInferCostMilliSec()/1000; + float income_auc = get_auc(income_preds, income); + float married_auc = 
get_auc(married_preds, married); + LogInfo << "<<==========Infer Metric==========>>"; + LogInfo << "Number of samples:" + std::to_string(rows); + LogInfo << "Total inference time:" + std::to_string(infer_total_time); + LogInfo << "Average infer time:" + std::to_string(infer_total_time/rows); + LogInfo << "Income infer auc:"+ std::to_string(income_auc); + LogInfo << "Married infer auc:"+ std::to_string(married_auc); + LogInfo << "<<================================>>"; + + mmoe->DeInit(); + return APP_ERR_OK; +} diff --git a/research/recommend/mmoe/infer/sdk/main.py b/research/recommend/mmoe/infer/sdk/main.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f6834af5cd2f1f1f610cf5e8289a9587fbe062 --- /dev/null +++ b/research/recommend/mmoe/infer/sdk/main.py @@ -0,0 +1,209 @@ +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +""" +sample script of MMoE infer using SDK run in docker +""" + +import argparse +import os +import time +import numpy as np +from sklearn.metrics import roc_auc_score + +import MxpiDataType_pb2 as MxpiDataType +from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, \ + MxProtobufIn, StringVector + +def parse_args(): + """set and check parameters.""" + parser = argparse.ArgumentParser(description='MMoE process') + parser.add_argument('--data_dir', type=str, default='../data/input', help='Data path') + parser.add_argument('--data_file', type=str, default='data_{}.npy') + parser.add_argument('--income_file', type=str, default='income_labels_{}.npy') + parser.add_argument('--married_file', type=str, default='married_labels_{}.npy') + parser.add_argument('--mode', type=str, default='eval') + parser.add_argument('--num_features', type=int, default=499, help='dim of feature') + parser.add_argument('--num_labels', type=int, default=2, help='dim of label') + parser.add_argument('--output_dir', type=str, default='./output', help='Data path') + parser.add_argument('--pipeline', type=str, default='../data/config/MMoE.pipeline', help='SDK infer pipeline') + args_opt = parser.parse_args() + return args_opt + +args = parse_args() + +def send_source_data(appsrc_id, file_name, file_data, stream_name, stream_manager, shape, tp): + """ + Construct the input of the stream, + send inputs data to a specified stream based on streamName. 
+ + Returns: + bool: send data success or not + """ + tensors = np.array(file_data, dtype=tp).reshape(shape) + tensor_package_list = MxpiDataType.MxpiTensorPackageList() + tensor_package = tensor_package_list.tensorPackageVec.add() + data_input = MxDataInput() + tensor_vec = tensor_package.tensorVec.add() + tensor_vec.deviceId = 0 + tensor_vec.memType = 0 + for i in tensors.shape: + tensor_vec.tensorShape.append(i) + array_bytes = tensors.tobytes() + data_input.data = array_bytes + tensor_vec.dataStr = data_input.data + tensor_vec.tensorDataSize = len(array_bytes) + key = "appsrc{}".format(appsrc_id).encode('utf-8') + protobuf_vec = InProtobufVector() + protobuf = MxProtobufIn() + protobuf.key = key + protobuf.type = b'MxTools.MxpiTensorPackageList' + protobuf.protobuf = tensor_package_list.SerializeToString() + protobuf_vec.push_back(protobuf) + ret = stream_manager.SendProtobuf(stream_name, appsrc_id, protobuf_vec) + if ret < 0: + print("Failed to send data to stream.") + return False + print("Send successfully!") + return True + +def send_appsrc_data(appsrc_id, file_name, file_data, stream_name, stream_manager, shape, tp): + """ + send three stream to infer model, include input ids, input mask and token type_id. + + Returns: + bool: send data success or not + """ + if not send_source_data(appsrc_id, file_name, file_data, stream_name, stream_manager, shape, tp): + return False + return True + +def post_process(infer_result): + """ + process the result of infer tensor to Visualization results. 
+ Args: + infer_result: get logit from infer result + """ + result = MxpiDataType.MxpiTensorPackageList() + result.ParseFromString(infer_result[0].messageBuf) + income_pred = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float16) + income_pred = income_pred.reshape((-1, 2)) + married_pred = np.frombuffer(result.tensorPackageVec[0].tensorVec[1].dataStr, dtype=np.float16) + married_pred = married_pred.reshape((-1, 2)) + return income_pred, married_pred + +def get_auc(labels, preds): + labels = labels.flatten().tolist() + preds = preds.flatten().tolist() + return roc_auc_score(labels, preds) + +def run(): + """ + read pipeline and do infer + """ + # init stream manager + stream_manager_api = StreamManagerApi() + ret = stream_manager_api.InitManager() + if ret != 0: + print("Failed to init Stream manager, ret=%s" % str(ret)) + exit() + + # create streams by pipeline config file + with open(os.path.realpath(args.pipeline), 'rb') as f: + pipeline_str = f.read() + ret = stream_manager_api.CreateMultipleStreams(pipeline_str) + if ret != 0: + print("Failed to create Stream, ret=%s" % str(ret)) + exit() + + # prepare data + data = np.load(os.path.join(args.data_dir, args.data_file.format(args.mode))).astype(np.float16) + income = np.load(os.path.join(args.data_dir, args.income_file.format(args.mode))).astype(np.float16) + married = np.load(os.path.join(args.data_dir, args.married_file.format(args.mode))).astype(np.float16) + + if(data.shape[0] != income.shape[0] or income.shape[0] != married.shape[0]): + print("number of input data not completely equal") + exit() + rows = data.shape[0] + + # statistical variable + income_labels = [] + married_labels = [] + income_preds = [] + married_preds = [] + infer_total_time = 0 + + # write predict results + if not os.path.exists(args.output_dir): + os.mkdir(args.output_dir) + for i in range(rows): + # fetch data + data_batch = data[i] + income_batch = income[i] + married_batch = married[i] + + # data shape + 
data_shape = (-1, args.num_features) + + # data type + data_type = np.float16 + + # send data + stream_name = b'MMoE' + if not send_appsrc_data(0, 'data', data_batch, stream_name, stream_manager_api, data_shape, data_type): + return + + # Obtain the inference result by specifying streamName and uniqueId. + key_vec = StringVector() + key_vec.push_back(b'mxpi_tensorinfer0') + start_time = time.time() + infer_result = stream_manager_api.GetProtobuf(stream_name, 0, key_vec) + infer_total_time += time.time() - start_time + if infer_result.size() == 0: + print("inferResult is null") + return + if infer_result[0].errorCode != 0: + print("GetProtobuf error. errorCode=%d" % (infer_result[0].errorCode)) + return + + # update variable + income_pred, married_pred = post_process(infer_result) + income_preds.extend(income_pred) + married_preds.extend(married_pred) + income_labels.extend(income_batch) + married_labels.extend(married_batch) + + income_preds = np.array(income_preds) + married_preds = np.array(married_preds) + income_labels = np.array(income_labels) + married_labels = np.array(married_labels) + np.save(os.path.join(args.output_dir, 'income_preds_{}.npy'.format(args.mode)), income_preds) + np.save(os.path.join(args.output_dir, 'married_preds_{}.npy'.format(args.mode)), married_preds) + np.save(os.path.join(args.output_dir, 'income_labels_{}.npy'.format(args.mode)), income_labels) + np.save(os.path.join(args.output_dir, 'married_labels_{}.npy'.format(args.mode)), married_labels) + income_auc = get_auc(income_labels, income_preds) + married_auc = get_auc(married_labels, married_preds) + print('<<======== Infer Metric ========>>') + print('Mode: {}'.format(args.mode)) + print('Number of samples: {}'.format(rows)) + print('Total inference time: {}'.format(infer_total_time)) + print('Average inference time: {}'.format(infer_total_time/rows)) + print('Income auc: {}'.format(income_auc)) + print('Married auc: {}'.format(married_auc)) 
print('<<===============================>>') + stream_manager_api.DestroyAllStreams() + +if __name__ == '__main__': + run() diff --git a/research/recommend/mmoe/infer/sdk/prec/calc_metric.py b/research/recommend/mmoe/infer/sdk/prec/calc_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..2b98b5677daf8605ce6d56ae036d2f9ce7dc464b --- /dev/null +++ b/research/recommend/mmoe/infer/sdk/prec/calc_metric.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +""" +sample script of mmoe calculating metric +""" + +import os +import argparse +import numpy as np +from sklearn.metrics import roc_auc_score + +def parse_args(): + """set and check parameters.""" + parser = argparse.ArgumentParser(description='calc metric') + parser.add_argument('--data_dir', type=str, default='../output') + parser.add_argument('--income_preds', type=str, default='income_preds_{}.npy') + parser.add_argument('--married_preds', type=str, default='married_preds_{}.npy') + parser.add_argument('--income_labels', type=str, default='income_labels_{}.npy') + parser.add_argument('--married_labels', type=str, default='married_labels_{}.npy') + parser.add_argument('--mode', type=str, default='eval') + parser.add_argument('--metric_file', type=str, default='./metric.txt') + args_opt = parser.parse_args() + return args_opt + +def get_auc(labels, preds): + return roc_auc_score(labels, preds) + +def run(): + """calc metric""" + args = parse_args() + income_preds = np.load(os.path.join(args.data_dir, args.income_preds.format(args.mode))) + income_preds = income_preds.flatten().tolist() + married_preds = np.load(os.path.join(args.data_dir, args.married_preds.format(args.mode))) + married_preds = married_preds.flatten().tolist() + + income_labels = np.load(os.path.join(args.data_dir, args.income_labels.format(args.mode))) + income_labels = income_labels.flatten().tolist() + married_labels = np.load(os.path.join(args.data_dir, args.married_labels.format(args.mode))) + married_labels = married_labels.flatten().tolist() + + income_auc = get_auc(income_labels, income_preds) + married_auc = get_auc(married_labels, married_preds) + print('<<======== Infer Metric ========>>') + print('Mode: {}'.format(args.mode)) + print('Income auc: {}'.format(income_auc)) + print('Married auc: {}'.format(married_auc)) + print('<<===============================>>') + fo = open(args.metric_file, "w") + 
fo.write('Mode: {}\n'.format(args.mode)) + fo.write('Income auc: {}\n'.format(income_auc)) + fo.write('Married auc: {}\n'.format(married_auc)) + fo.close() + +if __name__ == '__main__': + run() diff --git a/research/recommend/mmoe/infer/sdk/run.sh b/research/recommend/mmoe/infer/sdk/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..f19f7cf325ac99edc13025b08fc470ca213dd478 --- /dev/null +++ b/research/recommend/mmoe/infer/sdk/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -e + +# Simple log helper functions +info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; } +warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; } + +#export MX_SDK_HOME=/home/data/cz/app/mxManufacture +export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH} +export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner +export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins + +#to set PYTHONPATH, import the StreamManagerApi.py +export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python + +python3 main.py +exit 0 \ No newline at end of file diff --git a/research/recommend/mmoe/infer/utils/npy2txt.py b/research/recommend/mmoe/infer/utils/npy2txt.py new file mode 100644 index 0000000000000000000000000000000000000000..f69afd772e190eef58a3a3b14cf0e50289494199 --- /dev/null +++ b/research/recommend/mmoe/infer/utils/npy2txt.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import os +import argparse +import numpy as np + +def parse_args(): + """set and check parameters.""" + parser = argparse.ArgumentParser(description="prepare txt") + parser.add_argument('--data_dir', type=str, default='../data/input/') + parser.add_argument('--data_file', type=str, default='data_{}.npy') + parser.add_argument('--income_file', type=str, default='income_labels_{}.npy') + parser.add_argument('--married_file', type=str, default='married_labels_{}.npy') + parser.add_argument('--mode', type=str, default='eval') + args_opt = parser.parse_args() + return args_opt + +def run(): + """prepare txt data""" + args = parse_args() + # load npy data + data = np.load(os.path.join(args.data_dir, args.data_file.format(args.mode))) + income = np.load(os.path.join(args.data_dir, args.income_file.format(args.mode))) + married = np.load(os.path.join(args.data_dir, args.married_file.format(args.mode))) + + np.savetxt(os.path.join(args.data_dir, args.data_file.split('.')[0].format(args.mode)+'.txt'), data, delimiter='\t') + np.savetxt(os.path.join(args.data_dir, args.income_file.split('.')[0].format(args.mode)+'.txt'), \ + income, delimiter='\t') + np.savetxt(os.path.join(args.data_dir, args.married_file.split('.')[0].format(args.mode)+'.txt'), \ + married, delimiter='\t') + +if __name__ == '__main__': + run() diff --git a/research/recommend/mmoe/modelarts/train_start.py b/research/recommend/mmoe/modelarts/train_start.py new file mode 100644 index 0000000000000000000000000000000000000000..60e0005519d72f309b5e37eff27a37fa245d8e05 --- /dev/null +++ b/research/recommend/mmoe/modelarts/train_start.py @@ -0,0 +1,165 @@ +# Copyright (c) 2022. Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +#################train MMoE example on census-income data######################## +python train.py +""" + +import os +import datetime +import numpy as np +import mindspore as ms +from mindspore import context, load_checkpoint, load_param_into_net, Tensor, export +from mindspore.communication.management import init +from mindspore.context import ParallelMode +from mindspore.nn.optim import Adam +from mindspore.train.model import Model +from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.common import set_seed + +from src.mmoe import TrainStepWrap +from src.model_utils.moxing_adapter import moxing_wrapper +from src.load_dataset import create_dataset +from src.mmoe import MMoE_Layer, MMoE +from src.model_utils.config import config +from src.mmoe import LossForMultiLabel, NetWithLossClass +from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +from src.get_lr import get_lr + +set_seed(1) + + +def get_latest_ckpt(): + """get latest ckpt""" + ckpt_path = config.ckpt_path + ckpt_files = [ckpt_file for ckpt_file in os.listdir(ckpt_path) if ckpt_file.endswith(".ckpt")] + if not ckpt_files: + return None + latest_ckpt_file = sorted(ckpt_files)[-1] + return latest_ckpt_file + + +def modelarts_process(): + pass + + +@moxing_wrapper(pre_process=modelarts_process) +def export_mmoe(): + """export MMoE""" + 
latest_ckpt_file = get_latest_ckpt() + if not latest_ckpt_file: + print("Not found ckpt file") + return + config.ckpt_file_path = os.path.join(config.ckpt_path, latest_ckpt_file) + config.file_name = os.path.join(config.ckpt_path, config.file_name) + net = MMoE(num_features=config.num_features, num_experts=config.num_experts, units=config.units) + param_dict = load_checkpoint(config.ckpt_file_path) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.zeros([1, 499]), ms.float16) + config.file_format = "AIR" + export(net, input_arr, file_name=config.file_name, file_format=config.file_format) + + +def modelarts_pre_process(): + pass + + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_train(): + """train function""" + print('device id:', get_device_id()) + print('device num:', get_device_num()) + print('rank id:', get_rank_id()) + print('job id:', get_job_id()) + + device_target = config.device_target + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + context.set_context(save_graphs=False) + if config.device_target == "GPU": + context.set_context(enable_graph_kernel=True) + context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul") + + device_num = get_device_num() + + if config.run_distribute: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, + parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) + if device_target == "Ascend": + context.set_context(device_id=get_device_id()) + init() + elif device_target == "GPU": + init() + else: + context.set_context(device_id=get_device_id()) + print("init finished.") + + config.data_path = config.data_url + ds_train = create_dataset(config.data_path, config.batch_size, training=True, \ + target=config.device_target, run_distribute=config.run_distribute) + + if ds_train.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size.") + print("create dataset 
finished.") + + net = MMoE_Layer(input_size=config.num_features, num_experts=config.num_experts, units=config.units) + print("model created.") + loss = LossForMultiLabel() + loss_net = NetWithLossClass(net, loss) + + step_per_size = ds_train.get_dataset_size() + print("train dataset size:", step_per_size) + + if config.run_distribute: + learning_rate = get_lr(0.0005, config.epoch_size, step_per_size, step_per_size * 2) + else: + learning_rate = get_lr(0.001, config.epoch_size, step_per_size, step_per_size * 5) + opt = Adam(net.trainable_params(), + learning_rate=learning_rate, + beta1=0.9, + beta2=0.999, + eps=1e-7, + weight_decay=0.0, + loss_scale=1.0) + scale_update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, + scale_factor=2, + scale_window=1000) + train_net = TrainStepWrap(loss_net, opt, scale_update_cell) + train_net.set_train() + model = Model(train_net) + + time_cb = TimeMonitor() + loss_cb = LossMonitor(step_per_size) + config_ck = CheckpointConfig(save_checkpoint_steps=step_per_size, keep_checkpoint_max=100) + callbacks_list = [time_cb, loss_cb] + if get_rank_id() == 0: + config.ckpt_path = config.train_url + config.ckpt_path = os.path.join(config.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')) + ckpoint_cb = ModelCheckpoint(prefix='MMoE_train', directory=config.ckpt_path, config=config_ck) + callbacks_list.append(ckpoint_cb) + + print("train start!") + model.train(epoch=config.epoch_size, + train_dataset=ds_train, + callbacks=callbacks_list, + dataset_sink_mode=config.dataset_sink_mode) + + +if __name__ == '__main__': + run_train() + export_mmoe() diff --git a/research/recommend/mmoe/scripts/docker_start.sh b/research/recommend/mmoe/scripts/docker_start.sh new file mode 100644 index 0000000000000000000000000000000000000000..e4dd74bee5489998a832481bc63d8b4889e441f7 --- /dev/null +++ b/research/recommend/mmoe/scripts/docker_start.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright (c) 2022. 
Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +docker_image=$1 +data_dir=$2 +model_dir=$3 + +docker run -it --ipc=host \ + --device=/dev/davinci0 \ + --device=/dev/davinci1 \ + --device=/dev/davinci2 \ + --device=/dev/davinci3 \ + --device=/dev/davinci4 \ + --device=/dev/davinci5 \ + --device=/dev/davinci6 \ + --device=/dev/davinci7 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + --privileged \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons \ + -v ${data_dir}:${data_dir} \ + -v ${model_dir}:${model_dir} \ + -v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash \ No newline at end of file