Unverified commit c23c5c5d authored by i-robot, committed by Gitee

!3752 [UESTC][University Contribution][MindSpore][PSPNet] High-performance pre-trained model submission + SDK + MxBase + ModelArts

Merge pull request !3752 from 孙文健/master
parents 9d307f28 6e147eae
ARG FROM_IMAGE_NAME
FROM $FROM_IMAGE_NAME
COPY requirements.txt .
RUN pip install -r requirements.txt
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if [ $# -ne 2 ]
then
echo "Wrong parameter format."
echo "Usage:"
echo " bash $0 INPUT_AIR_PATH OUTPUT_OM_PATH_NAME"
echo "Example: "
echo " bash convert_om.sh models/0-150_1251.air models/0-150_1251.om"
exit 255
fi
input_air_path=$1
output_om_path=$2
echo "Input AIR file path: ${input_air_path}"
echo "Output OM file path: ${output_om_path}"
atc --input_format=NCHW \
--framework=1 \
--model=${input_air_path} \
--output=${output_om_path} \
--soc_version=Ascend310 \
--disable_reuse_memory=0 \
--output_type=FP32 \
--precision_mode=allow_fp32_to_fp16 \
--op_select_implmode=high_precision
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
model_dir=$2
data_dir=$3
function show_help() {
echo "Usage: docker_start.sh docker_image data_dir"
}
function param_check() {
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
show_help
exit 1
fi
if [ -z "${model_dir}" ]; then
echo "please input model_dir"
show_help
exit 1
fi
if [ -z "${data_dir}" ]; then
echo "please input data_dir"
show_help
exit 1
fi
}
param_check
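# Map the Ascend NPU devices and driver into the container so inference can run
# on the host's davinci hardware; model_dir and data_dir are bind-mounted at the
# same paths inside the container.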
docker run -it -u root \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${data_dir}:${data_dir} \
-v ${model_dir}:${model_dir} \
${docker_image} \
/bin/bash
cmake_minimum_required(VERSION 3.5.2)
project(PSPNet_mindspore)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
set(TARGET_MAIN PSPNet_mindspore)
set(ACL_LIB_PATH $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories($ENV{MX_SDK_HOME}/include)
include_directories($ENV{MX_SDK_HOME}/include/MxBase/postprocess/include)
include_directories($ENV{MX_SDK_HOME}/opensource/include)
include_directories($ENV{MX_SDK_HOME}/opensource/include/opencv4)
include_directories($ENV{MX_SDK_HOME}/opensource/include/gstreamer-1.0)
include_directories($ENV{MX_SDK_HOME}/opensource/include/glib-2.0)
include_directories($ENV{MX_SDK_HOME}/opensource/lib/glib-2.0/include)
link_directories($ENV{MX_SDK_HOME}/lib)
link_directories($ENV{MX_SDK_HOME}/opensource/lib)
link_directories($ENV{MX_SDK_HOME}/lib/modelpostprocessors)
add_compile_options(-std=c++11 -fPIC -fstack-protector-all -pie -Wno-deprecated-declarations)
add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}")
add_compile_options("-Dgoogle=mindxsdk_private")
add_definitions(-DENABLE_DVPP_INTERFACE)
include_directories(${ACL_LIB_PATH}/include)
link_directories(${ACL_LIB_PATH}/lib64/)
add_executable(${TARGET_MAIN} src/main.cpp src/PSPNet.cpp)
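# Note: the SDK's DeepLabV3 postprocess library is linked as-is, presumably
# because PSPNet produces the same kind of per-pixel class score maps.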
target_link_libraries(${TARGET_MAIN} glog cpprest mxbase deeplabv3post opencv_world)
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# env
rm -rf build
mkdir -p build
cd build || exit
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export ASCEND_HOME=/usr/local/Ascend/latest/
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function make_plugin() {
if ! cmake ..;
then
echo "cmake failed."
return 1
fi
if ! (make);
then
echo "make failed."
return 1
fi
return 0
}
if make_plugin;
then
echo "INFO: Build successfully."
else
echo "ERROR: Build failed."
fi
cd - || exit
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""the main sdk infer file"""
import argparse
import os
import cv2
import numpy as np
def _parse_args():
parser = argparse.ArgumentParser('mindspore PSPNet eval')
parser.add_argument('--data_lst', type=str, default='', help='list of val data')
parser.add_argument('--num_classes', type=int, default=21,
help='number of classes')
parser.add_argument('--result_path', type=str, default='./result',
help='the result path')
parser.add_argument('--name_txt', type=str,
default='',
help='the name_txt path')
args, _ = parser.parse_known_args()
return args
def intersectionAndUnion(output, target, K, ignore_index=255):
"""
'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
"""
assert (output.ndim in [1, 2, 3])
assert output.shape == target.shape
print("output.shape=", output.shape)
print("output.size=", output.size)
output = output.reshape(output.size).copy()
target = target.reshape(target.size)
output[np.where(target == ignore_index)[0]] = ignore_index
intersection = output[np.where(output == target)[0]]
area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
area_output, _ = np.histogram(output, bins=np.arange(K + 1))
area_target, _ = np.histogram(target, bins=np.arange(K + 1))
# |A ∪ B| = |A| + |B| - |A ∩ B|
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
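# Worked example (K=2, no ignored pixels): output=[0,1,1], target=[0,1,0]
# -> intersection counts [1,1], union counts [2,2], per-class IoU = [0.5, 0.5].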
def cal_acc(data_list, pred_folder, classes, names):
""" Calculation evaluating indicator """
intersection_meter = AverageMeter()
union_meter = AverageMeter()
target_meter = AverageMeter()
with open(data_list) as f:
img_lst = f.readlines()
for i, line in enumerate(img_lst):
image_path, target_path = line.strip().split(' ')
image_name = image_path.split('/')[-1].split('.')[0]
pred = cv2.imread(os.path.join(pred_folder, image_name + '.png'), cv2.IMREAD_GRAYSCALE)
target = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)
intersection, union, target = intersectionAndUnion(pred, target, classes)
intersection_meter.update(intersection)
union_meter.update(union)
target_meter.update(target)
accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
print('Evaluating {0}/{1} on image {2}, accuracy {3:.4f}.'.format(
i + 1, len(img_lst), image_name + '.png', accuracy))
iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
mIoU = np.mean(iou_class)  # mean of the per-class IoU
mAcc = np.mean(accuracy_class)
allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
print('Eval result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
for i in range(classes):
print('Class_{} result: iou/accuracy {:.4f}/{:.4f}, name: {}.'.format(
i, iou_class[i], accuracy_class[i], names[i]))
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.count = 0
self.sum = 0
self.avg = 0
self.val = 0
def update(self, val, n=1):
""" calculate the result """
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def main():
args = _parse_args()
gray_folder = os.path.join(args.result_path, "gray")
names = [line.rstrip('\n') for line in open(args.name_txt)]
cal_acc(args.data_lst, gray_folder, args.num_classes, names)
if __name__ == '__main__':
main()
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <string>
#include "PSPNet.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
APP_ERROR PSPNet::Init(const InitParam &initParam) {
deviceId_ = initParam.deviceId;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
model_ = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = model_->Init(initParam.modelPath, modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
void PSPNet::DeInit() {
model_->DeInit();
MxBase::DeviceManager::GetInstance()->DestroyDevices();
}
APP_ERROR PSPNet::CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase) {
const uint32_t dataSize = imageMat.total() * imageMat.elemSize();
MxBase::MemoryData memoryDataDst(dataSize, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
MxBase::MemoryData memoryDataSrc(imageMat.data, dataSize, MxBase::MemoryData::MEMORY_HOST);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
std::vector<uint32_t> shape = {static_cast<uint32_t>(imageMat.size[0]),
static_cast<uint32_t>(imageMat.size[1]),
static_cast<uint32_t>(imageMat.size[2]),
static_cast<uint32_t>(imageMat.size[3])};
tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
return APP_ERR_OK;
}
APP_ERROR PSPNet::Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> &outputs) {
auto dtypes = model_->GetOutputDataType();
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
}
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs.push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
APP_ERROR ret = model_->ModelInference(inputs, outputs, dynamicInfo);
if (ret != APP_ERR_OK) {
LogError << "ModelInference failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR PSPNet::Process(const cv::Mat &imageMat, std::vector<MxBase::TensorBase>& outputs) {
MxBase::TensorBase tensorBase;
APP_ERROR ret = CVMatToTensorBase(imageMat, tensorBase);
if (ret != APP_ERR_OK) {
LogError << "CVMatToTensorBase failed, ret=" << ret << ".";
return ret;
}
std::vector<MxBase::TensorBase> inputs = {};
inputs.push_back(tensorBase);
ret = Inference(inputs, outputs);
if (ret != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef PSPNet_H
#define PSPNet_H
#include <vector>
#include <string>
#include <memory>
#include <fstream>
#include <opencv2/opencv.hpp>
#include <opencv2/core/mat.hpp>
#include <opencv2/imgproc.hpp>
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/PostProcessBases/PostProcessDataType.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
struct InitParam {
uint32_t deviceId;
std::string labelPath;
std::string modelPath;
uint32_t classNum;
uint32_t modelType;
std::string checkTensor;
uint32_t frameworkType;
};
class PSPNet {
public:
APP_ERROR Init(const InitParam &initParam);
void DeInit();
APP_ERROR CVMatToTensorBase(const cv::Mat &imageMat, MxBase::TensorBase &tensorBase);
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
APP_ERROR Process(const cv::Mat &imageMat, std::vector<MxBase::TensorBase>& outputs);
private:
std::shared_ptr<MxBase::ModelInferenceProcessor> model_;
MxBase::ModelDesc modelDesc_;
uint32_t deviceId_ = 0;
};
#endif
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstring>
#include <iostream>
#include <opencv2/dnn/dnn.hpp>
#include "PSPNet.h"
#include "MxBase/Log/Log.h"
using std::vector;
using std::string;
using std::cout;
using std::endl;
using std::min;
std::vector<std::string> SplitLine(const std::string & str, const char* delim) {
std::vector<std::string> res;
if ("" == str) {
res.push_back(str);
return res;
}
char* p_strs = new char[str.length() + 1];
char* p_save = NULL;
strncpy(p_strs, str.c_str(), str.length());
p_strs[str.length()] = '\0';  // strncpy does not null-terminate when src fills the count
char* part = strtok_r(p_strs, delim, &p_save);
while (part) {
std::string s = part;
res.push_back(s);
part = strtok_r(NULL, delim, &p_save);
}
delete[] p_strs;  // free the scratch buffer
return res;
}
std::vector<std::string> GetAllFiles(const std::string & root_path, const std::string & data_path) {
std::ifstream ifs;
std::vector<std::string> files;
ifs.open(data_path, std::ios::in);
if (!ifs.is_open()) {
std::cout << "File: " << data_path << " is not exist" << std::endl;
return files;
}
std::string buf;
while (getline(ifs, buf)) {
std::vector<std::string> line = SplitLine(buf, " ");
std::string img_path = line[0];
std::string msk_path = line[1];
files.emplace_back(img_path);
files.emplace_back(msk_path);  // keep image/mask paths paired at even/odd indices
}
ifs.close();
return files;
}
void SaveResult(const cv::Mat& binImg, const std::string& res_dir, const std::string & file_name) {
cv::Mat imageGrayC3 = cv::Mat::zeros(binImg.rows, binImg.cols, CV_8UC3);
std::vector<cv::Mat> planes;
for (int i = 0; i < 3; i++) {
planes.push_back(binImg);
}
cv::merge(planes, imageGrayC3);
uchar rgbColorMap[256*3] = {
0, 0, 0,
128, 0, 0,
0, 128, 0,
128, 128, 0,
0, 0, 128,
128, 0, 128,
0, 128, 128,
128, 128, 128,
64, 0, 0,
192, 0, 0,
64, 128, 0,
192, 128, 0,
64, 0, 128,
192, 0, 128,
64, 128, 128,
192, 128, 128,
0, 64, 0,
128, 64, 0,
0, 192, 0,
128, 192, 0,
0, 64, 128,
};
cv::Mat lut(1, 256, CV_8UC3, rgbColorMap);
cv::Mat imageColor;
cv::LUT(imageGrayC3, lut, imageColor);
cv::cvtColor(imageColor, imageColor, cv::COLOR_RGB2BGR);
std::string gray_path = res_dir + "gray";
std::string color_path = res_dir + "color";
std::string command = "mkdir -p " + gray_path;
system(command.c_str());
command = "mkdir -p " + color_path;
system(command.c_str());
std::cout << "save to " << gray_path << std::endl;
std::cout << "save to " << color_path << std::endl;
cv::imwrite(color_path + file_name, imageColor);
cv::imwrite(gray_path + file_name, binImg);
}
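// ArgMax collapses the per-class score map (rows x cols x classes, float) into a
// single-channel label image by taking, per pixel, the class with the highest score.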
void ArgMax(const cv::Mat& Tensor, const cv::Mat& Res) {
uchar* pDst = Res.data;
float* tensordata = reinterpret_cast<float * >(Tensor.data);
int high = Tensor.rows;
int width = Tensor.cols;
int classes = Tensor.channels();
for (int i = 0; i < width; i++) {
for (int j = 0; j < high; j++) {
float max = 0;
uint8_t index = 0;
for (int k = 0; k < classes; k++) {
float res = *(tensordata + i*high*classes + j*classes + k);
if (res > max) {
max = res;
index = k;
}
}
uint8_t gray = index;
*(pDst + i*high + j) = gray;
}
}
}
cv::Mat ScaleProcess(cv::Mat image,
int classes,
int crop_h,
int crop_w,
int ori_h,
int ori_w,
float stride_rate,
bool flip,
PSPNet& pspnet) {
int ori_h1 = image.rows;
int ori_w1 = image.cols;
int pad_h = (crop_h - ori_h1) > 0 ? (crop_h - ori_h1) : 0;
int pad_w = (crop_w - ori_w1) > 0 ? (crop_w - ori_w1) : 0;
int pad_h_half = static_cast<int>(pad_h / 2);
int pad_w_half = static_cast<int>(pad_w / 2);
cv::Scalar mean_value(0.485 * 255, 0.456 * 255, 0.406 * 255);
vector<double> std_value = {0.229 * 255, 0.224 * 255, 0.225 * 255};
cv::Mat padded_img;
padded_img.convertTo(padded_img, CV_32FC3);
if (pad_h > 0 || pad_w > 0) {
cv::copyMakeBorder(image,
padded_img,
pad_h_half,
pad_h - pad_h_half,
pad_w_half,
pad_w - pad_w_half,
cv::BORDER_CONSTANT,
mean_value);
} else {
padded_img = image;
}
int new_h = padded_img.rows;
int new_w = padded_img.cols;
int stride_h = ceil(static_cast<float>(crop_h * stride_rate));
int stride_w = ceil(static_cast<float>(crop_w * stride_rate));
int grid_h = static_cast<int>(ceil(static_cast<float>(new_h - crop_h) / stride_h) + 1);
int grid_w = static_cast<int>(ceil(static_cast<float>(new_w - crop_w) / stride_w) + 1);
cv::Mat count_crop = cv::Mat::zeros(new_h, new_w, CV_32FC1);
cv::Mat prediction = cv::Mat::zeros(new_h, new_w, CV_32FC(classes));
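// Sliding-window inference: crop_h x crop_w tiles cover the padded image with
// stride_rate spacing; per-tile predictions are summed into `prediction` and
// later divided by `count_crop`, the number of times each pixel was inferred.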
for (int index_h = 0; index_h < grid_h; index_h++) {
for (int index_w = 0; index_w < grid_w; index_w++) {
int start_x = min(index_w * stride_w + crop_w, new_w) - crop_w;
int start_y = min(index_h * stride_h + crop_h, new_h) - crop_h;
cv::Mat crop_roi(count_crop, cv::Rect(start_x, start_y, crop_w, crop_h));
crop_roi += 1; // area infer count
cv::Mat prediction_roi(prediction, cv::Rect(start_x, start_y, crop_w, crop_h));
cv::Mat image_roi = padded_img(cv::Rect(start_x, start_y, crop_w, crop_h)).clone();
image_roi = image_roi - mean_value;
std::vector<cv::Mat> rgb_channels(3);
cv::split(image_roi, rgb_channels);
for (int i = 0; i < 3; i++) {
rgb_channels[i].convertTo(rgb_channels[i], CV_32FC1, 1.0 / std_value[i]);
}
cv::merge(rgb_channels, image_roi);
cv::Mat blob = cv::dnn::blobFromImage(image_roi); // 473 473 3 ---> 3 473 473
std::vector<MxBase::TensorBase> outputs;
pspnet.Process(blob, outputs);
MxBase::TensorBase pred = outputs[0];
pred.ToHost();
float* data = reinterpret_cast<float* >(pred.GetBuffer());
if (flip) {
cv::Mat flipped_img;
std::vector<MxBase::TensorBase> flipped_outputs;
cv::flip(image_roi, flipped_img, 1);
cv::Mat blob_flip = cv::dnn::blobFromImage(flipped_img);
pspnet.Process(blob_flip, flipped_outputs);
MxBase::TensorBase flipped_pred = flipped_outputs[0];
flipped_pred.ToHost();
float* flipped_data = reinterpret_cast<float* >(flipped_pred.GetBuffer());
for (int i = 0; i < crop_h; i++) {
for (int j = 0; j < crop_w; j++) {
for (int k = 0; k < classes; k ++) {
float res = (*(data+k*crop_h*crop_w + i*crop_w + j) + // data[k][i][j]
*(flipped_data+k*crop_h*crop_w + i*crop_w + (crop_w - 1 - j))) / 2; // mirrored column
*(data+k*crop_h*crop_w + i*crop_w + j) = res;
}
}
}
}
for (int i = 0; i < crop_h; i++) {
for (int j = 0; j < crop_w; j++) {
for (int k = 0; k < classes; k ++) {
float res = *(data+k*crop_h*crop_w + i*crop_w + j);
prediction_roi.ptr<float>(i)[j * classes + k] += res; // 21 473 473
}
}
}
}
}
std::vector<cv::Mat> cls_channels(classes);
cv::split(prediction, cls_channels);
for (int i = 0; i < classes; i++) {
cls_channels[i] = cls_channels[i] / count_crop;
}
cv::merge(cls_channels, prediction);
cv::Mat prediction_crop(prediction, cv::Rect(pad_w_half, pad_h_half, ori_w1, ori_h1));
cv::Mat final_pre;
cv::resize(prediction_crop, final_pre, cv::Size(ori_w, ori_h), 0, 0, cv::INTER_LINEAR);
return final_pre;
}
int main(int argc, char *argv[]) {
if (argc < 3) {
LogError << "Please input the om file path and dataset path";
return 1;
}
std::string om_path = argv[1];
std::string dataset_path = argv[2];
InitParam initParam = {};
initParam.deviceId = 0;
initParam.modelPath = om_path;
PSPNet pspnet;
APP_ERROR ret = pspnet.Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "PSPNet init failed, ret=" << ret << ".";
return ret;
}
cout << "PSPNet Init Done." << endl;
std::string voc_val_list = dataset_path + "/voc_val_lst.txt";
int crop_h = 473;
int crop_w = 473; // crop image to 473
int classes = 21; // number of classes
float stride_rate = 2.0/3.0;
cout << "Start to get image" << endl;
auto all_files = GetAllFiles(dataset_path, voc_val_list);
if (all_files.empty()) {
std::cout << "ERROR: no input data." << std::endl;
return APP_ERR_INVALID_FILE;
}
for (size_t i = 0; i < all_files.size(); i = i + 2) {
std::string imgPath = all_files[i];
cout << "Process image : " << imgPath << endl;
cv::Mat image = cv::imread(imgPath, cv::IMREAD_COLOR);
cv::Mat image_RGB;
cv::cvtColor(image, image_RGB, cv::COLOR_BGR2RGB);
float ori_h = image.rows;
float ori_w = image.cols;
float long_size = 512; // The longer edge should align to 512
int new_h = long_size;
int new_w = long_size;
if (ori_h > ori_w) {
new_w = round(long_size / ori_h * ori_w);
} else {
new_h = round(long_size / ori_w * ori_h);
}
cv::Mat resized_img;
image_RGB.convertTo(image_RGB, CV_32FC3);
resized_img.convertTo(resized_img, CV_32FC3);
cv::resize(image_RGB, resized_img, cv::Size(new_w, new_h), 0, 0, cv::INTER_LINEAR);
cv::Mat pre = ScaleProcess(resized_img,
classes,
crop_h,
crop_w,
image.rows,
image.cols,
stride_rate, true, pspnet);
cv::Mat pre_max(pre.rows, pre.cols, CV_8UC1, cv::Scalar(0));
ArgMax(pre, pre_max);
size_t pos = imgPath.find_last_of("/");
std::string file_name(imgPath.begin() + pos, imgPath.end());
pos = file_name.find_last_of(".");
file_name.replace(file_name.begin() + pos, file_name.end(), ".png");
SaveResult(pre_max, "cpp_res/", file_name);
}
pspnet.DeInit();
return APP_ERR_OK;
}
Pillow==9.1.0
onnx
opencv-python
{
"segmentation": {
"stream_config": {
"deviceId": "0"
},
"appsrc0": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "appsrc0",
"modelPath": "../data/models/PSPNet.om"
},
"factory": "mxpi_tensorinfer",
"next": "appsink0"
},
"appsink0": {
"props": {
"blocksize": "4096000"
},
"factory": "appsink"
}
}
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import os
def _parse_args():
parser = argparse.ArgumentParser('dataset list generator')
parser.add_argument("--data_dir", type=str, default='',
help='where dataset stored.')
return parser.parse_args()
def _get_data_list(data_list_file):
with open(data_list_file, mode='r') as f:
return f.readlines()
def main():
args = _parse_args()
data_dir = args.data_dir
voc_img_dir = os.path.join(data_dir, 'img')
voc_anno_gray_dir = os.path.join(data_dir, 'gray')
voc_train_txt = os.path.join(data_dir, 'train.txt')
voc_val_txt = os.path.join(data_dir, 'val.txt')
voc_train_lst_txt = os.path.join(data_dir, 'voc_train_lst.txt')
voc_val_lst_txt = os.path.join(data_dir, 'voc_val_lst.txt')
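# Each generated list line pairs an image with its gray-scale annotation:
# "<data_dir>/img/<id>.jpg <data_dir>/gray/<id>.png"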
voc_train_data_lst = _get_data_list(voc_train_txt)
with open(voc_train_lst_txt, mode='w') as f:
for id_ in voc_train_data_lst:
id_ = id_.strip()
img_ = os.path.join(voc_img_dir, id_ + '.jpg')
anno_ = os.path.join(voc_anno_gray_dir, id_ + '.png')
f.write(img_ + ' ' + anno_ + '\n')
print('generating voc train list success.')
voc_val_data_lst = _get_data_list(voc_val_txt)
with open(voc_val_lst_txt, mode='w') as f:
for id_ in voc_val_data_lst:
id_ = id_.strip()
img_ = os.path.join(voc_img_dir, id_ + '.jpg')
anno_ = os.path.join(voc_anno_gray_dir, id_ + '.png')
f.write(img_ + ' ' + anno_ + '\n')
print('generating voc val list success.')
if __name__ == '__main__':
main()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""the main sdk infer file"""
import argparse
import base64
import json
import os
import cv2
import numpy as np
from PIL import Image
import MxpiDataType_pb2 as MxpiDataType
from StreamManagerApi import StreamManagerApi, InProtobufVector, MxProtobufIn, StringVector
def _parse_args():
parser = argparse.ArgumentParser('mindspore PSPNet eval')
parser.add_argument('--data_root', type=str, default='',
help='root path of val data')
parser.add_argument('--data_lst', type=str, default='',
help='list of val data')
parser.add_argument('--num_classes', type=int, default=21,
help='number of classes')
parser.add_argument('--result_path', type=str, default='./result',
help='the result path')
parser.add_argument('--color_txt', type=str,
default='',
help='the color path')
parser.add_argument('--name_txt', type=str,
default='',
help='the name_txt path')
parser.add_argument('--pipeline_path', type=str,
default='',
help='root path of pipeline file')
args, _ = parser.parse_known_args()
return args
def _cal_hist(a, b, n):
k = (a >= 0) & (a < n)
return np.bincount(
n * a[k].astype(np.int32) + b[k].astype(np.int32), minlength=n ** 2).reshape(n, n)
def _init_stream(pipeline_path):
"""
initial sdk stream before inference
Returns:
stream manager api
"""
stream_manager_api = StreamManagerApi()
ret = stream_manager_api.InitManager()
if ret != 0:
raise RuntimeError(f"Failed to init stream manager, ret={ret}")
with open(pipeline_path, 'rb') as f:
pipeline_str = f.read()
ret = stream_manager_api.CreateMultipleStreams(pipeline_str)
if ret != 0:
raise RuntimeError(f"Failed to create stream, ret={ret}")
return stream_manager_api
def _do_infer(stream_manager_api, data_input):
"""
send images into stream to do infer
Returns:
infer result, numpy array
"""
stream_name = b'segmentation'
unique_id = stream_manager_api.SendDataWithUniqueId(
stream_name, 0, data_input)
if unique_id < 0:
raise RuntimeError("Failed to send data to stream.")
print("success to send data to stream.")
timeout = 3000
infer_result = stream_manager_api.GetResultWithUniqueId(
stream_name, unique_id, timeout)
if infer_result.errorCode != 0:
raise RuntimeError(
"GetResultWithUniqueId error, errorCode=%d, errorMsg=%s" % (
infer_result.errorCode, infer_result.data.decode()))
load_dict = json.loads(infer_result.data.decode())
image_mask = load_dict["MxpiImageMask"][0]
data_str = base64.b64decode(image_mask['dataStr'])
shape = image_mask['shape']
return np.frombuffer(data_str, dtype=np.uint8).reshape(shape)
def send_source_data(appsrc_id, tensor, stream_name, stream_manager):
"""
Construct the input of the stream,
send inputs data to a specified stream based on streamName.
Returns:
bool: send data success or not
"""
tensor_package_list = MxpiDataType.MxpiTensorPackageList()
tensor_package = tensor_package_list.tensorPackageVec.add()
array_bytes = tensor.tobytes()
tensor_vec = tensor_package.tensorVec.add()
tensor_vec.deviceId = 0
tensor_vec.memType = 0
for i in tensor.shape:
tensor_vec.tensorShape.append(i)
tensor_vec.dataStr = array_bytes
tensor_vec.tensorDataSize = len(array_bytes)
key = "appsrc{}".format(appsrc_id).encode('utf-8')
protobuf_vec = InProtobufVector()
protobuf = MxProtobufIn()
protobuf.key = key
protobuf.type = b'MxTools.MxpiTensorPackageList'
protobuf.protobuf = tensor_package_list.SerializeToString()
protobuf_vec.push_back(protobuf)
ret = stream_manager.SendProtobuf(stream_name, appsrc_id, protobuf_vec)
if ret < 0:
print("Failed to send data to stream.")
return False
print("Success to send data to stream.")
return True
def get_result(stream_name, stream_manager_api):
"""
Obtain the inference result by specifying streamName and uniqueId.
"""
key_vec = StringVector()
key_vec.push_back(b'mxpi_tensorinfer0')
infer_result = stream_manager_api.GetProtobuf(stream_name, 0, key_vec)
if infer_result.size() == 0:
print("inferResult is null")
return 0
if infer_result[0].errorCode != 0:
print("GetProtobuf error. errorCode=%d" %
(infer_result[0].errorCode))
return 0
result = MxpiDataType.MxpiTensorPackageList()
result.ParseFromString(infer_result[0].messageBuf)
vision_data_ = result.tensorPackageVec[0].tensorVec[0].dataStr
vision_data_ = np.frombuffer(vision_data_, dtype=np.float32)
shape = result.tensorPackageVec[0].tensorVec[0].tensorShape
mask_image = vision_data_.reshape(shape)
return mask_image[0]
def scale_process(stream_manager_api, image, classes, crop_h, crop_w, ori_h, ori_w,
mean, std, stride_rate, flip):
stream_name = b'segmentation'
print("image=", image.shape)
ori_h1, ori_w1, _ = image.shape
pad_h = max(crop_h - ori_h1, 0)
pad_w = max(crop_w - ori_w1, 0)
pad_h_half = int(pad_h / 2)
pad_w_half = int(pad_w / 2)
if pad_h > 0 or pad_w > 0:
image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half,
cv2.BORDER_CONSTANT, value=mean)
new_h, new_w, _ = image.shape
stride_h = int(np.ceil(crop_h * stride_rate))
stride_w = int(np.ceil(crop_w * stride_rate))
grid_h = int(np.ceil(float(new_h - crop_h) / stride_h) + 1)
grid_w = int(np.ceil(float(new_w - crop_w) / stride_w) + 1)
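# stride_rate = 2/3 leaves roughly one-third overlap between neighbouring crops;
# overlapping predictions are averaged below via count_crop.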
prediction_crop = np.zeros((new_h, new_w, classes), dtype=float)
count_crop = np.zeros((new_h, new_w), dtype=float)
print("grid_h, grid_w=", grid_h, grid_w)
for index_h in range(0, grid_h):
for index_w in range(0, grid_w):
s_h = index_h * stride_h
e_h = min(s_h + crop_h, new_h)
s_h = e_h - crop_h
s_w = index_w * stride_w
e_w = min(s_w + crop_w, new_w)
s_w = e_w - crop_w
print("s_h:e_h, s_w:e_w==", s_h, e_h, s_w, e_w)
image_crop = image[s_h:e_h, s_w:e_w].copy()
count_crop[s_h:e_h, s_w:e_w] += 1
mean = np.array(mean).astype(np.float32)
std = np.array(std).astype(np.float32)
image_crop = image_crop.transpose(2, 0, 1)
image_crop = (image_crop - mean[:, None, None]) / std[:, None, None]
image_crop = np.expand_dims(image_crop, 0)
if not send_source_data(0, image_crop, stream_name, stream_manager_api):
return 0
mask_image = get_result(stream_name, stream_manager_api)
mask_image = mask_image.transpose(1, 2, 0)
if flip:
image_crop = np.flip(image_crop, axis=3)
if not send_source_data(0, image_crop, stream_name, stream_manager_api):
return 0
mask_image_flip = get_result(stream_name, stream_manager_api).transpose(1, 2, 0)
mask_image_flip = np.flip(mask_image_flip, axis=1)
mask_image = (mask_image + mask_image_flip) / 2
prediction_crop[s_h:e_h, s_w:e_w, :] += mask_image
prediction_crop /= np.expand_dims(count_crop, 2) # (473, 512, 21)
print(f"prediction_crop = {pad_h_half}:{pad_h_half + ori_h1},{pad_w_half}:{pad_w_half + ori_w1}")
prediction_crop = prediction_crop[pad_h_half:pad_h_half + ori_h1, pad_w_half:pad_w_half + ori_w1]
prediction = cv2.resize(prediction_crop, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)
return prediction
def check_makedirs(dir_name):
""" check file dir """
if not os.path.exists(dir_name):
os.makedirs(dir_name)
def colorize(gray, palette):
""" gray: numpy array of the label and 1*3N size list palette 列表调色板 """
color = Image.fromarray(gray.astype(np.uint8)).convert('P')
color.putpalette(palette)
return color
def intersectionAndUnion(output, target, K, ignore_index=255):
"""
'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
"""
assert (output.ndim in [1, 2, 3])
assert output.shape == target.shape
print("output.shape=", output.shape)
print("output.size=", output.size)
output = output.reshape(output.size).copy()
target = target.reshape(target.size)
output[np.where(target == ignore_index)[0]] = ignore_index
intersection = output[np.where(output == target)[0]]
area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1))
area_output, _ = np.histogram(output, bins=np.arange(K + 1))
area_target, _ = np.histogram(target, bins=np.arange(K + 1))
area_union = area_output + area_target - area_intersection
return area_intersection, area_union, area_target
def cal_acc(data_root, data_list, pred_folder, classes, names):
""" Calculation evaluating indicator """
intersection_meter = AverageMeter()
union_meter = AverageMeter()
target_meter = AverageMeter()
with open(data_list) as f:
img_lst = f.readlines()
for i, line in enumerate(img_lst):
image_path, target_path = line.strip().split(' ')
image_name = image_path.split('/')[-1].split('.')[0]
pred = cv2.imread(os.path.join(pred_folder, image_name + '.png'), cv2.IMREAD_GRAYSCALE)
target = cv2.imread(os.path.join(data_root, target_path), cv2.IMREAD_GRAYSCALE)
intersection, union, target = intersectionAndUnion(pred, target, classes)
intersection_meter.update(intersection)
union_meter.update(union)
target_meter.update(target)
accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
print('Evaluating {0}/{1} on image {2}, accuracy {3:.4f}.'.format(
i + 1, len(img_lst), image_name + '.png', accuracy))
iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
mIoU = np.mean(iou_class)
mAcc = np.mean(accuracy_class)
allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
print('Eval result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
for i in range(classes):
print('Class_{} result: iou/accuracy {:.4f}/{:.4f}, name: {}.'.format(
i, iou_class[i], accuracy_class[i], names[i]))
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.count = 0
self.sum = 0
self.avg = 0
self.val = 0
def update(self, val, n=1):
""" calculate the result """
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def main():
args = _parse_args()
value_scale = 255
mean = [0.485, 0.456, 0.406]
mean = [item * value_scale for item in mean]
std = [0.229, 0.224, 0.225]
std = [item * value_scale for item in std]
crop_h = 473
crop_w = 473
classes = 21
long_size = 512
gray_folder = os.path.join(args.result_path, 'gray')
color_folder = os.path.join(args.result_path, 'color')
colors = np.loadtxt(args.color_txt).astype('uint8')
names = [line.rstrip('\n') for line in open(args.name_txt)]
stream_manager_api = _init_stream(args.pipeline_path)
if not stream_manager_api:
exit(1)
with open(args.data_lst) as f:
img_lst = f.readlines()
os.makedirs(args.result_path, exist_ok=True)
for _, line in enumerate(img_lst):
img_path, msk_path = line.strip().split(' ')
print("--------------------------------------------")
img_path = os.path.join(args.data_root, img_path)
print("img_path:", img_path)
print("msk_paty:", msk_path)
ori_image = cv2.imread(img_path, cv2.IMREAD_COLOR)
ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
ori_image = np.float32(ori_image)
ori_h, ori_w, _ = ori_image.shape
print("ori_h=", ori_h)
print("ori_w=", ori_w)
new_h = long_size
new_w = long_size
if ori_h > ori_w:
new_w = round(long_size / float(ori_h) * ori_w)
else:
new_h = round(long_size / float(ori_w) * ori_h)
print(f"new_w, new_h = ({new_w}, {new_h})")
image = cv2.resize(ori_image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
prediction = scale_process(stream_manager_api, image, classes, crop_h, crop_w, ori_h, ori_w, mean, std=std,
stride_rate=2 / 3, flip=True)
print("prediction0.shape=", prediction.shape)
prediction = np.argmax(prediction, axis=2)
check_makedirs(gray_folder)
check_makedirs(color_folder)
gray = np.uint8(prediction)
color = colorize(gray, colors)
image_name = img_path.split('/')[-1].split('.')[0]
gray_path = os.path.join(gray_folder, image_name + '.png')
color_path = os.path.join(color_folder, image_name + '.png')
cv2.imwrite(gray_path, gray)
color.save(color_path)
stream_manager_api.DestroyAllStreams()
cal_acc(args.data_root, args.data_lst, gray_folder, classes, names)
if __name__ == '__main__':
main()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import print_function
import os
import sys
import glob
import scipy.io
from PIL import Image as PILImage
def main():
input_path, output_path = process_arguments(sys.argv)
if os.path.isdir(input_path) and os.path.isdir(output_path):
# glob.glob returns a list of all file paths matching the pattern
mat_files = glob.glob(os.path.join(input_path, '*.mat'))
convert_mat2png(mat_files, output_path)
else:
helps('Input or output path does not exist!\n')
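# SBD-style .mat annotations store the label map under mat['GTcls'].Segmentation
# (the format of Hariharan et al.'s augmented VOC ground truth).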
def mat2png_hariharan(mat_file, key='GTcls'):
mat = scipy.io.loadmat(mat_file, mat_dtype=True, squeeze_me=True, struct_as_record=False)
return mat[key].Segmentation
def process_arguments(argv):
num_args = len(argv)
input_path = None
output_path = None
if num_args == 3:
input_path = argv[1]
output_path = argv[2]
else:
helps()
return input_path, output_path
def convert_mat2png(mat_files, output_path):
if not mat_files:
helps('Input directory does not contain any Matlab files!\n')
for mat in mat_files:
numpy_img = mat2png_hariharan(mat)
pil_img = PILImage.fromarray(numpy_img)
pil_img.save(os.path.join(output_path, modify_image_name(mat, 'png')))
# Extract name of image from given path, replace its extension with specified one
# and return new name only, not path.
def modify_image_name(path, ext):
return os.path.basename(path).split('.')[0] + '.' + ext
def helps(msg=''):
print(msg +
'Usage: python mat2png.py INPUT_PATH OUTPUT_PATH\n'
'INPUT_PATH denotes path containing Matlab files for conversion.\n'
'OUTPUT_PATH denotes path where converted Png files are going to be saved.'
, file=sys.stderr)
exit()
if __name__ == '__main__':
main()
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
data_root=$1
data_lst=$2
num_classes=$3
result_path=$4
color_txt=$5
name_txt=$6
pipeline_path=$7
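# Usage: bash run.sh DATA_ROOT DATA_LST NUM_CLASSES RESULT_PATH COLOR_TXT NAME_TXT PIPELINE_PATH
# (the positional arguments above mirror the flags passed to main.py below)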
set -e
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins
# add the MindX SDK python directory to PYTHONPATH so StreamManagerApi can be imported
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3 main.py \
--data_root=$data_root \
--data_lst=$data_lst \
--num_classes=$num_classes \
--result_path=$result_path \
--color_txt=$color_txt \
--name_txt=$name_txt \
--pipeline_path=$pipeline_path
exit 0
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" train PSPNet and get checkpoint files """
import os
import ast
import argparse
import subprocess
from src.utils import functions_args as fa
from src.model import pspnet
from src.model.cell import Aux_CELoss_Cell
from src.dataset import pt_dataset
from src.dataset import pt_transform as transform
from src.utils.lr import poly_lr
from src.utils.metric_and_evalcallback import pspnet_metric
import mindspore
from mindspore import nn
from mindspore import context
from mindspore import Tensor
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.communication import init
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.dataset as ds
import moxing as mox
set_seed(1234)
rank_id = int(os.getenv('RANK_ID', '0'))
device_id = int(os.getenv('DEVICE_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
Model_Art = True
def get_parser():
"""
Read parameter file
-> for ADE20k: ./src/config/ade20k_pspnet50.yaml
-> for voc2012: ./src/config/voc2012_pspnet50.yaml
"""
global Model_Art
parser = argparse.ArgumentParser(description='MindSpore Semantic Segmentation')
parser.add_argument('--config', type=str, required=True,
help='config file')
parser.add_argument('--model_art', type=ast.literal_eval, default=True,
help='train on modelArts or not, default: True')
parser.add_argument('--obs_data_path', type=str, default='',
help='dataset path in obs')
parser.add_argument('--epochs', type=int, default=None,
help='epochs for training')
parser.add_argument('--obs_save', type=str, default='',
help='.ckpt file save path in obs')
parser.add_argument('opts', help='see ./src/config/voc2012_pspnet50.yaml for all options', default=None,
nargs=argparse.REMAINDER)
# export AIR model
parser.add_argument('--device_target', type=str, default='Ascend',
choices=['Ascend', 'GPU'],
help='device id of GPU or Ascend. (Default: Ascend)')
parser.add_argument('--file_name', type=str, default='PSPNet', help='export file name')
parser.add_argument('--file_format', type=str, default="AIR",
choices=['AIR', 'MINDIR'],
help='export model type')
parser.add_argument('--num_classes', type=int, default=21, help='number of classes')
args_ = parser.parse_args()
Model_Art = args_.model_art
if args_.model_art:
mox.file.shift('os', 'mox')
root = "/cache/"
local_data_path = os.path.join(root, 'data')
print("local_data_path=", local_data_path)
print("########### Downloading data from OBS #############")
mox.file.copy_parallel(src_url=args_.obs_data_path, dst_url=local_data_path)
print('########### data downloading is completed ############')
assert args_.config is not None
cfg = fa.load_cfg_from_cfg_file(args_.config)
if args_.opts is not None:
cfg = fa.merge_cfg_from_list(cfg, args_.opts)
cfg.epochs = args_.epochs  # ModelArts arguments override the values in the yaml file
cfg.obs_save = args_.obs_save
cfg.config = args_.config
return cfg
def _get_last_ckpt(ckpt_dir):
ckpt_files = [ckpt_file for ckpt_file in os.listdir(ckpt_dir)
if ckpt_file.endswith('.ckpt')]
if not ckpt_files:
print("No ckpt file found.")
return None
return os.path.join(ckpt_dir, sorted(ckpt_files)[-1])
def _export_air(ckpt_dir):
ckpt_file = _get_last_ckpt(ckpt_dir)
if not ckpt_file:
return
print(os.path.abspath("export.py"))
print(os.path.realpath("export.py"))
export_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "export.py")
print("export_file=", export_file)
file_name = os.path.join(ckpt_dir, "PSPNet")
print("file_name=", file_name)
yamlpath = args.config
print("args.config=", args.config)
cmd = ["python", export_file,
f"--yaml_path={yamlpath}",
f"--ckpt_file={ckpt_file}",
f"--file_name={file_name}",
f"--file_format={'AIR'}",
f"--device_target={'Ascend'}"]
print(f"Start exporting AIR, cmd = {' '.join(cmd)}.")
process = subprocess.Popen(cmd, shell=False)
process.wait()
class EvalCallBack(Callback):
"""Precision verification using callback function."""
def __init__(self, models, eval_dataset, eval_per_epochs, epochs_per_eval):
super(EvalCallBack, self).__init__()
self.models = models
self.eval_dataset = eval_dataset
self.eval_per_epochs = eval_per_epochs
self.epochs_per_eval = epochs_per_eval
def epoch_end(self, run_context):
""" evaluate during training """
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
if cur_epoch % self.eval_per_epochs == 0:
val_loss = self.models.eval(self.eval_dataset, dataset_sink_mode=False)
self.epochs_per_eval["epoch"].append(cur_epoch)
self.epochs_per_eval["val_loss"].append(val_loss)
print(val_loss)
def get_dict(self):
""" get eval dict"""
return self.epochs_per_eval
def create_dataset(purpose, data_root, data_list, batch_size=8):
""" get dataset """
value_scale = 255
mean = [0.485, 0.456, 0.406]
mean = [item * value_scale for item in mean]
std = [0.229, 0.224, 0.225]
std = [item * value_scale for item in std]
if purpose == 'train':
cur_transform = transform.Compose([
transform.RandScale([0.5, 2.0]),
transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
transform.RandomGaussianBlur(),
transform.RandomHorizontalFlip(),
transform.Crop([473, 473], crop_type='rand', padding=mean, ignore_label=255),
transform.Normalize(mean=mean, std=std, is_train=True)])
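# the Compose above applies random scale in [0.5, 2.0], rotation in [-10, 10] degrees,
# Gaussian blur, horizontal flip, then a random 473x473 crop, all before normalization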
data = pt_dataset.SemData(
split=purpose, data_root=data_root,
data_list=data_list,
transform=cur_transform,
data_name=args.data_name
)
dataset = ds.GeneratorDataset(data, column_names=["data", "label"],
shuffle=True, num_shards=device_num, shard_id=rank_id)
dataset = dataset.batch(batch_size, drop_remainder=False)
else:
cur_transform = transform.Compose([
transform.Crop([473, 473], crop_type='center', padding=mean, ignore_label=255),
transform.Normalize(mean=mean, std=std, is_train=True)])
data = pt_dataset.SemData(
split=purpose, data_root=data_root,
data_list=data_list,
transform=cur_transform,
data_name=args.data_name
)
dataset = ds.GeneratorDataset(data, column_names=["data", "label"],
shuffle=False, num_shards=device_num, shard_id=rank_id)
dataset = dataset.batch(batch_size, drop_remainder=False)
return dataset
def psp_train():
""" Train process """
if Model_Art:
pre_path = args.art_pretrain_path
data_path = args.art_data_root
train_list_path = args.art_train_list
val_list_path = args.art_val_list
print("val_list_path=", val_list_path)
else:
pre_path = args.pretrain_path
data_path = args.data_root
train_list_path = args.train_list
val_list_path = args.val_list
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, gradients_mean=True)
init()
PSPNet = pspnet.PSPNet(
feature_size=args.feature_size, num_classes=args.classes, backbone=args.backbone, pretrained=True,
pretrained_path=pre_path, aux_branch=True, deep_base=True,
BatchNorm_layer=nn.SyncBatchNorm
)
train_dataset = create_dataset('train', data_path, train_list_path)
validation_dataset = create_dataset('val', data_path, val_list_path)
else:
PSPNet = pspnet.PSPNet(
feature_size=args.feature_size, num_classes=args.classes, backbone=args.backbone, pretrained=True,
pretrained_path=pre_path, aux_branch=True, deep_base=True
)
train_dataset = create_dataset('train', data_path, train_list_path)
validation_dataset = create_dataset('val', data_path, val_list_path)
# loss
train_net_loss = Aux_CELoss_Cell(args.classes, ignore_label=255)
steps_per_epoch = train_dataset.get_dataset_size() # Return the number of batches in an epoch.
total_train_steps = steps_per_epoch * args.epochs
if device_num > 1:
lr_iter = poly_lr(args.art_base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
lr_iter_ten = poly_lr(args.art_base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
else:
lr_iter = poly_lr(args.base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
lr_iter_ten = poly_lr(args.base_lr, total_train_steps, total_train_steps, end_lr=0.0, power=0.9)
pretrain_params = list(filter(lambda x: 'backbone' in x.name, PSPNet.trainable_params()))
cls_params = list(filter(lambda x: 'backbone' not in x.name, PSPNet.trainable_params()))
group_params = [{'params': pretrain_params, 'lr': Tensor(lr_iter, mindspore.float32)},
{'params': cls_params, 'lr': Tensor(lr_iter_ten, mindspore.float32)}]
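# backbone and head parameters sit in separate groups so each can follow its own
# LR schedule; here both use the same poly-decayed schedule built above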
opt = nn.SGD(
params=group_params,
momentum=0.9,
weight_decay=0.0001,
loss_scale=1024,
)
# loss scale
manager_loss_scale = FixedLossScaleManager(1024, False)
m_metric = {'val_loss': pspnet_metric(args.classes, 255)}
model = Model(
PSPNet, train_net_loss, optimizer=opt, loss_scale_manager=manager_loss_scale, metrics=m_metric
)
time_cb = TimeMonitor(data_size=steps_per_epoch)
loss_cb = LossMonitor()
epoch_per_eval = {"epoch": [], "val_loss": []}
eval_cb = EvalCallBack(model, validation_dataset, 10, epoch_per_eval)
config_ck = CheckpointConfig(
save_checkpoint_steps=10 * steps_per_epoch,
keep_checkpoint_max=12,
)
if Model_Art:
os.makedirs('/cache/save/', exist_ok=True)
ckpoint_cb = ModelCheckpoint(
prefix=args.prefix, directory='/cache/save/', config=config_ck
)
else:
ckpoint_cb = ModelCheckpoint(
prefix=args.prefix, directory=args.save_dir, config=config_ck
)
model.train(
args.epochs, train_dataset, callbacks=[loss_cb, time_cb, ckpoint_cb, eval_cb], dataset_sink_mode=True,
)
dict_eval = eval_cb.get_dict()
val_num_list = dict_eval["epoch"]
val_value = dict_eval["val_loss"]
for i in range(len(val_num_list)):
print(val_num_list[i], " : ", val_value[i])
if Model_Art:
print("######### upload to OBS #########")
mox.file.shift('os', 'mox')
mox.file.copy_parallel(src_url="/cache/save", dst_url=args.obs_save)
if __name__ == "__main__":
args = get_parser()
print(args.obs_save)
psp_train()
_export_air(args.obs_save)
mox.file.shift('os', 'mox')
mox.file.copy_parallel(src_url="/cache/save", dst_url=args.obs_save)