Unverified commit 598eee4b, authored by i-robot, committed by Gitee
!2452 [Zhejiang University][University Contribution][MindSpore][TGCN] High-performance pretrained model submission + features

Merge pull request !2452 from windhxs/tgcn
parents ae6c158f 979ff05f
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
model_path=$1
output_path=$2

atc --model="$model_path" \
    --framework=1 \
    --output="$output_path" \
    --log=error \
    --soc_version=Ascend310
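# Example usage (file names are illustrative; this script converts the AIR
# model exported by train.py into an offline .om model, and --output takes
# the target path without the .om suffix):
#   bash convert_om.sh ../data/models/tgcn_sztaxi.air ../data/models/tgcn_sztaxi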
{
    "im_tgcn": {
        "stream_config": {
            "deviceId": "0"
        },
        "appsrc0": {
            "props": {
                "blocksize": "409600"
            },
            "factory": "appsrc",
            "next": "mxpi_tensorinfer0"
        },
        "mxpi_tensorinfer0": {
            "props": {
                "dataSource": "appsrc0",
                "modelPath": "../data/models/tgcn_sztaxi.om",
                "outputDeviceId": "-1"
            },
            "factory": "mxpi_tensorinfer",
            "next": "mxpi_dataserialize0"
        },
        "mxpi_dataserialize0": {
            "props": {
                "outputDataKeys": "mxpi_tensorinfer0"
            },
            "factory": "mxpi_dataserialize",
            "next": "appsink0"
        },
        "appsink0": {
            "factory": "appsink"
        }
    }
}
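The pipeline above runs on device 0 and chains appsrc0 → mxpi_tensorinfer0 → mxpi_dataserialize0 → appsink0. The stream name im_tgcn and the output key mxpi_tensorinfer0 are the identifiers the SDK Python client (main.py, later in this commit) uses when sending tensors and fetching results.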
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
share_dir=$2
data_dir=$3
echo "$1"
echo "$2"
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
exit 1
fi
if [ ! -d "${share_dir}" ]; then
echo "please input share directory that contains dataset, models and codes"
exit 1
fi
docker run -it \
    --device=/dev/davinci0 \
    --device=/dev/davinci_manager \
    --device=/dev/devmm_svm \
    --device=/dev/hisi_hdc \
    --privileged \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
    -v "${data_dir}":"${data_dir}" \
    -v "${share_dir}":"${share_dir}" \
    -u root \
    "${docker_image}" \
    /bin/bash
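# Example usage (script name, image tag, and paths are illustrative):
#   bash docker_start_infer.sh mindx-sdk:latest /home/work/share /home/work/data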
cmake_minimum_required(VERSION 3.10.0)
project(tgcn)
set(TARGET tgcn)
add_definitions(-DENABLE_DVPP_INTERFACE)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-Dgoogle=mindxsdk_private)
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
# Check environment variables
if(NOT DEFINED ENV{ASCEND_HOME})
    message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_VERSION})
    message(WARNING "please define environment variable:ASCEND_VERSION")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
    message(WARNING "please define environment variable:ARCH_PATTERN")
endif()
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
    set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
else()
    set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
endif()
include_directories(${ACL_INC_DIR})
include_directories(${OPENSOURCE_DIR}/include)
include_directories(${OPENSOURCE_DIR}/include/opencv4)
include_directories(${MXBASE_INC})
include_directories(${MXBASE_POST_PROCESS_DIR})
link_directories(${ACL_LIB_DIR})
link_directories(${OPENSOURCE_DIR}/lib)
link_directories(${MXBASE_LIB_DIR})
link_directories(${MXBASE_POST_LIB_DIR})
add_executable(${TARGET} src/main.cpp src/tgcn.cpp)
target_link_libraries(${TARGET} glog cpprest mxbase opencv_world)
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
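# Typical environment before building (paths are examples, not verified
# defaults for every install):
#   export ASCEND_HOME=/usr/local/Ascend
#   export ASCEND_VERSION=ascend-toolkit/latest
#   export ARCH_PATTERN=x86_64-linux
#   export MX_SDK_HOME=/usr/local/sdk_home/mxManufacture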
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
path_cur=$(dirname "$0")

function check_env()
{
    # set ASCEND_HOME to /usr/local/Ascend/ when it was not specified by user
    if [ ! "${ASCEND_HOME}" ]; then
        export ASCEND_HOME=/usr/local/Ascend/
        echo "Set ASCEND_HOME to the default value: ${ASCEND_HOME}"
    else
        echo "ASCEND_HOME is set to ${ASCEND_HOME} by user"
    fi

    # set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
    if [ ! "${ASCEND_VERSION}" ]; then
        export ASCEND_VERSION=ascend-toolkit/latest
        echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
    else
        echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
    fi

    # set ARCH_PATTERN to ./ when it was not specified by user
    if [ ! "${ARCH_PATTERN}" ]; then
        export ARCH_PATTERN=./
        echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
    else
        echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
    fi
}

function build_tgcn()
{
    cd "$path_cur" || exit 1
    rm -rf build
    mkdir -p build
    cd build || exit 1
    cmake ..
    make
    ret=$?
    if [ ${ret} -ne 0 ]; then
        echo "Failed to build tgcn."
        exit ${ret}
    fi
    make install
}
check_env
build_tgcn
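# On success, the install() rule in CMakeLists.txt places the tgcn binary next
# to this script; it is then run with the dataset name, e.g.:
#   ./tgcn SZ-taxi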
/*
* Copyright 2022 Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MxBase_TGCN_H
#define MxBase_TGCN_H

#include <string>
#include <vector>
#include <memory>

#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"

struct InitParam {
    uint32_t deviceId;
    bool checkTensor;
    std::string modelPath;
    std::string dataset;
};

class TGCN {
 public:
    APP_ERROR Init(const InitParam &initParam);
    APP_ERROR DeInit();
    APP_ERROR VectorToTensorBase(const std::string &dataset, const std::vector<std::vector<float>> &input_x,
                                 MxBase::TensorBase *tensorBase);
    APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> *outputs);
    APP_ERROR Process(const std::string &dataset, const std::vector<std::vector<float>> &input_x,
                      const InitParam &initParam, std::vector<float> *output);
    APP_ERROR SaveInferResult(std::vector<float> *batchFeaturePaths,
                              const std::vector<MxBase::TensorBase> &inputs);

 private:
    // std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
    std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
    MxBase::ModelDesc modelDesc_;
    uint32_t deviceId_ = 0;
    double inferCostTimeMilliSec = 0.0;
};
#endif
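// Typical call sequence (see main.cpp in this commit): construct TGCN, call
// Init() once with the .om model path, Process() for each evaluation window,
// and DeInit() to release device resources.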
/*
* Copyright 2022 Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dirent.h>
#include <math.h>
#include <fstream>
#include <string>
#include <sstream>
#include <cstdlib>
#include <vector>
#include <cmath>
#include <cstdio>
#include "Tgcn.h"
#include "MxBase/Log/Log.h"
APP_ERROR ReadAdj(const std::string &dataset, const std::string &adj_path, std::vector<std::vector<float>> *adj) {
    std::ifstream fp(adj_path);
    std::string line;
    // number of nodes per row: 156 for SZ-taxi, otherwise 207
    int num = dataset == "SZ-taxi" ? 156 : 207;
    while (std::getline(fp, line)) {
        std::vector<float> data_adj;
        std::string number;
        std::istringstream readstr(line);
        for (int j = 0; j < num; j++) {
            std::getline(readstr, number, ' ');
            data_adj.emplace_back(atof(number.c_str()));
        }
        adj->emplace_back(data_adj);
    }
    return APP_ERR_OK;
}
APP_ERROR ReadFeat(const std::string &dataset, const std::string &feat_path, const int seq_len, const int pre_len,
                   float &max_val, std::vector<std::vector<float>> *feat_input,
                   std::vector<std::vector<float>> *feat_target) {
    std::ifstream fp(feat_path);
    std::string line;
    std::vector<std::vector<float>> feat;
    int num = dataset == "SZ-taxi" ? 156 : 207;
    int flag = 0;
    while (std::getline(fp, line)) {
        std::vector<float> data_feat;
        std::string number;
        std::istringstream readstr(line);
        for (int j = 0; j < num; j++) {
            std::getline(readstr, number, ',');
            float tmp = atof(number.c_str());
            max_val = std::max(max_val, tmp);
            data_feat.emplace_back(tmp);
        }
        // The first line of feature.csv is the header: discard it and reset
        // max_val, which was polluted by parsing the header tokens.
        if (!flag) {
            max_val = -1e9;
            flag = 1;
            continue;
        }
        feat.emplace_back(data_feat);
    }
    // Slide a window over the time axis: seq_len rows of input followed by
    // pre_len rows of prediction target.
    size_t time_len = feat.size();
    for (size_t i = 0; i < time_len - seq_len - pre_len; i++) {
        for (size_t j = i; j < i + seq_len; j++)
            feat_input->emplace_back(feat[j]);
        for (size_t j = i + seq_len; j < i + seq_len + pre_len; j++)
            feat_target->emplace_back(feat[j]);
    }
    return APP_ERR_OK;
}
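// For example, with seq_len = 4 and pre_len = 1, window i contributes feature
// rows i..i+3 to feat_input and row i+4 to feat_target, so feat_input grows by
// seq_len rows and feat_target by pre_len rows per window.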
float Rmse(const std::vector<float> &output, const std::vector<float> &target) {
    float res = 0;
    size_t len = output.size();
    for (size_t i = 0; i < len; i++) {
        res += (output[i] - target[i]) * (output[i] - target[i]);
    }
    res = res / len;
    return sqrt(res);
}

float Mae(const std::vector<float> &output, const std::vector<float> &target) {
    float res = 0;
    size_t len = output.size();
    for (size_t i = 0; i < len; i++) {
        res += std::fabs(output[i] - target[i]);
    }
    return res / len;
}
float Acc(const std::vector<float> &output, const std::vector<float> &target) {
    float diff_norm = 0, target_norm = 0;
    size_t len = output.size();
    for (size_t i = 0; i < len; i++) {
        diff_norm += (output[i] - target[i]) * (output[i] - target[i]);
        target_norm += target[i] * target[i];
    }
    diff_norm = sqrt(diff_norm);
    target_norm = sqrt(target_norm);
    return 1 - diff_norm / target_norm;
}
float R2(const std::vector<float> &output, const std::vector<float> &target) {
    float output_mean = 0, rsum = 0, rsumt = 0;
    size_t len = output.size();
    for (size_t i = 0; i < len; i++) {
        output_mean += output[i];
        rsum += (output[i] - target[i]) * (output[i] - target[i]);
    }
    output_mean /= len;
    for (size_t i = 0; i < len; i++)
        rsumt += (output_mean - target[i]) * (output_mean - target[i]);
    return 1 - rsum / rsumt;
}

float Var(const std::vector<float> &output, const std::vector<float> &target) {
    float diff_var = 0, target_var = 0, diff_mean = 0, target_mean = 0;
    std::vector<float> diff;
    size_t len = output.size();
    for (size_t i = 0; i < len; i++) {
        diff.emplace_back(target[i] - output[i]);
        diff_mean += target[i] - output[i];
        target_mean += target[i];
    }
    // turn the accumulated sums into means before computing the variances
    diff_mean /= len;
    target_mean /= len;
    for (size_t i = 0; i < len; i++) {
        diff_var += (diff[i] - diff_mean) * (diff[i] - diff_mean);
        target_var += (target[i] - target_mean) * (target[i] - target_mean);
    }
    diff_var /= len;
    target_var /= len;
    return 1 - diff_var / target_var;
}
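// Metric definitions implemented above (y = target, p = prediction):
//   RMSE = sqrt(mean((p - y)^2))
//   MAE  = mean(|p - y|)
//   Acc  = 1 - ||p - y||_2 / ||y||_2
//   R2   = 1 - sum((p - y)^2) / sum((mean(p) - y)^2)
//   Var  = 1 - var(y - p) / var(y)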
int main(int argc, char* argv[]) {
    if (argc < 2) {
        LogError << "Usage: ./tgcn <dataset>, e.g. ./tgcn SZ-taxi";
        return APP_ERR_COMM_FAILURE;
    }
    InitParam initParam = {};
    initParam.deviceId = 0;
    initParam.checkTensor = true;
    std::string dataset = argv[1];
    initParam.dataset = dataset;
    std::string adj_path, feat_path;
    int seq_len = 0, pre_len = 0;
    if (dataset == "SZ-taxi") {
        initParam.modelPath = "../data/models/tgcn_sztaxi.om";
        adj_path = "../data/input/SZ-taxi/adj.csv";
        feat_path = "../data/input/SZ-taxi/feature.csv";
        seq_len = 4;
        pre_len = 1;
    }
    auto tgcn = std::make_shared<TGCN>();
    printf("Start running\n");
    APP_ERROR ret = tgcn->Init(initParam);
    if (ret != APP_ERR_OK) {
        tgcn->DeInit();
        LogError << "tgcn init failed, ret=" << ret << ".";
        return ret;
    }
    float max_val = -1e9;
    float rmse = 0, mae = 0, acc = 0, r2 = 0, var = 0;
    std::vector<std::vector<float>> adj_data, feat_input, feat_target;
    ret = ReadAdj(dataset, adj_path, &adj_data);
    if (ret != APP_ERR_OK) {
        tgcn->DeInit();
        LogError << "read adj failed, ret=" << ret << ".";
        return ret;
    }
    ret = ReadFeat(dataset, feat_path, seq_len, pre_len, max_val, &feat_input, &feat_target);
    if (ret != APP_ERR_OK) {
        tgcn->DeInit();
        LogError << "read feat failed, ret=" << ret << ".";
        return ret;
    }
    LogInfo << "Read data done.";
    // Normalize inputs and targets by the global maximum value
    int feat_input_len = feat_input.size();
    int feat_input_content_len = feat_input[0].size();
    int feat_target_len = feat_target.size();
    int feat_target_content_len = feat_target[0].size();
    for (int i = 0; i < feat_input_len; i++)
        for (int j = 0; j < feat_input_content_len; j++)
            feat_input[i][j] /= max_val;
    for (int i = 0; i < feat_target_len; i++)
        for (int j = 0; j < feat_target_content_len; j++)
            feat_target[i][j] /= max_val;
std::vector<std::vector<float>> tot_output;
int tot_num = 0;
for (int i = data_num * 0.8 + 4; i < data_num ; i++) {
tot_num++;
std::vector<std::vector<float>> data;
std::vector<float> target;
std::vector<float> output;
for (int j = 0 ; j < seq_len ; j++)
data.push_back(feat_input[i * seq_len + j]);
target.insert(target.end(), feat_target[i].begin(), feat_target[i].end());
ret = tgcn->Process(dataset, data, initParam, &output);
if (ret !=APP_ERR_OK) {
LogError << "tgcn process failed, ret=" << ret << ".";
tgcn->DeInit();
return ret;
}
tot_output.push_back(output);
rmse += (Rmse(output, target) * max_val);
mae += (Mae(output, target) * max_val);
acc += Acc(output, target);
r2 += R2(output, target);
var += Var(output, target);
}
LogInfo << "totla " << " rmse: " << rmse / (tot_num)
<< " mae: " << mae / (tot_num)
<< " acc: " << acc / (tot_num)
<< " r2: " << r2 / (tot_num)
<< " var: " << var / (tot_num);
std::string resultPathName = "./result.txt";
std::ofstream outfile(resultPathName, std::ios::out);
if (outfile.fail()) {
LogError << "Failed to open result file: ";
return APP_ERR_COMM_FAILURE;
}
for (auto u : tot_output) {
std::string tmp;
for (auto x : u) {
tmp += std::to_string(x) + " ";
}
tmp = tmp.substr(0, tmp.size()-1);
outfile << tmp << std::endl;
}
outfile.close();
// tgcn->DeInit();
return APP_ERR_OK;
}
/*
* Copyright 2022 Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Tgcn.h"
#include <unistd.h>
#include <sys/stat.h>
#include <chrono>
#include <memory>
#include <string>
#include <fstream>
#include <algorithm>
#include <vector>
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
APP_ERROR TGCN::Init(const InitParam &initParam) {
    deviceId_ = initParam.deviceId;
    APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
    if (ret != APP_ERR_OK) {
        LogError << "Init devices failed, ret=" << ret << ".";
        return ret;
    }
    ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
    if (ret != APP_ERR_OK) {
        LogError << "Set context failed, ret=" << ret << ".";
        return ret;
    }
    model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
    ret = model_->Init(initParam.modelPath, modelDesc_);
    if (ret != APP_ERR_OK) {
        LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
        return ret;
    }
    LogInfo << "Init done.";
    return APP_ERR_OK;
}

APP_ERROR TGCN::DeInit() {
    // dvppWrapper_->DeInit();
    model_->DeInit();
    MxBase::DeviceManager::GetInstance()->DestroyDevices();
    return APP_ERR_OK;
}
APP_ERROR TGCN::VectorToTensorBase(const std::string &dataset, const std::vector<std::vector<float>> &input_x,
                                   MxBase::TensorBase *tensorBase) {
    // Total number of elements expected by the model input(s)
    uint32_t dataSize = 1;
    for (size_t i = 0; i < modelDesc_.inputTensors.size(); i++) {
        std::vector<uint32_t> shapes = {};
        for (size_t j = 0; j < modelDesc_.inputTensors[i].tensorDims.size(); j++) {
            shapes.push_back((uint32_t)modelDesc_.inputTensors[i].tensorDims[j]);
        }
        for (uint32_t s = 0; s < shapes.size(); ++s) {
            dataSize *= shapes[s];
        }
    }
    // Flatten the 2-D input into a contiguous host buffer (a vector avoids
    // leaking the raw new[] allocation used previously)
    std::vector<float> metaFeatureData(dataSize);
    uint32_t idx = 0;
    for (size_t bs = 0; bs < input_x.size(); bs++)
        for (size_t c = 0; c < input_x[bs].size(); c++)
            metaFeatureData[idx++] = input_x[bs][c];
    MxBase::MemoryData memoryDataDst(dataSize * sizeof(float), MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
    MxBase::MemoryData memoryDataSrc(reinterpret_cast<void *>(metaFeatureData.data()),
                                     dataSize * sizeof(float), MxBase::MemoryData::MEMORY_HOST_MALLOC);
    APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
    if (ret != APP_ERR_OK) {
        LogError << GetError(ret) << "Memory malloc failed.";
        return ret;
    }
    // (batch, seq_len, num_nodes): 64x4x156 for SZ-taxi, 64x12x207 otherwise
    std::vector<uint32_t> shape;
    if (dataset == "SZ-taxi")
        shape.assign({64, 4, 156});
    else
        shape.assign({64, 12, 207});
    *tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
    return APP_ERR_OK;
}
APP_ERROR TGCN::Inference(const std::vector<MxBase::TensorBase> &inputs,
                          std::vector<MxBase::TensorBase> *outputs) {
    auto dtypes = model_->GetOutputDataType();
    for (size_t i = 0; i < modelDesc_.outputTensors.size(); i++) {
        std::vector<uint32_t> shape = {};
        for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); j++) {
            shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
        }
        MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
        APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
        if (ret != APP_ERR_OK) {
            LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
            return ret;
        }
        (*outputs).push_back(tensor);
    }
    MxBase::DynamicInfo dynamicInfo = {};
    dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
    auto startTime = std::chrono::high_resolution_clock::now();
    APP_ERROR ret = model_->ModelInference(inputs, *outputs, dynamicInfo);
    auto endTime = std::chrono::high_resolution_clock::now();
    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
    inferCostTimeMilliSec += costMs;
    if (ret != APP_ERR_OK) {
        LogError << "ModelInference failed, ret=" << ret << ".";
        return ret;
    }
    return APP_ERR_OK;
}
APP_ERROR TGCN::SaveInferResult(std::vector<float> *batchFeaturePaths, const std::vector<MxBase::TensorBase> &inputs) {
    for (auto retTensor : inputs) {
        std::vector<uint32_t> shape = retTensor.GetShape();
        uint32_t N = shape[0];
        uint32_t C = shape[1];
        if (!retTensor.IsHost()) {
            retTensor.ToHost();
        }
        void *data = retTensor.GetBuffer();
        for (uint32_t i = 0; i < N; i++) {
            for (uint32_t j = 0; j < C; j++) {
                float value = *(reinterpret_cast<float *>(data) + i * C + j);
                batchFeaturePaths->emplace_back(value);
            }
        }
    }
    return APP_ERR_OK;
}
APP_ERROR TGCN::Process(const std::string &dataset, const std::vector<std::vector<float>> &input_x,
                        const InitParam &initParam, std::vector<float> *output) {
    std::vector<MxBase::TensorBase> inputs = {};
    std::vector<MxBase::TensorBase> infer_outputs;
    MxBase::TensorBase tensorBase;
    auto ret = VectorToTensorBase(dataset, input_x, &tensorBase);
    if (ret != APP_ERR_OK) {
        LogError << "ToTensorBase failed, ret=" << ret << ".";
        return ret;
    }
    inputs.push_back(tensorBase);
    auto startTime = std::chrono::high_resolution_clock::now();
    ret = Inference(inputs, &infer_outputs);
    auto endTime = std::chrono::high_resolution_clock::now();
    double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
    inferCostTimeMilliSec += costMs;
    if (ret != APP_ERR_OK) {
        LogError << "Inference failed, ret=" << ret << ".";
        return ret;
    }
    ret = SaveInferResult(output, infer_outputs);
    if (ret != APP_ERR_OK) {
        LogError << "Save model infer results into file failed. ret = " << ret << ".";
        return ret;
    }
    return APP_ERR_OK;
}
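// End-to-end flow of TGCN::Process for one batch: VectorToTensorBase() copies
// the normalized window onto the device, Inference() allocates device output
// tensors and runs the .om model with a static batch, and SaveInferResult()
// moves the result back to host and flattens it into a float vector.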
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import datetime
import numpy as np
from sklearn import metrics
import MxpiDataType_pb2 as MxpiDataType
from StreamManagerApi import StreamManagerApi, InProtobufVector, MxProtobufIn, StringVector, MxDataInput
def accuracy(preds, targets):
    """
    Calculate the accuracy between predictions and targets

    Args:
        preds(Tensor): predictions
        targets(Tensor): ground truth

    Returns:
        accuracy: defined as 1 - (norm(targets - preds) / norm(targets))
    """
    return 1 - np.linalg.norm(targets - preds) / np.linalg.norm(targets)
def r2(preds, targets):
    """
    Calculate R square between predictions and targets

    Args:
        preds(Tensor): predictions
        targets(Tensor): ground truth

    Returns:
        R square: coefficient of determination
    """
    # mean (not sum) of the predictions, matching the C++ R2 implementation
    return 1 - np.sum((targets - preds) ** 2) / np.sum((targets - np.mean(preds)) ** 2)
def explained_variance(preds, targets):
    """
    Calculate the explained variance between predictions and targets

    Args:
        preds(Tensor): predictions
        targets(Tensor): ground truth

    Returns:
        Var: explained variance
    """
    return 1 - (targets - preds).var() / targets.var()

def load_feat_matrix(path):
    feat = np.loadtxt(path, delimiter=',', skiprows=1)
    tmp_max_val = np.max(feat)
    return feat, tmp_max_val
def generate_dataset_np(feat, seq_len, pre_len, normalize=True):
    time_len = feat.shape[0]
    if normalize:
        tmp_max_val = np.max(feat)
        feat = feat / tmp_max_val
    train_size = int(time_len * 0.8)
    train_data = feat[0:train_size]
    eval_data = feat[train_size:time_len]
    train_inputs, train_targets, tmp_eval_inputs, tmp_eval_targets = list(), list(), list(), list()
    for i in range(len(train_data) - seq_len - pre_len):
        train_inputs.append(np.array(train_data[i: i + seq_len]))
        train_targets.append(np.array(train_data[i + seq_len: i + seq_len + pre_len]))
    for i in range(len(eval_data) - seq_len - pre_len):
        tmp_eval_inputs.append(np.array(eval_data[i: i + seq_len]))
        tmp_eval_targets.append(np.array(eval_data[i + seq_len: i + seq_len + pre_len]))
    return np.array(train_inputs), np.array(train_targets), np.array(tmp_eval_inputs), np.array(tmp_eval_targets)

def get_dataset(path):
    feat, tmp_max_val = load_feat_matrix(path)
    _, _, tmp_eval_inputs, tmp_eval_targets = generate_dataset_np(feat, 4, 1)
    return tmp_max_val, tmp_eval_inputs, tmp_eval_targets
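# Shape sketch: for a feature matrix of shape (time_len, 156) with seq_len=4
# and pre_len=1, tmp_eval_inputs has shape (num_windows, 4, 156) and
# tmp_eval_targets (num_windows, 1, 156), where
# num_windows = len(eval_data) - seq_len - pre_len.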
if __name__ == '__main__':
    # init stream manager
    stream_manager_api = StreamManagerApi()
    ret = stream_manager_api.InitManager()
    if ret != 0:
        print("Failed to init Stream manager, ret=%s" % str(ret))
        exit()

    # create streams by pipeline config file
    with open("../data/config/tgcn.pipeline", 'rb') as f:
        pipelineStr = f.read()
    ret = stream_manager_api.CreateMultipleStreams(pipelineStr)
    if ret != 0:
        print("Failed to create Stream, ret=%s" % str(ret))
        exit()

    # Construct the input of the stream
    infer_total_time = 0
    data_path = '../data/input/SZ-taxi/feature.csv'
    stream_name = b'im_tgcn'
    max_val, eval_inputs, eval_targets = get_dataset(data_path)
    print("eval input shape")
    print(eval_inputs.shape)
    print("eval target shape")
    print(eval_targets.shape)
    num = eval_inputs.shape[0]
    dataset = np.zeros([num, 1, 4, 156], np.float32)
    for idx in range(num):
        dataset[idx, :, :, :] = eval_inputs[idx].reshape(4, 156)
    tot_output = []
    tot_rmse = 0
    tot_mae = 0
    tot_acc = 0
    tot_r2 = 0
    tot_var = 0
    for idx in range(num):
        tensor = dataset[idx]
        tensor_bytes = tensor.tobytes()
        in_plugin_id = 0
        tensorPackageList = MxpiDataType.MxpiTensorPackageList()
        tensorPackage = tensorPackageList.tensorPackageVec.add()
        dataInput = MxDataInput()
        dataInput.data = tensor_bytes
        tensorVec = tensorPackage.tensorVec.add()
        tensorVec.deviceId = 0
        tensorVec.memType = 0
        for t in tensor.shape:
            tensorVec.tensorShape.append(t)
        tensorVec.dataStr = dataInput.data
        tensorVec.tensorDataSize = len(tensor_bytes)
        # add feature data end
        key = "appsrc{}".format(in_plugin_id).encode('utf-8')
        protobufVec = InProtobufVector()
        protobuf = MxProtobufIn()
        protobuf.key = key
        protobuf.type = b'MxTools.MxpiTensorPackageList'
        protobuf.protobuf = tensorPackageList.SerializeToString()
        protobufVec.push_back(protobuf)
        unique_id = stream_manager_api.SendProtobuf(stream_name, in_plugin_id, protobufVec)
        if unique_id < 0:
            print("Failed to send data to stream.")
            exit()

        # Obtain the inference result by specifying streamName and uniqueId.
        start_time = datetime.datetime.now()
        keyVec = StringVector()
        keyVec.push_back(b'mxpi_tensorinfer0')
        infer_result = stream_manager_api.GetProtobuf(stream_name, in_plugin_id, keyVec)
        if infer_result.size() == 0:
            print("inferResult is null")
            exit()
        if infer_result[0].errorCode != 0:
            print("GetProtobuf error. errorCode=%d" % (
                infer_result[0].errorCode))
            exit()

        # get infer result
        result = MxpiDataType.MxpiTensorPackageList()
        result.ParseFromString(infer_result[0].messageBuf)
        # convert the inference result to Numpy array
        output = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float32)
        tmp_target = np.squeeze(eval_targets[idx], axis=0)
        tot_rmse += np.sqrt(metrics.mean_squared_error(tmp_target, output))
        tot_mae += metrics.mean_absolute_error(tmp_target, output)
        # pass predictions first, targets second, matching the signatures above
        tot_acc += accuracy(output, tmp_target)
        tot_r2 += r2(output, tmp_target)
        tot_var += explained_variance(output, tmp_target)
        tot_output.append(output)
    with open('res.txt', 'w') as f:
        for output in tot_output:
            for x in output:
                f.write(('%.6f' % x) + ' ')
            f.write('\n')

    print("=====Evaluation Results=====")
    print('RMSE:', '{:.6f}'.format(tot_rmse * max_val / num))
    print('MAE:', '{:.6f}'.format(tot_mae * max_val / num))
    print('Accuracy:', '{:.6f}'.format(tot_acc / num))
    print('R2:', '{:.6f}'.format(tot_r2 / num))
    print('Var:', '{:.6f}'.format(tot_var / num))
    print("============================")

    # destroy streams
    stream_manager_api.DestroyAllStreams()
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
set -e
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3 main.py
exit 0
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Training script
"""
import os
import argparse
import time
import numpy as np
from mindspore.communication import init, get_rank
from mindspore.context import ParallelMode
from mindspore import dtype as mstype
from mindspore import export, set_seed, nn, context, Model, load_checkpoint, load_param_into_net, Tensor
from mindspore.train.callback import LossMonitor, TimeMonitor, Callback
from mindspore import save_checkpoint
from mindspore.dataset.core.validator_helpers import INT32_MAX
from src.config import ConfigTGCN
from src.dataprocess import load_adj_matrix, load_feat_matrix, generate_dataset_ms, generate_dataset_ms_distributed
from src.task import SupervisedForecastTask
from src.model.loss import TGCNLoss
from src.callback import RMSE
class SaveCallback(Callback):
    """
    Save the best checkpoint (minimum RMSE) during training
    """
    def __init__(self, eval_model, ds_eval, config):
        super(SaveCallback, self).__init__()
        self.model = eval_model
        self.ds_eval = ds_eval
        self.rmse = INT32_MAX
        self.config = config

    def epoch_end(self, run_context):
        """Evaluate the network and save the best checkpoint (minimum RMSE)"""
        if not os.path.exists('checkpoints'):
            os.mkdir('checkpoints')
        cb_params = run_context.original_args()
        file_name = self.config.dataset + '_' + str(self.config.pre_len) + '.ckpt'
        if self.config.save_best:
            result = self.model.eval(self.ds_eval)
            print('Eval RMSE:', '{:.6f}'.format(result['RMSE']))
            if result['RMSE'] < self.rmse:
                self.rmse = result['RMSE']
                save_checkpoint(save_obj=cb_params.train_network,
                                ckpt_file_name=os.path.join(self.config.train_url, file_name))
                print("Best checkpoint saved!")
        else:
            save_checkpoint(save_obj=cb_params.train_network,
                            ckpt_file_name=os.path.join(self.config.train_url, file_name))
def _export(config):
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device)
    # Create network
    adj = load_adj_matrix(config.dataset, config.data_url)
    net = SupervisedForecastTask(adj, config.hidden_dim, config.pre_len)
    # Load parameters from checkpoint into network
    ckpt_file = config.dataset + "_" + str(config.pre_len) + ".ckpt"
    param_dict = load_checkpoint(os.path.join(config.train_url, ckpt_file))
    print(os.path.join(config.train_url, ckpt_file))
    print(param_dict)
    load_param_into_net(net, param_dict)
    # Initialize dummy inputs
    inputs = np.random.uniform(0.0, 1.0, size=[config.batch_size, config.seq_len, adj.shape[0]]).astype(np.float32)
    # Export network into an AIR model file
    if not os.path.exists('outputs'):
        os.mkdir('outputs')
    file_name = config.dataset + "_" + str(config.pre_len)
    path = os.path.join(config.train_url, file_name)
    export(net, Tensor(inputs), file_name=path, file_format='AIR')
    print("==========================================")
    print(file_name + ".air exported successfully!")
    print("==========================================")
def merge_cfg(t_args, config):
    for arg in vars(t_args):
        setattr(config, arg, getattr(t_args, arg))
    return config
def run_train(args):
    """
    Run training
    """
    # Config initialization
    config = ConfigTGCN()
    # Set global seed for MindSpore and NumPy
    set_seed(config.seed)
    # ModelArts runtime, datasets and network initialization
    config = merge_cfg(args, config)
    print(config.data_url)
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
    if config.distributed:
        device_num = int(os.getenv('RANK_SIZE'))
        init()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
        training_set = generate_dataset_ms_distributed(config, training=True, abs_path=config.data_url)
        eval_set = generate_dataset_ms_distributed(config, training=False, abs_path=config.data_url)
        _, max_val = load_feat_matrix(config.dataset, config.data_url)
        net = SupervisedForecastTask(load_adj_matrix(config.dataset, config.data_url),
                                     config.hidden_dim, config.pre_len)
    else:
        training_set = generate_dataset_ms(config, training=True, abs_path=config.data_url)
        eval_set = generate_dataset_ms(config, training=False, abs_path=config.data_url)
        _, max_val = load_feat_matrix(config.dataset, abs_path=config.data_url)
        net = SupervisedForecastTask(load_adj_matrix(config.dataset, abs_path=config.data_url),
                                     config.hidden_dim, config.pre_len)
    # Mixed precision
    net.tgcn.tgcn_cell.graph_conv1.matmul.to_float(mstype.float16)
    net.tgcn.tgcn_cell.graph_conv2.matmul.to_float(mstype.float16)
    # Loss function
    loss_fn = TGCNLoss()
    # Optimizer
    optimizer = nn.Adam(net.trainable_params(), config.learning_rate, weight_decay=config.weight_decay)
    # Create model
    model = Model(net, loss_fn, optimizer, {'RMSE': RMSE(max_val)})
    # Training
    time_start = time.time()
    callbacks = [LossMonitor(), TimeMonitor()]
    if config.distributed:
        print("==========Distributed Training Start==========")
        save_callback = SaveCallback(model, eval_set, config)
        if get_rank() == 0:
            callbacks = [LossMonitor(), TimeMonitor(), save_callback]
        elif config.save_best:
            callbacks = [LossMonitor()]
    else:
        print("==========Training Start==========")
        save_callback = SaveCallback(model, eval_set, config)
        callbacks.append(save_callback)
    model.train(config.epochs, training_set,
                callbacks=callbacks,
                dataset_sink_mode=config.data_sink)
    time_end = time.time()
    if config.distributed:
        print("==========Distributed Training End==========")
    else:
        print("==========Training End==========")
    print("Training time in total:", '{:.6f}'.format(time_end - time_start), "s")
    _export(config)
if __name__ == '__main__':
    # Set universal arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', help="DEVICE_ID", type=int, default=0)
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so expose the flag as store_true instead
    parser.add_argument('--distributed', help="distributed training", action='store_true')
    # Set ModelArts arguments
    parser.add_argument('--data_url', help='ModelArts location of data', type=str, default=None)
    parser.add_argument('--train_url', help='ModelArts location of training outputs', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='SZ-taxi')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=3000)
    parser.add_argument('--pre_len', type=int, default=1)
    run_args = parser.parse_args()
    # Training
    run_train(run_args)
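# Example launch (paths are illustrative; data_url must contain the dataset
# CSVs expected by src.dataprocess):
#   python3 train.py --data_url ./data --train_url ./checkpoints \
#       --dataset SZ-taxi --batch_size 64 --epochs 3000 --pre_len 1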
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
data_dir=$2
model_dir=$3
docker run -it --ipc=host \
    --device=/dev/davinci0 \
    --device=/dev/davinci1 \
    --device=/dev/davinci2 \
    --device=/dev/davinci3 \
    --device=/dev/davinci4 \
    --device=/dev/davinci5 \
    --device=/dev/davinci6 \
    --device=/dev/davinci7 \
    --device=/dev/davinci_manager \
    --device=/dev/devmm_svm --device=/dev/hisi_hdc \
    -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
    -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \
    -v "${model_dir}":"${model_dir}" \
    -v "${data_dir}":"${data_dir}" \
    -v ~/ascend/log/npu/conf/slog/slog.conf:/var/log/npu/conf/slog/slog.conf \
    -v ~/ascend/log/npu/slog/:/var/log/npu/slog -v ~/ascend/log/npu/profiling/:/var/log/npu/profiling \
    -v ~/ascend/log/npu/dump/:/var/log/npu/dump -v ~/ascend/log/npu/:/usr/slog \
    -u root \
    "${docker_image}" \
    /bin/bash