Skip to content
Snippets Groups Projects
Commit 3538a3ff authored by wanglin's avatar wanglin
Browse files

sdk&mxbase

parent 60f10baa
No related branches found
No related tags found
No related merge requests found
Showing
with 1348 additions and 5 deletions
......@@ -26,4 +26,8 @@
"models/research/cv/FaceAttribute/infer/mxbase/faceattribute/FaceAttribute.cpp" "runtime/references"
"models/official/cv/lenet/infer/mxbase/LenetOpencv.h" "runtime/references"
"models/official/cv/lenet/infer/mxbase/main_opencv.cpp" "runtime/references"
\ No newline at end of file
"models/official/cv/lenet/infer/mxbase/main_opencv.cpp" "runtime/references"
"models/research/cv/squeezenet1_1/infer/mxbase/Squeezenet1_1ClassifyOpencv.h" "runtime/references"
"models/research/cv/squeezenet1_1/infer/mxbase/main_opencv.cpp" "runtime/references"
"models/research/cv/squeezenet1_1/infer/mxbase/Squeezenet1_1ClassifyOpencv.cpp" "runtime/references"
\ No newline at end of file
......@@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth
from src.squeezenet import SqueezeNet as squeezenet
from src.dataset import create_dataset_imagenet as create_dataset
from src.config import config
from src.config import config_imagenet as config
local_data_url = '/cache/data'
local_ckpt_url = '/cache/ckpt.ckpt'
......
aipp_op {
aipp_mode: static
input_format : RGB888_U8
rbuv_swap_switch : true
mean_chn_0 : 0
mean_chn_1 : 0
mean_chn_2 : 0
min_chn_0 : 123.675
min_chn_1 : 116.28
min_chn_2 : 103.53
var_reci_chn_0 : 0.0171247538316637
var_reci_chn_1 : 0.0175070028011204
var_reci_chn_2 : 0.0174291938997821
}
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
model_path=$1
output_model_name=$2
aipp_cfg=$3
/usr/local/Ascend/atc/bin/atc \
--model=$model_path \
--framework=1 \
--output=$output_model_name \
--input_format=NCHW --input_shape="actual_input_1:1,3,227,227" \
--enable_small_channel=1 \
--log=error \
--soc_version=Ascend310 \
--insert_op_conf=$aipp_cfg \
--output_type=FP32
\ No newline at end of file
CLASS_NUM=1000
SOFTMAX=false
TOP_K=5
#!/usr/bin/env bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
model_dir=$2
data_dir=$3
function show_help() {
echo "Usage: docker_start.sh docker_image model_dir data_dir"
}
function param_check() {
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
show_help
exit 1
fi
if [ -z "${model_dir}" ]; then
echo "please input model_dir"
show_help
exit 1
fi
if [ -z "${data_dir}" ]; then
echo "please input data_dir"
show_help
exit 1
fi
}
param_check
docker run -it \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${model_dir}:${model_dir} \
-v ${data_dir}:${data_dir} \
${docker_image} \
/bin/bash
cmake_minimum_required(VERSION 3.14.0)
project(squeezenet1_1)
set(TARGET squeezenet1_1)
add_definitions(-DENABLE_DVPP_INTERFACE)
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall
-Dgoogle=mindxsdk_private -D_GLIBCXX_USE_CXX11_ABI=0)
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -pie)
# Check environment variable
if(NOT DEFINED ENV{MX_SDK_HOME})
message(FATAL_ERROR "please define environment variable:MX_SDK_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_HOME})
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_VERSION})
message(WARNING "please define environment variable:ASCEND_VERSION")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
message(WARNING "please define environment variable:ARCH_PATTERN")
endif()
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
if(DEFINED ENV{MXSDK_OPENSOURCE_DIR})
set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
else()
set(OPENSOURCE_DIR ${MXBASE_ROOT_DIR}/opensource)
endif()
include_directories(${ACL_INC_DIR})
include_directories(${OPENSOURCE_DIR}/include)
include_directories(${OPENSOURCE_DIR}/include/opencv4)
include_directories(${MXBASE_INC})
include_directories(${MXBASE_POST_PROCESS_DIR})
link_directories(${ACL_LIB_DIR})
link_directories(${OPENSOURCE_DIR}/lib)
link_directories(${MXBASE_LIB_DIR})
link_directories(${MXBASE_POST_LIB_DIR})
add_executable(${TARGET} main_opencv.cpp Squeezenet1_1ClassifyOpencv.cpp)
target_link_libraries(${TARGET} glog cpprest mxbase resnet50postprocess opencv_world stdc++fs)
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
/*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <map>
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
#include "Squeezenet1_1ClassifyOpencv.h"
APP_ERROR Squeezenet1_1ClassifyOpencv::Init(const InitParam &initParam) {
deviceId_ = initParam.deviceId;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
dvppWrapper_ = std::make_shared<MxBase::DvppWrapper>();
ret = dvppWrapper_->Init();
if (ret != APP_ERR_OK) {
LogError << "DvppWrapper init failed, ret=" << ret << ".";
return ret;
}
model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = model_->Init(initParam.modelPath, modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
MxBase::ConfigData configData;
const std::string softmax = initParam.softmax ? "true" : "false";
const std::string checkTensor = initParam.checkTensor ? "true" : "false";
configData.SetJsonValue("CLASS_NUM", std::to_string(initParam.classNum));
configData.SetJsonValue("TOP_K", std::to_string(initParam.topk));
configData.SetJsonValue("SOFTMAX", softmax);
configData.SetJsonValue("CHECK_MODEL", checkTensor);
auto jsonStr = configData.GetCfgJson().serialize();
std::map<std::string, std::shared_ptr<void>> config;
config["postProcessConfigContent"] = std::make_shared<std::string>(jsonStr);
config["labelPath"] = std::make_shared<std::string>(initParam.labelPath);
post_ = std::make_shared<MxBase::Resnet50PostProcess>();
ret = post_->Init(config);
if (ret != APP_ERR_OK) {
LogError << "Squeezenet1_1PostProcess init failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::DeInit() {
dvppWrapper_->DeInit();
model_->DeInit();
post_->DeInit();
MxBase::DeviceManager::GetInstance()->DestroyDevices();
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::ReadImage(const std::string &imgPath, cv::Mat &imageMat) {
imageMat = cv::imread(imgPath, cv::IMREAD_COLOR);
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
static constexpr uint32_t resizeHeight = 256;
static constexpr uint32_t resizeWidth = 256;
cv::resize(srcImageMat, dstImageMat, cv::Size(resizeWidth, resizeHeight));
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) {
const uint32_t dataSize = imageMat.cols * imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU;
LogInfo << "image size after crop" << imageMat.cols << " " << imageMat.rows;
MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
MxBase::MemoryData memoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
std::vector<uint32_t> shape = {imageMat.rows * MxBase::YUV444_RGB_WIDTH_NU, static_cast<uint32_t>(imageMat.cols)};
tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_UINT8);
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::Crop(const cv::Mat &srcImageMat, cv::Mat &dstImageMat) {
static cv::Rect rectOfImg(14.5, 14.5, 227, 227);
dstImageMat = srcImageMat(rectOfImg).clone();
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> &outputs) {
uint32_t first = inputs[0].GetShape()[MxBase::VECTOR_FIRST_INDEX];
uint32_t second = inputs[0].GetShape()[MxBase::VECTOR_SECOND_INDEX];
uint32_t third = inputs[0].GetShape()[MxBase::VECTOR_THIRD_INDEX];
uint32_t fourth = inputs[0].GetShape()[MxBase::VECTOR_FOURTH_INDEX];
std::cout << "++ inputs: " << inputs.size() << " " << first << " "
<< second << " " << third << " " << fourth << std::endl;
auto dtypes = model_->GetOutputDataType();
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
}
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs.push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); // save time
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "ModelInference failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::PostProcess(const std::vector<MxBase::TensorBase> &inputs,
std::vector<std::vector<MxBase::ClassInfo>> &clsInfos) {
APP_ERROR ret = post_->Process(inputs, clsInfos);
if (ret != APP_ERR_OK) {
LogError << "Process failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::SaveResult(const std::string &imgPath,
const std::vector<std::vector<MxBase::ClassInfo>> &BatchClsInfos) {
LogInfo << "image path" << imgPath;
std::string fileName = imgPath.substr(imgPath.find_last_of("/") + 1);
size_t dot = fileName.find_last_of(".");
std::string resFileName = "result/" + fileName.substr(0, dot) + "_1.txt";
LogInfo << "file path for saving result" << resFileName;
std::ofstream outfile(resFileName);
if (outfile.fail()) {
LogError << "Failed to open result file: ";
return APP_ERR_COMM_FAILURE;
}
uint32_t batchIndex = 0;
for (auto clsInfos : BatchClsInfos) {
std::string resultStr;
for (auto clsInfo : clsInfos) {
LogDebug << " className:" << clsInfo.className << " confidence:" << clsInfo.confidence <<
" classIndex:" << clsInfo.classId;
resultStr += std::to_string(clsInfo.classId) + " ";
}
outfile << resultStr << std::endl;
batchIndex++;
}
outfile.close();
return APP_ERR_OK;
}
APP_ERROR Squeezenet1_1ClassifyOpencv::Process(const std::string &imgPath) {
cv::Mat imageMat;
APP_ERROR ret = ReadImage(imgPath, imageMat);
if (ret != APP_ERR_OK) {
LogError << "ReadImage failed, ret=" << ret << ".";
return ret;
}
cv::Mat resizeImage;
ret = ResizeImage(imageMat, resizeImage);
if (ret != APP_ERR_OK) {
LogError << "Resize failed, ret=" << ret << ".";
return ret;
}
cv::Mat cropImage;
ret = Crop(resizeImage, cropImage);
if (ret != APP_ERR_OK) {
LogError << "Crop failed, ret=" << ret << ".";
return ret;
}
std::vector<MxBase::TensorBase> inputs = {};
std::vector<MxBase::TensorBase> outputs = {};
MxBase::TensorBase tensorBase;
ret = CVMatToTensorBase(cropImage, tensorBase);
if (ret != APP_ERR_OK) {
LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
return ret;
}
inputs.push_back(tensorBase);
auto startTime = std::chrono::high_resolution_clock::now();
ret = Inference(inputs, outputs);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count(); // save time
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret << ".";
return ret;
}
std::vector<std::vector<MxBase::ClassInfo>> BatchClsInfos = {};
ret = PostProcess(outputs, BatchClsInfos);
if (ret != APP_ERR_OK) {
LogError << "PostProcess failed, ret=" << ret << ".";
return ret;
}
ret = SaveResult(imgPath, BatchClsInfos);
if (ret != APP_ERR_OK) {
LogError << "Save infer results into file failed. ret = " << ret << ".";
return ret;
}
return APP_ERR_OK;
}
/*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_Squeezenet1_1CLASSIFYOPENCV_H
#define MXBASE_Squeezenet1_1CLASSIFYOPENCV_H
#include <string>
#include <vector>
#include <memory>
#include <opencv2/opencv.hpp>
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
#include "ClassPostProcessors/Resnet50PostProcess.h"
struct InitParam {
uint32_t deviceId;
std::string labelPath;
uint32_t classNum;
uint32_t topk;
bool softmax;
bool checkTensor;
std::string modelPath;
};
class Squeezenet1_1ClassifyOpencv {
public:
APP_ERROR Init(const InitParam &initParam);
APP_ERROR DeInit();
APP_ERROR ReadImage(const std::string &imgPath, cv::Mat &imageMat);
APP_ERROR ResizeImage(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase);
APP_ERROR Crop(const cv::Mat &srcImageMat, cv::Mat &dstImageMat);
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
APP_ERROR PostProcess(const std::vector<MxBase::TensorBase> &inputs,
std::vector<std::vector<MxBase::ClassInfo>> &clsInfos);
APP_ERROR Process(const std::string &imgPath);
// get infer time
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
APP_ERROR SaveResult(const std::string &imgPath,
const std::vector<std::vector<MxBase::ClassInfo>> &batchClsInfos);
private:
std::shared_ptr<MxBase::DvppWrapper> dvppWrapper_;
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
std::shared_ptr<MxBase::Resnet50PostProcess> post_;
MxBase::ModelDesc modelDesc_;
uint32_t deviceId_ = 0;
// infer time
double inferCostTimeMilliSec = 0.0;
};
#endif
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
path_cur=$(dirname $0)
function check_env()
{
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
if [ ! "${ASCEND_VERSION}" ]; then
export ASCEND_VERSION=ascend-toolkit/latest
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
else
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
fi
if [ ! "${ARCH_PATTERN}" ]; then
# set ARCH_PATTERN to ./ when it was not specified by user
export ARCH_PATTERN=./
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
else
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
fi
}
function build_squeezenet1_1()
{
cd $path_cur
rm -rf build
mkdir -p build
cd build
cmake ..
make
ret=$?
if [ ${ret} -ne 0 ]; then
echo "Failed to build squeezenet1_1."
exit ${ret}
fi
}
check_env
build_squeezenet1_1
\ No newline at end of file
/*
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Squeezenet1_1ClassifyOpencv.h"
#include <dirent.h>
#include "MxBase/Log/Log.h"
namespace {
const uint32_t CLASS_NUM = 1000;
} // namespace
APP_ERROR ScanImages(const std::string &path, std::vector<std::string> &imgFiles) {
DIR *dirPtr = opendir(path.c_str());
if (dirPtr == nullptr) {
LogError << "opendir failed. dir:" << path;
return APP_ERR_INTERNAL_ERROR;
}
dirent *direntPtr = nullptr;
while ((direntPtr = readdir(dirPtr)) != nullptr) {
std::string fileName = direntPtr->d_name;
if (fileName == "." || fileName == "..") {
continue;
}
imgFiles.emplace_back(path + "/" + fileName);
}
closedir(dirPtr);
return APP_ERR_OK;
}
int main(int argc, char* argv[]) {
if (argc <= 1) {
LogWarn << "Please input image path, such as './squeezenet1_1 image_dir'.";
return APP_ERR_OK;
}
InitParam initParam = {};
initParam.deviceId = 0;
initParam.classNum = CLASS_NUM;
initParam.labelPath = "../data/config/imagenet1000_clsidx_to_labels.names";
initParam.topk = 5;
initParam.softmax = false;
initParam.checkTensor = true;
initParam.modelPath = "../data/models/squeezenet.om";
auto squeezenet1_1 = std::make_shared<Squeezenet1_1ClassifyOpencv>();
APP_ERROR ret = squeezenet1_1->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "Squeezenet1_1Classify init failed, ret=" << ret << ".";
return ret;
}
std::string imgPath = argv[1];
std::vector<std::string> imgFilePaths;
ret = ScanImages(imgPath, imgFilePaths);
if (ret != APP_ERR_OK) {
squeezenet1_1->DeInit();
return ret;
}
auto startTime = std::chrono::high_resolution_clock::now();
for (auto &imgFile : imgFilePaths) {
ret = squeezenet1_1->Process(imgFile);
if (ret != APP_ERR_OK) {
LogError << "Squeezenet1_1Classify process failed, ret=" << ret << ".";
squeezenet1_1->DeInit();
return ret;
}
}
auto endTime = std::chrono::high_resolution_clock::now();
squeezenet1_1->DeInit();
double costMilliSecs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
double fps = 1000.0 * imgFilePaths.size() / squeezenet1_1->GetInferCostMilliSec();
LogInfo << "[Process Delay] cost: " << costMilliSecs << " ms\tfps: " << fps << " imgs/sec";
return APP_ERR_OK;
}
# coding = utf-8
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""calculation accuracy"""
import os
import sys
import json
import numpy as np
np.set_printoptions(threshold=sys.maxsize)
LABEL_FILE = "HiAI_label.json"
def gen_file_name(img_name):
full_name = img_name.split('/')[-1]
return os.path.splitext(full_name)
def cre_groundtruth_dict(gtfile_path):
"""
:param filename: file contains the imagename and label number
:return: dictionary key imagename, value is label number
"""
img_gt_dict = {}
for gtfile in os.listdir(gtfile_path):
if gtfile != LABEL_FILE:
with open(os.path.join(gtfile_path, gtfile), 'r') as f:
gt = json.load(f)
ret = gt["image"]["annotations"][0]["category_id"]
img_gt_dict[gen_file_name(gtfile)] = ret
return img_gt_dict
def cre_groundtruth_dict_fromtxt(gtfile_path):
"""
:param filename: file contains the imagename and label number
:return: dictionary key imagename, value is label number
"""
img_gt_dict = {}
with open(gtfile_path, 'r')as f:
for line in f.readlines():
temp = line.strip().split(" ")
img_name = temp[0].split(".")[0]
img_lab = temp[1]
img_gt_dict[img_name] = img_lab
return img_gt_dict
def load_statistical_predict_result(filepath):
"""
function:
the prediction esult file data extraction
input:
result file:filepath
output:
n_label:numble of label
data_vec: the probabilitie of prediction in the 1000
:return: probabilities, numble of label, in_type, color
"""
with open(filepath, 'r')as f:
data = f.readline()
temp = data.strip().split(" ")
n_label = len(temp)
data_vec = np.zeros((n_label), dtype=np.float32)
in_type = ''
color = ''
if n_label == 0:
in_type = f.readline()
color = f.readline()
else:
for ind, cls_ind in enumerate(temp):
data_vec[ind] = np.int32(cls_ind)
return data_vec, n_label, in_type, color
def create_visualization_statistical_result(prediction_file_path,
result_store_path, _json_file_name,
img_gt_dict, topn=5):
"""
:param prediction_file_path:
:param result_store_path:
:param json_file_name:
:param img_gt_dict:
:param topn:
:return:
"""
writer = open(os.path.join(result_store_path, json_file_name), 'w')
table_dict = {}
table_dict["title"] = "Overall statistical evaluation"
table_dict["value"] = []
count = 0
res_cnt = 0
n_labels = ""
count_hit = np.zeros(topn)
for tfile_name in os.listdir(prediction_file_path):
count += 1
temp = tfile_name.split('.')[0]
index = temp.rfind('_')
img_name = temp[:index]
filepath = os.path.join(prediction_file_path, tfile_name)
ret = load_statistical_predict_result(filepath)
prediction = ret[0]
n_labels = ret[1]
gt = img_gt_dict[img_name]
if n_labels == 1000:
real_label = int(gt)
elif n_labels == 1001:
real_label = int(gt) + 1
else:
real_label = int(gt)
res_cnt = min(len(prediction), topn)
for i in range(res_cnt):
if str(real_label) == str(int(prediction[i])):
count_hit[i] += 1
break
if 'value' not in table_dict.keys():
print("the item value does not exist!")
else:
table_dict["value"].extend(
[{"key": "Number of images", "value": str(count)},
{"key": "Number of classes", "value": str(n_labels)}])
if count == 0:
accuracy = 0
else:
accuracy = np.cumsum(count_hit) / count
for i in range(res_cnt):
table_dict["value"].append({"key": "Top" + str(i + 1) + " accuracy",
"value": str(
round(accuracy[i] * 100, 2)) + '%'})
json.dump(table_dict, writer)
writer.close()
if __name__ == '__main__':
try:
# txt file path
folder_davinci_target = sys.argv[1]
# annotation files path, "val_label.txt"
annotation_file_path = sys.argv[2]
# the path to store the results json path
result_json_path = sys.argv[3]
# result json file name
json_file_name = sys.argv[4]
except IndexError:
print("Please enter target file result folder | ground truth label file | result json file folder | "
"result json file name, such as ./result val_label.txt . result.json")
exit(1)
if not os.path.exists(folder_davinci_target):
print("Target file folder does not exist.")
if not os.path.exists(annotation_file_path):
print("Ground truth file does not exist.")
if not os.path.exists(result_json_path):
print("Result folder doesn't exist.")
img_label_dict = cre_groundtruth_dict_fromtxt(annotation_file_path)
create_visualization_statistical_result(folder_davinci_target,
result_json_path, json_file_name,
img_label_dict, topn=5)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""infer squeezenet"""
import os
import sys
import json
import datetime
from StreamManagerApi import StreamManagerApi, MxDataInput
if __name__ == '__main__':
# init stream manager
stream_manager_api = StreamManagerApi()
ret = stream_manager_api.InitManager()
if ret != 0:
print("Failed to init Stream manager, ret=%s" % str(ret))
exit()
# create streams by pipeline config file
with open("squeezenet.pipeline", 'rb') as f:
pipeline_str = f.read()
ret = stream_manager_api.CreateMultipleStreams(pipeline_str)
if ret != 0:
print("Failed to create Stream, ret=%s" % str(ret))
exit()
# Construct the input of the stream
data_input = MxDataInput()
dir_name = sys.argv[1]
res_dir_name = sys.argv[2]
file_list = os.listdir(dir_name)
file_list.sort()
if not os.path.exists(res_dir_name):
os.makedirs(res_dir_name)
for file_name in file_list:
print(file_name)
file_path = dir_name + file_name
if file_name.lower().endswith((".JPEG", ".jpeg", "JPG", "jpg")):
portion = os.path.splitext(file_name)
with open(file_path, 'rb') as f:
data_input.data = f.read()
else:
continue
empty_data = []
stream_name = b'im_squeezenet1_1'
in_plugin_id = 0
uniqueId = stream_manager_api.SendData(stream_name, in_plugin_id, data_input)
if uniqueId < 0:
print("Failed to send data to stream.")
exit()
# Obtain the inference result by specifying stream_name and uniqueId.
start_time = datetime.datetime.now()
infer_result = stream_manager_api.GetResult(stream_name, uniqueId)
end_time = datetime.datetime.now()
print('sdk run time: {}'.format((end_time - start_time).microseconds))
if infer_result.errorCode != 0:
print("GetResultWithUniqueId error. errorCode=%d, errorMsg=%s" % (
infer_result.errorCode, infer_result.data.decode()))
exit()
# print the infer result
print(infer_result.data.decode())
load_dict = json.loads(infer_result.data.decode())
print(load_dict)
if load_dict.get('MxpiClass') is None:
with open(res_dir_name + "/" + file_name[:-5] + '.txt', 'w') as f_write:
f_write.write("")
continue
res_vec = load_dict['MxpiClass']
with open(res_dir_name + "/" + file_name[:-5] + '_1.txt', 'w') as f_write:
list1 = [str(item.get("classId")) + " " for item in res_vec]
f_write.writelines(list1)
f_write.write('\n')
# destroy streams
stream_manager_api.DestroyAllStreams()
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
image_path=$1
result_dir=$2
set -e
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3.7 main_squeezenet.py $image_path $result_dir
exit 0
{
"im_squeezenet1_1": {
"stream_config": {
"deviceId": "0"
},
"appsrc1": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_imagedecoder0"
},
"mxpi_imagedecoder0": {
"props": {
"handleMethod": "opencv"
},
"factory": "mxpi_imagedecoder",
"next": "mxpi_imageresize0"
},
"mxpi_imageresize0": {
"props": {
"handleMethod": "opencv",
"resizeType": "Resizer_Stretch",
"resizeHeight": "256",
"resizeWidth": "256"
},
"factory": "mxpi_imageresize",
"next": "mxpi_opencvcentercrop0"
},
"mxpi_opencvcentercrop0": {
"props": {
"dataSource": "mxpi_imageresize0",
"cropHeight": "227",
"cropWidth": "227"
},
"factory": "mxpi_opencvcentercrop",
"next": "mxpi_tensorinfer0"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "mxpi_opencvcentercrop0",
"modelPath": "../data/models/squeezenet.om",
"waitingTime": "2000",
"outputDeviceId": "-1"
},
"factory": "mxpi_tensorinfer",
"next": "mxpi_classpostprocessor0"
},
"mxpi_classpostprocessor0": {
"props": {
"dataSource": "mxpi_tensorinfer0",
"postProcessConfigPath": "../data/config/squeezenet.cfg",
"labelPath": "../data/config/imagenet1000_clsidx_to_labels.names",
"postProcessLibPath": "libresnet50postprocess.so"
},
"factory": "mxpi_classpostprocessor",
"next": "mxpi_dataserialize0"
},
"mxpi_dataserialize0": {
"props": {
"outputDataKeys": "mxpi_classpostprocessor0"
},
"factory": "mxpi_dataserialize",
"next": "appsink0"
},
"appsink0": {
"props": {
"blocksize": "4096000"
},
"factory": "appsink"
}
}
}
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train squeezenet."""
import ast
import os
import argparse
import glob
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore import export
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import init
from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.squeezenet import SqueezeNet as squeezenet
parser = argparse.ArgumentParser(description='SqueezeNet1_1')
parser.add_argument('--net', type=str, default='squeezenet', help='Model.')
parser.add_argument('--dataset', type=str, default='imagenet', help='Dataset.')
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=False,
help='Whether it is running on CloudBrain platform.')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default="None", help='Pretrained checkpoint path')
parser.add_argument('--data_url', type=str, default="None", help='Datapath')
parser.add_argument('--train_url', type=str, default="None", help='Train output path')
parser.add_argument('--num_classes', type=int, default="1000", help="classes")
parser.add_argument('--epoch_size', type=int, default="200", help="epoch_size")
parser.add_argument('--batch_size', type=int, default="32", help="batch_size")
args_opt = parser.parse_args()
local_data_url = '/cache/data'
local_train_url = '/cache/ckpt'
local_pretrain_url = '/cache/preckpt.ckpt'
set_seed(1)
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
"""remove useless parameters according to filter_list"""
for key in list(origin_dict.keys()):
for name in param_filter:
if name in key:
print("Delete parameter from checkpoint: ", key)
del origin_dict[key]
break
def frozen_to_air(network, args):
paramdict = load_checkpoint(args.get("ckpt_file"))
load_param_into_net(network, paramdict)
input_arr = Tensor(np.zeros([args.get("batch_size"), 3, args.get("height"), args.get("width")], np.float32))
export(network, input_arr, file_name=args.get("file_name"), file_format=args.get("file_format"))
if __name__ == '__main__':
target = args_opt.device_target
if args_opt.device_target != "Ascend":
raise ValueError("Unsupported device target.")
# init context
if args_opt.run_distribute:
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
context.set_context(mode=context.GRAPH_MODE,
device_target=target)
context.set_context(device_id=device_id,
enable_auto_mixed_precision=True)
context.set_auto_parallel_context(
device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
local_data_url = os.path.join(local_data_url, str(device_id))
else:
device_id = 0
context.set_context(mode=context.GRAPH_MODE,
device_target=target)
# create dataset
if args_opt.dataset == "cifar10":
from src.config import config_cifar as config
from src.dataset import create_dataset_cifar as create_dataset
else:
from src.config import config_imagenet as config
from src.dataset import create_dataset_imagenet as create_dataset
if args_opt.run_cloudbrain:
import moxing as mox
mox.file.copy_parallel(args_opt.data_url, local_data_url)
dataset = create_dataset(dataset_path=local_data_url,
do_train=True,
repeat_num=1,
batch_size=args_opt.batch_size,
target=target,
run_distribute=args_opt.run_distribute)
step_size = dataset.get_dataset_size()
# define net
net = squeezenet(num_classes=args_opt.num_classes)
# load checkpoint
if args_opt.pre_trained != "None":
if args_opt.run_cloudbrain:
dir_path = os.path.dirname(os.path.abspath(__file__))
ckpt_name = args_opt.pre_trained[2:]
ckpt_path = os.path.join(dir_path, ckpt_name)
print(ckpt_path)
param_dict = load_checkpoint(ckpt_path)
filter_list = [x.name for x in net.final_conv.get_parameters()]
filter_checkpoint_parameter_by_list(param_dict, filter_list)
load_param_into_net(net, param_dict)
# init lr
lr = get_lr(lr_init=config.lr_init,
lr_end=config.lr_end,
lr_max=config.lr_max,
total_epochs=args_opt.epoch_size,
warmup_epochs=config.warmup_epochs,
pretrain_epochs=config.pretrain_epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
lr = Tensor(lr)
# define loss
if args_opt.dataset == "imagenet":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True,
reduction='mean',
smooth_factor=config.label_smooth_factor,
num_classes=args_opt.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define opt, model
loss_scale = FixedLossScaleManager(config.loss_scale,
drop_overflow_update=False)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
lr,
config.momentum,
config.weight_decay,
config.loss_scale,
use_nesterov=True)
model = Model(net,
loss_fn=loss,
optimizer=opt,
loss_scale_manager=loss_scale,
metrics={'acc': Accuracy()},
amp_level="O2",
keep_batchnorm_fp32=False)
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint and device_id == 0:
config_ck = CheckpointConfig(
save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=args_opt.net,
directory=local_train_url,
config=config_ck)
cb += [ckpt_cb]
# train model
model.train(args_opt.epoch_size - config.pretrain_epoch_size,
dataset,
callbacks=cb)
if device_id == 0:
ckpt_list = glob.glob("/cache/ckpt/squeezenet*.ckpt")
if not ckpt_list:
print("ckpt file not generated.")
ckpt_list.sort(key=os.path.getmtime)
ckpt_model = ckpt_list[-1]
print("checkpoint path", ckpt_model)
net = squeezenet(args_opt.num_classes)
frozen_to_air_args = {'ckpt_file': ckpt_model,
'batch_size': 1,
'height': 227,
'width': 227,
'file_name': '/cache/ckpt/squeezenet',
'file_format': 'AIR'}
frozen_to_air(net, frozen_to_air_args)
if args_opt.run_cloudbrain:
mox.file.copy_parallel(local_train_url, args_opt.train_url)
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.mitations under the License.
docker_image=$1
data_dir=$2
model_dir=$3
docker run -it --ipc=host \
--device=/dev/davinci0 \
--device=/dev/davinci1 \
--device=/dev/davinci2 \
--device=/dev/davinci3 \
--device=/dev/davinci4 \
--device=/dev/davinci5 \
--device=/dev/davinci6 \
--device=/dev/davinci7 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm --device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \
-v ${model_dir}:${model_dir} \
-v ${data_dir}:${data_dir} \
-v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash
......@@ -18,7 +18,7 @@ network config setting, will be used in train.py and eval.py
from easydict import EasyDict as ed
# config for squeezenet, imagenet
config = ed({
config_imagenet = ed({
"class_num": 1000,
"batch_size": 32,
"loss_scale": 1024,
......@@ -38,3 +38,23 @@ config = ed({
"lr_end": 0,
"lr_max": 0.01
})
# config for squeezenet, cifar10
config_cifar = ed({
"class_num": 10,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 120,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 5,
"lr_decay_mode": "poly",
"lr_init": 0,
"lr_end": 0,
"lr_max": 0.01
})
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
create train or eval dataset of imagenet and cifar10.
"""
import os
import mindspore.common.dtype as mstype
......@@ -102,3 +102,78 @@ def create_dataset_imagenet(dataset_path,
data_set = data_set.repeat(repeat_num)
return data_set
def create_dataset_cifar(dataset_path,
do_train,
repeat_num=1,
batch_size=32,
target="Ascend",
run_distribute=False):
"""
create a train or evaluate cifar10 dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
if run_distribute:
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
else:
device_num = 1
else:
raise ValueError("Unsupported device target.")
if device_num == 1:
data_set = ds.Cifar10Dataset(dataset_path,
num_parallel_workers=8,
shuffle=True)
else:
data_set = ds.Cifar10Dataset(dataset_path,
num_parallel_workers=8,
shuffle=True,
num_shards=device_num,
shard_id=device_id)
# define map operations
if do_train:
trans = [
C.RandomCrop((32, 32), (4, 4, 4, 4)),
C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
C.Resize((227, 227)),
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.CutOut(112),
C.HWC2CHW()
]
else:
trans = [
C.Resize((227, 227)),
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=type_cast_op,
input_columns="label",
num_parallel_workers=8)
data_set = data_set.map(operations=trans,
input_columns="image",
num_parallel_workers=8)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
......@@ -31,7 +31,7 @@ from mindspore.communication.management import init, get_rank
from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.squeezenet import SqueezeNet as squeezenet
from src.config import config
from src.config import config_imagenet as config
from src.dataset import create_dataset_imagenet as create_dataset
parser = argparse.ArgumentParser(description='SqueezeNet1_1')
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment