diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt index f2541575ca2e812913c4340723a526cc6e536ced..ff04e56592ca61df411959de28ab0f9becc5636a 100644 --- a/.jenkins/check/config/filter_cpplint.txt +++ b/.jenkins/check/config/filter_cpplint.txt @@ -15,4 +15,8 @@ "models/research/cv/FaceQualityAssessment/infer/mxbase/src/FQA.cpp" "runtime/references" "models/research/cv/FaceQualityAssessment/infer/util/plugins/MxpiTransposePlugin.h" "runtime/references" "models/research/cv/FaceQualityAssessment/infer/util/plugins/MxpiTransposePlugin.h" "build/namespaces" -"models/research/cv/FaceQualityAssessment/infer/util/plugins/MxpiTransposePlugin.cpp" "runtime/references" \ No newline at end of file +"models/research/cv/FaceQualityAssessment/infer/util/plugins/MxpiTransposePlugin.cpp" "runtime/references" + +"models/research/cv/stgcn/infer/mxbase/src/stgcnUtil.h" "runtime/references" +"models/research/cv/stgcn/infer/mxbase/src/stgcnUtil.cpp" "runtime/references" +"models/research/cv/stgcn/infer/mxbase/src/main.cpp" "runtime/references" diff --git a/research/cv/stgcn/infer/Dockerfile b/research/cv/stgcn/infer/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..360861ede17fb0ab697fbcac190acde7c1e29fef --- /dev/null +++ b/research/cv/stgcn/infer/Dockerfile @@ -0,0 +1,5 @@ +ARG FROM_IMAGE_NAME +FROM ${FROM_IMAGE_NAME} + +COPY requirements.txt . +RUN pip3.7 install -r requirements.txt diff --git a/research/cv/stgcn/infer/convert/convert_om.sh b/research/cv/stgcn/infer/convert/convert_om.sh new file mode 100644 index 0000000000000000000000000000000000000000..bab1e987fca8e6cc78c064d85b0d79dee2fc7263 --- /dev/null +++ b/research/cv/stgcn/infer/convert/convert_om.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# -ne 2 ] +then + echo "Wrong parameter format." + echo "Usage:" + echo " bash $0 [INPUT_AIR_PATH] [OUTPUT_OM_PATH_NAME]" + echo "Example: " + echo " bash convert_om.sh xxx.air xx" + + exit 1 +fi + +input_air_path=$1 +output_om_path=$2 + +export install_path=/usr/local/Ascend/ + +export ASCEND_ATC_PATH=${install_path}/atc +export LD_LIBRARY_PATH=${install_path}/atc/lib64:$LD_LIBRARY_PATH +export PATH=/usr/local/python3.7.5/bin:${install_path}/atc/ccec_compiler/bin:${install_path}/atc/bin:$PATH +export PYTHONPATH=${install_path}/atc/python/site-packages:${install_path}/latest/atc/python/site-packages/auto_tune.egg/auto_tune:${install_path}/atc/python/site-packages/schedule_search.egg + +echo "Input AIR file path: ${input_air_path}" +echo "Output OM file path: ${output_om_path}" + +atc --framework=1 \ + --model="${input_air_path}" \ + --input_shape="actual_input_1:1,3,128,64" \ + --output="${output_om_path}" \ + --enable_small_channel=1 \ + --log=error \ + --soc_version=Ascend310 \ + --op_select_implmode=high_precision \ + --output_type=FP32 diff --git a/research/cv/stgcn/infer/docker_start_infer.sh b/research/cv/stgcn/infer/docker_start_infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..64cf90a2311bdfb21d68a4e90e08602670fdf632 --- /dev/null +++ b/research/cv/stgcn/infer/docker_start_infer.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +docker_image=$1 +data_dir=$2 + +function show_help() { + echo "Usage: docker_start.sh docker_image data_dir" +} + +function param_check() { + if [ -z "${docker_image}" ]; then + echo "please input docker_image" + show_help + exit 1 + fi + + if [ -z "${data_dir}" ]; then + echo "please input data_dir" + show_help + exit 1 + fi +} + +param_check + +docker run -it \ + --device=/dev/davinci0 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v ${data_dir}:${data_dir} \ + ${docker_image} \ + /bin/bash diff --git a/research/cv/stgcn/infer/mxbase/CMakeLists.txt b/research/cv/stgcn/infer/mxbase/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..abc356e8d8f199362ad2ce9084a9e6888ae5e196 --- /dev/null +++ b/research/cv/stgcn/infer/mxbase/CMakeLists.txt @@ -0,0 +1,56 @@ +cmake_minimum_required(VERSION 3.10.0) +project(stgcn) + +set(TARGET stgcn) + +add_definitions(-DENABLE_DVPP_INTERFACE) +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-Dgoogle=mindxsdk_private) +add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall) +add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie) + +# Check environment variable +if(NOT DEFINED ENV{ASCEND_HOME}) + message(FATAL_ERROR "please define environment variable:ASCEND_HOME") +endif() +if(NOT DEFINED ENV{ASCEND_VERSION}) + message(WARNING "please define environment variable:ASCEND_VERSION") +endif() +if(NOT DEFINED ENV{ARCH_PATTERN}) + message(WARNING "please define environment variable:ARCH_PATTERN") +endif() + +set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include) +set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64) + +set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME}) +set(MXBASE_INC ${MXBASE_ROOT_DIR}/include) +set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib) +set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors) +set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include) + +if(DEFINED ENV{MXSDK_OPENSOURCE_DIR}) + set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR}) +else() + set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource) +endif() + + +include_directories(${ACL_INC_DIR}) +include_directories(${OPENSOURCE_DIR}/include) +include_directories(${OPENSOURCE_DIR}/include/opencv4) + +include_directories(${MXBASE_INC}) +include_directories(${MXBASE_POST_PROCESS_DIR}) + +link_directories(${ACL_LIB_DIR}) +link_directories(${OPENSOURCE_DIR}/lib) +link_directories(${MXBASE_LIB_DIR}) +link_directories(${MXBASE_POST_LIB_DIR}) + + +add_executable(${TARGET} src/main.cpp src/stgcnUtil.cpp) + +target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world) + +install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/) diff --git a/research/cv/stgcn/infer/mxbase/build.sh b/research/cv/stgcn/infer/mxbase/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..bbfd526b59c709107440cc404c953778d1edbc05 --- /dev/null +++ b/research/cv/stgcn/infer/mxbase/build.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +path_cur=$(dirname $0) + +function check_env() +{ + # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user + if [ ! "${ASCEND_HOME}" ]; then + export ASCEND_HOME=/usr/local/Ascend/ + echo "Set ASCEND_HOME to the default value: ${ASCEND_HOME}" + else + echo "ASCEND_HOME is set to ${ASCEND_HOME} by user" + fi + + if [ ! "${ASCEND_VERSION}" ]; then + export ASCEND_VERSION=nnrt/latest + echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}" + else + echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user" + fi + + if [ ! "${ARCH_PATTERN}" ]; then + # set ARCH_PATTERN to ./ when it was not specified by user + export ARCH_PATTERN=./ + echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}" + else + echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user" + + fi + + + +} + +function build_stgcn() +{ + cd $path_cur + rm -rf build + mkdir -p build + cd build + cmake .. + make + ret=$? + if [ ${ret} -ne 0 ]; then + echo "Failed to build resnet50." + exit ${ret} + fi + make install +} + +check_env +build_stgcn diff --git a/research/cv/stgcn/infer/mxbase/src/main.cpp b/research/cv/stgcn/infer/mxbase/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c6851d21da2cee90561ab87725ca9ed5f9bc3b56 --- /dev/null +++ b/research/cv/stgcn/infer/mxbase/src/main.cpp @@ -0,0 +1,164 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <dirent.h> +#include <fstream> +#include <string> +#include <sstream> +#include <cstdlib> +#include <vector> +#include <cmath> +#include <iostream> +#include "stgcnUtil.h" +#include "MxBase/Log/Log.h" + +APP_ERROR ReadCsv(const std::string &path, std::vector<std::vector<float>> &dataset) { + std::ifstream fp(path); + std::string line; + while (std::getline(fp, line)) { + std::vector<float> data_line; + std::string number; + std::istringstream readstr(line); + + for (int j = 0; j < 228; j++) { + std::getline(readstr, number, ','); + data_line.push_back(atof(number.c_str())); + } + dataset.push_back(data_line); + } + return APP_ERR_OK; +} + +APP_ERROR transform(std::vector<std::vector<float>>& dataset, + const std::vector<float>& mean, const std::vector<float>& stdd) { + for (uint32_t i = 0; i < dataset.size(); ++i) { + for (uint32_t j = 0; j < dataset[0].size(); ++j) { + dataset[i][j] = (dataset[i][j]-mean[j])/sqrt(stdd[j]); + } + } + return APP_ERR_OK; +} + +APP_ERROR getMeanStd(std::vector<std::vector<float>> dataset, std::vector<float>& mean, std::vector<float>& stdd) { + for (uint32_t j = 0; j < dataset[0].size(); ++j) { + float m = 0.0; + float var = 0.0; + for (uint32_t i = 0; i < dataset.size(); ++i) { + m += dataset[i][j]; + } + m /= dataset.size(); + for (uint32_t i = 0; i < dataset.size(); ++i) { + var += (dataset[i][j] - m)*(dataset[i][j] - m); + } + var /= (dataset.size()); + + mean.emplace_back(m); + stdd.emplace_back(var); + } + return APP_ERR_OK; +} + +int main(int argc, char* argv[]) { + if (argc <= 2) { + LogWarn << "Please input dataset path and n_pred, such as './data/vel/csv 9'"; + return APP_ERR_OK; + } + + InitParam initParam = {}; + initParam.deviceId = 0; + initParam.checkTensor = true; + initParam.modelPath = "../data/models/stgcn.om"; + auto stgcn = std::make_shared<STGCN>(); + APP_ERROR ret = stgcn->Init(initParam); + if (ret != APP_ERR_OK) { + stgcn->DeInit(); + LogError << "stgcn init failed, ret=" << ret << "."; + return ret; + } + + std::string imgPath = argv[1]; + int n_pred = atoi(argv[2]); + + std::vector<std::vector<float>> dataset; + + ret = ReadCsv(imgPath, dataset); + if (ret != APP_ERR_OK) { + stgcn->DeInit(); + LogError << "read dataset failed, ret=" << ret << "."; + return ret; + } + + float val_and_test_rate = 0.15; + int data_row = dataset.size(); + int data_col = data_row > 0 ? dataset[0].size():0; + + if (data_col == 0) { + LogError << "stgcn process dataset failed, data_col=" << data_col << "."; + stgcn->DeInit(); + return -1; + } + + int len_val = static_cast<int>(floor(data_row * val_and_test_rate)); + int len_test = static_cast<int>(floor(data_row * val_and_test_rate)); + int len_train = static_cast<int>(data_row - len_val - len_test); + + std::vector<std::vector<float>> dataset_train; + std::vector<std::vector<float>> dataset_test; + + for (int i = 0; i < data_row; ++i) { + if (i < len_train) { + dataset_train.emplace_back(dataset[i]); + } else if (i >= (len_train + len_val)) { + dataset_test.emplace_back(dataset[i]); + } else { + continue; + } + } + + ret = getMeanStd(dataset_train, initParam.MEAN, initParam.STD); + if (ret != APP_ERR_OK) { + LogError << "get mean and std of train dataset failed, ret=" << ret << "."; + return ret; + } + + // Norlize test dataset + ret = transform(dataset_test, initParam.MEAN, initParam.STD); + if (ret != APP_ERR_OK) { + LogError << "transform test dataset failed, ret=" << ret << "."; + return ret; + } + + int n_his = 12; + int num = dataset_test.size() - n_his - n_pred; + + for (int i=0; i < num; ++i) { + std::vector<std::vector<float>> data; + for (int t = i; t < i + n_his; ++t) { + data.emplace_back(dataset_test[t]); + } + + ret = stgcn->Process(data, initParam); + if (ret != APP_ERR_OK) { + LogError << "stgcn process failed, ret=" << ret << "."; + stgcn->DeInit(); + return ret; + } + } + + stgcn->DeInit(); + return APP_ERR_OK; +} + diff --git a/research/cv/stgcn/infer/mxbase/src/stgcnUtil.cpp b/research/cv/stgcn/infer/mxbase/src/stgcnUtil.cpp new file mode 100644 index 0000000000000000000000000000000000000000..004a2e043af4bcc7a3830efb00bd3bbd2fc09579 --- /dev/null +++ b/research/cv/stgcn/infer/mxbase/src/stgcnUtil.cpp @@ -0,0 +1,219 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <unistd.h> +#include <sys/stat.h> +#include <memory> +#include <cmath> +#include <string> +#include <fstream> +#include <vector> +#include "stgcnUtil.h" +#include "MxBase/DeviceManager/DeviceManager.h" +#include "MxBase/Log/Log.h" + +APP_ERROR STGCN::Init(const InitParam &initParam) { + deviceId_ = initParam.deviceId; + APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices(); + if (ret != APP_ERR_OK) { + LogError << "Init devices failed, ret=" << ret << "."; + return ret; + } + ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId); + if (ret != APP_ERR_OK) { + LogError << "Set context failed, ret=" << ret << "."; + return ret; + } + dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>(); + ret = dvppWrapper_->Init(); + if (ret != APP_ERR_OK) { + LogError << "DvppWrapper init failed, ret=" << ret << "."; + return ret; + } + model_ = std::make_shared<MxBase::ModelInferenceProcessor>(); + ret = model_->Init(initParam.modelPath, modelDesc_); + if (ret != APP_ERR_OK) { + LogError << "ModelInferenceProcessor init failed, ret=" << ret << "."; + return ret; + } + + return APP_ERR_OK; +} + +APP_ERROR STGCN::DeInit() { + dvppWrapper_->DeInit(); + model_->DeInit(); + MxBase::DeviceManager::GetInstance()->DestroyDevices(); + return APP_ERR_OK; +} + +APP_ERROR STGCN::VectorToTensorBase(const std::vector<std::vector<float>> &input_x, MxBase::TensorBase &tensorBase) { + uint32_t dataSize = 1; + for (size_t i = 0; i < modelDesc_.inputTensors.size(); ++i) { + for (size_t j = 0; j < modelDesc_.inputTensors[i].tensorDims.size(); ++j) { + dataSize *= (uint32_t)modelDesc_.inputTensors[i].tensorDims[j]; + } + } + + float *metaFeatureData = new float[dataSize]; + uint32_t idx = 0; + for (size_t bs = 0; bs < input_x.size(); bs++) { + for (size_t c = 0; c < input_x[bs].size(); c++) { + metaFeatureData[idx++] = input_x[bs][c]; + } + } + + MxBase::MemoryData memoryDataDst(dataSize * sizeof(float), MxBase::MemoryData::MEMORY_DEVICE, deviceId_); + MxBase::MemoryData memoryDataSrc(reinterpret_cast<void *>(metaFeatureData), + dataSize * 4, MxBase::MemoryData::MEMORY_HOST_MALLOC); + + APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Memory malloc failed."; + return ret; + } + + std::vector<uint32_t> shape = {1, 1, 12, 228}; + tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32); + return APP_ERR_OK; +} + +APP_ERROR STGCN::Inference(const std::vector<MxBase::TensorBase> &inputs, + std::vector<MxBase::TensorBase> &outputs) { + auto dtypes = model_->GetOutputDataType(); + for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) { + std::vector<uint32_t> shape = {}; + for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) { + shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]); + } + MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_); + APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor); + if (ret != APP_ERR_OK) { + LogError << "TensorBaseMalloc failed, ret=" << ret << "."; + return ret; + } + outputs.push_back(tensor); + } + MxBase::DynamicInfo dynamicInfo = {}; + dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH; + auto startTime = std::chrono::high_resolution_clock::now(); + APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo); + auto endTime = std::chrono::high_resolution_clock::now(); + double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); + inferCostTimeMilliSec += costMs; + if (ret != APP_ERR_OK) { + LogError << "ModelInference failed, ret=" << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +APP_ERROR STGCN::PostProcess(std::vector<float> &inputs, const InitParam &initParam) { + for (uint32_t i = 0; i < inputs.size(); ++i) { + inputs[i] = inputs[i] * sqrt(initParam.STD[i]) + initParam.MEAN[i]; + } + return APP_ERR_OK; +} + +APP_ERROR STGCN::SaveInferResult(std::vector<float> &batchFeaturePaths, const std::vector<MxBase::TensorBase> &inputs) { + LogInfo << "Infer results before postprocess:\n"; + for (auto retTensor : inputs) { + LogInfo << "Tensor description:\n" << retTensor.GetDesc(); + std::vector<uint32_t> shape = retTensor.GetShape(); + uint32_t N = shape[0]; + uint32_t C = shape[1]; + uint32_t H = shape[2]; + uint32_t W = shape[3]; + if (!retTensor.IsHost()) { + LogInfo << "this tensor is not in host. Now deploy it to host"; + retTensor.ToHost(); + } + void* data = retTensor.GetBuffer(); + + for (uint32_t i = 0; i < N; i++) { + for (uint32_t j = 0; j < C; j++) { + for (uint32_t k = 0; k < H; k++) { + for (uint32_t l = 0; l < W; l++) { + float value = *(reinterpret_cast<float*>(data) + i * C + j * H + k * W + l); + batchFeaturePaths.emplace_back(value); + } + } + } + } + } + return APP_ERR_OK; +} + +APP_ERROR STGCN::WriteResult(const std::vector<float> &outputs) { + std::string resultPathName = "./result.txt"; + std::ofstream outfile(resultPathName, std::ios::app); + if (outfile.fail()) { + LogError << "Failed to open result file: "; + return APP_ERR_COMM_FAILURE; + } + + std::string tmp; + for (auto x : outputs) { + tmp += std::to_string(x) + " "; + } + tmp = tmp.substr(0, tmp.size()-1); + outfile << tmp << std::endl; + outfile.close(); + return APP_ERR_OK; +} + +APP_ERROR STGCN::Process(const std::vector<std::vector<float>> &input_x, const InitParam &initParam) { + std::vector<MxBase::TensorBase> inputs = {}; + std::vector<MxBase::TensorBase> outputs = {}; + std::vector<float> batchFeaturePaths; + MxBase::TensorBase tensorBase; + auto ret = VectorToTensorBase(input_x, tensorBase); + if (ret != APP_ERR_OK) { + LogError << "ToTensorBase failed, ret=" << ret << "."; + return ret; + } + + inputs.push_back(tensorBase); + + ret = Inference(inputs, outputs); + + + + + if (ret != APP_ERR_OK) { + LogError << "Inference failed, ret=" << ret << "."; + return ret; + } + + ret = SaveInferResult(batchFeaturePaths, outputs); + if (ret != APP_ERR_OK) { + LogError << "Save model infer results into file failed. ret = " << ret << "."; + return ret; + } + + ret = PostProcess(batchFeaturePaths, initParam); + if (ret != APP_ERR_OK) { + LogError << "PostProcess failed, ret=" << ret << "."; + return ret; + } + + ret = WriteResult(batchFeaturePaths); + if (ret != APP_ERR_OK) { + LogError << "WriteResult failed, ret=" << ret << "."; + return ret; + } + return APP_ERR_OK; +} diff --git a/research/cv/stgcn/infer/mxbase/src/stgcnUtil.h b/research/cv/stgcn/infer/mxbase/src/stgcnUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..77398a615e749acb139ada18a60c0e901f3fab60 --- /dev/null +++ b/research/cv/stgcn/infer/mxbase/src/stgcnUtil.h @@ -0,0 +1,52 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MxBase_STGCN_H +#define MxBase_STGCN_H +#include <string> +#include <vector> +#include <memory> +#include "MxBase/DvppWrapper/DvppWrapper.h" +#include "MxBase/ModelInfer/ModelInferenceProcessor.h" +#include "MxBase/Tensor/TensorContext/TensorContext.h" + +struct InitParam { + uint32_t deviceId; + bool checkTensor; + std::string modelPath; + std::vector<float> MEAN; + std::vector<float> STD; +}; + +class STGCN { + public: + APP_ERROR Init(const InitParam &initParam); + APP_ERROR DeInit(); + APP_ERROR VectorToTensorBase(const std::vector<std::vector<float>> &input_x, MxBase::TensorBase &tensorBase); + APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs); + APP_ERROR PostProcess(std::vector<float> &inputs, const InitParam &initParam); + APP_ERROR Process(const std::vector<std::vector<float>> &input_x, const InitParam &initParam); + APP_ERROR SaveInferResult(std::vector<float> &batchFeaturePaths, const std::vector<MxBase::TensorBase> &inputs); + APP_ERROR WriteResult(const std::vector<float> &outputs); + double GetInferCostMilliSec() const {return inferCostTimeMilliSec;} + private: + std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_; + std::shared_ptr<MxBase::ModelInferenceProcessor> model_; + MxBase::ModelDesc modelDesc_; + uint32_t deviceId_ = 0; + double inferCostTimeMilliSec = 0.0; +}; +#endif diff --git a/research/cv/stgcn/infer/requirements.txt b/research/cv/stgcn/infer/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7f15d3dbfe87dd4f9479e11090394850ad9b6767 --- /dev/null +++ b/research/cv/stgcn/infer/requirements.txt @@ -0,0 +1,3 @@ +numpy +pandas +sklearn \ No newline at end of file diff --git a/research/cv/stgcn/infer/sdk/main.py b/research/cv/stgcn/infer/sdk/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a1941bcd438a6e45b719b824834f3e836c14365a --- /dev/null +++ b/research/cv/stgcn/infer/sdk/main.py @@ -0,0 +1,155 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""run sdk""" +import sys +import math +import datetime +import pandas as pd +import numpy as np +from sklearn import preprocessing + +import MxpiDataType_pb2 as MxpiDataType +from StreamManagerApi import StreamManagerApi, InProtobufVector, MxProtobufIn, StringVector + + +def send_source_data(appsrc_id, tensor, stream_name, stream_manager): + """ + Construct the input of the stream, + send inputs data to a specified stream based on streamName. + """ + tensor_package_list = MxpiDataType.MxpiTensorPackageList() + tensor_package = tensor_package_list.tensorPackageVec.add() + array_bytes = tensor.tobytes() + tensor_vec = tensor_package.tensorVec.add() + tensor_vec.deviceId = 0 + tensor_vec.memType = 0 + for i in tensor.shape: + tensor_vec.tensorShape.append(i) + tensor_vec.dataStr = array_bytes + tensor_vec.tensorDataSize = len(array_bytes) + key = "appsrc{}".format(appsrc_id).encode('utf-8') + protobuf_vec = InProtobufVector() + protobuf = MxProtobufIn() + protobuf.key = key + protobuf.type = b'MxTools.MxpiTensorPackageList' + protobuf.protobuf = tensor_package_list.SerializeToString() + protobuf_vec.push_back(protobuf) + + ret = stream_manager.SendProtobuf(stream_name, appsrc_id, protobuf_vec) + return ret + +def run(): + """ + read pipeline and do infer + """ + if len(sys.argv) == 4: + dir_name = sys.argv[1] + res_dir_name = sys.argv[2] + n_pred = int(sys.argv[3]) + else: + print("Please enter Dataset path| Inference result path " + "such as ../data ./result 9") + exit(1) + stream_manager_api = StreamManagerApi() + ret = stream_manager_api.InitManager() + if ret != 0: + print("Failed to init Stream manager, ret=%s" % str(ret)) + return + + # create streams by pipeline config file + with open("./pipeline/stgcn.pipeline", 'rb') as f: + pipeline_str = f.read() + ret = stream_manager_api.CreateMultipleStreams(pipeline_str) + + if ret != 0: + print("Failed to create Stream, ret=%s" % str(ret)) + return + + # Construct the input of the stream + + n_his = 12 + zscore = preprocessing.StandardScaler() + + df = pd.read_csv(dir_name, header=None) + data_col = df.shape[0] + val_and_test_rate = 0.15 + + len_val = int(math.floor(data_col * val_and_test_rate)) + len_test = int(math.floor(data_col * val_and_test_rate)) + len_train = int(data_col - len_val - len_test) + + dataset = df[len_train + len_val:] + + zscore.fit(df[: len_train]) + dataset = zscore.transform(dataset) + + n_vertex = dataset.shape[1] + len_record = len(dataset) + num = len_record - n_his - n_pred + + x = np.zeros([num, 1, n_his, n_vertex], np.float32) + y = np.zeros([num, n_vertex], np.float32) + + for i in range(num): + head = i + tail = i + n_his + x[i, :, :, :] = dataset[head: tail].reshape(1, n_his, n_vertex) + y[i] = dataset[tail + n_pred - 1] + + labels = [] + predcitions = [] + stream_name = b'im_stgcn' + #start infer + for i in range(num): + inPluginId = 0 + tensor = np.expand_dims(x[i], axis=0) + uniqueId = send_source_data(0, tensor, stream_name, stream_manager_api) + if uniqueId < 0: + print("Failed to send data to stream.") + return + + # Obtain the inference result by specifying stream_name and uniqueId. + start_time = datetime.datetime.now() + + keyVec = StringVector() + keyVec.push_back(b'mxpi_tensorinfer0') + infer_result = stream_manager_api.GetProtobuf(stream_name, inPluginId, keyVec) + + end_time = datetime.datetime.now() + print('sdk run time: {}'.format((end_time - start_time).microseconds)) + + if infer_result.size() == 0: + print("inferResult is null") + return + if infer_result[0].errorCode != 0: + print("GetProtobuf error. errorCode=%d" % (infer_result[0].errorCode)) + return + # get infer result + result = MxpiDataType.MxpiTensorPackageList() + result.ParseFromString(infer_result[0].messageBuf) + # convert the inference result to Numpy array + res = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float32) + + labels.append(zscore.inverse_transform(np.expand_dims(y[i], axis=0)).reshape(-1)) + predcitions.append(zscore.inverse_transform(np.expand_dims(res, axis=0)).reshape(-1)) + + np.savetxt(res_dir_name+'labels.txt', np.array(labels)) + np.savetxt(res_dir_name+'predcitions.txt', np.array(predcitions)) + + # destroy streams + stream_manager_api.DestroyAllStreams() + +if __name__ == '__main__': + run() diff --git a/research/cv/stgcn/infer/sdk/pipeline/stgcn.pipeline b/research/cv/stgcn/infer/sdk/pipeline/stgcn.pipeline new file mode 100644 index 0000000000000000000000000000000000000000..e0712792989ad7cf47411105e5b882a3a1cb591a --- /dev/null +++ b/research/cv/stgcn/infer/sdk/pipeline/stgcn.pipeline @@ -0,0 +1,34 @@ +{ + "im_stgcn": { + "stream_config": { + "deviceId": "0" + }, + "appsrc0": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_tensorinfer0" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource": "appsrc0", + "modelPath": "../data/models/stgcn.om", + "outputDeviceId": "-1" + }, + "factory": "mxpi_tensorinfer", + "next": "mxpi_dataserialize0" + }, + "mxpi_dataserialize0": { + "props": { + "outputDataKeys": "mxpi_tensorinfer0" + }, + "factory": "mxpi_dataserialize", + "next": "appsink0" + }, + "appsink0": { + "factory": "appsink" + } + } +} + diff --git a/research/cv/stgcn/infer/sdk/run_sdk_infer.sh b/research/cv/stgcn/infer/sdk/run_sdk_infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..31180ddd2d8f17418f05497e02e7e50bfabdaf74 --- /dev/null +++ b/research/cv/stgcn/infer/sdk/run_sdk_infer.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +image_path=$1 +result_dir=$2 +n_pred=$3 + +set -e + + + +# Simple log helper functions +info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; } +warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; } + +export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/driver/lib64/:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH} +export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner +export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins + +#to set PYTHONPATH, import the StreamManagerApi.py +export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python + +python3.7 main.py $image_path $result_dir $n_pred +exit 0 diff --git a/research/cv/stgcn/infer/sdk/stgcn_metric.py b/research/cv/stgcn/infer/sdk/stgcn_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d4197c11b8a510ec1096512095573448b0a1a847 --- /dev/null +++ b/research/cv/stgcn/infer/sdk/stgcn_metric.py @@ -0,0 +1,52 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""compute accuracy""" +import os +import sys +import numpy as np + +def run(): + """compute acc""" + if len(sys.argv) == 3: + # the path to store the results path + result_file = sys.argv[1] + # the path to store the label path + label_file = sys.argv[2] + else: + print("Please enter target file result folder | ground truth label file | result json file folder | " + "result json file name, such as ./result val_label.txt . result.json") + exit(1) + if not os.path.exists(result_file): + print("Target file folder does not exist.") + + if not os.path.exists(label_file): + print("Label file does not exist.") + + predcitions = np.loadtxt(result_file) + labels = np.loadtxt(label_file) + mae, mape, mse = [], [], [] + for predcition, label in zip(predcitions, labels): + d = np.abs(predcition - label) + mae += d.tolist() + mape += (d / label).tolist() + mse += (d ** 2).tolist() + + MAE = np.array(mae).mean() + MAPE = np.array(mape).mean() + RMSE = np.sqrt(np.array(mse).mean()) + print(f'MAE {MAE:.2f} | MAPE {MAPE*100:.2f} | RMSE {RMSE:.2f}') + +if __name__ == '__main__': + run() diff --git a/research/cv/stgcn/modelarts/start_train.py b/research/cv/stgcn/modelarts/start_train.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed4aa7d570f11e9b58c2ef91da00237f158d19b --- /dev/null +++ b/research/cv/stgcn/modelarts/start_train.py @@ -0,0 +1,227 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +train network. +""" + +import os +import glob +import argparse +import pandas as pd +import numpy as np +import moxing as mox +from sklearn import preprocessing + +from mindspore.common import dtype as mstype +import mindspore.nn as nn + +from mindspore import Tensor +from mindspore.common import set_seed +from mindspore.communication.management import init +from mindspore.train.model import Model, ParallelMode +from mindspore import context, load_checkpoint, load_param_into_net, export +from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor + +from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg,\ + stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg +from src import dataloader, utility +from src.model import models, metric + +set_seed(1) + +def export_stgcn(config, vertex, checkpoint_path, s_prefix, file_name, file_format): + """ export_stgcn """ + # load checkpoint + net_export = models.STGCN_Conv(config.Kt, config.Ks, blocks, config.n_his, vertex, \ + config.gated_act_func, config.graph_conv_type, conv_matrix, config.drop_rate) + prob_ckpt_list = os.path.join(checkpoint_path, "{}*.ckpt".format(s_prefix)) + ckpt_list = glob.glob(prob_ckpt_list) + if not ckpt_list: + print('Freezing model failed!') + print("can not find ckpt files. ") + else: + ckpt_list.sort(key=os.path.getmtime) + ckpt_name = ckpt_list[-1] + print("checkpoint file name", ckpt_name) + param_dict = load_checkpoint(ckpt_name) + load_param_into_net(net_export, param_dict) + + input_x = Tensor(np.zeros([1, 1, 12, 228]), mstype.float32) + + if not os.path.exists(checkpoint_path): + os.makedirs(checkpoint_path, exist_ok=True) + file_name = os.path.join(checkpoint_path, file_name) + export(net_export, input_x, file_name=file_name, file_format=file_format) + print('Freezing model success!') + return 0 + +parser = argparse.ArgumentParser('mindspore stgcn training') +# The way of training +parser.add_argument('--device_target', type=str, default='Ascend', \ + help='device where the code will be implemented. (Default: Ascend)') +parser.add_argument('--save_check_point', type=bool, default=True, help='Whether save checkpoint') + +# Parameter +parser.add_argument('--epochs', type=int, default=2, help='Whether save checkpoint') +parser.add_argument('--batch_size', type=int, default=8, help='Whether save checkpoint') + +# Path for data and checkpoint +parser.add_argument('--data_url', type=str, required=True, help='Train dataset directory.') +parser.add_argument('--train_url', type=str, required=True, help='Save checkpoint directory.') +parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.') +parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.') + +# Super parameters for training +parser.add_argument('--n_pred', type=int, default=3, help='The number of time interval for predcition, default as 3') +parser.add_argument('--opt', type=str, default='AdamW', help='optimizer, default as AdamW') + +#network +parser.add_argument('--graph_conv_type', type=str, default="gcnconv", help='Grapg convolution type') + +parser.add_argument("--file_name", type=str, default="stgcn", help="output file name.") +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") + +args, _ = parser.parse_known_args() + +if args.graph_conv_type == "chebconv": + if args.n_pred == 9: + cfg = stgcn_chebconv_45min_cfg + elif args.n_pred == 6: + cfg = stgcn_chebconv_30min_cfg + elif args.n_pred == 3: + cfg = stgcn_chebconv_15min_cfg + else: + raise ValueError("Unsupported n_pred.") +elif args.graph_conv_type == "gcnconv": + if args.n_pred == 9: + cfg = stgcn_gcnconv_45min_cfg + elif args.n_pred == 6: + cfg = stgcn_gcnconv_30min_cfg + elif args.n_pred == 3: + cfg = stgcn_gcnconv_15min_cfg + else: + raise ValueError("Unsupported pred.") +else: + raise ValueError("Unsupported graph_conv_type.") + +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False) + +if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0): + raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.') + +Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num + +if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"): + raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.') + +if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2): + cfg.Ks = 2 + +# blocks: settings of channel size in st_conv_blocks and output layer, +# using the bottleneck design in st_conv_blocks +blocks = [] +blocks.append([1]) +for l in range(cfg.stblock_num): + blocks.append([64, 16, 64]) +if Ko == 0: + blocks.append([128]) +elif Ko > 0: + blocks.append([128, 128]) +blocks.append([1]) + +day_slot = int(24 * 60 / cfg.time_intvl) + +time_pred = cfg.n_pred * cfg.time_intvl +time_pred_str = str(time_pred) + '_mins' + +device_id = int(os.getenv('DEVICE_ID')) +device_num = int(os.getenv('RANK_SIZE')) + +context.set_context(device_id=device_id) +local_data_url = '/cache/data' +local_train_url = '/cache/train' +mox.file.copy_parallel(args.data_url, local_data_url) +if device_num > 1: + init() + #context.set_auto_parallel_context(parameter_broadcast=True) + context.set_auto_parallel_context(device_num=device_num, \ + parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) +data_dir = local_data_url + '/' + +adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path) + +n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1] +n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1] +if n_vertex_vel == n_vertex_adj: + n_vertex = n_vertex_vel +else: + raise ValueError(f"ERROR: number of vertices in dataset is not equal to number \ + of vertices in weighted adjacency matrix.") + +mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type) +conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32) +if cfg.graph_conv_type == "chebconv": + if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"): + raise ValueError(f'ERROR: {cfg.mat_type} is wrong.') +elif cfg.graph_conv_type == "gcnconv": + if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"): + raise ValueError(f'ERROR: {cfg.mat_type} is wrong.') + +stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \ + cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate) +net = stgcn_conv + +if __name__ == "__main__": + #start training + + zscore = preprocessing.StandardScaler() + dataset = dataloader.create_dataset(data_dir+args.data_path, args.batch_size, cfg.n_his, \ + cfg.n_pred, zscore, False, device_num, device_id, mode=0) + data_len = dataset.get_dataset_size() + + learning_rate = nn.exponential_decay_lr(learning_rate=cfg.learning_rate, decay_rate=cfg.gamma, \ + total_step=data_len*args.epochs, step_per_epoch=data_len, decay_epoch=cfg.decay_epoch) + if args.opt == "RMSProp": + optimizer = nn.RMSProp(net.trainable_params(), learning_rate=learning_rate) + elif args.opt == "Adam": + optimizer = nn.Adam(net.trainable_params(), learning_rate=learning_rate, \ + weight_decay=cfg.weight_decay_rate) + elif args.opt == "AdamW": + optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=learning_rate, \ + weight_decay=cfg.weight_decay_rate) + else: + raise ValueError(f'ERROR: optimizer {args.opt} is undefined.') + + loss_cb = LossMonitor() + time_cb = TimeMonitor(data_size=data_len) + callbacks = [time_cb, loss_cb] + prefix = "" + #save training results + if args.save_check_point and (device_num == 1 or device_id == 0): + config_ck = CheckpointConfig( + save_checkpoint_steps=data_len*args.epochs, keep_checkpoint_max=args.epochs) + prefix = 'STGCN' + cfg.graph_conv_type + str(cfg.n_pred) + '-' + ckpoint_cb = ModelCheckpoint(prefix=prefix, directory=local_train_url, config=config_ck) + callbacks += [ckpoint_cb] + + network = metric.LossCellWithNetwork(net) + model = Model(network, optimizer=optimizer, amp_level='O3') + + model.train(args.epochs, dataset, callbacks=callbacks) + print("train success") + export_stgcn(cfg, n_vertex, local_train_url, prefix, args.file_name, args.file_format) + # export + + mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url) diff --git a/research/cv/stgcn/scripts/docker_start.sh b/research/cv/stgcn/scripts/docker_start.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0fe5720341c4b29264a5a0c6a9f97fbf54d1211 --- /dev/null +++ b/research/cv/stgcn/scripts/docker_start.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.mitations under the License. + +docker_image=$1 +data_dir=$2 +model_dir=$3 + +docker run -it --ipc=host \ + --device=/dev/davinci0 \ + --device=/dev/davinci1 \ + --device=/dev/davinci2 \ + --device=/dev/davinci3 \ + --device=/dev/davinci4 \ + --device=/dev/davinci5 \ + --device=/dev/davinci6 \ + --device=/dev/davinci7 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \ + -v ${model_dir}:${model_dir} \ + -v ${data_dir}:${data_dir} \ + -v /var/log/npu/conf/slog/slog.conf:/var/log/npu/conf/slog/slog.conf \ + -v /var/log/npu/slog/:/var/log/npu/slog -v /var/log/npu/profiling/:/var/log/npu/profiling \ + -v /var/log/npu/dump/:/var/log/npu/dump -v /var/log/npu/:/usr/slog ${docker_image} \ + /bin/bash \ No newline at end of file