Skip to content
Snippets Groups Projects
Unverified Commit c21c88e0 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2622 [哈尔滨工业大学威海][高校贡献][Mindspore][E-NET]-高性能预训练模型提交+语义分割

Merge pull request !2622 from 孙大千/code_E-NET
parents 6185be67 97a9d59c
No related branches found
No related tags found
No related merge requests found
Showing
with 1474 additions and 0 deletions
#!/bin/bash
# Copyright(C) 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
data_dir=$2
model_dir=$3
docker run -it --ipc=host \
--device=/dev/davinci0 \
--device=/dev/davinci1 \
--device=/dev/davinci2 \
--device=/dev/davinci3 \
--device=/dev/davinci4 \
--device=/dev/davinci5 \
--device=/dev/davinci6 \
--device=/dev/davinci7 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm --device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons/ \
-v ${model_dir}:${model_dir} \
-v ${data_dir}:${data_dir} \
-v ~/ascend/log/npu/conf/slog/slog.conf:/var/log/npu/conf/slog/slog.conf \
-v ~/ascend/log/npu/slog/:/var/log/npu/slog -v ~/ascend/log/npu/profiling/:/var/log/npu/profiling \
-v ~/ascend/log/npu/dump/:/var/log/npu/dump -v ~/ascend/log/npu/:/usr/slog ${docker_image} \
/bin/bash
#!/bin/bash
# Copyright(C) 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
model_dir=$2
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
exit 1
fi
if [ ! -d "${model_dir}" ]; then
echo "please input model_dir"
exit 1
fi
docker run -it \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${model_dir}:${model_dir} \
${docker_image} \
/bin/bash
\ No newline at end of file
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
for para in "$@"
do
if [[ $para == --model* ]];then
model=`echo ${para#*=}`
elif [[ $para == --output* ]];then
output=`echo ${para#*=}`
elif [[ $para == --soc_version* ]];then
soc_version=`echo ${para#*=}`
fi
done
echo "Input AIR file path: ${model}"
echo "Input aipp file path: ${output}"
soc_version=Ascend310
atc --input_format=NCHW \
--model=${model} \
--output=${output} \
--soc_version=${soc_version} \
--framework=1
{
"enet": {
"appsrc0": {
"factory": "appsrc",
"next": "modelInfer"
},
"modelInfer": {
"props": {
"modelPath": "../../out/Enet.om",
"dataSource": "appsrc0"
},
"factory": "mxpi_tensorinfer",
"next": "dataserialize"
},
"dataserialize": {
"props": {
"outputDataKeys": "modelInfer"
},
"factory": "mxpi_dataserialize",
"next": "appsink0"
},
"appsink0": {
"factory": "appsink"
}
}
}
cmake_minimum_required(VERSION 3.14.0)
project(enet)
set(TARGET enet)
add_definitions(-DENABLE_DVPP_INTERFACE)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-Dgoogle=mindxsdk_private)
add_compile_options(-std=c++11 -fPIE -fstack-protector-all -fPIC -Wall)
add_link_options(-Wl,-z,relro,-z,now,-z,noexecstack -s -pie)
# Check environment variable
if(NOT DEFINED ENV{ASCEND_HOME})
message(FATAL_ERROR "please define environment variable:ASCEND_HOME")
endif()
if(NOT DEFINED ENV{ASCEND_VERSION})
message(WARNING "please define environment variable:ASCEND_VERSION")
endif()
if(NOT DEFINED ENV{ARCH_PATTERN})
message(WARNING "please define environment variable:ARCH_PATTERN")
endif()
set(ACL_INC_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/include)
set(ACL_LIB_DIR $ENV{ASCEND_HOME}/$ENV{ASCEND_VERSION}/$ENV{ARCH_PATTERN}/acllib/lib64)
set(MXBASE_ROOT_DIR $ENV{MX_SDK_HOME})
set(MXBASE_INC ${MXBASE_ROOT_DIR}/include)
set(MXBASE_LIB_DIR ${MXBASE_ROOT_DIR}/lib)
set(MXBASE_POST_LIB_DIR ${MXBASE_ROOT_DIR}/lib/modelpostprocessors)
set(MXBASE_POST_PROCESS_DIR ${MXBASE_ROOT_DIR}/include/MxBase/postprocess/include)
if(NOT DEFINED ENV{MXSDK_OPENSOURCE_DIR})
message(WARNING "please define environment variable:MXSDK_OPENSOURCE_DIR")
endif()
set(OPENSOURCE_DIR $ENV{MXSDK_OPENSOURCE_DIR})
include_directories(src)
include_directories(${ACL_INC_DIR})
include_directories(${OPENSOURCE_DIR}/include)
include_directories(${OPENSOURCE_DIR}/include/opencv4)
include_directories(/usr/local/Ascend/ascend-toolkit/5.0.4/x86_64-linux/x86_64-linux/include)
include_directories(${MXBASE_INC})
include_directories(${MXBASE_POST_PROCESS_DIR})
link_directories(${ACL_LIB_DIR})
link_directories(${OPENSOURCE_DIR}/lib)
link_directories(${MXBASE_LIB_DIR})
link_directories(${MXBASE_POST_LIB_DIR})
add_executable(${TARGET} src/main.cpp src/Enet.cpp)
target_link_libraries(${TARGET} glog cpprest mxbase opencv_world stdc++fs)
install(TARGETS ${TARGET} RUNTIME DESTINATION ${PROJECT_SOURCE_DIR}/)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export ASCEND_HOME=/usr/local/Ascend
export ASCEND_VERSION=nnrt/latest
export ARCH_PATTERN=.
export MXSDK_OPENSOURCE_DIR=/usr/local/sdk_home/mxManufacture/opensource
export LD_LIBRARY_PATH="${MX_SDK_HOME}/lib/plugins:${MX_SDK_HOME}/opensource/lib64:${MX_SDK_HOME}/lib:${MX_SDK_HOME}/lib/modelpostprocessors:${MX_SDK_HOME}/opensource/lib:/usr/local/Ascend/nnae/latest/fwkacllib/lib64:${LD_LIBRARY_PATH}"
export ASCEND_OPP_PATH="/usr/local/Ascend/nnae/latest/opp"
export ASCEND_AICPU_PATH="/usr/local/Ascend/nnae/latest"
function check_env()
{
# set ASCEND_VERSION to ascend-toolkit/latest when it was not specified by user
if [ ! "${ASCEND_VERSION}" ]; then
export ASCEND_VERSION=ascend-toolkit/latest
echo "Set ASCEND_VERSION to the default value: ${ASCEND_VERSION}"
else
echo "ASCEND_VERSION is set to ${ASCEND_VERSION} by user"
fi
if [ ! "${ARCH_PATTERN}" ]; then
# set ARCH_PATTERN to ./ when it was not specified by user
export ARCH_PATTERN=./
echo "ARCH_PATTERN is set to the default value: ${ARCH_PATTERN}"
else
echo "ARCH_PATTERN is set to ${ARCH_PATTERN} by user"
fi
}
function build_enet()
{
cd .
rm -rf build
mkdir -p build
cd build
cmake ..
make
ret=$?
if [ ${ret} -ne 0 ]; then
echo "Failed to build brdnet."
exit ${ret}
fi
make install
}
rm -rf ./result
mkdir -p ./result
check_env
build_enet
'''
The scripts to execute sdk infer
'''
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import numpy as np
import PIL.Image as Image
def parse_args():
"""set and check parameters."""
parser = argparse.ArgumentParser(description="ENET process")
parser.add_argument("--image_path", type=str, default=None, help="root path of image")
parser.add_argument('--image_width', default=1024, type=int, help='image width')
parser.add_argument('--image_height', default=512, type=int, help='image height')
parser.add_argument('--output_path', default='./bin', type=str, help='bin file path')
args_opt = parser.parse_args()
return args_opt
def _get_city_pairs(folder, split='train'):
"""_get_city_pairs"""
def get_path_pairs(img_folder, mask_folder):
img_paths = []
mask_paths = []
for root, _, files in os.walk(img_folder):
for filename in files:
if filename.startswith('._'):
continue
if filename.endswith('.png'):
imgpath = os.path.join(root, filename)
foldername = os.path.basename(os.path.dirname(imgpath))
maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
maskpath = os.path.join(mask_folder, foldername, maskname)
if os.path.isfile(imgpath) and os.path.isfile(maskpath):
img_paths.append(imgpath)
mask_paths.append(maskpath)
else:
print('cannot find the mask or image:', imgpath, maskpath)
print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
return img_paths, mask_paths
if split in ('train', 'val'):
img_folder = os.path.join(folder, 'leftImg8bit' + os.sep + split) # os.sep:/
mask_folder = os.path.join(folder, 'gtFine' + os.sep + split)
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
return img_paths, mask_paths
assert split == 'trainval'
print('trainval set')
train_img_folder = os.path.join(folder, 'leftImg8bit' + os.sep + 'train')
train_mask_folder = os.path.join(folder, 'gtFine' + os.sep + 'train')
val_img_folder = os.path.join(folder, 'leftImg8bit' + os.sep + 'val')
val_mask_folder = os.path.join(folder, 'gtFine' + os.sep + 'val')
train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder)
val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
img_paths = train_img_paths + val_img_paths
mask_paths = train_mask_paths + val_mask_paths
return img_paths, mask_paths
def _val_sync_transform(outsize, img):
"""_val_sync_transform"""
short_size = min(outsize)
w, h = img.size
if w > h:
oh = short_size
ow = int(1.0 * w * oh / h)
else:
ow = short_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
w, h = img.size
x1 = int(round((w - outsize[1]) / 2.))
y1 = int(round((h - outsize[0]) / 2.))
img = img.crop((x1, y1, x1 + outsize[1], y1 + outsize[0]))
img = np.array(img)
return img
def main():
args = parse_args()
images, mask_paths = _get_city_pairs(args.image_path, 'val')
assert len(images) == len(mask_paths)
if not images:
raise RuntimeError("Found 0 images in subfolders of:" + args.image_path + "\n")
for index in range(len(images)):
image_name = images[index].split(os.sep)[-1].split(".")[0] # get the name of image file
print("Processing ---> ", image_name)
img = Image.open(images[index]).convert('RGB')
img = _val_sync_transform((args.image_height, args.image_width), img)
img = img.astype(np.float32)
img = img.transpose((2, 0, 1)) # HWC->CHW(H:height W:width C:channel)
for channel, _ in enumerate(img):
img[channel] /= 255
img = np.expand_dims(img, 0) # NCHW
# save bin file
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
data = img
dataname = image_name + ".bin"
data.tofile(args.output_path + '/' + dataname)
if __name__ == '__main__':
main()
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
image_width=1024
image_height=512
output_path=./bin
# help message
if [[ $1 == --help || $1 == -h ]];then
echo "usage:bash ./preprocess.sh <args>"
echo "parameter explain:
--image_path root path of processed images, e.g. --image_path=../data/
--image_width set the image width, default: --image_width=1024
--image_height set the image height, default: --image_height=512
--output_path bin file path, default: --output_path=./bin
-h/--help show help message
"
exit 1
fi
for para in "$@"
do
if [[ $para == --image_path* ]];then
image_path=`echo ${para#*=}`
elif [[ $para == --image_width* ]];then
image_width=`echo ${para#*=}`
elif [[ $para == --image_height* ]];then
image_height=`echo ${para#*=}`
elif [[ $para == --output_path* ]];then
output_path=`echo ${para#*=}`
fi
done
if [[ $image_path == "" ]];then
echo "[Error] para \"image_path \" must be config"
exit 1
fi
python3 main.py --image_path=$image_path \
--image_width=$image_width \
--image_height=$image_height \
--output_path=$output_path
exit 0
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Enet.h"
#include <unistd.h>
#include <sys/stat.h>
#include <map>
#include <fstream>
#include <vector>
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/Log/Log.h"
std::vector<cv::Vec3b> cityspallete {
{128, 64, 128},
{244, 35, 232},
{70, 70, 70},
{102, 102, 156},
{190, 153, 153},
{153, 153, 153},
{0, 130, 180},
{220, 220, 0},
{107, 142, 35},
{152, 251, 152},
{250, 170, 30},
{220, 20, 60},
{0, 0, 230},
{119, 11, 32},
{0, 0, 70},
{0, 60, 100},
{0, 80, 100},
{255, 0, 0},
{0, 0, 142},
{0, 0, 0}};
APP_ERROR Enet::Init(const InitParam &initParam) {
this->deviceId_ = initParam.deviceId;
this->outputDataPath_ = initParam.outputDataPath;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
this->model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = this->model_->Init(initParam.modelPath, this->modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
uint32_t input_data_size = 1;
for (size_t j = 0; j < this->modelDesc_.inputTensors[0].tensorDims.size(); ++j) {
this->inputDataShape_[j] = (uint32_t)this->modelDesc_.inputTensors[0].tensorDims[j];
input_data_size *= this->inputDataShape_[j];
}
this->inputDataSize_ = input_data_size;
return APP_ERR_OK;
}
APP_ERROR Enet::DeInit() {
this->model_->DeInit();
MxBase::DeviceManager::GetInstance()->DestroyDevices();
return APP_ERR_OK;
}
APP_ERROR Enet::ReadTensorFromFile(const std::string &file, float *data) {
if (data == NULL) {
LogError << "input data is invalid.";
return APP_ERR_COMM_INVALID_POINTER;
}
std::ifstream infile;
// open data file
infile.open(file, std::ios_base::in | std::ios_base::binary);
// check data file validity
if (infile.fail()) {
LogError << "Failed to open data file: " << file << ".";
return APP_ERR_COMM_OPEN_FAIL;
}
infile.read(reinterpret_cast<char *>(data), sizeof(float) * this->inputDataSize_);
infile.close();
return APP_ERR_OK;
}
APP_ERROR Enet::ReadInputTensor(const std::string &fileName, std::vector<MxBase::TensorBase> *inputs) {
float data[this->inputDataSize_] = {0};
APP_ERROR ret = ReadTensorFromFile(fileName, data);
if (ret != APP_ERR_OK) {
LogError << "ReadTensorFromFile failed.";
return ret;
}
const uint32_t dataSize = modelDesc_.inputTensors[0].tensorSize;
MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, this->deviceId_);
MxBase::MemoryData memoryDataSrc(reinterpret_cast<void *>(data), dataSize, MxBase::MemoryData::MEMORY_HOST_MALLOC);
ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc and copy failed.";
return ret;
}
inputs->push_back(MxBase::TensorBase(memoryDataDst, false, this->inputDataShape_, MxBase::TENSOR_DTYPE_FLOAT32));
return APP_ERR_OK;
}
APP_ERROR Enet::Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> *outputs) {
auto dtypes = this->model_->GetOutputDataType();
for (size_t i = 0; i < this->modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)this->modelDesc_.outputTensors[i].tensorDims[j]);
}
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, this->deviceId_);
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs->push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret = this->model_->ModelInference(inputs, *outputs, dynamicInfo);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
g_inferCost.push_back(costMs);
if (ret != APP_ERR_OK) {
LogError << "ModelInference failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Enet::PostProcess(std::vector<MxBase::TensorBase> &inputs, cv::Mat &output) {
MxBase::TensorBase &tensor = inputs[0];
int channel = tensor.GetShape()[MxBase::VECTOR_SECOND_INDEX];
int outputModelHeight = tensor.GetShape()[MxBase::VECTOR_THIRD_INDEX];
int outputModelWidth = tensor.GetShape()[MxBase::VECTOR_FOURTH_INDEX];
// argmax
for (int h = 0; h < outputModelHeight; h++) {
for (int w = 0; w < outputModelWidth; w++) {
float max;
int index = 0;
std::vector<int> index_ori = {0, index, h, w};
tensor.GetValue(max, index_ori);
for (int c = 1; c < channel; c++) {
float num_c;
std::vector<int> index_cur = {0, c, h, w};
tensor.GetValue(num_c, index_cur);
if (num_c > max) {
index = c;
max = num_c;
}
}
output.at<cv::Vec3b>(h, w) = cityspallete[index];
}
}
return APP_ERR_OK;
}
APP_ERROR Enet::Process(const std::string &inferPath, const std::string &fileName) {
std::vector<MxBase::TensorBase> inputs = {};
std::string inputIdsFile = inferPath + fileName;
APP_ERROR ret = ReadInputTensor(inputIdsFile, &inputs);
if (ret != APP_ERR_OK) {
LogError << "Read input ids failed, ret=" << ret << ".";
return ret;
}
std::vector<MxBase::TensorBase> outputs = {};
ret = Inference(inputs, &outputs);
if (ret != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret << ".";
return ret;
}
ret = outputs[0].ToHost();
if (ret != APP_ERR_OK) {
LogError << "ToHost failed, ret=" << ret << ".";
return ret;
}
int outputModelHeight = outputs[0].GetShape()[MxBase::VECTOR_THIRD_INDEX];
int outputModelWidth = outputs[0].GetShape()[MxBase::VECTOR_FOURTH_INDEX];
cv::Mat output(outputModelHeight, outputModelWidth, CV_8UC3);
ret = PostProcess(outputs, output);
if (ret != APP_ERR_OK) {
LogError << "PostProcess failed, ret=" << ret << ".";
return ret;
}
std::string outFileName = this->outputDataPath_ + "/" + fileName;
size_t pos = outFileName.find_last_of(".");
outFileName.replace(outFileName.begin() + pos, outFileName.end(), "_infer.png");
cv::imwrite(outFileName, output);
return APP_ERR_OK;
}
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_Enet_H
#define MXBASE_Enet_H
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include <map>
#include <opencv2/opencv.hpp>
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/PostProcessBases/PostProcessDataType.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
extern std::vector<double> g_inferCost;
struct InitParam {
uint32_t deviceId;
std::string modelPath;
std::string outputDataPath;
};
class Enet {
public:
APP_ERROR Init(const InitParam &initParam);
APP_ERROR DeInit();
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> *outputs);
APP_ERROR Process(const std::string &inferPath, const std::string &fileName);
protected:
APP_ERROR ReadTensorFromFile(const std::string &file, float *data);
APP_ERROR ReadInputTensor(const std::string &fileName, std::vector<MxBase::TensorBase> *inputs);
APP_ERROR PostProcess(std::vector<MxBase::TensorBase> &inputs, cv::Mat &output);
private:
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
MxBase::ModelDesc modelDesc_ = {};
uint32_t deviceId_ = 0;
std::string outputDataPath_ = "./result";
std::vector<uint32_t> inputDataShape_ = {1, 3, 512, 1024};
uint32_t inputDataSize_ = 1572864;
};
#endif
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unistd.h>
#include <dirent.h>
#include <iostream>
#include <fstream>
#include <vector>
#include "Enet.h"
#include "MxBase/Log/Log.h"
std::vector<double> g_inferCost;
void InitProtonetParam(InitParam *initParam, const std::string &model_path, const std::string &output_data_path) {
initParam->deviceId = 0;
initParam->modelPath = model_path;
initParam->outputDataPath = output_data_path;
}
APP_ERROR ReadFilesFromPath(const std::string &path, std::vector<std::string> *files) {
DIR *dir = NULL;
struct dirent *ptr = NULL;
if ((dir = opendir(path.c_str())) == NULL) {
LogError << "Open dir error: " << path;
return APP_ERR_COMM_OPEN_FAIL;
}
while ((ptr = readdir(dir)) != NULL) {
if (ptr->d_type == 8) {
files->push_back(ptr->d_name);
}
}
closedir(dir);
return APP_ERR_OK;
}
int main(int argc, char *argv[]) {
LogInfo << "======================================= !!!Parameters setting!!! "
<< "========================================";
std::string model_path = argv[1];
LogInfo << "========== loading model weights from: " << model_path;
std::string input_data_path = argv[2];
LogInfo << "========== input data path = " << input_data_path;
std::string output_data_path = argv[3];
LogInfo << "========== output data path = " << output_data_path;
LogInfo << "======================================== !!!Parameters setting!!! "
<< "========================================";
InitParam initParam;
InitProtonetParam(&initParam, model_path, output_data_path);
auto enet = std::make_shared<Enet>();
APP_ERROR ret = enet->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "Enet init failed, ret=" << ret << ".";
return ret;
}
std::vector<std::string> files;
ret = ReadFilesFromPath(input_data_path, &files);
if (ret != APP_ERR_OK) {
LogError << "Read files from path failed, ret=" << ret << ".";
return ret;
}
// do infer
for (uint32_t i = 0; i < files.size(); i++) {
LogInfo << "Processing: " + std::to_string(i + 1) + "/" + std::to_string(files.size()) + " ---> " + files[i];
ret = enet->Process(input_data_path, files[i]);
if (ret != APP_ERR_OK) {
LogError << "Enet process failed, ret=" << ret << ".";
enet->DeInit();
return ret;
}
}
LogInfo << "infer succeed and write the result data with binary file !";
enet->DeInit();
double costSum = 0;
for (uint32_t i = 0; i < g_inferCost.size(); i++) {
costSum += g_inferCost[i];
}
LogInfo << "Infer images sum " << g_inferCost.size() << ", cost total time: " << costSum << " ms.";
LogInfo << "The throughput: " << g_inferCost.size() * 1000 / costSum << " bin/sec.";
LogInfo << "========== The infer result has been saved in ---> " << output_data_path;
return APP_ERR_OK;
}
'''
The scripts to execute sdk infer
'''
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
import time
import numpy as np
import PIL.Image as Image
import MxpiDataType_pb2 as MxpiDataType
from StreamManagerApi import StreamManagerApi, InProtobufVector, \
MxProtobufIn, StringVector
def parse_args():
"""set and check parameters."""
parser = argparse.ArgumentParser(description="ENET process")
parser.add_argument("--pipeline", type=str,
default=None, help="SDK infer pipeline")
parser.add_argument("--image_path", type=str,
default=None, help="root path of image")
parser.add_argument('--image_width', default=1024,
type=int, help='image width')
parser.add_argument('--image_height', default=512,
type=int, help='image height')
parser.add_argument('--save_mask', default=1, type=int,
help='0 for False, 1 for True')
parser.add_argument('--mask_result_path', default='./mask_result', type=str,
help='the folder to save the semantic mask images')
args_opt = parser.parse_args()
return args_opt
def send_source_data(appsrc_id, tensor, stream_name, stream_manager):
"""
Construct the input of the stream,
send inputs data to a specified stream based on streamName.
Returns:
bool: send data success or not
"""
tensor_package_list = MxpiDataType.MxpiTensorPackageList()
tensor_package = tensor_package_list.tensorPackageVec.add()
array_bytes = tensor.tobytes()
tensor_vec = tensor_package.tensorVec.add()
tensor_vec.deviceId = 0
tensor_vec.memType = 0
for i in tensor.shape:
tensor_vec.tensorShape.append(i)
tensor_vec.dataStr = array_bytes
tensor_vec.tensorDataSize = len(array_bytes)
key = "appsrc{}".format(appsrc_id).encode('utf-8')
protobuf_vec = InProtobufVector()
protobuf = MxProtobufIn()
protobuf.key = key
protobuf.type = b'MxTools.MxpiTensorPackageList'
protobuf.protobuf = tensor_package_list.SerializeToString()
protobuf_vec.push_back(protobuf)
ret = stream_manager.SendProtobuf(stream_name, appsrc_id, protobuf_vec)
if ret < 0:
print("Failed to send data to stream.")
return False
# print("Success to send data to stream.")
return True
cityspallete = [
128, 64, 128,
244, 35, 232,
70, 70, 70,
102, 102, 156,
190, 153, 153,
153, 153, 153,
250, 170, 30,
220, 220, 0,
107, 142, 35,
152, 251, 152,
0, 130, 180,
220, 20, 60,
255, 0, 0,
0, 0, 142,
0, 0, 70,
0, 60, 100,
0, 80, 100,
0, 0, 230,
119, 11, 32,
]
classes = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
'truck', 'bus', 'train', 'motorcycle', 'bicycle')
valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
23, 24, 25, 26, 27, 28, 31, 32, 33]
_key = np.array([-1, -1, -1, -1, -1, -1,
-1, -1, 0, 1, -1, -1,
2, 3, 4, -1, -1, -1,
5, -1, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15,
-1, -1, 16, 17, 18])
_mapping = np.array(range(-1, len(_key) - 1)
).astype('int32') # [-1, 0, 1, ..., 33]
def _get_city_pairs(folder, split='train'):
"""_get_city_pairs"""
def get_path_pairs(img_folder, mask_folder):
img_paths = []
mask_paths = []
for root, _, files in os.walk(img_folder):
for filename in files:
if filename.startswith('._'):
continue
if filename.endswith('.png'):
imgpath = os.path.join(root, filename)
foldername = os.path.basename(os.path.dirname(imgpath))
maskname = filename.replace(
'leftImg8bit', 'gtFine_labelIds')
maskpath = os.path.join(mask_folder, foldername, maskname)
if os.path.isfile(imgpath) and os.path.isfile(maskpath):
img_paths.append(imgpath)
mask_paths.append(maskpath)
else:
print('cannot find the mask or image:',
imgpath, maskpath)
print('Found {} images in the folder {}'.format(
len(img_paths), img_folder))
return img_paths, mask_paths
if split in ('train', 'val'):
img_folder = os.path.join(
folder, 'leftImg8bit' + os.sep + split) # os.sep:/
mask_folder = os.path.join(folder, 'gtFine' + os.sep + split)
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
return img_paths, mask_paths
assert split == 'trainval'
print('trainval set')
train_img_folder = os.path.join(folder, 'leftImg8bit' + os.sep + 'train')
train_mask_folder = os.path.join(folder, 'gtFine' + os.sep + 'train')
val_img_folder = os.path.join(folder, 'leftImg8bit' + os.sep + 'val')
val_mask_folder = os.path.join(folder, 'gtFine' + os.sep + 'val')
train_img_paths, train_mask_paths = get_path_pairs(
train_img_folder, train_mask_folder)
val_img_paths, val_mask_paths = get_path_pairs(
val_img_folder, val_mask_folder)
img_paths = train_img_paths + val_img_paths
mask_paths = train_mask_paths + val_mask_paths
return img_paths, mask_paths
def _val_sync_transform(outsize, img, mask):
"""_val_sync_transform"""
short_size = min(outsize)
w, h = img.size
if w > h:
oh = short_size
ow = int(1.0 * w * oh / h)
else:
ow = short_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
w, h = img.size
x1 = int(round((w - outsize[1]) / 2.))
y1 = int(round((h - outsize[0]) / 2.))
img = img.crop((x1, y1, x1 + outsize[1], y1 + outsize[0]))
mask = mask.crop((x1, y1, x1 + outsize[1], y1 + outsize[0]))
# final transform
img, mask = np.array(img), _mask_transform(mask)
return img, mask
def _class_to_index(mask):
# assert the value
values = np.unique(mask)
for value in values:
assert value in _mapping
index = np.digitize(mask.ravel(), _mapping,
right=True)
return _key[index].reshape(mask.shape)
def _mask_transform(mask):
target = _class_to_index(np.array(mask).astype('int32'))
return np.array(target).astype('int32')
class SegmentationMetric:
"""Computes pixAcc and mIoU metric scores
"""
def __init__(self, nclass):
super(SegmentationMetric, self).__init__()
self.nclass = nclass
self.reset()
def update(self, preds, labels):
"""Updates the internal evaluation result.
Parameters
----------
labels : 'NumpyArray' or list of `NumpyArray`
The labels of the data.
preds : 'NumpyArray' or list of `NumpyArray`
Predicted values.
"""
def evaluate_worker(self, pred, label):
correct, labeled = batch_pix_accuracy(pred, label)
inter, union = batch_intersection_union(pred, label, self.nclass)
self.total_correct += correct
self.total_label += labeled
self.total_inter += inter
self.total_union += union
evaluate_worker(self, preds, labels)
def get(self, return_category_iou=False):
"""Gets the current evaluation result.
Returns
-------
metrics : tuple of float
pixAcc and mIoU
"""
# remove np.spacing(1)
pixAcc = 1.0 * self.total_correct / \
(2.220446049250313e-16 + self.total_label)
IoU = 1.0 * self.total_inter / \
(2.220446049250313e-16 + self.total_union)
mIoU = IoU.mean().item()
if return_category_iou:
return pixAcc, mIoU, IoU
return pixAcc, mIoU
def reset(self):
"""Resets the internal evaluation result to initial state."""
self.total_inter = np.zeros(self.nclass)
self.total_union = np.zeros(self.nclass)
self.total_correct = 0
self.total_label = 0
def batch_pix_accuracy(output, target):
"""PixAcc"""
# inputs are numpy array, output 4D NCHW where 'C' means label classes, target 3D NHW
predict = np.argmax(output.astype(np.int64), 1) + 1
target = target.astype(np.int64) + 1
pixel_labeled = (target > 0).sum()
pixel_correct = ((predict == target) * (target > 0)).sum()
assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
return pixel_correct, pixel_labeled
def batch_intersection_union(output, target, nclass):
"""mIoU"""
# inputs are numpy array, output 4D, target 3D
mini = 1
maxi = nclass
nbins = nclass
predict = np.argmax(output.astype(np.float32), 1) + 1
target = target.astype(np.float32) + 1
predict = predict.astype(np.float32) * (target > 0).astype(np.float32)
intersection = predict * (predict == target).astype(np.float32)
# areas of intersection and union
# element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary.
area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
area_union = area_pred + area_lab - area_inter
assert (area_inter > area_union).sum(
) == 0, "Intersection area should be smaller than Union area"
return area_inter.astype(np.float32), area_union.astype(np.float32)
def main():
"""
read pipeline and do infer
"""
args = parse_args()
# init stream manager
stream_manager_api = StreamManagerApi()
ret = stream_manager_api.InitManager()
if ret != 0:
print("Failed to init Stream manager, ret=%s" % str(ret))
return
# create streams by pipeline config file
with open(os.path.realpath(args.pipeline), 'rb') as f:
pipeline_str = f.read()
ret = stream_manager_api.CreateMultipleStreams(pipeline_str)
if ret != 0:
print("Failed to create Stream, ret=%s" % str(ret))
return
stream_name = b'enet'
infer_total_time = 0
assert os.path.exists(
args.image_path), "Please put dataset in " + str(args.image_path)
images, mask_paths = _get_city_pairs(args.image_path, 'val')
assert len(images) == len(mask_paths)
if not images:
raise RuntimeError(
"Found 0 images in subfolders of:" + args.image_path + "\n")
if args.save_mask and not os.path.exists(args.mask_result_path):
os.makedirs(args.mask_result_path)
metric = SegmentationMetric(19)
metric.reset()
for index in range(len(images)):
image_name = images[index].split(
os.sep)[-1].split(".")[0] # get the name of image file
print("Processing ---> ", image_name)
img = Image.open(images[index]).convert('RGB')
mask = Image.open(mask_paths[index])
img, mask = _val_sync_transform(
(args.image_height, args.image_width), img, mask)
img = img.astype(np.float32)
mask = mask.astype(np.int32)
img = img.transpose((2, 0, 1)) # HWC->CHW(H:height W:width C:channel)
for channel, _ in enumerate(img):
img[channel] /= 255
img = np.expand_dims(img, 0) # NCHW
mask = np.expand_dims(mask, 0) # NHW
if not send_source_data(0, img, stream_name, stream_manager_api):
return
# Obtain the inference result by specifying streamName and uniqueId.
key_vec = StringVector()
key_vec.push_back(b'modelInfer')
start_time = time.time()
infer_result = stream_manager_api.GetProtobuf(stream_name, 0, key_vec)
infer_total_time += time.time() - start_time
if infer_result.size() == 0:
print("inferResult is null")
return
if infer_result[0].errorCode != 0:
print("GetProtobuf error. errorCode=%d" %
(infer_result[0].errorCode))
return
result = MxpiDataType.MxpiTensorPackageList()
result.ParseFromString(infer_result[0].messageBuf)
res = np.frombuffer(
result.tensorPackageVec[0].tensorVec[0].dataStr, dtype='<f4')
mask_image = res.reshape(1, 20, args.image_height, args.image_width)
metric.update(mask_image, mask)
pixAcc, mIoU = metric.get()
print("[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(
index + 1, pixAcc * 100, mIoU * 100))
if args.save_mask:
output = np.argmax(mask_image[0], axis=0)
out_img = Image.fromarray(output.astype('uint8'))
out_img.putpalette(cityspallete)
outname = str(image_name) + '.png'
out_img.save(os.path.join(args.mask_result_path, outname))
pixAcc, mIoU, category_iou = metric.get(return_category_iou=True)
print('category_iou: ', category_iou)
print('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(
pixAcc * 100, mIoU * 100))
print("Testing finished....")
print("=======================================")
print("The total time of inference is {} s".format(infer_total_time))
print("=======================================")
# destroy streams
stream_manager_api.DestroyAllStreams()
if __name__ == '__main__':
main()
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
image_width=1024
image_height=512
save_mask=1
mask_result_path=./mask_result
# help message
if [[ $1 == --help || $1 == -h ]];then
echo "usage:bash ./run.sh <args>"
echo "parameter explain:
--pipeline set SDK infer pipeline, e.g. --pipeline=../data/config/enet.pipeline
--image_path root path of processed images, e.g. --image_path=../data/
--image_width set the image width, default: --image_width=1024
--image_height set the image height, default: --image_height=512
--save_mask whether to save the semantic mask images, 0 for False, 1 for True, default: --save_mask=1
--mask_result_path the folder to save the semantic mask images, default: --mask_result_path=./mask_result
-h/--help show help message
"
exit 1
fi
for para in "$@"
do
if [[ $para == --pipeline* ]];then
pipeline=`echo ${para#*=}`
elif [[ $para == --image_path* ]];then
image_path=`echo ${para#*=}`
elif [[ $para == --image_width* ]];then
image_width=`echo ${para#*=}`
elif [[ $para == --image_height* ]];then
image_height=`echo ${para#*=}`
elif [[ $para == --save_mask* ]];then
save_mask=`echo ${para#*=}`
elif [[ $para == --mask_result_path* ]];then
mask_result_path=`echo ${para#*=}`
fi
done
if [[ $pipeline == "" ]];then
echo "[Error] para \"pipeline \" must be config"
exit 1
fi
if [[ $image_path == "" ]];then
echo "[Error] para \"image_path \" must be config"
exit 1
fi
python3 main.py --pipeline=$pipeline \
--image_path=$image_path \
--image_width=$image_width \
--image_height=$image_height \
--save_mask=$save_mask \
--mask_result_path=$mask_result_path
exit 0
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorMove op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import TBERegOp
from mindspore.ops.op_info_register import DataType
tensor_move_info = TBERegOp("TensorMove") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("tensor_move.so") \
.compute_cost(10) \
.kernel_name("tensor_move") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("formatAgnostic") \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.U8_None, DataType.U8_None) \
.dtype_format(DataType.U32_None, DataType.U32_None) \
.dtype_format(DataType.BOOL_None, DataType.BOOL_None) \
.get_op_info()
@op_info_register(tensor_move_info)
def _tensor_move_tbe():
"""TensorMove TBE register"""
return
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train enet"""
import os
import moxing as mox
import numpy as np
from mindspore import Model, context, nn, load_param_into_net, export
from mindspore.common.tensor import Tensor
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.serialization import _update_param, load_checkpoint
from src.config import (TrainConfig_1, TrainConfig_2, TrainConfig_3,
ms_train_data, num_class, repeat, run_distribute, save_path, weight_init)
from src.criterion import SoftmaxCrossEntropyLoss
from src.dataset import getCityScapesDataLoader_mindrecordDataset
from src.model import Encoder_pred, Enet
from src.util import getCityLossWeight
def attach(enet, encoder_pretrain):
"""move the params in encoder to enet"""
print("attach decoder.")
encoder_trained_par = encoder_pretrain.parameters_dict()
enet_par = enet.parameters_dict()
for name, param_old in encoder_trained_par.items():
if name.startswith("encoder"):
_update_param(enet_par[name], param_old, False)
def train(ckpt_path_, trainConfig_, rank_id, rank_size, stage_):
"""train enet"""
print("stage:", stage_)
save_prefix = "Encoder" if trainConfig_.encode else "ENet"
if trainConfig_.epoch == 0:
raise RuntimeError("epoch num cannot be zero")
if trainConfig_.encode:
network = Encoder_pred(num_class, weight_init)
else:
network = Enet(num_class, weight_init)
if not os.path.exists(ckpt_path_):
print("load no ckpt file.")
else:
load_checkpoint(ckpt_file_name=ckpt_path_, net=network)
print("load ckpt file:", ckpt_path_)
# attach decoder
if trainConfig_.attach_decoder:
network_enet = Enet(num_class, weight_init)
attach(network_enet, network)
network = network_enet
dataloader = getCityScapesDataLoader_mindrecordDataset(stage_, ms_train_data, 6,
trainConfig_.encode, trainConfig_.train_img_size,
shuffle=True, aug=True,
rank_id=rank_id, global_size=rank_size, repeat=repeat)
opt = nn.Adam(network.trainable_params(), trainConfig_.lr,
weight_decay=1e-4, eps=1e-08)
loss = SoftmaxCrossEntropyLoss(num_class, getCityLossWeight(trainConfig_.encode))
loss_scale_manager = DynamicLossScaleManager()
wrapper = Model(network, loss, opt, loss_scale_manager=loss_scale_manager,
keep_batchnorm_fp32=True)
time_cb = TimeMonitor()
loss_cb = LossMonitor()
if rank_id == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=
trainConfig_.epoch_num_save * dataloader.get_dataset_size(), \
keep_checkpoint_max=9999)
saveModel_cb = ModelCheckpoint(prefix=save_prefix, directory= \
"./", config=config_ck)
call_backs = [saveModel_cb, time_cb, loss_cb]
else:
call_backs = [time_cb, loss_cb]
print("============== Starting {} Training ==============".format(save_prefix))
wrapper.train(trainConfig_.epoch, dataloader, callbacks=call_backs, dataset_sink_mode=True)
return network
def export_models(ckptfile):
print("exporting model....")
net = Enet(20, "XavierUniform", train=False)
param_dict = load_checkpoint(ckptfile)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([1, 3, 512, 1024]).astype(np.float32))
export(net, input_arr, file_name="ENet.air", file_format="AIR")
print("export model finished....")
if __name__ == "__main__":
rank_id_ = 0
rank_size_ = 1
if run_distribute:
context.set_auto_parallel_context(parameter_broadcast=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=False)
init()
rank_id_ = get_rank()
rank_size_ = get_group_size()
trainConfig = {
1: TrainConfig_1(),
2: TrainConfig_2(),
3: TrainConfig_3()
}
for i in [1, 2, 3]:
data_loader = getCityScapesDataLoader_mindrecordDataset(i, ms_train_data, 6,
trainConfig[i].encode, trainConfig[i].train_img_size,
shuffle=True, aug=True,
rank_id=rank_id_, global_size=rank_size_, repeat=repeat)
steps = int(TrainConfig_1().epoch_num_save * data_loader.get_dataset_size() / 5)
if i == 1:
ckpt_path = ""
elif i == 2:
ckpt_path = "./Encoder-{}_{}.ckpt".format(TrainConfig_1().epoch, steps)
else:
ckpt_path = "./Encoder_1-{}_{}.ckpt".format(TrainConfig_2().epoch, steps)
network_ = train(ckpt_path, trainConfig[i], rank_id=rank_id_,
rank_size=rank_size_, stage_=i)
ckpt = "./ENet-{}_{}.ckpt".format(TrainConfig_3().epoch, steps)
export_models(ckpt)
mox.file.copy_parallel("./", save_path)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment