diff --git a/research/cv/resnetv2/Dockerfile b/research/cv/resnetv2/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..360861ede17fb0ab697fbcac190acde7c1e29fef --- /dev/null +++ b/research/cv/resnetv2/Dockerfile @@ -0,0 +1,5 @@ +ARG FROM_IMAGE_NAME +FROM ${FROM_IMAGE_NAME} + +COPY requirements.txt . +RUN pip3.7 install -r requirements.txt diff --git a/research/cv/resnetv2/infer/resnetv2_101/convert/aipp.config b/research/cv/resnetv2/infer/resnetv2_101/convert/aipp.config new file mode 100644 index 0000000000000000000000000000000000000000..2b22e28521375fadfdb6e489a2d0432300770e84 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/convert/aipp.config @@ -0,0 +1,16 @@ +aipp_op { + aipp_mode : static + input_format : RGB888_U8 + csc_switch : false + rbuv_swap_switch : true + + mean_chn_0 : 0 + mean_chn_1 : 0 + mean_chn_2 : 0 + min_chn_0 : 123.675 + min_chn_1 : 116.28 + min_chn_2 : 103.53 + var_reci_chn_0 : 0.0171247538316637 + var_reci_chn_1 : 0.0175070028011204 + var_reci_chn_2 : 0.0174291938997821 +} diff --git a/research/cv/resnetv2/infer/resnetv2_101/convert/convert_om.sh b/research/cv/resnetv2/infer/resnetv2_101/convert/convert_om.sh new file mode 100644 index 0000000000000000000000000000000000000000..c6366c062f127bbdf1d1fa1b1d7067b00f9222c8 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/convert/convert_om.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +model_path=$1 +output_model_name=$2 + +/usr/local/Ascend/atc/bin/atc \ + --model=$model_path \ + --framework=1 \ + --output=$output_model_name \ + --input_format=NCHW --input_shape="actual_input_1:1,3,32,32" \ + --enable_small_channel=1 \ + --log=error \ + --soc_version=Ascend310 \ + --insert_op_conf=./aipp.config diff --git a/research/cv/resnetv2/infer/resnetv2_101/data/config/resnetv2.cfg b/research/cv/resnetv2/infer/resnetv2_101/data/config/resnetv2.cfg new file mode 100644 index 0000000000000000000000000000000000000000..b18718a1da5891d2b8a4498b06638a2b4c7619a9 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/data/config/resnetv2.cfg @@ -0,0 +1,3 @@ +CLASS_NUM=10 +SOFTMAX=false +TOP_K=5 diff --git a/research/cv/resnetv2/infer/resnetv2_101/data/config/resnetv2.pipeline b/research/cv/resnetv2/infer/resnetv2_101/data/config/resnetv2.pipeline new file mode 100644 index 0000000000000000000000000000000000000000..eb8b263f75f9461a0a1708cbce183d1b3b50a3e0 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/data/config/resnetv2.pipeline @@ -0,0 +1,64 @@ +{ + "im_resnetv2": { + "stream_config": { + "deviceId": "0" + }, + "appsrc1": { + "props": { + "blocksize": "409600" + }, + "factory": "appsrc", + "next": "mxpi_imagedecoder0" + }, + "mxpi_imagedecoder0": { + "props": { + "handleMethod": "opencv" + }, + "factory": "mxpi_imagedecoder", + "next": "mxpi_imageresize0" + }, + "mxpi_imageresize0": { + "props": { + "handleMethod": "opencv", + "resizeType": "Resizer_Stretch", + "resizeHeight": "32", + "resizeWidth": "32" + }, + "factory": "mxpi_imageresize", + "next": "mxpi_tensorinfer0" + }, + "mxpi_tensorinfer0": { + "props": { + "dataSource": "mxpi_imageresize0", + "modelPath": "../data/model/resnetv2.om", + "waitingTime": "1", + "outputDeviceId": "-1" + }, + "factory": "mxpi_tensorinfer", + "next": "mxpi_classpostprocessor0" + }, + "mxpi_classpostprocessor0": { + "props": { + "dataSource": "mxpi_tensorinfer0", + "postProcessConfigPath": "../data/config/resnetv2.cfg", + "labelPath": "../data/config/cifar10.names", + "postProcessLibPath": "/usr/local/sdk_home/mxManufacture/lib/modelpostprocessors/libresnet50postprocess.so" + }, + "factory": "mxpi_classpostprocessor", + "next": "mxpi_dataserialize0" + }, + "mxpi_dataserialize0": { + "props": { + "outputDataKeys": "mxpi_classpostprocessor0" + }, + "factory": "mxpi_dataserialize", + "next": "appsink0" + }, + "appsink0": { + "props": { + "blocksize": "4096000" + }, + "factory": "appsink" + } + } +} diff --git a/research/cv/resnetv2/infer/resnetv2_101/data/images/cifar10.py b/research/cv/resnetv2/infer/resnetv2_101/data/images/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..0564414ef2d853fa22a4a59ab6866aa37a8b6727 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/data/images/cifar10.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" cifar10.py """ + +import os +import cv2 +import numpy as np + +loc_1 = './train_cifar10/' +loc_2 = './test_cifar10/' + + +def unpickle(file_name): + import pickle + with open(file_name, 'rb') as fo: + dict_res = pickle.load(fo, encoding='bytes') + return dict_res + + +def convert_train_data(file_dir): + """ ./train_cifar10/ """ + if not os.path.exists(loc_1): + os.mkdir(loc_1) + for i in range(1, 6): + data_name = os.path.join(file_dir, 'data_batch_' + str(i)) + data_dict = unpickle(data_name) + print('{} is processing'.format(data_name)) + for j in range(10000): + img = np.reshape(data_dict[b'data'][j], (3, 32, 32)) + img = np.transpose(img, (1, 2, 0)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_name = "%s%s%s.jpg" % (loc_1, str(data_dict[b'labels'][j]), str((i) * 10000 + j)) + cv2.imwrite(img_name, img) + print('{} is done'.format(data_name)) + + +def convert_test_data(file_dir): + """ ./test_cifar10/ && test_label.txt """ + if not os.path.exists(loc_2): + os.mkdir(loc_2) + test_data_name = file_dir + '/test_batch' + print('{} is processing'.format(test_data_name)) + test_dict = unpickle(test_data_name) + for m in range(10000): + img = np.reshape(test_dict[b'data'][m], (3, 32, 32)) + img = np.transpose(img, (1, 2, 0)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_name = '%s%s%s%s' % (loc_2, str(test_dict[b'labels'][m]), str(10000 + m), '.jpg') + img_label = "%s%s.jpg" % (str(test_dict[b'labels'][m]), str(10000 + m)) + cv2.imwrite(img_name, img) + with open("test_label.txt", "a") as f: + f.write(img_label + " " * 10 + str(test_dict[b'labels'][m])) + f.write("\n") + print("{} is done".format(test_data_name)) + + +def cifar10_img(): + """ + from ./cifar-10-batches-py + to ./train_cifar10/ ./test_cifar10/ ./test_label.txt + """ + file_dir = './cifar-10-batches-py' + convert_train_data(file_dir) + convert_test_data(file_dir) + print('Finish transforming to image') + + +if __name__ == '__main__': + cifar10_img() diff --git a/research/cv/resnetv2/infer/resnetv2_101/docker_start_infer.sh b/research/cv/resnetv2/infer/resnetv2_101/docker_start_infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..2dcac5f542755d375b5fe97456606cef927659c6 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/docker_start_infer.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +docker_image=$1 +model_dir=$2 +data_dir=$3 + + +function show_help() { + echo "Usage: docker_start_infer.sh docker_image model_dir data_dir" +} + +function param_check() { + if [ -z "${docker_image}" ]; then + echo "please input docker_image" + show_help + exit 1 + fi + + if [ -z "${model_dir}" ]; then + echo "please input model_dir" + show_help + exit 1 + fi + + if [ -z "${data_dir}" ]; then + echo "please input data_dir" + show_help + exit 1 + fi + +} + +param_check + +docker run -it \ + --device=/dev/davinci0 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm \ + --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v ${model_dir}:${model_dir} \ + -v ${data_dir}:${data_dir} \ + ${docker_image} \ + /bin/bash diff --git a/research/cv/resnetv2/infer/resnetv2_101/mxbase/CMakeLists.txt b/research/cv/resnetv2/infer/resnetv2_101/mxbase/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4d6334b04c4e5c65c442e08e6946acff8719c8d9 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/mxbase/CMakeLists.txt @@ -0,0 +1,53 @@ +cmake_minimum_required(VERSION 3.14.0) +project(resnetv2) +set(TARGET resnetv2) +add_definitions(-DENABLE_DVPP_INTERFACE) +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-Dgoogle=mindxsdk_private) +add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall) +add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie) + + +# Check environment variable +if(NOT DEFINED ENV{ASCEND_HOME}) + message(FATAL_ERROR "please define environment variable:ASCEND_HOME") +endif() +if(NOT DEFINED ENV{ASCEND_VERSION}) + message(WARNING "please define environment variable:ASCEND_VERSION") +endif() +if(NOT DEFINED ENV{ARCH_PATTERN}) + message(WARNING "please define environment variable:ARCH_PATTERN") +endif() +set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include) +set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64) + + +set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME}) + +set(MXBASE_INC ${MXBASE_ROOT_DIR}/include) +set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib) +set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors) +set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include) + +if(DEFINED ENV{MXSDK_OPENSOURCE_DIR}) + set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR}) +else() + set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource) +endif() + +include_directories(${ACL_INC_DIR}) +include_directories(${OPENSOURCE_DIR}/include) +include_directories(${OPENSOURCE_DIR}/include/opencv4) + +include_directories(${MXBASE_INC}) +include_directories(${MXBASE_POST_PROCESS_DIR}) +link_directories(${ACL_LIB_DIR}) +link_directories(${OPENSOURCE_DIR}/lib) +link_directories(${MXBASE_LIB_DIR}) +link_directories(${MXBASE_POST_LIB_DIR}) + +add_executable(${TARGET} src/main.cpp src/Resnetv2.cpp) + +target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world) + +install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/) diff --git a/research/cv/resnetv2/infer/resnetv2_101/mxbase/build.sh b/research/cv/resnetv2/infer/resnetv2_101/mxbase/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..efebf5e03be51ff93bdd746eb9bcc0daee4263cb --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/mxbase/build.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +path_cur=$(dirname $0) + +function check_env() +{ + # set ASCEND_HOME to /usr/local/Ascend/ when it was not specified by user + if [ ! "${ASCEND_HOME}" ]; then + export ASCEND_HOME=/usr/local/Ascend/ + echo "Set ASCEND_HOME to the default value: ${ASCEND_HOME}" + else + echo "ASCEND_HOME is set to ${ASCEND_HOME} by user" + fi + + if [ ! "${ASCEND_VERSION}" ]; then + export ASCEND_VERSION=nnrt/latest + echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}" + else + echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user" + fi + + if [ ! "${ARCH_PATTERN}" ]; then + # set ARCH_PATTERN to ./ when it was not specified by user + export ARCH_PATTERN=./ + echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}" + else + echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user" + fi +} + +function build() +{ + cd $path_cur + rm -rf build + mkdir -p build + cd build + cmake .. + make + ret=$? + if [ ${ret} -ne 0 ]; then + echo "Failed to build resnetv2." + exit ${ret} + fi + make install +} + +check_env +build diff --git a/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/Resnetv2.cpp b/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/Resnetv2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8a435aa558b5e1460fa025eaabbfe26040bd0785 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/Resnetv2.cpp @@ -0,0 +1,221 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include <map> +#include <memory> +#include <string> +#include <vector> +#include "Resnetv2.h" +#include "MxBase/DeviceManager/DeviceManager.h" +#include "MxBase/Log/Log.h" + +namespace { + const uint32_t YUV_BYTE_NU = 3; + const uint32_t YUV_BYTE_DE = 2; + const uint32_t VPC_H_ALIGN = 2; + const uint32_t MAX_LENGTH = 128; +} + +APP_ERROR Resnetv2::Init(const InitParam &initParam) { + deviceId_ = initParam.deviceId; + APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices(); + if (ret != APP_ERR_OK) { + LogError << "Init devices failed, ret = " << ret << "."; + return ret; + } + ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId); + if (ret != APP_ERR_OK) { + LogError << "Set context failed, ret = " << ret << "."; + return ret; + } + model_ = std::make_shared<MxBase::ModelInferenceProcessor>(); + ret = model_->Init(initParam.modelPath, modelDesc_); + if (ret != APP_ERR_OK) { + LogError << "ModelInferenceProcessor init failed, ret = " << ret << "."; + return ret; + } + MxBase::ConfigData configData; + const std::string softmax = initParam.softmax ? "true" : "false"; + const std::string checkTensor = initParam.checkTensor ? "true" : "false"; + + configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum)); + configData.SetJsonValue("TOP_K", std::to_string(initParam.topk)); + configData.SetJsonValue("SOFTMAX", softmax); + configData.SetJsonValue("CHECK_MODEL", checkTensor); + + auto jsonStr = configData.GetCfgJson().serialize(); + std::map<std::string, std::shared_ptr<void>> config; + config["postProcessConfigContent"] = std::make_shared<std::string>(jsonStr); + config["labelPath"] = std::make_shared<std::string>(initParam.labelPath); + + post_ = std::make_shared<MxBase::Resnet50PostProcess>(); + ret = post_->Init(config); + if (ret != APP_ERR_OK) { + LogError << "Resnetv2PostProcess init failed, ret = " << ret << "."; + return ret; + } + + return APP_ERR_OK; +} + +APP_ERROR Resnetv2::DeInit() { + model_->DeInit(); + post_->DeInit(); + MxBase::DeviceManager::GetInstance()->DestroyDevices(); + return APP_ERR_OK; +} + +APP_ERROR Resnetv2::ReadImage(const std::string &imgPath, cv::Mat *imageMat) { + *imageMat = cv::imread(imgPath, cv::IMREAD_COLOR); + LogInfo << "image size: " << imageMat->size(); + return APP_ERR_OK; +} + +APP_ERROR Resnetv2::ResizeImage(cv::Mat *imageMat) { + static constexpr uint32_t resizeHeight = 32; + static constexpr uint32_t resizeWidth = 32; + cv::resize(*imageMat, *imageMat, cv::Size(resizeWidth, resizeHeight)); + return APP_ERR_OK; +} + +APP_ERROR Resnetv2::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase) { + uint32_t batchSize = modelDesc_.inputTensors[0].tensorDims[0]; + const uint32_t dataSize = imageMat.cols * imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU * batchSize; + LogInfo << "image size after crop: [" << imageMat.cols << " x " << imageMat.rows << "]"; + MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_); + MxBase::MemoryData memoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC); + + APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc); + if (ret != APP_ERR_OK) { + LogError << GetError(ret) << "Memory malloc failed."; + return ret; + } + std::vector<uint32_t> shape = {imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU * batchSize, + static_cast<uint32_t>(imageMat.cols)}; + *tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8); + return APP_ERR_OK; +} + 
+APP_ERROR Resnetv2::Inference(const std::vector<MxBase::TensorBase> &inputs, + std::vector<MxBase::TensorBase> *outputs) { + auto dtypes = model_->GetOutputDataType(); + for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) { + std::vector<uint32_t> shape = {}; + for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) { + shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]); + } + MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_); + APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor); + if (ret != APP_ERR_OK) { + LogError << "TensorBaseMalloc failed, ret = " << ret << "."; + return ret; + } + outputs->push_back(tensor); + } + MxBase::DynamicInfo dynamicInfo = {}; + dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH; + auto startTime = std::chrono::high_resolution_clock::now(); + APP_ERROR ret = model_->ModelInference(inputs, *outputs, dynamicInfo); + auto endTime = std::chrono::high_resolution_clock::now(); + double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); + inferCostTimeMilliSec += costMs; + if (ret != APP_ERR_OK) { + LogError << "ModelInference failed, ret = " << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +APP_ERROR Resnetv2::PostProcess(const std::vector<MxBase::TensorBase> &inputs, + std::vector<std::vector<MxBase::ClassInfo>> *clsInfos) { + APP_ERROR ret = post_->Process(inputs, *clsInfos); + if (ret != APP_ERR_OK) { + LogError << "Process failed, ret = " << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +APP_ERROR Resnetv2::SaveResult(const std::string &imgPath, const std::string &resPath, + const std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos) { + LogInfo << "image path: " << imgPath; + std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1); + size_t dot = fileName.find_last_of("."); + std::string resFileName = resPath + "/" + fileName.substr(0, dot) + "_1.txt"; + LogInfo << "file path for saving result: " << resFileName; + std::ofstream outfile(resFileName); + if (outfile.fail()) { + LogError << "Failed to open result file: "; + return APP_ERR_COMM_FAILURE; + } + uint32_t batchIndex = 0; + for (auto clsInfos : batchClsInfos) { + std::string resultStr; + for (auto clsInfo : clsInfos) { + LogDebug << "className:" << clsInfo.className << " confidence:" << clsInfo.confidence << + " classIndex:" << clsInfo.classId; + resultStr += std::to_string(clsInfo.classId) + " "; + } + outfile << resultStr << std::endl; + batchIndex++; + } + outfile.close(); + return APP_ERR_OK; +} + +APP_ERROR Resnetv2::Process(const std::string &imgPath, const std::string &resPath) { + cv::Mat imageMat; + APP_ERROR ret = ReadImage(imgPath, &imageMat); + if (ret != APP_ERR_OK) { + LogError << "ReadImage failed, ret = " << ret << "."; + return ret; + } + ret = ResizeImage(&imageMat); + if (ret != APP_ERR_OK) { + LogError << "ResizeImage failed, ret = " << ret << "."; + return ret; + } + std::vector<MxBase::TensorBase> inputs = {}; + std::vector<MxBase::TensorBase> outputs = {}; + MxBase::TensorBase tensorBase; + ret = CVMatToTensorBase(imageMat, &tensorBase); + if (ret != APP_ERR_OK) { + LogError << "CVMatToTensorBase failed, ret = " << ret << "."; + return ret; + } + inputs.push_back(tensorBase); + auto startTime = std::chrono::high_resolution_clock::now(); + ret = Inference(inputs, &outputs); + auto endTime = std::chrono::high_resolution_clock::now(); + double costMs = std::chrono::duration<double, 
std::milli>(endTime - startTime).count(); + inferCostTimeMilliSec += costMs; + if (ret != APP_ERR_OK) { + LogError << "Inference failed, ret = " << ret << "."; + return ret; + } + std::vector<std::vector<MxBase::ClassInfo>> BatchClsInfos = {}; + ret = PostProcess(outputs, &BatchClsInfos); + if (ret != APP_ERR_OK) { + LogError << "PostProcess failed, ret = " << ret << "."; + return ret; + } + ret = SaveResult(imgPath, resPath, BatchClsInfos); + if (ret != APP_ERR_OK) { + LogError << "Save infer results into file failed. ret = " << ret << "."; + return ret; + } + return APP_ERR_OK; +} diff --git a/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/Resnetv2.h b/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/Resnetv2.h new file mode 100644 index 0000000000000000000000000000000000000000..f87a532a1d362d97c033e1aaec4188507e410218 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/Resnetv2.h @@ -0,0 +1,67 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <string> +#include <vector> +#include <memory> +#ifndef MxBase_ALEXNET_H +#define MxBase_ALEXNET_H +#include <opencv2/opencv.hpp> + +#include "MxBase/DvppWrapper/DvppWrapper.h" +#include "MxBase/ModelInfer/ModelInferenceProcessor.h" +#include "MxBase/postprocess/include/ClassPostProcessors/Resnet50PostProcess.h" +#include "MxBase/Tensor/TensorContext/TensorContext.h" + + +struct InitParam { + uint32_t deviceId; + std::string labelPath; + uint32_t classNum; + uint32_t topk; + bool softmax; + bool checkTensor; + std::string modelPath; +}; + +struct ImageShape { + uint32_t width; + uint32_t height; +}; + +class Resnetv2 { + public: + APP_ERROR Init(const InitParam &initParam); + APP_ERROR DeInit(); + APP_ERROR ReadImage(const std::string &imgPath, cv::Mat *imageMat); + APP_ERROR ResizeImage(cv::Mat *imageMat); + APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase *tensorBase); + APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> *outputs); + APP_ERROR PostProcess(const std::vector<MxBase::TensorBase> &inputs, + std::vector<std::vector<MxBase::ClassInfo>> *clsInfos); + APP_ERROR Process(const std::string &imgPath, const std::string &resPath); + double GetInferCostMilliSec() const {return inferCostTimeMilliSec;} + private: + APP_ERROR SaveResult(const std::string &imgPath, const std::string &resPath, + const std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos); + private: + std::shared_ptr<MxBase::ModelInferenceProcessor> model_; + std::shared_ptr<MxBase::Resnet50PostProcess> post_; + MxBase::ModelDesc modelDesc_; + uint32_t deviceId_ = 0; + double inferCostTimeMilliSec = 0.0; +}; +#endif diff --git a/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/main.cpp b/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d398bfb13962c253ab4155775a1bef0710fc3c9 --- /dev/null +++ 
b/research/cv/resnetv2/infer/resnetv2_101/mxbase/src/main.cpp @@ -0,0 +1,91 @@ +/* + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <dirent.h> +#include "Resnetv2.h" +#include "MxBase/Log/Log.h" + + +namespace { + const uint32_t CLASS_NUM = 10; +} + +APP_ERROR ScanImages(const std::string &path, std::vector<std::string> *imgFiles) { + DIR *dirPtr = opendir(path.c_str()); + if (dirPtr == nullptr) { + LogError << "opendir failed. dir: " << path; + return APP_ERR_INTERNAL_ERROR; + } + dirent *direntPtr = nullptr; + while ((direntPtr = readdir(dirPtr)) != nullptr) { + std::string fileName = direntPtr->d_name; + if (fileName == "." || fileName == "..") { + continue; + } + + imgFiles->emplace_back(path + "/" + fileName); + } + LogInfo << "opendir ok. dir: " << path; + closedir(dirPtr); + return APP_ERR_OK; +} + +int main(int argc, char* argv[]) { + if (argc <= 2) { + LogWarn << "Please input image path and result path, such as './resnetv2 image_dir res_dir'"; + return APP_ERR_OK; + } + + InitParam initParam = {}; + initParam.deviceId = 0; + initParam.classNum = CLASS_NUM; + initParam.labelPath = "../data/config/cifar10.names"; + initParam.topk = 5; + initParam.softmax = false; + initParam.checkTensor = true; + initParam.modelPath = "../data/model/resnetv2.om"; + auto resnetv2 = std::make_shared<Resnetv2>(); + APP_ERROR ret = resnetv2->Init(initParam); + if (ret != APP_ERR_OK) { + resnetv2->DeInit(); + LogError << "Resnetv2Classify init failed, ret = " << ret << "."; + return ret; + } + + std::string imgPath = argv[1]; + std::vector<std::string> imgFilePaths; + ret = ScanImages(imgPath, &imgFilePaths); + if (ret != APP_ERR_OK) { + return ret; + } + std::string resPath = argv[2]; + auto startTime = std::chrono::high_resolution_clock::now(); + for (auto &imgFile : imgFilePaths) { + ret = resnetv2->Process(imgFile, resPath); + if (ret != APP_ERR_OK) { + LogError << "Resnetv2Classify process failed, ret = " << ret << "."; + resnetv2->DeInit(); + return ret; + } + } + auto endTime = std::chrono::high_resolution_clock::now(); + resnetv2->DeInit(); + double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); + double fps = 1000.0 * imgFilePaths.size() / resnetv2->GetInferCostMilliSec(); + LogInfo << "[Process Delay] cost:" << costMilliSecs << " ms\tfps: " << fps << " imgs/sec"; + return APP_ERR_OK; +} + diff --git a/research/cv/resnetv2/infer/resnetv2_101/sdk/classification_task_metric.py b/research/cv/resnetv2/infer/resnetv2_101/sdk/classification_task_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..5dff3659f308508f375382bce553d819938fc498 --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/sdk/classification_task_metric.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" classification_task_metric.py """ + +import os +import sys +import json +import numpy as np + +np.set_printoptions(threshold=sys.maxsize) + +LABEL_FILE = "HiAI_label.json" + +def gen_file_name(img_name): + full_name = img_name.split('/')[-1] + return os.path.splitext(full_name)[0] + +def cre_groundtruth_dict(gtfile_path): + """ + :param gtfile_path: directory that contains the image names and label numbers + :return: dictionary with image name as key and label number as value + """ + img_gt_dict = {} + for gtfile in os.listdir(gtfile_path): + if gtfile != LABEL_FILE: + with open(os.path.join(gtfile_path, gtfile), 'r') as f: + gt = json.load(f) + image_name = os.path.splitext(gtfile.split('/')[-1])[0] + img_gt_dict[image_name] = gt["image"]["annotations"][0]["category_id"] + return img_gt_dict + +def cre_groundtruth_dict_fromtxt(gtfile_path): + """ + :param gtfile_path: file that contains the image names and label numbers + :return: dictionary with image name as key and label number as value + """ + img_gt_dict = {} + with open(gtfile_path, 'r') as f: + + for line in f.readlines(): + temp = line.strip().split(" ") + img_name = temp[0].split(".")[0] + img_lab = temp[-1] + img_gt_dict[img_name] = img_lab + return img_gt_dict + +def load_statistical_predict_result(filepath): + """ + function: + extract the prediction result from the result file + input: + result file: filepath + output: + n_label: number of labels + data_vec: the predicted class indices read from the result file + :return: data_vec, n_label + """ + with open(filepath, 'r') as f: + temp = f.readline().strip().split(" ") + n_label = len(temp) + data_vec = np.zeros((len(temp)), dtype=np.float32) + if n_label != 0: + for ind, cls_ind in enumerate(temp): + data_vec[ind] = int(cls_ind) + return data_vec, n_label + +def create_visualization_statistical_result(prediction_file_path, + result_store_path, json_file_name, + img_gt_dict, topn=5): + """ create_visualization_statistical_result """ + writer = open(os.path.join(result_store_path, json_file_name), 'w') + table_dict = {} + table_dict["title"] = "Overall statistical evaluation" + table_dict["value"] = [] + count = 0 + res_cnt = 0 + n_labels = "" + count_hit = np.zeros(topn) + for tfile_name in os.listdir(prediction_file_path): + count += 1 + temp = tfile_name.split('.')[0] + index = temp.rfind('_') + img_name = temp[:index] + filepath = os.path.join(prediction_file_path, tfile_name) + + ret = load_statistical_predict_result(filepath) + prediction = ret[0] + n_labels = ret[1] + gt = img_gt_dict[img_name] + if n_labels == 1000: + real_label = int(gt) + elif n_labels == 1001: + real_label = int(gt) + 1 + else: + real_label = int(gt) + res_cnt = min(len(prediction), topn) + for i in range(res_cnt): + if str(real_label) == str(int(prediction[i])): + count_hit[i] += 1 + break + if 'value' not in table_dict.keys(): + print("the item value does not exist!") + else: + table_dict["value"].extend( + [{"key": "Number of images", "value": str(count)}, + {"key": "Number of classes", "value": str(n_labels)}]) + if count == 0: + accuracy = 0 + else: + accuracy = np.cumsum(count_hit) / count + for i in range(res_cnt): +
table_dict["value"].append({"key": "Top" + str(i + 1) + " accuracy", + "value": str(round(accuracy[i] * 100, 2)) + '%'}) + json.dump(table_dict, writer, indent=4) + writer.close() + +def run(): + """ run """ + if len(sys.argv) == 5: + # Target file folder path + folder_davinci_target = sys.argv[1] + # annotation files path, "val_label.txt" + annotation_file_path = sys.argv[2] + # the path to store the results json path + result_json_path = sys.argv[3] + # result json file name + result_json_file_name = sys.argv[4] + else: + print("Please enter target file result folder | ground truth label file | result json file folder | " + "result json file name, such as ./result val_label.txt . result.json") + exit(1) + + if not os.path.exists(folder_davinci_target): + print("Target file folder does not exist.") + + if not os.path.exists(annotation_file_path): + print("Ground truth file does not exist.") + + if not os.path.exists(result_json_path): + print("Result folder doesn't exist.") + + img_label_dict = cre_groundtruth_dict_fromtxt(annotation_file_path) + create_visualization_statistical_result(folder_davinci_target, + result_json_path, result_json_file_name, + img_label_dict, topn=5) + +if __name__ == '__main__': + run() diff --git a/research/cv/resnetv2/infer/resnetv2_101/sdk/main.py b/research/cv/resnetv2/infer/resnetv2_101/sdk/main.py new file mode 100644 index 0000000000000000000000000000000000000000..52a2a7fa26fb0c46b2f2aa55b3dda40b4b8438fa --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/sdk/main.py @@ -0,0 +1,108 @@ +# coding=utf-8 + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" main.py """ + +import datetime +import json +import os +import sys +from StreamManagerApi import StreamManagerApi, MxDataInput + +def info(msg): + nowtime = datetime.datetime.now().isoformat() + print("[INFO][%s %d %s] %s" %(nowtime, os.getpid(), __file__, msg)) + +def warn(msg): + nowtime = datetime.datetime.now().isoformat() + print("\033[33m[WARN][%s %d %s] %s\033[0m" %(nowtime, os.getpid(), __file__, msg)) + +def err(msg): + nowtime = datetime.datetime.now().isoformat() + print("\033[31m[ERROR][%s %d %s] %s\033[0m" %(nowtime, os.getpid(), __file__, msg)) + +if __name__ == '__main__': + # init stream manager + stream_manager_api = StreamManagerApi() + ret = stream_manager_api.InitManager() + if ret != 0: + err("Failed to init Stream manager, ret=%s" % str(ret)) + exit() + + # create streams by pipeline config file + with open("../data/config/resnetv2.pipeline", 'rb') as f: + pipelineStr = f.read() + ret = stream_manager_api.CreateMultipleStreams(pipelineStr) + + if ret != 0: + err("Failed to create Stream, ret=%s" % str(ret)) + exit() + + # Construct the input of the stream + data_input = MxDataInput() + + dir_name = sys.argv[1] + res_dir_name = sys.argv[2] + file_list = os.listdir(dir_name) + if not os.path.exists(res_dir_name): + os.makedirs(res_dir_name) + + for file_name in file_list: + file_path = os.path.join(dir_name, file_name) + if not (file_name.lower().endswith(".jpg") + or file_name.lower().endswith(".jpeg")): + continue + + with open(file_path, 'rb') as f: + data_input.data = f.read() + info("Read data from %s" % file_path) + + empty_data = [] + stream_name = b'im_resnetv2' + in_plugin_id = 0 + + start_time = datetime.datetime.now() + + unique_id = stream_manager_api.SendData(stream_name, in_plugin_id, + data_input) + if unique_id < 0: + err("Failed to send data to stream.") + exit() + # Obtain the inference result by specifying streamName and uniqueId. + infer_result = stream_manager_api.GetResult(stream_name, unique_id) + end_time = datetime.datetime.now() + info('sdk run time: {}us'.format((end_time - start_time).microseconds)) + if infer_result.errorCode != 0: + err("GetResultWithUniqueId error. errorCode=%d, errorMsg=%s" % + (infer_result.errorCode, infer_result.data.decode())) + exit() + # print the infer result + infer_res = infer_result.data.decode() + info("process img: {}, infer result: {}".format(file_name, infer_res)) + load_dict = json.loads(infer_result.data.decode()) + if load_dict.get('MxpiClass') is None: + with open(res_dir_name + "/" + file_name.split('.')[0] + '.txt', + 'w') as f_write: + f_write.write("") + continue + res_vec = load_dict.get('MxpiClass') + with open(res_dir_name + "/" + file_name.split('.')[0] + '_1.txt', + 'w') as f_write: + res_list = [str(item.get("classId")) + " " for item in res_vec] + f_write.writelines(res_list) + f_write.write('\n') + + # destroy streams + stream_manager_api.DestroyAllStreams() diff --git a/research/cv/resnetv2/infer/resnetv2_101/sdk/run.sh b/research/cv/resnetv2/infer/resnetv2_101/sdk/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..01cf3158106ae8079f73d3e404a5a5bee213150d --- /dev/null +++ b/research/cv/resnetv2/infer/resnetv2_101/sdk/run.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +image_path=$1 +result_dir=$2 + +set -e + +# Simple log helper functions +info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; } +warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; } + +export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH} +export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner +export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins + +# set PYTHONPATH so that StreamManagerApi.py can be imported +export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python + +python3.7 main.py $image_path $result_dir +exit 0 diff --git a/research/cv/resnetv2/modelarts/resnetv2_101/train_start.py b/research/cv/resnetv2/modelarts/resnetv2_101/train_start.py new file mode 100644 index 0000000000000000000000000000000000000000..99ffcafc9796089a1c41e3f68f16c3f6c517c282 --- /dev/null +++ b/research/cv/resnetv2/modelarts/resnetv2_101/train_start.py @@ -0,0 +1,166 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +"""train resnetv2.""" + +import argparse +import os +import numpy as np + +from mindspore.nn import Momentum +from mindspore import Model, Tensor, load_checkpoint, load_param_into_net, export, context +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.nn import SoftmaxCrossEntropyWithLogits +from mindspore.common import set_seed +from mindspore.train.loss_scale_manager import FixedLossScaleManager + +# should find /src + +from src.lr_generator import get_lr +from src.CrossEntropySmooth import CrossEntropySmooth + +parser = argparse.ArgumentParser('mindspore resnetv2 training') + +parser.add_argument('--net', type=str, default='resnetv2_50', + help='Resnetv2 Model, resnetv2_50, resnetv2_101, resnetv2_152') +parser.add_argument('--dataset', type=str, default='cifar10', + help='Dataset, cifar10, imagenet2012') +parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'], + help='device where the code will be implemented (default: Ascend)') +parser.add_argument('--train_url', type=str, required=True, default='', + help='where training ckpts saved') +parser.add_argument('--data_url', type=str, required=True, default='', + help='path of dataset') + +# train +parser.add_argument('--pre_trained', type=str, default=None, help='pretrained checkpoint path') +parser.add_argument('--epoch_size', type=int, default=None, help='epochs') +parser.add_argument('--lr_init', type=float, default=None, help='base learning rate') + +# export +parser.add_argument('--width', type=int, default=32, help='input width') +parser.add_argument('--height', type=int, default=32, help='input height') +parser.add_argument('--file_name', type=str, default='resnetv2', help='output air file name') +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") + +args, _ = parser.parse_known_args() + +# import net +if args.net == "resnetv2_50": + from src.resnetv2 import PreActResNet50 as resnetv2 +elif args.net == 'resnetv2_101': + from src.resnetv2 import PreActResNet101 as resnetv2 +elif args.net == 'resnetv2_152': + from src.resnetv2 import PreActResNet152 as resnetv2 +else: + raise ValueError("network is not support.") + +# import dataset and config +if args.dataset == "cifar10": + from src.dataset import create_dataset1 as create_dataset + from src.config import config1 as config +elif args.dataset == "cifar100": + from src.dataset import create_dataset2 as create_dataset + from src.config import config2 as config +elif args.dataset == 'imagenet2012': + from src.dataset import create_dataset3 as create_dataset + from src.config import config3 as config +else: + raise ValueError("dataset is not support.") + +def _train(): + """ train """ + print("============== Starting Training ==============") + target = args.device_target + + # init context + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + + # create dataset + dataset = create_dataset(dataset_path=args.data_url, do_train=True, repeat_num=1, + batch_size=config.batch_size, target=target) + step_size = dataset.get_dataset_size() + + # define net + epoch_size = args.epoch_size if args.epoch_size else config.epoch_size + net = resnetv2(config.class_num, config.low_memory) + + # init weight + if args.pre_trained: + param_dict = load_checkpoint(args.pre_trained) + load_param_into_net(net, param_dict) + + # init 
lr + lr_init = args.lr_init if args.lr_init else config.lr_init + lr = get_lr(lr_init=lr_init, lr_end=config.lr_end, lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size, + lr_decay_mode=config.lr_decay_mode) + lr = Tensor(lr) + + # define loss, opt, model + if args.dataset == "imagenet2012": + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropySmooth(sparse=True, reduction="mean", + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + else: + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, + config.weight_decay, config.loss_scale) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}) + + # define callbacks + time_cb = TimeMonitor(data_size=step_size) + loss_cb = LossMonitor() + + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_save_dir = args.train_url if args.train_url else config.save_checkpoint_path + ckpoint_cb = ModelCheckpoint(prefix=f"train_{args.net}_{args.dataset}", + directory=ckpt_save_dir, config=config_ck) + + # train + callbacks = [time_cb, loss_cb, ckpoint_cb] + model.train(epoch_size, dataset, callbacks=callbacks) + +def _get_last_ckpt(ckpt_dir): + """ get ckpt """ + ckpt_files = [(os.stat(os.path.join(ckpt_dir, ckpt_file)).st_ctime, ckpt_file) + for ckpt_file in os.listdir(ckpt_dir) + if ckpt_file.endswith('.ckpt')] + if not ckpt_files: + print("No ckpt file found.") + return None + + return os.path.join(ckpt_dir, max(ckpt_files)[1]) + +def _export_air(): + """ export air """ + print("============== Starting Exporting ==============") + ckpt_file = _get_last_ckpt(args.train_url) + if not ckpt_file: + return + + net = resnetv2(config.class_num) + param_dict = load_checkpoint(ckpt_file) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.zeros([config.batch_size, 3, args.height, args.width], np.float32)) + export(net, input_arr, file_name=os.path.join(args.train_url, args.file_name), file_format=args.file_format) + +if __name__ == '__main__': + set_seed(1) + _train() + _export_air() diff --git a/research/cv/resnetv2/scripts/docker_start.sh b/research/cv/resnetv2/scripts/docker_start.sh new file mode 100644 index 0000000000000000000000000000000000000000..e39553192b435f70dae6aa6b6adbfd6ffe901b40 --- /dev/null +++ b/research/cv/resnetv2/scripts/docker_start.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +docker_image=$1 +data_dir=$2 +model_dir=$3 + +docker run -it --ipc=host \ + --device=/dev/davinci0 \ + --device=/dev/davinci1 \ + --device=/dev/davinci2 \ + --device=/dev/davinci3 \ + --device=/dev/davinci4 \ + --device=/dev/davinci5 \ + --device=/dev/davinci6 \ + --device=/dev/davinci7 \ + --device=/dev/davinci_manager \ + --device=/dev/devmm_svm --device=/dev/hisi_hdc \ + -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \ + -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \ + -v ${model_dir}:${model_dir} \ + -v ${data_dir}:${data_dir} \ + -v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash \ No newline at end of file