Skip to content
Snippets Groups Projects
Unverified Commit 57c90cd8 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!3827 [西安交通大学][高校贡献][Mindspore][PvNet]-高性能预训练模型提交+modelarts/SDK/Mxbase

Merge pull request !3827 from Whishing/master
parents b135448e 6f8b0461
Branches
No related tags found
No related merge requests found
Showing
with 1281 additions and 0 deletions
ARG FROM_IMAGE_NAME
FROM ${FROM_IMAGE_NAME}
COPY requirements.txt .
RUN pip3.7 install -r requirements.txt
\ No newline at end of file
aipp_op {
aipp_mode: static
input_format : RGB888_U8
rbuv_swap_switch : false
src_image_size_w: 640
src_image_size_h: 480
mean_chn_0 : 124
mean_chn_1 : 117
mean_chn_2 : 104
var_reci_chn_0 : 0.017124754
var_reci_chn_1 : 0.017507003
var_reci_chn_2 : 0.017429194
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import shutil
import numpy as np
from model_utils.config import config as cfg
from model_utils.data_file_utils import read_pickle
if __name__ == '__main__':
re_dir = os.path.join("./data/", cfg.cls_name)
img_dir = os.path.join(re_dir, 'images')
pose_dir = os.path.join(re_dir, 'poses')
# 6D pose estimation input
real_pkl = os.path.join(cfg.eval_dataset, cfg.dataset_name, 'posedb', '{}_real.pkl'.format(cfg.cls_name))
real_set = read_pickle(real_pkl)
data_root_dir = os.path.join(cfg.dataset_dir, cfg.dataset_name, cfg.cls_name)
test_fn = os.path.join(data_root_dir, 'test.txt')
val_fn = os.path.join(data_root_dir, 'val.txt')
with open(test_fn, 'r') as f:
test_fns = [line.strip().split('/')[-1] for line in f.readlines()]
with open(val_fn, 'r') as f:
val_fns = [line.strip().split('/')[-1] for line in f.readlines()]
test_real_set = []
val_real_set = []
for data in real_set:
if data['rgb_pth'].split('/')[-1] in test_fns:
if data['rgb_pth'].split('/')[-1] in val_fns:
val_real_set.append(data)
else:
test_real_set.append(data)
test_db = []
test_db += test_real_set
test_db += val_real_set
if not os.path.exists(img_dir):
os.makedirs(img_dir)
if not os.path.exists(pose_dir):
os.makedirs(pose_dir)
for idx, _ in enumerate(test_db):
rgb_path = os.path.join(cfg.eval_dataset, cfg.dataset_name, test_db[idx]['rgb_pth'])
rgb_name = test_db[idx]['rgb_pth'].strip().split('/')[-1]
shutil.copyfile(rgb_path, os.path.join(img_dir, rgb_name))
pose = test_db[idx]['RT'].copy()
np.savetxt(os.path.join(pose_dir, rgb_name.split('.')[0]+'.txt'), pose)
np.savetxt(os.path.join(re_dir, "test.txt"), test_fns, fmt='%s')
print('preprocess success!')
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
model_path=$1
output_model_name=$2
atc \
--model=$model_path \
--framework=1 \
--output=$output_model_name \
--input_format=NCHW --input_shape="actual_input_1:1,3,480,640" \
--enable_small_channel=1 \
--log=error \
--soc_version=Ascend310 \
--insert_op_conf=./aipp.config
\ No newline at end of file
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
share_dir=$2
data_dir=$3
echo "$1"
echo "$2"
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
exit 1
fi
if [ ! -d "${share_dir}" ]; then
echo "please input share directory that contains dataset, models and codes"
exit 1
fi
docker run -it -u root \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
--privileged \
-v //usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${data_dir}:${data_dir} \
-v ${share_dir}:${share_dir} \
${docker_image} \
/bin/bash
cmake_minimum_required(VERSION 3.14.0)
project(pvnet)
set(TARGET pvnet)
add_definitions(-DENABLE_DVPP_INTERFACE)
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall
-Dgoogle=mindxsdk_private -D_GLIBCXX_USE_CXX11_ABI=0)
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -pie)
# Check environment variable
if(NOT DEFINED ENV{MX_SDK_HOME})
message(FATAL_ERROR "please define environment variable:MX_SDK_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_HOME})
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_VERSION})
message(WARNING "please define environment variable:ASCEND_VERSION")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
message(WARNING "please define environment variable:ARCH_PATTERN")
endif()
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
else()
set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
endif()
include_directories(${ACL_INC_DIR})
include_directories(${OPENSOURCE_DIR}/include)
include_directories(${OPENSOURCE_DIR}/include/opencv4)
include_directories(${MXBASE_INC})
include_directories(${MXBASE_POST_PROCESS_DIR})
link_directories(${ACL_LIB_DIR})
link_directories(${OPENSOURCE_DIR}/lib)
link_directories(${MXBASE_LIB_DIR})
link_directories(${MXBASE_POST_LIB_DIR})
add_executable(${TARGET} main_pvnet.cpp PVNet.cpp)
target_link_libraries(${TARGET} glog cpprest mxbase opencv_world stdc++fs)
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
/*
* Copyright (c) 2022. Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <map>
#include "PVNet.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
APP_ERROR PVNet::Init(const InitParam &initParam) {
deviceId_ = initParam.deviceId;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = model_->Init(initParam.modelPath, modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR PVNet::DeInit() {
model_->DeInit();
MxBase::DeviceManager::GetInstance()->DestroyDevices();
return APP_ERR_OK;
}
APP_ERROR PVNet::ReadImage(const std::string &imgPath, cv::Mat &imageMat) {
imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
cv::cvtColor(imageMat, imageMat, cv::COLOR_RGB2BGR);
return APP_ERR_OK;
}
APP_ERROR PVNet::ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
static constexpr uint32_t resizeHeight = 480;
static constexpr uint32_t resizeWidth = 640;
cv::resize(srcImageMat, dstImageMat, cv::Size(resizeWidth, resizeHeight));
return APP_ERR_OK;
}
APP_ERROR PVNet::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) {
const uint32_t dataSize = imageMat.cols * imageMat.rows * 3;
MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
MxBase::MemoryData memoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
std::vector<uint32_t> shape = {1, static_cast<uint32_t>(imageMat.rows), static_cast<uint32_t>(imageMat.cols), 3};
tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8);
return APP_ERR_OK;
}
APP_ERROR PVNet::Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> &outputs) {
auto dtypes = model_->GetOutputDataType();
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
}
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs.push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); // save time
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "ModelInference failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR PVNet::SaveResult(const std::string &imgPath, std::vector<MxBase::TensorBase> &outputs) {
// LogInfo << "result path" << imgPath;
std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1);
size_t dot = fileName.find_last_of(".");
std::string segFileName = "result/seg_pred/" + fileName.substr(0, dot) + ".bin";
std::string verFileName = "result/ver_pred/" + fileName.substr(0, dot) + ".bin";
LogInfo << "file path for saving seg_pred:" << segFileName;
LogInfo << "file path for saving ver_pred:" << verFileName;
// save seg_pred
APP_ERROR ret0 = outputs[0].ToHost();
if (ret0 != APP_ERR_OK) {
LogError << GetError(ret0) << "tohost fail.";
return ret0;
}
void *segOutput = outputs[0].GetBuffer();
std::vector<uint32_t> segShape = outputs[0].GetShape();
FILE *segOutputFile = fopen(segFileName.c_str(), "wb");
fwrite(segOutput, segShape[0]*segShape[1]*segShape[2]*segShape[3], sizeof(float), segOutputFile);
fclose(segOutputFile);
segOutputFile = nullptr;
// save ver_pred
APP_ERROR ret1 = outputs[1].ToHost();
if (ret1 != APP_ERR_OK) {
LogError << GetError(ret1) << "tohost fail.";
return ret1;
}
void *verOutput = outputs[1].GetBuffer();
std::vector<uint32_t> verShape = outputs[1].GetShape();
FILE *verOutputFile = fopen(verFileName.c_str(), "wb");
fwrite(verOutput, verShape[0]*verShape[1]*verShape[2]*verShape[3], sizeof(float), verOutputFile);
fclose(verOutputFile);
verOutputFile = nullptr;
return APP_ERR_OK;
}
APP_ERROR PVNet::Process(const std::string &imgPath) {
cv::Mat imageMat;
APP_ERROR ret = ReadImage(imgPath, imageMat);
if (ret != APP_ERR_OK) {
LogError << "ReadImage failed, ret=" << ret << ".";
return ret;
}
std::vector<MxBase::TensorBase> inputs = {};
std::vector<MxBase::TensorBase> outputs = {};
MxBase::TensorBase tensorBase;
ret = CVMatToTensorBase(imageMat, tensorBase);
if (ret != APP_ERR_OK) {
LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
return ret;
}
inputs.push_back(tensorBase);
auto startTime = std::chrono::high_resolution_clock::now();
ret = Inference(inputs, outputs);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); // save time
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret << ".";
return ret;
}
ret = SaveResult(imgPath, outputs);
if (ret != APP_ERR_OK) {
LogError << "Save infer results into file failed. ret = " << ret << ".";
return ret;
}
return APP_ERR_OK;
}
/*
* Copyright (c) 2022. Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_PVNET_H
#define MXBASE_PVNET_H
#include <string>
#include <vector>
#include <memory>
#include <opencv2/opencv.hpp>
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
struct InitParam {
uint32_t deviceId;
std::string modelPath;
};
class PVNet {
public:
APP_ERROR Init(const InitParam &initParam);
APP_ERROR DeInit();
APP_ERROR ReadImage(const std::string &imgPath, cv::Mat &imageMat);
APP_ERROR ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase);
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
APP_ERROR Process(const std::string &imgPath);
// get infer time
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
APP_ERROR SaveResult(const std::string &resultPath,
std::vector<MxBase::TensorBase> &outputs);
private:
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
MxBase::ModelDesc modelDesc_;
uint32_t deviceId_ = 0;
// infer time
double inferCostTimeMilliSec = 0.0;
};
#endif
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
path_cur=$(dirname $0)
function check_env()
{
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
if [ ! "${ASCEND_VERSION}" ]; then
export ASCEND_VERSION=ascend-toolkit/latest
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
else
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
fi
if [ ! "${ARCH_PATTERN}" ]; then
# set ARCH_PATTERN to ./ when it was not specified by user
export ARCH_PATTERN=./
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
else
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
fi
}
function build_pvnet()
{
cd $path_cur
rm -rf build
mkdir -p build
cd build
cmake ..
make
ret=$?
if [ ${ret} -ne 0 ]; then
echo "Failed to build pvnet."
exit ${ret}
fi
make install
}
check_env
build_pvnet
/*
* Copyright (c) 2022. Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dirent.h>
#include <fstream>
#include <string>
#include <iostream>
#include "PVNet.h"
#include "MxBase/Log/Log.h"
APP_ERROR ScanImages(const std::string &path, std::vector<std::string> &imgFiles) {
DIR *dirPtr = opendir(path.c_str());
if (dirPtr == nullptr) {
LogError << "opendir failed. dir:" << path;
return APP_ERR_INTERNAL_ERROR;
}
dirent *direntPtr = nullptr;
while ((direntPtr = readdir(dirPtr)) != nullptr) {
std::string fileName = direntPtr->d_name;
if (fileName == "." || fileName == "..") {
continue;
}
imgFiles.emplace_back(path + "/" + fileName);
}
closedir(dirPtr);
return APP_ERR_OK;
}
APP_ERROR ReadImgFiles(const std::string &txtPath, const std::string &datasetPath, std::vector<std::string> &imgFiles) {
std::ifstream testTxt(txtPath.c_str());
std::string fileName;
if (testTxt) {
while (getline(testTxt, fileName)) {
imgFiles.emplace_back(datasetPath + "/" + fileName);
}
} else {
LogError << "Open file failed. file:" << txtPath;
return APP_ERR_INTERNAL_ERROR;
}
return APP_ERR_OK;
}
int main(int argc, char* argv[]) {
if (argc <= 2) {
LogWarn << "Please inputs test.txt and dataset path, such as '../data/cat/test.txt ../data/cat/images'.";
return APP_ERR_OK;
}
InitParam initParam = {};
initParam.deviceId = 0;
initParam.modelPath = "./data/models/pvnet.om";
auto pvnet = std::make_shared<PVNet>();
APP_ERROR ret = pvnet->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "PVNet init failed, ret=" << ret << ".";
return ret;
}
std::string txtPath = argv[1];
std::string datasetPath = argv[2];
std::vector<std::string> imgFilePaths;
ret = ReadImgFiles(txtPath, datasetPath, imgFilePaths);
if (ret != APP_ERR_OK) {
return ret;
}
auto startTime = std::chrono::high_resolution_clock::now();
for (auto &imgFile : imgFilePaths) {
ret = pvnet->Process(imgFile);
if (ret != APP_ERR_OK) {
LogError << "PVNet process failed, ret=" << ret << ".";
pvnet->DeInit();
return ret;
}
}
auto endTime = std::chrono::high_resolution_clock::now();
pvnet->DeInit();
double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
double fps = 1000.0 * imgFilePaths.size() / pvnet->GetInferCostMilliSec();
LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
return APP_ERR_OK;
}
opencv-python
tqdm
pycocotools
\ No newline at end of file
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Infer """
import json
import logging
import MxpiDataType_pb2 as MxpiDataType
from StreamManagerApi import StreamManagerApi, MxDataInput, InProtobufVector, MxProtobufIn, StringVector
from config import config as cfg
class SdkApi:
""" Class SdkApi """
INFER_TIMEOUT = cfg.INFER_TIMEOUT
STREAM_NAME = cfg.STREAM_NAME
def __init__(self, pipeline_cfg):
self.pipeline_cfg = pipeline_cfg
self._stream_api = None
self._data_input = None
self._device_id = None
def init(self):
""" Initialize Stream """
with open(self.pipeline_cfg, 'r') as fp:
self._device_id = int(
json.loads(fp.read())[self.STREAM_NAME]["stream_config"]
["deviceId"])
print(f"The device id: {self._device_id}.")
# create api
self._stream_api = StreamManagerApi()
# init stream mgr
ret = self._stream_api.InitManager()
if ret != 0:
print(f"Failed to init stream manager, ret={ret}.")
return False
# create streams
with open(self.pipeline_cfg, 'rb') as fp:
pipe_line = fp.read()
ret = self._stream_api.CreateMultipleStreams(pipe_line)
if ret != 0:
print(f"Failed to create stream, ret={ret}.")
return False
self._data_input = MxDataInput()
return True
def __del__(self):
if not self._stream_api:
return
self._stream_api.DestroyAllStreams()
def _send_protobuf(self, stream_name, plugin_id, element_name, buf_type,
pkg_list):
""" Send Stream """
protobuf = MxProtobufIn()
protobuf.key = element_name.encode("utf-8")
protobuf.type = buf_type
protobuf.protobuf = pkg_list.SerializeToString()
protobuf_vec = InProtobufVector()
protobuf_vec.push_back(protobuf)
err_code = self._stream_api.SendProtobuf(stream_name, plugin_id,
protobuf_vec)
if err_code != 0:
logging.error(
"Failed to send data to stream, stream_name(%s), plugin_id(%s), element_name(%s), "
"buf_type(%s), err_code(%s).", stream_name, plugin_id,
element_name, buf_type, err_code)
return False
return True
def send_tensor_input(self, stream_name, plugin_id, element_name,
input_data, input_shape, data_type):
""" Send Tensor """
tensor_list = MxpiDataType.MxpiTensorPackageList()
tensor_pkg = tensor_list.tensorPackageVec.add()
# init tensor vector
tensor_vec = tensor_pkg.tensorVec.add()
tensor_vec.deviceId = self._device_id
tensor_vec.memType = 0
tensor_vec.tensorShape.extend(input_shape)
tensor_vec.tensorDataType = data_type
tensor_vec.dataStr = input_data
tensor_vec.tensorDataSize = len(input_data)
buf_type = b"MxTools.MxpiTensorPackageList"
return self._send_protobuf(stream_name, plugin_id, element_name,
buf_type, tensor_list)
def get_result(self, stream_name, out_plugin_id=0):
""" Get Result """
keys = [b"mxpi_tensorinfer0"]
keyVec = StringVector()
for key in keys:
keyVec.push_back(key)
infer_result = self._stream_api.GetProtobuf(stream_name, 0, keyVec)
print(infer_result)
if infer_result.size() == 0:
print("infer_result is null")
exit()
if infer_result[0].errorCode != 0:
print("GetProtobuf error. errorCode=%d" % (
infer_result[0].errorCode))
exit()
TensorList = MxpiDataType.MxpiTensorPackageList()
TensorList.ParseFromString(infer_result[0].messageBuf)
return TensorList
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Model Config """
STREAM_NAME = "pvnet"
INFER_TIMEOUT = 100000
TENSOR_DTYPE_FLOAT32 = 0
TENSOR_DTYPE_FLOAT16 = 1
TENSOR_DTYPE_INT8 = 2
{
"pvnet": {
"stream_config": {
"deviceId": "0"
},
"appsrc0": {
"factory": "appsrc",
"next": "mxpi_tensorinfer0"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "appsrc0",
"modelPath": "../data/models/pvnet.om",
"waitingTime": "2000"
},
"factory": "mxpi_tensorinfer",
"next": "appsink0"
},
"appsink0": {
"factory": "appsink"
}
}
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import datetime
import libransac_voting as ransac_vote
from model_utils.config import config as cfg
from model_utils.data_file_utils import read_rgb_np
import numpy as np
from api.infer import SdkApi
from config import config as stream_cfg
def vote(seg_pred, ver_pred):
"""
save infer result to the file, Write format:
Object detected num is 5
#Obj: 1, box: 453 369 473 391, confidence: 0.3, label: person, id: 0
...
:param result_dir is the dir of save result
:param result content bbox and class_id of all object
"""
data = np.concatenate([seg_pred, ver_pred], 1)[0]
channel = cfg.vote_num * 2 + 2
ransac_vote.init_voting(cfg.img_height, cfg.img_width, channel, 2, cfg.vote_num)
print('vote init success!----------------------------------------------------------')
corner_pred = np.zeros((cfg.vote_num, 2), dtype=np.float32)
ransac_vote.do_voting(data, corner_pred)
print('do voting success!----------------------------------------------------------')
return corner_pred
if __name__ == '__main__':
# init stream manager
pipeline_path = "./config/pvnet.pipeline"
sdk_api = SdkApi(pipeline_path)
if not sdk_api.init():
exit(-1)
# Construct the input of the stream
img_data_plugin_id = 0
re_dir = os.path.join("./data/", cfg.cls_name)
image_dir = os.path.join(re_dir, 'images')
test_fn = os.path.join(re_dir, 'test.txt')
res_dir_name = './result'
stream_name = b'pvnet'
TENSOR_DTYPE_FLOAT32 = 0
if not os.path.exists(os.path.join(res_dir_name, 'seg_pred')):
os.makedirs(os.path.join(res_dir_name, 'seg_pred'))
if not os.path.exists(os.path.join(res_dir_name, 'ver_pred')):
os.makedirs(os.path.join(res_dir_name, 'ver_pred'))
test_fns = np.loadtxt(test_fn, dtype=str)
corner_preds = []
poses = []
for _, img_fn in enumerate(test_fns):
rgb_path = os.path.join(image_dir, img_fn)
rgb = read_rgb_np(rgb_path).reshape(1, 480, 640, 3)
sdk_api.send_tensor_input(stream_name, img_data_plugin_id, "appsrc0",
rgb.tobytes(), rgb.shape, stream_cfg.TENSOR_DTYPE_FLOAT32)
start_time = datetime.datetime.now()
result = sdk_api.get_result(stream_name)
end_time = datetime.datetime.now()
print('sdk run time: {}'.format((end_time - start_time).microseconds))
seg_result = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr,
dtype=np.float32).reshape(1, -1, 480, 640)
ver_result = np.frombuffer(result.tensorPackageVec[0].tensorVec[1].dataStr,
dtype=np.float32).reshape(1, -1, 480, 640)
seg_result.tofile(os.path.join(res_dir_name, 'seg_pred', img_fn.split('.')[0]+'.bin'))
ver_result.tofile(os.path.join(res_dir_name, 'ver_pred', img_fn.split('.')[0]+'.bin'))
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/driver/lib64/:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3 main.py
exit 0
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval utils"""
import numpy as np
import cv2
from src.evaluation_dataset import get_pts_3d
def pnp(points_3d, points_2d, camera_matrix, method=cv2.SOLVEPNP_ITERATIVE):
"""pnp"""
dist_coeffs = np.zeros(shape=[8, 1], dtype='float64')
assert points_3d.shape[0] == points_2d.shape[0], 'points 3D and points 2D must have same number of vertices'
if method == cv2.SOLVEPNP_EPNP:
points_3d = np.expand_dims(points_3d, 0)
points_2d = np.expand_dims(points_2d, 0)
points_2d = np.ascontiguousarray(points_2d.astype(np.float64))
points_3d = np.ascontiguousarray(points_3d.astype(np.float64))
camera_matrix = camera_matrix.astype(np.float64)
_, R_exp, t = cv2.solvePnP(points_3d,
points_2d,
camera_matrix,
dist_coeffs,
flags=method)
R, _ = cv2.Rodrigues(R_exp)
return np.concatenate([R, t], axis=-1)
def evaluate(points_2d, class_type):
"""evaluate"""
points_3d = get_pts_3d(class_type)
k_matrix = np.array([[572.41140, 0., 325.26110],
[0., 573.57043, 242.04899],
[0., 0., 1.]], np.float32)
pose_pred = pnp(points_3d, points_2d, k_matrix)
return pose_pred
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import libransac_voting as ransac_vote
from model_utils.config import config as cfg
from src.evaluation_utils import Evaluator
def acc(test_fn, pose_dir, result_dir):
channel = cfg.vote_num * 2 + 2
ransac_vote.init_voting(cfg.img_height, cfg.img_width, channel, 2, cfg.vote_num)
test_fns = np.loadtxt(test_fn, dtype=str)
evaluator = Evaluator()
for _, img_fn in enumerate(test_fns):
seg_fn = os.path.join(result_dir, 'seg_pred', img_fn.split('.')[0] + '.bin')
ver_fn = os.path.join(result_dir, 'ver_pred', img_fn.split('.')[0] + '.bin')
seg_pred = np.fromfile(seg_fn, dtype=np.float32).reshape(1, -1, 480, 640)
ver_pred = np.fromfile(ver_fn, dtype=np.float32).reshape(1, -1, 480, 640)
pose = np.loadtxt(os.path.join(pose_dir, img_fn.split('.')[0] + '.txt')).reshape(3, 4)
data = np.concatenate([seg_pred, ver_pred], 1)[0]
corner_pred = np.zeros((cfg.vote_num, 2), dtype=np.float32)
ransac_vote.do_voting(data, corner_pred)
pose_pred = evaluator.evaluate(corner_pred, pose, cfg.cls_name)
np.savetxt(os.path.join(result_dir, "pred_pose", img_fn.split('.')[0] + '.txt'), pose_pred)
proj_err, add, _ = evaluator.average_precision(False)
print('Processing object:{}, 2D error:{}, ADD:{}'.format(cfg.cls_name, proj_err, add))
if __name__ == '__main__':
test_fn_out = os.path.join('./data', cfg.cls_name, 'test.txt')
pose_dir_out = os.path.join('./data', cfg.cls_name, 'poses')
result_dir_out = './result'
if not os.path.exists(os.path.join(result_dir_out, 'pred_pose')):
os.makedirs(os.path.join(result_dir_out, 'pred_pose'))
acc(test_fn=test_fn_out, pose_dir=pose_dir_out, result_dir=result_dir_out)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train"""
import os
import time
import glob
import argparse
import ast
import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
from mindspore.communication import get_rank, init, get_group_size
from mindspore.nn import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, _InternalCallbackParam, RunContext
# from model_utils.config import config as cfg
from src.dataset import create_dataset
from src.loss_scale import TrainOneStepWithLossScaleCell
from src.model_reposity import Resnet18_8s, NetworkWithLossCell
from src.net_utils import AverageMeter, adjust_learning_rate
import moxing as mox
from model_utils.config import config as cfg
loss_rec = AverageMeter()
recs = [loss_rec]
print(os.system('env'))
def export_AIR(args_opt):
"""start modelarts export"""
ckpt_list = glob.glob(os.path.join(args_opt.modelarts_result_dir, args_opt.cls_name, "train*.ckpt"))
if not ckpt_list:
print("ckpt file not generated.")
ckpt_list.sort(key=os.path.getmtime)
ckpt_model = ckpt_list[-1]
print("checkpoint path", ckpt_model)
# if args.device_target == "Ascend":
# context.set_context(device_id=args.rank)
net = Resnet18_8s(ver_dim=args.vote_num * 2)
param_dict = mindspore.load_checkpoint(ckpt_model)
mindspore.load_param_into_net(net, param_dict)
net.set_train(False)
input_data = Tensor(np.zeros([1, 3, args.img_height, args.img_width]), mindspore.float32)
mindspore.export(net, input_data, file_name=args.file_name, file_format=args.file_format)
class Train:
"""PVNet Train class"""
def __init__(self, arg):
"""__init__"""
self.cls_num = 1 + len(arg.cls_name.split(','))
self.arg = arg
self.dataset = create_dataset(
cls_list=arg.cls_name,
batch_size=arg.batch_size,
workers=arg.workers_num,
devices=arg.group_size,
rank=arg.rank
)
self.current_dir = os.path.dirname(os.path.abspath(__file__))
if arg.pretrained_path is None:
self.pretrained_path = None
else:
self.pretrained_path = os.path.join(self.current_dir, arg.pretrained_path)
self.step_per_epoch = self.dataset.get_dataset_size()
self.dataset = self.dataset.create_tuple_iterator(output_numpy=True, do_copy=False)
print("cls:{}, device_num:{}, rank:{}, data_size:{}".format(arg.cls_name, arg.group_size, arg.rank,
self.step_per_epoch))
def _build_net(self):
""" build pvnet network"""
lr = mindspore.Tensor(adjust_learning_rate(global_step=0,
lr_init=self.arg.lr,
lr_decay_rate=self.arg.learning_rate_decay_rate,
lr_decay_epoch=self.arg.learning_rate_decay_epoch,
total_epochs=self.arg.epoch_size,
steps_per_epoch=self.step_per_epoch))
net = Resnet18_8s(ver_dim=self.arg.vote_num * 2, pretrained_path=self.pretrained_path)
self.opt = nn.Adam(net.trainable_params(), learning_rate=lr)
self.net = NetworkWithLossCell(net, cls_num=self.cls_num)
scale_manager = DynamicLossScaleUpdateCell(loss_scale_value=self.arg.loss_scale_value,
scale_factor=self.arg.scale_factor,
scale_window=self.arg.scale_window)
self.net = TrainOneStepWithLossScaleCell(self.net, self.opt, scale_sense=scale_manager)
self.net.set_train()
def train_net(self):
""" train pvnet network"""
self._build_net()
if self.arg.rank == 0:
self._train_begin()
for i in range(self.arg.epoch_size):
start = time.time()
iter_start = time.time()
for idx, data in enumerate(self.dataset):
for rec in recs:
rec.reset()
cost_time = time.time() - iter_start
image, mask, vertex, vertex_weight = data
image = Tensor.from_numpy(image)
mask = Tensor.from_numpy(mask)
vertex = Tensor.from_numpy(vertex)
vertex_weight = Tensor.from_numpy(vertex_weight)
total_loss = self.net(image, mask, vertex, vertex_weight)
for rec, val in zip(recs, total_loss):
rec.update(val)
if idx % 80 == 0:
log_str = "Rank:{}/{}, Epoch:[{}/{}], Step[{}/{}] cost:{}.s total:{}".format(
self.arg.rank, self.arg.group_size, i + 1, self.arg.epoch_size, idx, self.step_per_epoch,
cost_time,
recs[0].avg)
print(log_str)
iter_start = time.time()
if self.arg.rank == 0:
self._cb_params.output = total_loss
self._cb_params.cur_step_num += 1
self._ckpt_saver.step_end(self._run_context)
print('Epoch Cost:{}'.format(time.time() - start), "seconds.")
if self.arg.rank == 0:
self._cb_params.cur_epoch_num += 1
def _train_begin(self):
""" the step before training """
begin_epoch = 0
cb_params = _InternalCallbackParam()
cb_params.epoch_num = self.arg.epoch_size
cb_params.batch_num = self.step_per_epoch
cb_params.cur_epoch_num = begin_epoch
cb_params.cur_step_num = begin_epoch * self.step_per_epoch
cb_params.train_network = self.net
self._cb_params = cb_params
self._run_context = RunContext(cb_params)
ckpt_config = CheckpointConfig(save_checkpoint_steps=self.step_per_epoch,
keep_checkpoint_max=self.arg.keep_checkpoint_max)
self._ckpt_saver = ModelCheckpoint(
prefix="train",
directory=os.path.join(self.arg.modelarts_result_dir, self.arg.cls_name),
config=ckpt_config
)
self._ckpt_saver.begin(self._run_context)
def network_init(argvs):
""" init distribute training """
context.set_context(mode=context.GRAPH_MODE,
device_target=argvs.device_target,
save_graphs=False,
device_id=int(os.getenv('DEVICE_ID', '0')),
reserve_class_name_in_scope=False)
# Init distributed
if argvs.distribute:
init()
argvs.rank = get_rank()
argvs.group_size = get_group_size()
context.reset_auto_parallel_context()
parallel_mode = context.ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=argvs.group_size)
def parse_args():
parser = argparse.ArgumentParser('PVNet')
parser.add_argument("--train_url", type=str, default="./output")
parser.add_argument("--data_url", type=str, default="./dataset")
parser.add_argument("--modelarts_data_dir", type=str, default="/cache/dataset")
parser.add_argument("--modelarts_result_dir", type=str,
default="/cache/result") # modelarts train result: /cache/result
parser.add_argument('--random_seed', type=int, default=0, help='random_seed')
parser.add_argument('--cls_name', type=str, default="cat",
help='Sub-Dataset to train, for example, cat,ape,cam ')
parser.add_argument('--epoch_size', type=int, default=1, help='epoch_size')
parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
parser.add_argument('--pretrain_epoch_size', type=int, default=0, help='pretrain_epoch_size, use with pre_trained')
parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.5, help='learning_rate_decay_rate')
parser.add_argument('--learning_rate_decay_epoch', type=int, default=20, help='learning_rate_decay_epoch')
parser.add_argument('--vote_num', type=int, default=9, help='vote num')
parser.add_argument('--workers_num', type=int, default=16, help='workers_num')
parser.add_argument('--group_size', type=int, default=1, help='group_size')
parser.add_argument('--rank', type=int, default=0, help='rank')
parser.add_argument('--loss_scale_value', type=int, default=1024, help='loss_scale_value')
parser.add_argument('--scale_factor', type=int, default=2, help='scale_factor')
parser.add_argument('--scale_window', type=int, default=1000, help='scale_window')
parser.add_argument('--keep_checkpoint_max', type=int, default=10, help='keep_checkpoint_max')
# Do not change the following two hyper-parameter, it conflicts with pvnet_linemod_config.yaml
parser.add_argument('--img_height', type=int, default=480, help='img_height')
parser.add_argument('--img_width', type=int, default=640, help='img_width')
parser.add_argument('--distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--device_target', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
help="Device target, support Ascend, GPU and CPU.")
parser.add_argument('--pretrained_path', type=str, default="./resnet18-5c106cde.ckpt",
help='Pretrained checkpoint path')
parser.add_argument('--file_name', type=str, default='pvnet', help='output air file name')
parser.add_argument('--file_format', type=str, default='AIR', help='file_format')
return parser.parse_args()
# _CACHE_DATA_URL = "/cache/data_url"
# _CACHE_TRAIN_URL = "/cache/train_url"
if __name__ == '__main__':
args = parse_args()
cfg.data_url = args.data_url
mindspore.set_seed(args.random_seed)
network_init(args)
## copy dataset from obs to modelarts
os.makedirs(args.modelarts_data_dir, exist_ok=True)
os.makedirs(args.modelarts_result_dir, exist_ok=True)
mox.file.copy_parallel(args.data_url, args.modelarts_data_dir)
train = Train(args)
train.train_net()
## start export air
export_AIR(args)
## copy result from modelarts to obs
mox.file.copy_parallel(args.modelarts_result_dir, args.train_url)
air_file = args.file_name + ".air"
mox.file.copy(src_url=air_file, dst_url=os.path.join(args.train_url, air_file))
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.mitations under the License.
docker_image=$1
data_dir=$2
model_dir=$3
docker run -it -u root --ipc=host \
--device=/dev/davinci0 \
--device=/dev/davinci1 \
--device=/dev/davinci2 \
--device=/dev/davinci3 \
--device=/dev/davinci4 \
--device=/dev/davinci5 \
--device=/dev/davinci6 \
--device=/dev/davinci7 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm --device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \
-v ${model_dir}:${model_dir} \
-v ${data_dir}:${data_dir} \
-v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment