Skip to content
Snippets Groups Projects
Unverified Commit 3a81c03d authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!2822 [华为大学][自研贡献][MindSpore]DNCNN-移动official下C类代码至reserch

Merge pull request !2822 from 迎接光辉岁月/master
parents 77b2ca94 02dfdcee
No related branches found
No related tags found
No related merge requests found
Showing
with 1054 additions and 0 deletions
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
model_path=$1
output_model_name=$2
atc --model=$model_path \
--framework=1 \
--output=$output_model_name \
--input_format=NCHW \
--soc_version=Ascend310 \
--output_type=FP32
\ No newline at end of file
{
"DnCNN": {
"stream_config": {
"deviceId": "0"
},
"appsrc0": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "appsrc0",
"modelPath": "../data/model/DnCNN.om"
},
"factory": "mxpi_tensorinfer",
"next": "mxpi_dataserialize0"
},
"mxpi_dataserialize0": {
"props": {
"outputDataKeys": "mxpi_tensorinfer0"
},
"factory": "mxpi_dataserialize",
"next": "appsink0"
},
"appsink0": {
"props": {
"blocksize": "4096000"
},
"factory": "appsink"
}
}
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
export mnist dataset to bin.
"""
import os
import glob
import argparse
import PIL
import numpy as np
import cv2
import mindspore
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
def ResziePadding(img, fixed_side=256):
h, w = img.shape[0], img.shape[1]
scale = max(w, h) / float(fixed_side)
new_w, new_h = int(w / scale), int(h / scale)
resize_img = cv2.resize(img, (new_w, new_h))
if new_w % 2 != 0 and new_h % 2 == 0:
top, bottom, left, right = (fixed_side - new_h) // 2, (fixed_side - new_h) // 2, (
fixed_side - new_w) // 2 + 1, (fixed_side - new_w) // 2
elif new_w % 2 == 0 and new_h % 2 != 0:
top, bottom, left, right = (fixed_side - new_h) // 2 + 1, (fixed_side - new_h) // 2, (
fixed_side - new_w) // 2, (fixed_side - new_w) // 2
elif new_w % 2 == 0 and new_h % 2 == 0:
top, bottom, left, right = (fixed_side - new_h) // 2, (fixed_side - new_h) // 2, (fixed_side - new_w) // 2, (
fixed_side - new_w) // 2
else:
top, bottom, left, right = (fixed_side - new_h) // 2 + 1, (fixed_side - new_h) // 2, (
fixed_side - new_w) // 2 + 1, (fixed_side - new_w) // 2
pad_img = cv2.copyMakeBorder(resize_img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0])
return pad_img
class DnCNN_eval_Dataset():
def __init__(self, dataset_path, task_type, noise_level):
self.im_list = []
self.im_list.extend(glob.glob(os.path.join(dataset_path, "*.png")))
self.im_list.extend(glob.glob(os.path.join(dataset_path, "*.bmp")))
self.im_list.extend(glob.glob(os.path.join(dataset_path, "*.jpg")))
self.task_type = task_type
self.noise_level = noise_level
def __getitem__(self, i):
img = cv2.imread(self.im_list[i], 0)
if self.task_type == "denoise":
noisy = self.add_noise(img, self.noise_level)
elif self.task_type == "super-resolution":
h, w = img.shape
noisy = cv2.resize(img, (int(w/self.noise_level), int(h/self.noise_level)))
noisy = cv2.resize(noisy, (w, h))
elif self.task_type == "jpeg-deblock":
noisy = self.jpeg_compression(img, self.noise_level)
#add channel dimension
noisy = noisy[np.newaxis, :, :]
noisy = noisy / 255.0
noisy = ResziePadding(noisy[0])
img = ResziePadding(img)
return noisy, img
def __len__(self):
return len(self.im_list)
def add_noise(self, im, sigma):
gauss = np.random.normal(0, sigma, im.shape)
noisy = im + gauss
noisy = np.clip(noisy, 0, 255)
noisy = noisy.astype('float32')
return noisy
def jpeg_compression(self, img, severity):
im_pil = PIL.Image.fromarray(img)
output = io.BytesIO()
im_pil.save(output, 'JPEG', quality=severity)
im_pil = PIL.Image.open(output)
img_np = np.asarray(im_pil)
return img_np
def create_eval_dataset(data_path, task_type, noise_level, batch_size=1):
# define dataset
dataset = DnCNN_eval_Dataset(data_path, task_type, noise_level)
dataloader = ds.GeneratorDataset(dataset, ["noisy", "clear"])
# apply map operations on images
dataloader = dataloader.map(input_columns="noisy", operations=C.TypeCast(mindspore.float32))
dataloader = dataloader.map(input_columns="clear", operations=C.TypeCast(mindspore.uint8))
dataloader = dataloader.batch(batch_size, drop_remainder=False)
return dataloader
def parse_args():
parser = argparse.ArgumentParser(description='MNIST to bin')
parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--dataset_dir', type=str, default='', help='dataset path')
parser.add_argument('--save_dir', type=str, default='', help='path to save bin file')
parser.add_argument('--batch_size', type=int, default=1, help='batch size for bin')
parser.add_argument('--model_type', type=str, default='DnCNN-S', \
choices=['DnCNN-S', 'DnCNN-B', 'DnCNN-3'], help='type of DnCNN')
parser.add_argument('--noise_type', type=str, default="denoise", \
choices=["denoise", "super-resolution", "jpeg-deblock"], help='trained ckpt')
parser.add_argument('--noise_level', type=int, default=25, help='trained ckpt')
args_, _ = parser.parse_known_args()
return args_
if __name__ == "__main__":
args = parse_args()
os.environ["RANK_SIZE"] = '1'
os.environ["RANK_ID"] = '0'
device_id = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0
mnist_path = os.path.join(args.dataset_dir, 'test')
batchsize = args.batch_size
save_dir = os.path.join(args.save_dir, 'dncnn_infer_data')
folder_noisy = os.path.join(save_dir, 'dncnn_bs_' + str(batchsize) + '_noisy_bin')
folder_clear = os.path.join(save_dir, 'dncnn_bs_' + str(batchsize) + '_clear_bin')
if not os.path.exists(folder_clear):
os.makedirs(folder_clear)
if not os.path.exists(folder_noisy):
os.makedirs(folder_noisy)
ds = create_eval_dataset(args.dataset_dir, args.noise_type, args.noise_level, batch_size=args.batch_size)
iter_num = 0
label_file = os.path.join(save_dir, './dncnn_bs_' + str(batchsize) + '_label.txt')
with open(label_file, 'w') as f:
for data in ds.create_dict_iterator():
noisy_img = data['noisy']
clear_img = data['clear']
noisy_file_name = "dncnn_noisy_" + str(iter_num) + ".bin"
noisy_file_path = folder_noisy + "/" + noisy_file_name
noisy_img.asnumpy().tofile(noisy_file_path)
clear_file_name = "dncnn_clear_" + str(iter_num) + ".bin"
clear_file_path = folder_clear + "/" + clear_file_name
clear_img.asnumpy().tofile(clear_file_path)
f.write(noisy_file_name + ',' + clear_file_name + '\n')
iter_num += 1
print("=====iter_num:{}=====".format(iter_num))
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
model_dir=$2
function show_help() {
echo "Usage: docker_start.sh docker_image model_dir data_dir"
}
function param_check() {
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
show_help
exit 1
fi
if [ -z "${model_dir}" ]; then
echo "please input model_dir"
show_help
exit 1
fi
}
param_check
docker run -it -u root \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${model_dir}:${model_dir} \
${docker_image} \
/bin/bash
cmake_minimum_required(VERSION 3.5.2)
project(DnCNN)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
set(TARGET_MAIN DnCNN)
set(ACL_LIB_PATH $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories($ENV{MX_SDK_HOME}/include)
include_directories($ENV{MX_SDK_HOME}/opensource/include)
include_directories($ENV{MX_SDK_HOME}/opensource/include/opencv4)
include_directories($ENV{MX_SDK_HOME}/opensource/include/gstreamer-1.0)
include_directories($ENV{MX_SDK_HOME}/opensource/include/glib-2.0)
include_directories($ENV{MX_SDK_HOME}/opensource/lib/glib-2.0/include)
link_directories($ENV{MX_SDK_HOME}/lib)
link_directories($ENV{MX_SDK_HOME}/opensource/lib/)
add_compile_options(-std=c++11 -fPIC -fstack-protector-all -pie -Wno-deprecated-declarations)
add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}")
add_compile_options("-Dgoogle=mindxsdk_private")
add_definitions(-DENABLE_DVPP_INTERFACE)
include_directories(${ACL_LIB_PATH}/include)
link_directories(${ACL_LIB_PATH}/lib64/)
add_executable(${TARGET_MAIN} src/main.cpp src/DnCNN.cpp)
target_link_libraries(${TARGET_MAIN} ${TARGET_LIBRARY} glog cpprest mxbase libascendcl.so)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export ASCEND_VERSION=ascend-toolkit/latest
export ARCH_PATTERN=.
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib/modelpostprocessors:${LD_LIBRARY_PATH}
mkdir -p build
cd build || exit
function make_plugin() {
if ! cmake ..;
then
echo "cmake failed."
return 1
fi
if ! (make);
then
echo "make failed."
return 1
fi
return 0
}
if make_plugin;
then
echo "INFO: Build successfully."
else
echo "ERROR: Build failed."
fi
cd - || exit
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "DnCNN.h"
#include <cstdlib>
#include <memory>
#include <string>
#include <cmath>
#include <vector>
#include <algorithm>
#include <queue>
#include <utility>
#include <fstream>
#include <map>
#include <iostream>
#include "acl/acl.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
namespace {
const int FLOAT_SIZE = 4;
}
void WriteResult(const int index, const std::vector<MxBase::TensorBase> &outputs) {
std::string homePath = "./result";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
outputSize = outputs[i].GetSize();
std::string outFileName = homePath + "/output_" + std::to_string(index) + ".bin";
float *boxes = reinterpret_cast<float *>(outputs[i].GetBuffer());
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(boxes, sizeof(float), 256 * 256, outputFile);
fclose(outputFile);
outputFile = nullptr;
}
}
APP_ERROR DnCNN::Init(const InitParam &initParam) {
deviceId_ = initParam.deviceId;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
model_DnCNN = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = model_DnCNN->Init(initParam.modelPath, modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR DnCNN::DeInit() {
model_DnCNN->DeInit();
MxBase::DeviceManager::GetInstance()->DestroyDevices();
return APP_ERR_OK;
}
APP_ERROR DnCNN::ReadBin(const std::string &path, std::vector<std::vector<float>> &dataset) {
std::ifstream inFile(path, std::ios::binary);
float data[256*256];
inFile.read(reinterpret_cast<char *>(&data), sizeof(data));
std::vector<float> temp(data, data+sizeof(data) / sizeof(data[0]));
dataset.push_back(temp);
return APP_ERR_OK;
}
APP_ERROR DnCNN::VectorToTensorBase(const std::vector<std::vector<float>> &input,
MxBase::TensorBase &tensorBase) {
uint32_t dataSize = 4*1*256*256;
float *metaFeatureData = new float[dataSize];
uint32_t idx = 0;
for (size_t bs = 0; bs < input.size(); bs++) {
for (size_t c = 0; c < input[bs].size(); c++) {
metaFeatureData[idx++] = input[bs][c];
}
}
MxBase::MemoryData memoryDataDst(dataSize * FLOAT_SIZE, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
MxBase::MemoryData memoryDataSrc(reinterpret_cast<void *>(metaFeatureData), dataSize * FLOAT_SIZE,
MxBase::MemoryData::MEMORY_HOST_MALLOC);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
std::vector<uint32_t> shape = {1, 1, 256, 256};
tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
return APP_ERR_OK;
}
APP_ERROR DnCNN::Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> &outputs) {
auto dtypes = model_DnCNN->GetOutputDataType();
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
}
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs.push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret = model_DnCNN->ModelInference(inputs, outputs, dynamicInfo);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "ModelInference DnCNN failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR DnCNN::Process(const int index, const std::string &image_path,
const InitParam &initParam, std::vector<int> &outputs) {
std::vector<MxBase::TensorBase> inputs = {};
std::vector<MxBase::TensorBase> outputs_tb = {};
std::vector<std::vector<float>> image_data;
APP_ERROR ret = ReadBin(image_path, image_data);
if (ret != APP_ERR_OK) {
LogError << "ToTensorBase failed, ret=" << ret << ".";
return ret;
}
MxBase::TensorBase tensorBase;
APP_ERROR ret1 = VectorToTensorBase(image_data, tensorBase);
if (ret1 != APP_ERR_OK) {
LogError << "ToTensorBase failed, ret=" << ret1 << ".";
return ret1;
}
inputs.push_back(tensorBase);
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret3 = Inference(inputs, outputs_tb);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
inferCostTimeMilliSec += costMs;
if (ret3 != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret3 << ".";
return ret3;
}
if (!outputs_tb[0].IsHost()) {
outputs_tb[0].ToHost();
}
WriteResult(index, outputs_tb);
}
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_DnCNN_H
#define MXBASE_DnCNN_H
#include <memory>
#include <string>
#include <vector>
#include "acl/acl.h"
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
#include "MxBase/CV/Core/DataType.h"
struct InitParam {
uint32_t deviceId;
bool checkTensor;
std::string modelPath;
};
class DnCNN {
public:
APP_ERROR Init(const InitParam &initParam);
APP_ERROR DeInit();
APP_ERROR VectorToTensorBase(const std::vector<std::vector<float>> &input, MxBase::TensorBase &tensorBase);
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
APP_ERROR Process(const int i, const std::string &image_path, const InitParam &initParam, std::vector<int> &outputs);
APP_ERROR ReadBin(const std::string &path, std::vector<std::vector<float>> &dataset);
// get infer time
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
std::shared_ptr<MxBase::ModelInferenceProcessor> model_DnCNN;
MxBase::ModelDesc modelDesc_;
uint32_t deviceId_ = 0;
// infer time
double inferCostTimeMilliSec = 0.0;
};
#endif
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dirent.h>
#include <fstream>
#include "MxBase/Log/Log.h"
#include "DnCNN.h"
void SplitString(const std::string &s, std::vector<std::string> *v, const std::string &c) {
std::string::size_type pos1, pos2;
pos2 = s.find(c);
pos1 = 0;
while (std::string::npos != pos2) {
v->push_back(s.substr(pos1, pos2 - pos1));
pos1 = pos2 + c.size();
pos2 = s.find(c, pos1);
}
if (pos1 != s.length()) {
v->push_back(s.substr(pos1));
}
}
APP_ERROR ReadImagesPath(const std::string &path, std::vector<std::string> *imagesPath) {
std::ifstream inFile;
inFile.open(path, std::ios_base::in);
std::string line;
// Check images path file validity
if (inFile.fail()) {
LogError << "Failed to open annotation file: " << path;
return APP_ERR_COMM_OPEN_FAIL;
}
std::vector<std::string> vectorStr_path;
std::string splitStr_path = ",";
// construct label map
while (std::getline(inFile, line)) {
vectorStr_path.clear();
SplitString(line, &vectorStr_path, splitStr_path);
std::string str_path = vectorStr_path[0];
imagesPath->push_back(str_path);
}
inFile.close();
return APP_ERR_OK;
}
int main(int argc, char* argv[]) {
InitParam initParam = {};
initParam.deviceId = 0;
initParam.checkTensor = true;
initParam.modelPath = "../data/model/DnCNN.om";
std::string dataPath = "../data/dncnn_infer_data/dncnn_bs_1_noisy_bin/";
std::string annoPath = "../data/dncnn_infer_data/dncnn_bs_1_label.txt";
auto model_DnCNN = std::make_shared<DnCNN>();
APP_ERROR ret = model_DnCNN->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "Tagging init failed, ret=" << ret << ".";
return ret;
}
std::vector<std::string> imagesPath;
ret = ReadImagesPath(annoPath, &imagesPath);
if (ret != APP_ERR_OK) {
model_DnCNN->DeInit();
return ret;
}
int img_size = imagesPath.size();
std::vector<int> outputs;
for (int i=0; i < img_size; i++) {
ret = model_DnCNN->Process(i, dataPath + imagesPath[i], initParam, outputs);
if (ret !=APP_ERR_OK) {
LogError << "DnCNN process failed, ret=" << ret << ".";
model_DnCNN->DeInit();
return ret;
}
}
model_DnCNN->DeInit();
double total_time = model_DnCNN->GetInferCostMilliSec() / 1000;
LogInfo<< "inferance total cost time: "<< total_time<< ", FPS: "<< img_size/total_time;
return APP_ERR_OK;
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" main.py """
import argparse
import os
from StreamManagerApi import StreamManagerApi, StringVector
from StreamManagerApi import MxDataInput, InProtobufVector, MxProtobufIn
import MxpiDataType_pb2 as MxpiDataType
import numpy as np
shape = [1, 1, 256, 256]
def parse_args(parsers):
"""
Parse commandline arguments.
"""
parsers.add_argument('--images_txt_path', type=str,
default="../data/dncnn_infer_data/dncnn_bs_1_label.txt",
help='image text')
return parsers
def read_file_list(input_file):
"""
:param infer file content:
1.bin 0
2.bin 2
...
:return image path list, label list
"""
noisy_image_file = []
clear_image_file = []
if not os.path.exists(input_file):
print('input file does not exists.')
with open(input_file, "r") as fs:
for line in fs.readlines():
line = line.strip('\n').split(',')
noisy_file_name = line[0]
clear_file_name = line[1]
noisy_image_file.append(noisy_file_name)
clear_image_file.append(clear_file_name)
return noisy_image_file, clear_image_file
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Om DnCNN Inference')
parser = parse_args(parser)
args, _ = parser.parse_known_args()
# init stream manager
stream_manager = StreamManagerApi()
ret = stream_manager.InitManager()
if ret != 0:
print("Failed to init Stream manager, ret=%s" % str(ret))
exit()
# create streams by pipeline config file
with open("../data/config/DnCNN.pipeline", 'rb') as f:
pipeline = f.read()
ret = stream_manager.CreateMultipleStreams(pipeline)
if ret != 0:
print("Failed to create Stream, ret=%s" % str(ret))
exit()
# Construct the input of the stream
res_dir_name = 'result'
if not os.path.exists(res_dir_name):
os.makedirs(res_dir_name)
noisy_image_files, clear_image_files = read_file_list(args.images_txt_path)
img_size = len(noisy_image_files)
results = []
for idx, file in enumerate(noisy_image_files):
image_path = os.path.join(args.images_txt_path.replace('label.txt', 'noisy_bin'), file)
# Construct the input of the stream
data_input = MxDataInput()
with open(image_path, 'rb') as f:
data = f.read()
data_input.data = data
tensorPackageList1 = MxpiDataType.MxpiTensorPackageList()
tensorPackage1 = tensorPackageList1.tensorPackageVec.add()
tensorVec1 = tensorPackage1.tensorVec.add()
tensorVec1.deviceId = 0
tensorVec1.memType = 0
for t in shape:
tensorVec1.tensorShape.append(t)
tensorVec1.dataStr = data_input.data
tensorVec1.tensorDataSize = len(data)
protobufVec1 = InProtobufVector()
protobuf1 = MxProtobufIn()
protobuf1.key = b'appsrc0'
protobuf1.type = b'MxTools.MxpiTensorPackageList'
protobuf1.protobuf = tensorPackageList1.SerializeToString()
protobufVec1.push_back(protobuf1)
unique_id = stream_manager.SendProtobuf(b'DnCNN', b'appsrc0', protobufVec1)
keyVec = StringVector()
keyVec.push_back(b'mxpi_tensorinfer0')
infer_result = stream_manager.GetProtobuf(b'DnCNN', 0, keyVec)
if infer_result.size() == 0:
print("inferResult is null")
exit()
if infer_result[0].errorCode != 0:
print("GetProtobuf error. errorCode=%d" % (
infer_result[0].errorCode))
exit()
# get infer result
result = MxpiDataType.MxpiTensorPackageList()
result.ParseFromString(infer_result[0].messageBuf)
# convert the inference result to Numpy array
res = np.frombuffer(result.tensorPackageVec[0].tensorVec[0].dataStr, dtype=np.float32).reshape(1, 1, 256, 256)
res.tofile(os.path.join(res_dir_name, 'output_{}.bin'.format(idx)))
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" eval_sdk.py """
import os
import numpy as np
import skimage.metrics
def read_file_list(input_file):
"""
:param infer file content:
1.bin 0
2.bin 2
...
:return image path list, label list
"""
noisy_image_file = []
clear_image_file = []
if not os.path.exists(input_file):
print('input file does not exists.')
with open(input_file, "r") as fs:
for line in fs.readlines():
line = line.strip('\n').split(',')
noisy_file_name = line[0]
clear_file_name = line[1]
noisy_image_file.append(noisy_file_name)
clear_image_file.append(clear_file_name)
return noisy_image_file, clear_image_file
images_txt_path = "../data/dncnn_infer_data/dncnn_bs_1_label.txt"
noisy_image_files, clear_image_files = read_file_list(images_txt_path)
mean_psnr = 0
mean_ssim = 0
count = 0
for index, out_file in enumerate(clear_image_files):
out_path = '../data/dncnn_infer_data/dncnn_bs_1_clear_bin/' + out_file
clear = np.fromfile(out_path, dtype=np.uint8).reshape(1, 256, 256)
noisy = np.fromfile(out_path.replace('clear', 'noisy'), dtype=np.float32).reshape(1, 256, 256)
# get denoised image
residual = np.fromfile('./result/output_{}.bin'.format(index), dtype=np.float32).reshape(1, 256, 256)
denoised = np.clip(noisy - residual, 0, 255).astype("uint8")
denoised = np.squeeze(denoised)
clear = np.squeeze(clear)
noisy = np.squeeze(noisy)
# calculate psnr
mse = np.mean((clear - denoised) ** 2)
psnr = 10 * np.log10(255 * 255 / mse)
# calculate ssim
ssim = skimage.metrics.structural_similarity(clear, denoised, data_range=255) # skimage 0.18
mean_psnr += psnr
mean_ssim += ssim
count += 1
mean_psnr = mean_psnr / count
mean_ssim = mean_ssim / count
print("mean psnr", mean_psnr)
print("mean_ssim", mean_ssim)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }
#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python
python3 main.py
exit 0
\ No newline at end of file
#!/usr/bin/env python3
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import datetime
import numpy as np
import mindspore as ms
from mindspore import Tensor, export
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, LearningRateScheduler
from mindspore.train import Model
from mindspore.train.callback import Callback
from src.dataset import create_train_dataset
from src.model import DnCNN
class BatchAverageMSELoss(nn.Cell):
def __init__(self, batch_size):
super(BatchAverageMSELoss, self).__init__()
self.batch_size = batch_size
self.sumMSELoss = nn.MSELoss(reduction='sum')
def construct(self, logits, labels):
#equation 1 on the paper
loss = self.sumMSELoss(logits, labels) / self.batch_size / 2
return loss
class Print_info(Callback):
def epoch_end(self, run_context):
cb_params = run_context.original_args()
print(datetime.datetime.now(), "end epoch", cb_params.cur_epoch_num)
def learning_rate_function(lr, cur_step_num):
if cur_step_num % 40000 == 0:
lr = lr*0.8
print("current lr: ", str(lr))
return lr
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="DnCNN")
parser.add_argument("--data_url", type=str, default="/code/BSR_bsds500/BSR/BSDS500/data/images/", \
help='training image path')
parser.add_argument("--train_url", type=str, default="./", \
help='training image path')
parser.add_argument("--batch_size", type=int, default=128, help='training batch size')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight_decay')
parser.add_argument('--model_type', type=str, default='DnCNN-S', \
choices=['DnCNN-S', 'DnCNN-B', 'DnCNN-3'], help='type of DnCNN')
parser.add_argument('--noise_level', type=int, default=25, help="noise level only for DnCNN-S")
parser.add_argument('--ckpt_prefix', type=str, default="dncnn_mindspore", help='ckpt name prefix')
parser.add_argument('--epoch_num', type=int, default=50, help='epoch number')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if args.model_type == 'DnCNN-S':
network = DnCNN(1, num_of_layers=17)
elif args.model_type == 'DnCNN-3' or args.model_type == 'DnCNN-B':
network = DnCNN(1, num_of_layers=20)
else:
print("wrong model type")
exit()
ds_train = create_train_dataset(args.data_url, args.model_type, noise_level=args.noise_level, \
batch_size=args.batch_size)
opt = nn.AdamWeightDecay(network.trainable_params(), args.lr, weight_decay=args.weight_decay)
loss_fun = BatchAverageMSELoss(args.batch_size)
model = Model(network, loss_fun, opt)
#training callbacks
checkpoint_config = CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=3)
ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_prefix, directory=args.train_url, config=checkpoint_config)
print_cb = Print_info()
lr_cb = LearningRateScheduler(learning_rate_function)
loss_monitor_cb = LossMonitor(per_print_times=100)
print(datetime.datetime.now(), " training starts")
model.train(args.epoch_num, ds_train, callbacks=[lr_cb, ckpoint_cb, print_cb, loss_monitor_cb], \
dataset_sink_mode=False)
input_arr = Tensor(np.ones([1, 1, 256, 256]), ms.float32)
export(network, input_arr, file_name=args.train_url+'/DnCNN', file_format='AIR')
#!/bin/bash
# Copyright (c) 2022. Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
docker_image=$1
data_dir=$2
model_dir=$3
docker run -it -u root --ipc=host \
--device=/dev/davinci0 \
--device=/dev/davinci1 \
--device=/dev/davinci2 \
--device=/dev/davinci3 \
--device=/dev/davinci4 \
--device=/dev/davinci5 \
--device=/dev/davinci6 \
--device=/dev/davinci7 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
--privileged \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons \
-v ${data_dir}:${data_dir} \
-v ${model_dir}:${model_dir} \
-v /root/ascend/log:/root/ascend/log ${docker_image} /bin/bash
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment