Commit dbcfa81f authored by dengjian (parent f3810dd8)
Showing with 1395 additions and 1 deletion
@@ -51,6 +51,9 @@
"models/research/cv/ibnnet/infer/mxbase/src/IbnnetOpencv.h" "runtime/references"
"models/official/cv/nasnet/infer/mxbase/NASNet_A_MobileClassifyOpencv.h" "runtime/references"
"models/official/cv/nasnet/infer/mxbase/main_opencv.cpp" "runtime/references"
"models/official/cv/shufflenetv2/infer/mxbase/ShuffleNetV2ClassifyOpencv.h" "runtime/references"
"models/official/cv/shufflenetv2/infer/mxbase/main_opencv.cpp" "runtime/references"
......
ARG FROM_IMAGE_NAME
FROM ${FROM_IMAGE_NAME}
COPY requirements.txt .
RUN pip3.7 install -r requirements.txt
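The Dockerfile above takes its base image as a build argument; an illustrative build command (image name and tag are placeholders, not from this commit) is: docker build --build-arg FROM_IMAGE_NAME=<mindspore_base_image> -t nasnet_a_mobile:latest . The pip3.7 step then installs requirements.txt into the image.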
@@ -14,6 +14,7 @@
# ============================================================================
"""export checkpoint file into AIR MINDIR ONNX models"""
import argparse
import ast
import numpy as np
import mindspore as ms
@@ -33,7 +34,13 @@ if __name__ == '__main__':
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"],
help="device where the code will be implemented (default: Ascend)")
parser.add_argument('--overwrite_config', type=ast.literal_eval, default=False,
help='whether to overwrite the config according to the arguments')
parser.add_argument('--num_classes', type=int, default=1000, help='number of classes')
args = parser.parse_args()
if args.overwrite_config:
cfg.num_classes = args.num_classes
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend" or args.device_target == "GPU":
......
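With the flags added above, the class count in the config can be overridden at export time, e.g. (illustrative; other required export.py arguments are omitted here): python export.py --device_target Ascend --overwrite_config True --num_classes 10.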
# coding=utf-8
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import numpy as np
np.set_printoptions(threshold=sys.maxsize)
LABEL_FILE = "HiAI_label.json"
def gen_file_name(img_name):
full_name = img_name.split('/')[-1]
return os.path.splitext(full_name)[0]  # file name without extension
def cre_groundtruth_dict(gtfile_path):
"""
:param gtfile_path: directory that contains the per-image annotation json files
:return: dictionary with image name as key and label number as value
"""
img_gt_dict = {}
for gtfile in os.listdir(gtfile_path):
if gtfile != LABEL_FILE:
with open(os.path.join(gtfile_path, gtfile), 'r') as f:
gt = json.load(f)
ret = gt["image"]["annotations"][0]["category_id"]
img_gt_dict[gen_file_name(gtfile)] = ret
return img_gt_dict
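# The function above assumes each annotation file (other than LABEL_FILE) holds JSON shaped
# like {"image": {"annotations": [{"category_id": <label number>}]}}; the first annotation's
# category_id is taken as the image's ground-truth label.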
def cre_groundtruth_dict_fromtxt(gtfile_path):
"""
:param gtfile_path: file that contains the image names and label numbers
:return: dictionary with image name as key and label number as value
"""
img_gt_dict = {}
with open(gtfile_path, 'r') as f:
for line in f.readlines():
temp = line.strip().split(" ")
img_name = temp[0].split(".")[0]
img_lab = temp[1]
img_gt_dict[img_name] = img_lab
return img_gt_dict
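# cre_groundtruth_dict_fromtxt expects one "<image file name> <label>" pair per line, e.g.
# "ILSVRC2012_val_00000001.JPEG 65" (label value illustrative), producing
# img_gt_dict = {"ILSVRC2012_val_00000001": "65", ...}.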
def load_statistical_predict_result(filepath):
"""
function:
the prediction result file data extraction
input:
result file: filepath
output:
n_label: number of labels
data_vec: the prediction probabilities over the 1000 classes
:return: probabilities, number of labels, in_type, color
"""
with open(filepath, 'r') as f:
data = f.readline()
temp = data.strip().split(" ")
n_label = len(temp)
data_vec = np.zeros((n_label), dtype=np.float32)
in_type = ''
color = ''
if n_label == 0:
in_type = f.readline()
color = f.readline()
else:
for ind, cls_ind in enumerate(temp):
data_vec[ind] = np.int_(cls_ind)
return data_vec, n_label, in_type, color
def create_visualization_statistical_result(prediction_file_path,
result_store_path, json_file_name,
img_gt_dict, topn=5):
"""
:param prediction_file_path:
:param result_store_path:
:param json_file_name:
:param img_gt_dict:
:param topn:
:return:
"""
writer = open(os.path.join(result_store_path, json_file_name), 'w')
table_dict = {}
table_dict["title"] = "Overall statistical evaluation"
table_dict["value"] = []
count = 0
res_cnt = 0
n_labels = ""
count_hit = np.zeros(topn)
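# count_hit[i] counts images whose ground-truth label first matches the prediction at
# rank i; np.cumsum(count_hit) below turns these per-rank counts into top-(i+1) hit totals.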
for tfile_name in os.listdir(prediction_file_path):
count += 1
temp = tfile_name.split('.')[0]
index = temp.rfind('_')
img_name = temp[:index]
filepath = os.path.join(prediction_file_path, tfile_name)
ret = load_statistical_predict_result(filepath)
prediction = ret[0]
n_labels = ret[1]
gt = img_gt_dict[img_name]
if n_labels == 1000:
real_label = int(gt)
elif n_labels == 1001:
real_label = int(gt) + 1
else:
real_label = int(gt)
res_cnt = min(len(prediction), topn)
for i in range(res_cnt):
if str(real_label) == str(int(prediction[i])):
count_hit[i] += 1
break
if 'value' not in table_dict.keys():
print("the item value does not exist!")
else:
table_dict["value"].extend(
[{"key": "Number of images", "value": str(count)},
{"key": "Number of classes", "value": str(n_labels)}])
if count == 0:
accuracy = 0
else:
accuracy = np.cumsum(count_hit) / count
for i in range(res_cnt):
table_dict["value"].append({"key": "Top" + str(i + 1) + " accuracy",
"value": str(
round(accuracy[i] * 100, 2)) + '%'})
json.dump(table_dict, writer)
writer.close()
if __name__ == '__main__':
try:
# txt file path
folder_davinci_target = sys.argv[1]
# annotation files path, "val_label.txt"
annotation_file_path = sys.argv[2]
# the path to store the results json path
result_json_path = sys.argv[3]
except IndexError:
print("Please enter target file result folder | ground truth label file | result json file folder | "
"result json file name, such as ./result val_label.txt . result.json")
exit(1)
if not os.path.exists(folder_davinci_target):
print("Target file folder does not exist.")
if not os.path.exists(annotation_file_path):
print("Ground truth file does not exist.")
if not os.path.exists(result_json_path):
print("Result folder doesn't exist.")
img_label_dict = cre_groundtruth_dict_fromtxt(annotation_file_path)
create_visualization_statistical_result(folder_davinci_target,
result_json_path,
sys.argv[4], # json_file_name,
img_label_dict, topn=5)
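For reference, an illustrative invocation matching the sys.argv layout above (the script's file name is not shown in this commit, so a placeholder is used): python3 <accuracy_script>.py ./result ./val_label.txt ./ result.json, where ./result holds the per-image prediction txt files, val_label.txt the ground-truth labels, ./ the folder for the output json, and result.json its file name.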
aipp_op {
aipp_mode: static
input_format : RGB888_U8
rbuv_swap_switch : true
mean_chn_0 : 0
mean_chn_1 : 0
mean_chn_2 : 0
min_chn_0 : 127.5
min_chn_1 : 127.5
min_chn_2 : 127.5
var_reci_chn_0 : 0.00784
var_reci_chn_1 : 0.00784
var_reci_chn_2 : 0.00784
}
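For reference, the AIPP block above performs input preprocessing on-device. Assuming the standard AIPP normalization formula out = (in - mean_chn - min_chn) * var_reci_chn, it is roughly equivalent to the NumPy sketch below (the function name is ours; a sketch, not part of this commit):

import numpy as np

def aipp_equivalent(img_u8):
    # Minimal host-side sketch of the static AIPP config above.
    img = img_u8[..., ::-1].astype(np.float32)  # rbuv_swap_switch: swap the R and B channels
    # mean_chn = 0, min_chn = 127.5, var_reci_chn = 0.00784 (about 1/127.5):
    return (img - 0.0 - 127.5) * 0.00784  # maps [0, 255] to roughly [-1, 1]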
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 2 ]
then
echo "Usage: sh air2om.sh [INPUT_MODEL_FILE] [OUTPUT_MODEL_NAME]"
exit 1
fi
# check the INPUT_MODEL_FILE
if [ ! -f "$1" ]
then
echo "error: INPUT_MODEL_FILE=$1 is not a file"
exit 1
fi
input_model_file=$1
output_model_name=$2
/usr/local/Ascend/atc/bin/atc \
--model=$input_model_file \
--framework=1 \
--output=$output_model_name \
--input_format=NCHW --input_shape="actual_input_1:1,3,224,224" \
--disable_reuse_memory=0 \
--enable_small_channel=0 \
--log=error \
--soc_version=Ascend310 \
--insert_op_conf=./aipp.config
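With this, converting the exported AIR model to OM might look like (illustrative): bash air2om.sh ./nasnet_a_mobile.air nasnet_a_mobile, which writes nasnet_a_mobile.om next to the script; --framework=1 selects a MindSpore AIR input, and --insert_op_conf folds in the AIPP preprocessing above.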
#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
data_dir=$2
function show_help() {
echo "Usage: docker_start.sh docker_image data_dir"
}
function param_check() {
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
show_help
exit 1
fi
if [ -z "${data_dir}" ]; then
echo "please input data_dir"
show_help
exit 1
fi
}
param_check
docker run -it \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${data_dir}:${data_dir} \
${docker_image} \
/bin/bash
cmake_minimum_required(VERSION 3.14.0)
project(nasnet_a_mobile)
set(TARGET nasnet_a_mobile)
add_definitions(-DENABLE_DVPP_INTERFACE)
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall
-Dgoogle=mindxsdk_private -D_GLIBCXX_USE_CXX11_ABI=0)
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -pie)
# Check environment variable
if(NOT DEFINED ENV{MX_SDK_HOME})
message(FATAL_ERROR "please define environment variable:MX_SDK_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_HOME})
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_VERSION})
message(WARNING "please define environment variable:ASCEND_VERSION")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
message(WARNING "please define environment variable:ARCH_PATTERN")
endif()
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
else()
set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
endif()
include_directories(${ACL_INC_DIR})
include_directories(${OPENSOURCE_DIR}/include)
include_directories(${OPENSOURCE_DIR}/include/opencv4)
include_directories(${MXBASE_INC})
include_directories(${MXBASE_POST_PROCESS_DIR})
link_directories(${ACL_LIB_DIR})
link_directories(${OPENSOURCE_DIR}/lib)
link_directories(${MXBASE_LIB_DIR})
link_directories(${MXBASE_POST_LIB_DIR})
add_executable(${TARGET} main_opencv.cpp NASNet_A_MobileClassifyOpencv.cpp)
target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world stdc++fs)
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
/*
* Copyright (c) 2021. Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include "NASNet_A_MobileClassifyOpencv.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
APP_ERROR NASNet_A_MobileClassifyOpencv::Init(const InitParam &initParam) {
deviceId_ = initParam.deviceId;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = model_->Init(initParam.modelPath, modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
MxBase::ConfigData configData;
const std::string softmax = initParam.softmax ? "true" : "false";
const std::string checkTensor = initParam.checkTensor ? "true" : "false";
configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum));
configData.SetJsonValue("TOP_K", std::to_string(initParam.topk));
configData.SetJsonValue("SOFTMAX", softmax);
configData.SetJsonValue("CHECK_MODEL", checkTensor);
auto jsonStr = configData.GetCfgJson().serialize();
std::map<std::string, std::shared_ptr<void>> config;
config["postProcessConfigContent"] = std::make_shared<std::string>(jsonStr);
config["labelPath"] = std::make_shared<std::string>(initParam.labelPath);
post_ = std::make_shared<MxBase::Resnet50PostProcess>();
ret = post_->Init(config);
if (ret != APP_ERR_OK) {
LogError << "Resnet50PostProcess init failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::DeInit() {
model_->DeInit();
post_->DeInit();
MxBase::DeviceManager::GetInstance()->DestroyDevices();
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::ReadImage(const std::string &imgPath, cv::Mat &imageMat) {
imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
static constexpr uint32_t resizeHeight = 304;
static constexpr uint32_t resizeWidth = 304;
cv::resize(srcImageMat, dstImageMat, cv::Size(resizeWidth, resizeHeight));
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) {
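// The crop output is packed 8-bit RGB, so dataSize = rows * cols * 3;
// MxBase::YUV444_RGB_WIDTH_NU is assumed here to be that 3-bytes-per-pixel factor.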
const uint32_t dataSize = imageMat.cols * imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU;
LogInfo << "image size after crop" << imageMat.cols << " " << imageMat.rows;
MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
MxBase::MemoryData memoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
std::vector<uint32_t> shape = {imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU, static_cast<uint32_t>(imageMat.cols)};
tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8);
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::Crop(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
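// (304 - 224) / 2 = 40: take the central 224 x 224 patch of the 304 x 304 resized image.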
static cv::Rect rectOfImg(40, 40, 224, 224);
dstImageMat = srcImageMat(rectOfImg).clone();
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> &outputs) {
auto dtypes = model_->GetOutputDataType();
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
}
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs.push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); // save time
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "ModelInference failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::PostProcess(const std::vector<MxBase::TensorBase> &inputs,
std::vector<std::vector<MxBase::ClassInfo>> &clsInfos) {
APP_ERROR ret = post_->Process(inputs, clsInfos);
if (ret != APP_ERR_OK) {
LogError << "Process failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::SaveResult(const std::string &imgPath,
const std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos) {
LogInfo << "image path" << imgPath;
std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1);
size_t dot = fileName.find_last_of(".");
std::string resFileName = "result/" + fileName.substr(0, dot) + "_1.txt";
LogInfo << "file path for saving result" << resFileName;
std::ofstream outfile(resFileName);
if (outfile.fail()) {
LogError << "Failed to open result file: ";
return APP_ERR_COMM_FAILURE;
}
uint32_t batchIndex = 0;
for (auto clsInfos : batchClsInfos) {
std::string resultStr;
for (auto clsInfo : clsInfos) {
LogDebug << " className:" << clsInfo.className << " confidence:" << clsInfo.confidence <<
" classIndex:" << clsInfo.classId;
resultStr += std::to_string(clsInfo.classId) + " ";
}
outfile << resultStr << std::endl;
batchIndex++;
}
outfile.close();
return APP_ERR_OK;
}
APP_ERROR NASNet_A_MobileClassifyOpencv::Process(const std::string &imgPath) {
cv::Mat imageMat;
APP_ERROR ret = ReadImage(imgPath, imageMat);
if (ret != APP_ERR_OK) {
LogError << "ReadImage failed, ret=" << ret << ".";
return ret;
}
cv::Mat resizeImage;
ret = ResizeImage(imageMat, resizeImage);
if (ret != APP_ERR_OK) {
LogError << "Resize failed, ret=" << ret << ".";
return ret;
}
cv::Mat cropImage;
ret = Crop(resizeImage, cropImage);
if (ret != APP_ERR_OK) {
LogError << "Crop failed, ret=" << ret << ".";
return ret;
}
std::vector<MxBase::TensorBase> inputs = {};
std::vector<MxBase::TensorBase> outputs = {};
MxBase::TensorBase tensorBase;
ret = CVMatToTensorBase(cropImage, tensorBase);
if (ret != APP_ERR_OK) {
LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
return ret;
}
inputs.push_back(tensorBase);
auto startTime = std::chrono::high_resolution_clock::now();
ret = Inference(inputs, outputs);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); // save time
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret << ".";
return ret;
}
std::vector<std::vector<MxBase::ClassInfo>> BatchClsInfos = {};
ret = PostProcess(outputs, BatchClsInfos);
if (ret != APP_ERR_OK) {
LogError << "PostProcess failed, ret=" << ret << ".";
return ret;
}
ret = SaveResult(imgPath, BatchClsInfos);
if (ret != APP_ERR_OK) {
LogError << "Save infer results into file failed. ret = " << ret << ".";
return ret;
}
return APP_ERR_OK;
}
/*
* Copyright (c) 2021. Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_NASNET_A_MOBILECLASSIFYOPENCV_H
#define MXBASE_NASNET_A_MOBILECLASSIFYOPENCV_H
#include <string>
#include <vector>
#include <memory>
#include <opencv2/opencv.hpp>
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
#include "ClassPostProcessors/Resnet50PostProcess.h"
struct InitParam {
uint32_t deviceId;
std::string labelPath;
uint32_t classNum;
uint32_t topk;
bool softmax;
bool checkTensor;
std::string modelPath;
};
class NASNet_A_MobileClassifyOpencv {
public:
APP_ERROR Init(const InitParam &initParam);
APP_ERROR DeInit();
APP_ERROR ReadImage(const std::string &imgPath, cv::Mat &imageMat);
APP_ERROR ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase);
APP_ERROR Crop(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
APP_ERROR PostProcess(const std::vector<MxBase::TensorBase> &inputs,
std::vector<std::vector<MxBase::ClassInfo>> &clsInfos);
APP_ERROR Process(const std::string &imgPath);
// get infer time
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
APP_ERROR SaveResult(const std::string &imgPath,
const std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos);
private:
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
std::shared_ptr<MxBase::Resnet50PostProcess> post_;
MxBase::ModelDesc modelDesc_;
uint32_t deviceId_ = 0;
// infer time
double inferCostTimeMilliSec = 0.0;
};
#endif
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
path_cur=$(dirname $0)
function check_env()
{
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
if [ ! "${ASCEND_VERSION}" ]; then
export ASCEND_VERSION=ascend-toolkit/latest
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
else
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
fi
if [ ! "${ARCH_PATTERN}" ]; then
# set ARCH_PATTERN to ./ when it was not specified by user
export ARCH_PATTERN=./
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
else
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
fi
}
function build_nasnet_a_mobile()
{
cd $path_cur
rm -rf build
mkdir -p build
cd build
cmake ..
make
ret=$?
if [ ${ret} -ne 0 ]; then
echo "Failed to build nasnet_a_mobile."
exit ${ret}
fi
make install
}
check_env
build_nasnet_a_mobile
\ No newline at end of file
/*
* Copyright (c) 2021. Huawei Technologies Co., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dirent.h>
#include "NASNet_A_MobileClassifyOpencv.h"
#include "MxBase/Log/Log.h"
namespace {
const uint32_t CLASS_NUM = 1000;
} // namespace
APP_ERROR ScanImages(const std::string &path, std::vector<std::string> &imgFiles) {
DIR *dirPtr = opendir(path.c_str());
if (dirPtr == nullptr) {
LogError << "opendir failed. dir:" << path;
return APP_ERR_INTERNAL_ERROR;
}
dirent *direntPtr = nullptr;
while ((direntPtr = readdir(dirPtr)) != nullptr) {
std::string fileName = direntPtr->d_name;
if (fileName == "." || fileName == "..") {
continue;
}
imgFiles.emplace_back(path + "/" + fileName);
}
closedir(dirPtr);
return APP_ERR_OK;
}
int main(int argc, char* argv[]) {
if (argc <= 1) {
LogWarn << "Please input image path, such as './imagenet/val'.";
return APP_ERR_OK;
}
InitParam initParam = {};
initParam.deviceId = 0;
initParam.classNum = CLASS_NUM;
initParam.labelPath = "../imagenet1000_clsidx_to_labels.names";
initParam.topk = 5;
initParam.softmax = false;
initParam.checkTensor = true;
initParam.modelPath = "../nasnet_a_mobile.om";
auto nasnet_a_mobile = std::make_shared<NASNet_A_MobileClassifyOpencv>();
APP_ERROR ret = nasnet_a_mobile->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "NASNet_A_MobileClassify init failed, ret=" << ret << ".";
return ret;
}
std::string imgPath = argv[1];
std::vector<std::string> imgFilePaths;
ret = ScanImages(imgPath, imgFilePaths);
if (ret != APP_ERR_OK) {
return ret;
}
auto startTime = std::chrono::high_resolution_clock::now();
for (auto &imgFile : imgFilePaths) {
ret = nasnet_a_mobile->Process(imgFile);
if (ret != APP_ERR_OK) {
LogError << "NASNet_A_MobileClassify process failed, ret=" << ret << ".";
nasnet_a_mobile->DeInit();
return ret;
}
}
auto endTime = std::chrono::high_resolution_clock::now();
nasnet_a_mobile->DeInit();
double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
double fps = 1000.0 * imgFilePaths.size() / nasnet_a_mobile->GetInferCostMilliSec();
LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
return APP_ERR_OK;
}
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 1 ]
then
echo "Usage: sh run.sh [DATASET_VAL_PATH]"
exit 1
fi
# check the DATASET_VAL_PATH
if [ ! -d "$1" ]
then
echo "error: DATASET_VAL_PATH=$1 is not a directory"
exit 1
fi
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/lib/modelpostprocessors:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
# run
./nasnet_a_mobile $1
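Note that SaveResult in NASNet_A_MobileClassifyOpencv.cpp writes its per-image txt files into a relative result/ directory, which is not created automatically. An illustrative session from the mxbase directory would therefore be: bash build.sh, then mkdir -p result, then bash run.sh ./imagenet/val.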
# coding=utf-8
"""
Copyright (c) 2021. Huawei Technologies Co., Ltd.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import datetime
import json
import os
import sys
from StreamManagerApi import StreamManagerApi
from StreamManagerApi import MxDataInput
def run():
# init stream manager
stream_manager_api = StreamManagerApi()
ret = stream_manager_api.InitManager()
if ret != 0:
print("Failed to init Stream manager, ret=%s" % str(ret))
return
# create streams by pipeline config file
with open("nasnet_a_mobile.pipeline", 'rb') as f:
pipelineStr = f.read()
ret = stream_manager_api.CreateMultipleStreams(pipelineStr)
if ret != 0:
print("Failed to create Stream, ret=%s" % str(ret))
return
# Construct the input of the stream
data_input = MxDataInput()
dir_name = sys.argv[1]
res_dir_name = sys.argv[2]
file_list = os.listdir(dir_name)
if not os.path.exists(res_dir_name):
os.makedirs(res_dir_name)
for file_name in file_list:
file_path = os.path.join(dir_name, file_name)
if not (file_name.lower().endswith(".jpg") or file_name.lower().endswith(".jpeg")):
continue
with open(file_path, 'rb') as f:
data_input.data = f.read()
stream_name = b'im_nasnet'
in_plugin_id = 0
unique_id = stream_manager_api.SendData(stream_name, in_plugin_id, data_input)
if unique_id < 0:
print("Failed to send data to stream.")
return
# Obtain the inference result by specifying streamName and uniqueId.
start_time = datetime.datetime.now()
infer_result = stream_manager_api.GetResult(stream_name, unique_id)
end_time = datetime.datetime.now()
print('sdk run time: {:.0f} us'.format((end_time - start_time).total_seconds() * 1000000))
if infer_result.errorCode != 0:
print("GetResultWithUniqueId error. errorCode=%d, errorMsg=%s" % (
infer_result.errorCode, infer_result.data.decode()))
return
# print the infer result
infer_res = infer_result.data.decode()
print("process img: {}, infer result: {}".format(file_name, infer_res))
load_dict = json.loads(infer_result.data.decode())
if load_dict.get('MxpiClass') is None:
with open(res_dir_name + "/" + os.path.splitext(file_name)[0] + '.txt', 'w') as f_write:  # splitext handles both .jpg and .jpeg
f_write.write("")
continue
res_vec = load_dict.get('MxpiClass')
with open(res_dir_name + "/" + os.path.splitext(file_name)[0] + '_1.txt', 'w') as f_write:
res_list = [str(item.get("classId")) + " " for item in res_vec]
f_write.writelines(res_list)
f_write.write('\n')
# destroy streams
stream_manager_api.DestroyAllStreams()
if __name__ == '__main__':
run()
CLASS_NUM=1000
SOFTMAX=false
TOP_K=5
{
"im_nasnet": {
"stream_config": {
"deviceId": "0"
},
"appsrc1": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_imagedecoder0"
},
"mxpi_imagedecoder0": {
"props": {
"handleMethod": "opencv"
},
"factory": "mxpi_imagedecoder",
"next": "mxpi_imageresize0"
},
"mxpi_imageresize0": {
"props": {
"handleMethod": "opencv",
"resizeType": "Resizer_Stretch",
"resizeHeight": "304",
"resizeWidth": "304"
},
"factory": "mxpi_imageresize",
"next": "mxpi_opencvcentercrop0"
},
"mxpi_opencvcentercrop0": {
"props": {
"dataSource": "mxpi_imageresize0",
"cropHeight": "224",
"cropWidth": "224"
},
"factory": "mxpi_opencvcentercrop",
"next": "mxpi_tensorinfer0"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "mxpi_opencvcentercrop0",
"modelPath": "../nasnet_a_mobile.om",
"waitingTime": "2000",
"outputDeviceId": "-1"
},
"factory": "mxpi_tensorinfer",
"next": "mxpi_classpostprocessor0"
},
"mxpi_classpostprocessor0": {
"props": {
"dataSource": "mxpi_tensorinfer0",
"postProcessConfigPath": "nasnet_a_mobile.cfg",
"labelPath": "../imagenet1000_clsidx_to_labels.names",
"postProcessLibPath": "libresnet50postprocess.so"
},
"factory": "mxpi_classpostprocessor",
"next": "mxpi_dataserialize0"
},
"mxpi_dataserialize0": {
"props": {
"outputDataKeys": "mxpi_classpostprocessor0"
},
"factory": "mxpi_dataserialize",
"next": "appsink0"
},
"appsink0": {
"props": {
"blocksize": "4096000"
},
"factory": "appsink"
}
}
}
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# != 2 ]
then
echo "Usage: sh run.sh [IMAGE_PATH] [RESULT_DIR]"
exit 1
fi
# check the IMAGE_PATH
if [ ! -d "$1" ]
then
echo "error: IMAGE_PATH=$1 is not a directory"
exit 1
fi
image_path=$1
result_dir=$2
set -e
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
# set PYTHONPATH so that StreamManagerApi.py can be imported
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3.7 main.py $image_path $result_dir
exit 0
\ No newline at end of file
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import ast
import os
import time
from collections import OrderedDict
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.rmsprop import RMSProp
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.common import dtype as mstype
from mindspore import export
from src.config import nasnet_a_mobile_config_gpu, nasnet_a_mobile_config_ascend
from src.dataset import create_dataset
from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobile
from src.lr_generator import get_lr
def export_models(checkpoint_path):
net = NASNetAMobile(num_classes=config.num_classes, is_training=False)
file_list = []
for root, _, files in os.walk(checkpoint_path):
for file in files:
if os.path.splitext(file)[1] == '.ckpt':
file_list.append(os.path.join(root, file))
file_list.sort(key=os.path.getmtime, reverse=True)
exported_count = 0
for checkpoint in file_list:
ckpt_dict = load_checkpoint(checkpoint)
parameter_dict = OrderedDict()
for name in ckpt_dict:
new_name = name
if new_name.startswith("network."):
new_name = new_name.replace("network.", "")
parameter_dict[new_name] = ckpt_dict[name]
load_param_into_net(net, parameter_dict)
output_file = checkpoint.replace('.ckpt', '')
input_data = Tensor(np.zeros([1, 3, 224, 224]), mstype.float32)
if args_opt.export_mindir_model:
export(net, input_data, file_name=output_file, file_format="MINDIR")
if args_opt.export_air_model and context.get_context("device_target") == "Ascend":
export(net, input_data, file_name=output_file, file_format="AIR")
if args_opt.export_onnx_model:
export(net, input_data, file_name=output_file, file_format="ONNX")
print(checkpoint, 'is exported')
exported_count += 1
if exported_count >= args_opt.export_checkpoint_count:
print('exported checkpoint count =', exported_count)
break
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
for key in list(origin_dict.keys()):
for name in param_filter:
if name in key:
print("Delete parameter from checkpoint: ", key)
del origin_dict[key]
break
if __name__ == '__main__':
start_time = time.time()
parser = argparse.ArgumentParser(description='image classification training')
parser.add_argument('--dataset_path', type=str, default='../imagenet', help='Dataset path')
parser.add_argument('--resume', type=str, default='',
help='resume training with existed checkpoint')
parser.add_argument('--resume_epoch', type=int, default=1, help='Resume from which epoch')
parser.add_argument('--is_distributed', type=ast.literal_eval, default=False,
help='distributed training')
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'),
help='run platform')
parser.add_argument('--device_id', type=int, default=0, help='device id(Default:0)')
parser.add_argument('--is_modelarts', type=ast.literal_eval, default=False)
parser.add_argument('--data_url', type=str, default=None, help='Dataset path for modelarts')
parser.add_argument('--train_url', type=str, default=None, help='Output path for modelarts')
parser.add_argument('--use_pynative_mode', type=ast.literal_eval, default=False,
help='whether to use pynative mode for the device (Default: False)')
parser.add_argument('--amp_level', type=str, default='O0', help='level for mixed precision training')
parser.add_argument('--remove_classifier_parameter', type=ast.literal_eval, default=False,
help='whether to filter the classifier parameter in the checkpoint (Default: False)')
parser.add_argument('--export_mindir_model', type=ast.literal_eval, default=True,
help='whether to export MINDIR model (Default: True)')
parser.add_argument('--export_air_model', type=ast.literal_eval, default=True,
help='whether to export AIR model on Ascend 910 (Default: True)')
parser.add_argument('--export_onnx_model', type=ast.literal_eval, default=False,
help='whether to export ONNX model (Default: False)')
parser.add_argument('--export_checkpoint_count', type=int, default=1,
help='how many of the most recent checkpoints to export (Default: 1)')
parser.add_argument('--overwrite_config', type=ast.literal_eval, default=False,
help='whether to overwrite the config according to the arguments')
# when overwrite_config == True, the following arguments will be written into the config
parser.add_argument('--epoch_size', type=int, default=600,
help='Epochs for training (default: 600)')
parser.add_argument('--num_classes', type=int, default=1000, help='number of classes')
parser.add_argument('--cutout', type=ast.literal_eval, default=False,
help='whether to apply cutout to the training data (Default: False)')
parser.add_argument('--train_batch_size', type=int, default=32, help='batch size for training')
parser.add_argument('--lr_init', type=float, default=0.32, help='learning rate for training')
args_opt = parser.parse_args()
is_modelarts = args_opt.is_modelarts
if args_opt.platform == 'GPU':
config = nasnet_a_mobile_config_gpu
drop_remainder = True
else:
config = nasnet_a_mobile_config_ascend
drop_remainder = False
if args_opt.overwrite_config:
config.epoch_size = args_opt.epoch_size
config.num_classes = args_opt.num_classes
config.cutout = args_opt.cutout
config.train_batch_size = args_opt.train_batch_size
config.lr_init = args_opt.lr_init
print('epoch_size = ', config.epoch_size, ' num_classes = ', config.num_classes)
print('train_batch_size = ', config.train_batch_size, ' lr_init = ', config.lr_init)
print('cutout = ', config.cutout, ' cutout_length =', config.cutout_length)
set_seed(config.random_seed)
if args_opt.use_pynative_mode:
context.set_context(mode=context.PYNATIVE_MODE, device_target=args_opt.platform)
else:
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
# init distributed
if args_opt.is_distributed:
init()
if args_opt.is_modelarts:
device_id = get_rank()
config.group_size = get_group_size()
else:
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID', default='0'))
config.group_size = int(os.getenv('DEVICE_NUM', default='1'))
else:
device_id = get_rank()
config.group_size = get_group_size()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=config.group_size,
gradients_mean=True)
else:
device_id = args_opt.device_id
config.group_size = 1
context.set_context(device_id=device_id)
rank_id = device_id
config.rank = rank_id
print('rank_id = ', rank_id, ' group_size = ', config.group_size)
resume = args_opt.resume
if args_opt.is_modelarts:
# download dataset from obs to cache
import moxing
dataset_path = '/cache/dataset'
if args_opt.data_url.find('/train/') > 0:
dataset_path += '/train/'
moxing.file.copy_parallel(src_url=args_opt.data_url, dst_url=dataset_path)
# download the checkpoint from obs to cache
if resume != '':
base_name = os.path.basename(resume)
dst_url = '/cache/checkpoint/' + base_name
moxing.file.copy_parallel(src_url=resume, dst_url=dst_url)
resume = dst_url
# the path for the output of training
save_checkpoint_path = '/cache/train_output/' + str(device_id) + '/'
else:
dataset_path = args_opt.dataset_path
save_checkpoint_path = os.path.join(config.ckpt_path, 'ckpt_' + str(config.rank) + '/')
log_filename = os.path.join(save_checkpoint_path, 'log_' + str(device_id) + '.txt')
# dataloader
if dataset_path.find('/train') > 0:
dataset_train_path = dataset_path
else:
dataset_train_path = os.path.join(dataset_path, 'train')
if not os.path.exists(dataset_train_path):
dataset_train_path = dataset_path
train_dataset = create_dataset(dataset_train_path, True, config.rank, config.group_size,
num_parallel_workers=config.work_nums,
batch_size=config.train_batch_size,
drop_remainder=drop_remainder, shuffle=True,
cutout=config.cutout, cutout_length=config.cutout_length,
image_size=config.image_size)
batches_per_epoch = train_dataset.get_dataset_size()
# network
net_with_loss = NASNetAMobileWithLoss(config)
if resume != '':
ckpt = load_checkpoint(resume)
print('remove_classifier_parameter = ', args_opt.remove_classifier_parameter)
if args_opt.remove_classifier_parameter:
filter_list = [x.name for x in net_with_loss.network.classifier.get_parameters()]
filter_checkpoint_parameter_by_list(ckpt, filter_list)
filter_list = [x.name for x in net_with_loss.network.aux_logits.fc.get_parameters()]
filter_checkpoint_parameter_by_list(ckpt, filter_list)
load_param_into_net(net_with_loss, ckpt)
print(resume, ' is loaded')
# learning rate schedule
lr = get_lr(lr_init=config.lr_init, lr_decay_rate=config.lr_decay_rate,
num_epoch_per_decay=config.num_epoch_per_decay, total_epochs=config.epoch_size,
steps_per_epoch=batches_per_epoch, is_stair=True)
if resume:
resume_epoch = args_opt.resume_epoch
step_num_in_epoch = train_dataset.get_dataset_size()
lr = lr[step_num_in_epoch * resume_epoch:]
# adjust epoch_size in the config so that the call to model.train stays simple.
config.epoch_size = config.epoch_size - resume_epoch
print('Effective epoch_size = ', config.epoch_size)
lr = Tensor(lr, mstype.float32)
# optimizer
decayed_params = []
no_decayed_params = []
for param in net_with_loss.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net_with_loss.trainable_params()}]
optimizer = RMSProp(group_params, lr, decay=config.rmsprop_decay, weight_decay=config.weight_decay,
momentum=config.momentum, epsilon=config.opt_eps, loss_scale=config.loss_scale)
# high performance
net_with_loss.set_train()
print('amp_level = ', args_opt.amp_level)
model = Model(net_with_loss, optimizer=optimizer, amp_level=args_opt.amp_level)
print("============== Starting Training ==============")
loss_cb = LossMonitor(per_print_times=batches_per_epoch)
time_cb = TimeMonitor(data_size=batches_per_epoch)
callbacks = [loss_cb, time_cb]
config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=f"nasnet-a-mobile-rank{config.rank}",
directory=save_checkpoint_path, config=config_ck)
if args_opt.is_distributed and config.is_save_on_master == 1:
if config.rank == 0:
callbacks.append(ckpoint_cb)
else:
callbacks.append(ckpoint_cb)
try:
model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)
except KeyboardInterrupt:
print("!!!!!!!!!!!!!! Train Failed !!!!!!!!!!!!!!!!!!!")
else:
print("============== Train Success ==================")
export_models(save_checkpoint_path)
print("data_url = ", args_opt.data_url)
print("cutout = ", config.cutout, " cutout_length = ", config.cutout_length)
print("epoch_size = ", config.epoch_size, " train_batch_size = ", config.train_batch_size,
" lr_init = ", config.lr_init, " weight_decay = ", config.weight_decay)
print("time: ", (time.time() - start_time) / 3600, " hours")
fp = open(log_filename, 'at+')
print("data_url = ", args_opt.data_url, file=fp)
print("cutout = ", config.cutout, " cutout_length = ", config.cutout_length, file=fp)
print("epoch_size = ", config.epoch_size, " train_batch_size = ", config.train_batch_size,
" lr_init = ", config.lr_init, " weight_decay = ", config.weight_decay, file=fp)
print("time: ", (time.time() - start_time) / 3600, file=fp)
fp.close()
if args_opt.is_modelarts:
if os.path.exists('/cache/train_output'):
moxing.file.copy_parallel(src_url='/cache/train_output', dst_url=args_opt.train_url)
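As a usage sketch for the training script above (all paths and values illustrative), a single-device run could be launched with python train.py --dataset_path=/path/to/imagenet --platform=Ascend --device_id=0, and a fine-tuning run with python train.py --dataset_path=/path/to/new_dataset --resume=/path/to/pretrained.ckpt --remove_classifier_parameter=True --overwrite_config=True --num_classes=10. On normal completion, export_models converts the most recent checkpoint(s) according to the export_* flags.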
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
data_dir=$2
model_dir=$3
docker run -it --ipc=host \
--device=/dev/davinci0 \
--device=/dev/davinci1 \
--device=/dev/davinci2 \
--device=/dev/davinci3 \
--device=/dev/davinci4 \
--device=/dev/davinci5 \
--device=/dev/davinci6 \
--device=/dev/davinci7 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm --device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \
-v ${model_dir}:${model_dir} \
-v ${data_dir}:${data_dir} \
-v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash
\ No newline at end of file
@@ -65,7 +65,7 @@ if __name__ == '__main__':
help='Epochs for training (default: 600)')
parser.add_argument('--num_classes', type=int, default=1000, help='number of classes')
parser.add_argument('--cutout', type=ast.literal_eval, default=False,
- help='whether to apply cutout to the training data (Default: True)')
+ help='whether to apply cutout to the training data (Default: False)')
args_opt = parser.parse_args()
......