# Contents
<!-- TOC -->
- [Contents](#contents)
- [RedNet30 Description](#rednet30-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Training Process](#training-process)
- [Training Parameters](#training-parameters)
- [Training Launch](#training-launch)
- [Evaluation Process](#evaluation-process)
- [Evaluation Parameters](#evaluation-parameters)
- [Evaluation Launch](#evaluation-launch)
- [Conversion Process](#conversion-process)
- [Inference Process](#inference-process)
- [Inference on Ascend 310](#inference-on-ascend-310)
- [Model Description](#model-description)
- [Training Performance](#training-performance)
<!-- /TOC -->
## RedNet30 Description
RedNet30 is an encoder-decoder model for image denoising. This project is a reimplementation of the RedNet30 image denoising model in MindSpore.
Paper: Mao X J, Shen C, Yang Y B. [Image Restoration Using Very Deep Convolutional Encoder-Decoder Networks with Symmetric Skip Connections](https://arxiv.org/pdf/1603.09056v2.pdf). 2016.
## Model Architecture
The network consists of 15 convolutional blocks and 15 deconvolutional blocks. Each encoder layer is a convolution followed by ReLU, and each decoder layer is a deconvolution followed by ReLU; symmetric skip connections link matching encoder and decoder layers, and a global residual connection adds the noisy input to the output.
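A condensed two-layer sketch of this pattern is shown below; the actual network stacks 15 conv and 15 deconv layers and inserts a skip connection every two layers (see `src/model.py` for the full implementation).

```python
# Condensed sketch of the REDNet-style encoder-decoder; the real model is in src/model.py.
import mindspore.nn as nn

class TinyRED(nn.Cell):
    """Two conv layers, two deconv layers, one symmetric skip connection."""
    def __init__(self, num_features=64):
        super().__init__()
        self.conv1 = nn.Conv2d(3, num_features, kernel_size=3, pad_mode="pad", padding=1)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size=3, pad_mode="pad", padding=1)
        self.deconv1 = nn.Conv2dTranspose(num_features, num_features, kernel_size=3, pad_mode="pad", padding=1)
        self.deconv2 = nn.Conv2dTranspose(num_features, 3, kernel_size=3, pad_mode="same")
        self.relu = nn.ReLU()

    def construct(self, x):
        residual = x
        f1 = self.relu(self.conv1(x))                   # encoder feature kept for the skip connection
        f2 = self.relu(self.conv2(f1))
        d1 = self.relu(self.deconv1(f2) + f1)           # symmetric skip: add the matching encoder feature
        out = self.relu(self.deconv2(d1) + residual)    # global residual from the noisy input
        return out
```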
## Dataset
Training set: [BSD300](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html)
Test set: [BSD200](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html)
The training set consists of 300 color images obtained by merging the training and validation splits of [BSD300](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html); the test set (BSD200) consists of the 200 color images of the BSD300 training split.
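As implemented in `src/dataset.py` and `src/get_input_data.py`, the degradation being removed is JPEG compression at quality 10: the clean patch serves as the label and its quality-10 re-encoding serves as the noisy input. Below is a minimal sketch of how a training pair is built (with a fixed crop position for brevity; the real loader crops at a random position):

```python
# Sketch of training-pair construction, mirroring src/dataset.py.
import io
import numpy as np
from PIL import Image

def make_pair(image_path, patch_size=50):
    label = Image.open(image_path).convert('RGB')
    label = label.crop((0, 0, patch_size, patch_size))       # src/dataset.py crops at a random position
    buffer = io.BytesIO()
    label.save(buffer, format='jpeg', quality=10)             # JPEG compression artifacts act as the noise
    noisy = Image.open(buffer)
    noisy = np.transpose(np.array(noisy, dtype=np.float32) / 255.0, (2, 0, 1))  # CHW, normalized
    clean = np.transpose(np.array(label, dtype=np.float32) / 255.0, (2, 0, 1))
    return noisy, clean
```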
## Environment Requirements
- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
- Install the Python packages listed in requirements.txt.
- Generate the config JSON file for 8-device training.
## Script Description
```shell
.
└── rednet30
    ├── README_CN.md
    ├── ascend310_infer
    │   ├── src                          # Ascend 310 inference source code
    │   ├── inc                          # Ascend 310 inference headers
    │   ├── build.sh                     # shell script to build the Ascend 310 inference program
    │   └── CMakeLists.txt               # CMakeLists for the Ascend 310 inference program
    ├── scripts
    │   ├── run_standalone_train.sh      # single-device training script for Ascend
    │   ├── run_standalone_train_gpu.sh  # single-device training script for GPU
    │   ├── run_distribute_train.sh      # 8-device parallel training script for Ascend
    │   ├── run_distribute_train_gpu.sh  # 8-device parallel training script for GPU
    │   ├── run_eval.sh                  # evaluation script for Ascend
    │   ├── run_infer_310.sh             # Ascend 310 inference shell script
    │   └── run_eval_gpu.sh              # evaluation script for GPU
    ├── src
    │   ├── dataset.py                   # data loading
    │   ├── get_input_data_310.py        # generate noisy images for Ascend 310 inference
    │   ├── get_input_data.py            # generate noisy images
    │   └── model.py                     # model definition
    ├── export.py                        # export a MINDIR file
    ├── preprocess.py                    # data preparation script for Ascend 310 inference
    ├── postprocess.py                   # data post-processing script for Ascend 310 inference
    ├── eval.py                          # evaluation script
    └── train.py                         # training script
```
## Training Process
Training behavior can be adjusted through the arguments of the `train.py` script, listed below:
### Training Parameters
```bash
--dataset_path ./data/BSD300   # training data path
--platform 'GPU'               # device target
--is_distributed False         # distributed training
--patch_size 50                # input patch size
--batch_size 16                # batch size
--num_epochs 1000              # number of training epochs
--lr 0.0001                    # learning rate
--seed 1                       # random seed
--ckpt_save_max 5              # maximum number of checkpoints to keep
--init_loss_scale 65536.       # initial loss scale
```
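As a purely illustrative sketch (not the actual `train.py`), the flags above could be declared with argparse as follows:

```python
# Illustrative sketch of the training flags listed above; the real train.py may differ.
import argparse

parser = argparse.ArgumentParser(description='RedNet30 training (sketch)')
parser.add_argument('--dataset_path', type=str, default='./data/BSD300', help='training data path')
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='device target')
parser.add_argument('--is_distributed', type=lambda s: s.lower() in ('true', '1'),
                    default=False, help='enable distributed training')  # booleans need explicit parsing
parser.add_argument('--patch_size', type=int, default=50, help='input patch size')
parser.add_argument('--batch_size', type=int, default=16, help='batch size')
parser.add_argument('--num_epochs', type=int, default=1000, help='number of training epochs')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--seed', type=int, default=1, help='random seed')
parser.add_argument('--ckpt_save_max', type=int, default=5, help='maximum number of checkpoints to keep')
parser.add_argument('--init_loss_scale', type=float, default=65536., help='initial loss scale')
args = parser.parse_args()
```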
### Training Launch
You can start training with Python or with the shell scripts.
```shell
# Training examples
- running on Ascend with default parameters
  python:
      Ascend single-device training: python train.py --dataset_path [DATA_DIR] --platform Ascend
      # example: python train.py --dataset_path ./data/BSD300 --platform Ascend
  shell:
      Ascend 8-device parallel training: bash scripts/run_distribute_train.sh [DATA_DIR] [RANK_TABLE_FILE]
      # example: bash scripts/run_distribute_train.sh ./data/BSD300 ./rank_table_8p.json
      Ascend single-device training: bash scripts/run_standalone_train.sh [DATA_DIR]
      # example: bash scripts/run_standalone_train.sh ./data/BSD300
- running on GPU with default parameters
  python:
      GPU single-device training: python train.py --dataset_path [DATA_DIR] --platform GPU
      # example: python train.py --dataset_path ./data/BSD300 --platform GPU
  shell:
      GPU 8-device parallel training: bash scripts/run_distribute_train_gpu.sh [DATA_DIR]
      # example: bash scripts/run_distribute_train_gpu.sh ./data/BSD300
      GPU single-device training: bash scripts/run_standalone_train_gpu.sh [DATA_DIR]
      # example: bash scripts/run_standalone_train_gpu.sh ./data/BSD300
```
Distributed training requires an HCCL configuration file in JSON format to be created in advance.
Distributed jobs use the rank table specified by RANK_TABLE_FILE. You can generate this file with hccl_tools; see the [link](https://gitee.com/mindspore/models/blob/master/utils/hccl_tools/hccl_tools.py) for details.
## Evaluation Process
### Evaluation Parameters
```bash
--dataset_path ./data/BSD300 # 测试数据路径
--ckpt_path ./ckpt/RedNet30-1000_18.ckpt # 测试ckpt文件路径
--platform 'GPU' # 训练设备
```
### Evaluation Launch
You can start evaluation with Python or with the shell scripts.
```shell
# Noisy images must be generated before evaluation:
python ./src/get_input_data.py --dataset_path [DATA_DIR] --output_path [NOISE_IMAGE_DIR]
# Ascend evaluation example
python:
    python eval.py --dataset_path [DATA_DIR] --noise_path [NOISE_IMAGE_DIR] --ckpt_path [PATH_CHECKPOINT] --platform [PLATFORM]
    # example: python eval.py --dataset_path ./data/BSD200 --noise_path ./data/BSD200_jpeg_quality10 --ckpt_path ./train/ckpt/ckpt_0/RedNet30_0-1000_18.ckpt --platform 'Ascend'
shell:
    bash scripts/run_eval.sh [DATA_DIR] [NOISE_IMAGE_DIR] [PATH_CHECKPOINT] [PLATFORM]
    # example: bash scripts/run_eval.sh ./data/BSD200 ./data/BSD200_jpeg_quality10 ./train/ckpt/ckpt_0/RedNet30_0-1000_18.ckpt Ascend
# GPU evaluation example
python:
    python eval.py --dataset_path [DATA_DIR] --noise_path [NOISE_IMAGE_DIR] --ckpt_path [PATH_CHECKPOINT] --platform [PLATFORM]
    # example: python eval.py --dataset_path ./data/BSD200 --noise_path ./data/BSD200_jpeg_quality10 --ckpt_path ./train/ckpt/ckpt_0/RedNet30_0-1000_18.ckpt --platform 'GPU'
shell:
    bash scripts/run_eval_gpu.sh [DATA_DIR] [NOISE_IMAGE_DIR] [PATH_CHECKPOINT] [PLATFORM]
    # example: bash scripts/run_eval_gpu.sh ./data/BSD200 ./data/BSD200_jpeg_quality10 ./train/ckpt/ckpt_0/RedNet30_0-1000_18.ckpt GPU
```
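`eval.py` reports PSNR and SSIM computed with scikit-image over the 200 test images (SSIM uses `data_range=255`). For reference, PSNR for 8-bit images is defined as

$$
\mathrm{PSNR} = 10 \log_{10}\!\frac{255^2}{\mathrm{MSE}}, \qquad
\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\left(x_i - \hat{x}_i\right)^2,
$$

where $x$ is the clean image and $\hat{x}$ is the denoised output.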
## Conversion Process
### Conversion
To run inference on Ascend 310, first convert the model to MINDIR:
```shell
python export.py --ckpt_path [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
The `ckpt_path` argument is required.
`FILE_FORMAT` must be one of ["AIR", "MINDIR"].
The exported model uses a fixed input shape of [1, 3, 480, 480] (see `export.py`), which is why the Ascend 310 preprocessing pads each test image to 480x480.
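If you want a quick host-side sanity check of the exported MINDIR before running it on Ascend 310, something along these lines should work. This is not part of the repository's scripts and assumes MindSpore's `mindspore.load` / `nn.GraphCell` interface is usable on your device target:

```python
# Hypothetical sanity check of an exported MINDIR file; not part of the original scripts.
import numpy as np
import mindspore as ms
import mindspore.nn as nn

ms.context.set_context(mode=ms.context.GRAPH_MODE)
graph = ms.load("REDNet30.mindir")                         # load the exported graph
net = nn.GraphCell(graph)                                  # wrap it so it can be called like a Cell
dummy = ms.Tensor(np.ones((1, 3, 480, 480), np.float32))   # input shape is fixed at export time
print(net(dummy).shape)                                    # expect (1, 3, 480, 480)
```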
## Inference Process
### Inference on Ascend 310
```shell
# Run inference with ./scripts/run_infer_310.sh ([DEVICE_ID] is optional and defaults to 0);
# the results are written to run_infer.log:
bash ./scripts/run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SAVE_BIN_PATH] [SAVE_OUTPUT_PATH] [DEVICE_ID]
vim run_infer.log
```
## Model Description
### Training Performance
| Parameters               | Ascend                    | GPU                       | Ascend                    |
| ------------------------ | ------------------------- | ------------------------- | ------------------------- |
| Model name               | RedNet30                  | RedNet30                  | RedNet30                  |
| Environment              | Ascend 910                | RTX 3090                  | Ascend 310                |
| Upload date              | 2022-03-06                | 2022-03-06                | 2022-03-06                |
| MindSpore version        | 1.5.2                     | 1.5.2                     | 1.5.2                     |
| Dataset                  | BSD                       | BSD                       | BSD                       |
| Optimizer                | Adam                      | Adam                      | Adam                      |
| Loss function            | MSELoss                   | MSELoss                   | MSELoss                   |
| Accuracy (1p)            | PSNR[27.51], SSIM[0.7946] | PSNR[27.35], SSIM[0.7886] | PSNR[28.67], SSIM[0.8614] |
| Total training time (1p) | 15m11s                    | 19m23s                    | -                         |
| Total evaluation time    | 38s                       | 17s                       | -                         |
| Number of parameters     | 11.8M                     | 11.8M                     | 11.8M                     |
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
find_package(gflags REQUIRED)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs,
const std::string &homePath);
#endif
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <map>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/vision_ascend.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::dataset::MapTargetDevice;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::Rescale;
using mindspore::dataset::vision::HWC2CHW;
using mindspore::dataset::vision::Normalize;
using mindspore::dataset::vision::Decode;
using mindspore::dataset::vision::RGB2GRAY;
using mindspore::dataset::vision::CenterCrop;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(dataset_path, ".", "dataset path");
DEFINE_string(save_dir, "", "save dir");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
DIR *dir = OpenDir(FLAGS_save_dir);
if (dir == nullptr) {
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
ascend310->SetBufferOptimizeMode("off_optimize");
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
auto decode = Decode();
auto normalize = Normalize({0.0, 0.0, 0.0}, {1.0, 1.0, 1.0});
auto hwc2chw = HWC2CHW();
Execute transform({decode, normalize, hwc2chw});
auto all_files = GetAllFiles(FLAGS_dataset_path);
std::map<double, double> costTime_map;
size_t size = all_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << all_files[i] << std::endl;
auto img = MSTensor();
auto image = ReadFileToTensor(all_files[i]);
transform(image, &img);
std::vector<MSTensor> model_inputs = model.GetInputs();
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
img.Data().get(), img.DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << all_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(all_files[i], outputs, FLAGS_save_dir);
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = FLAGS_save_dir + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inc/utils.h"
#include <fstream>
#include <algorithm>
#include <iostream>
#include <climits>
#include <cstdlib>
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> dirs;
std::vector<std::string> files;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == "..") {
continue;
} else if (filename->d_type == DT_DIR) {
dirs.emplace_back(std::string(dirName) + "/" + filename->d_name);
} else if (filename->d_type == DT_REG) {
files.emplace_back(std::string(dirName) + "/" + filename->d_name);
} else {
continue;
}
}
for (auto d : dirs) {
dir = OpenDir(d);
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
files.emplace_back(std::string(d) + "/" + filename->d_name);
}
}
std::sort(files.begin(), files.end());
for (auto &f : files) {
std::cout << "image file: " << f << std::endl;
}
return files;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs, const std::string &homePath) {
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval rednet30."""
import argparse
import os
import time
import glob
from tqdm import tqdm
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from PIL import Image
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.model import REDNet30
def PSNR(img1, img2):
"""metrics"""
psnr = peak_signal_noise_ratio(img1, img2)
return psnr
def SSIM(img1, img2):
"""metrics"""
ssim = structural_similarity(img1, img2, data_range=255, multichannel=True)
return ssim
def get_metric(ori_path, res_path):
"""metrics"""
files = glob.glob(os.path.join(ori_path, "*"))
names = []
for i in files:
names.append(i.split("/")[-1])
# PSNR
print("PSNR...")
res = 0
for i in tqdm(names):
ori = Image.open(os.path.join(ori_path, i))
gen = Image.open(os.path.join(res_path, i))
res += PSNR(np.array(ori), np.array(gen))
psnr_res = res / len(names)
# SSIM
print("SSIM...")
res = 0
for i in tqdm(names):
ori = Image.open(os.path.join(ori_path, i))
gen = Image.open(os.path.join(res_path, i))
res += SSIM(np.array(ori), np.array(gen))
ssim_res = res / len(names)
print("PSNR: ", psnr_res)
print("SSIM: ", ssim_res)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, default='./data/BSD200', help='evaluation image path')
    parser.add_argument('--noise_path', type=str, default='./data/BSD200_jpeg_quality10', help='noisy image path')
parser.add_argument('--ckpt_path', type=str, default="./ckpt/RedNet30-1000_18.ckpt", help='ckpt path')
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID', '0'))
if opt.platform == "GPU":
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
else:
context.set_context(mode=context.GRAPH_MODE, save_graphs=False,
device_target="Ascend", device_id=device_id)
save_path = "./output"
if not os.path.exists(save_path):
os.makedirs(save_path)
model = REDNet30()
if opt.ckpt_path:
param_dict = load_checkpoint(opt.ckpt_path)
load_param_into_net(model, param_dict)
model.set_train(False)
# data
img_files = glob.glob(opt.noise_path + "/*")
time_start = time.time()
for file in tqdm(img_files):
name = file.split("/")[-1]
img = np.array(Image.open(file))
img = np.expand_dims(img, axis=0).transpose(0, 3, 1, 2)
input_img = Tensor(img, dtype=mindspore.float32)
result = model(input_img)
out_img = result[0].asnumpy().transpose(1, 2, 0)
out_img = np.clip(out_img, 0, 255)
out_img = np.uint8(out_img)
out_img = Image.fromarray(out_img)
out_img.save(os.path.join(save_path, name), quality=95)
print("finished!")
time_end = time.time()
print('--------------------')
print('test time: %f' % (time_end - time_start))
print('--------------------')
get_metric(opt.dataset_path, save_path)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export file of MINDIR format"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.model import REDNet30
parse = argparse.ArgumentParser(description='REDNet30 export')
parse.add_argument("--batch_size", type=int, default=1, help="batch size")
parse.add_argument("--image_height", type=int, default=480, help="height of each input image")
parse.add_argument("--image_width", type=int, default=480, help="width of each input image")
parse.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint file path.")
parse.add_argument("--file_name", type=str, default="REDNet30", help="output file name.")
parse.add_argument("--file_format", type=str, default="MINDIR", help="output file format")
args = parse.parse_args()
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
model = REDNet30()
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(model, param_dict)
input_arr = Tensor(np.ones([args.batch_size, 3, args.image_height, args.image_width]), ms.float32)
export(model, input_arr, file_name=args.file_name, file_format=args.file_format)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""postprocess"""
import os
import math
import glob
import argparse
import numpy as np
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from PIL import Image
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, default="/cache/data", help="dataset path.")
parser.add_argument("--save_path", type=str, default="/cache/output", help="output path.")
parser.add_argument("--bin_path", type=str, default="/cache/data", help="lr bin path.")
args = parser.parse_args()
def read_bin(bin_path):
img = np.fromfile(bin_path, dtype=np.float32)
num_pix = img.size
img_shape = int(math.sqrt(num_pix / 3))
if 1 * 3 * img_shape * img_shape != num_pix:
        raise RuntimeError(f'bin file error: {bin_path} is not an output of the REDNet30 network')
img = img.reshape(1, 3, img_shape, img_shape)
return img
def read_bin_as_hwc(bin_path):
nchw_img = read_bin(bin_path)
chw_img = nchw_img[0]
hwc_img = chw_img.transpose(1, 2, 0)
return hwc_img
def PSNR(img1, img2):
"""metrics"""
psnr = peak_signal_noise_ratio(img1, img2)
return psnr
def SSIM(img1, img2):
"""metrics"""
ssim = structural_similarity(img1, img2, data_range=255, multichannel=True)
return ssim
def get_metric(ori_path, res_path):
"""metrics"""
files = glob.glob(os.path.join(ori_path, "*"))
names = []
for i in files:
names.append(i.split("/")[-1])
# PSNR
print("PSNR...")
res = 0
for i in names:
ori = Image.open(os.path.join(ori_path, i))
gen = Image.open(os.path.join(res_path, i))
res += PSNR(np.array(ori), np.array(gen))
psnr_res = res / len(names)
# SSIM
print("SSIM...")
res = 0
for i in names:
ori = Image.open(os.path.join(ori_path, i))
gen = Image.open(os.path.join(res_path, i))
res += SSIM(np.array(ori), np.array(gen))
ssim_res = res / len(names)
print("PSNR: ", psnr_res)
print("SSIM: ", ssim_res)
def run_post_process(dataset_path, save_path, bin_path):
"""run post process """
files = os.listdir(dataset_path)
files.sort()
for file in files:
file_name = file.split('.')[0]
bin_file = os.path.join(bin_path, file_name + "_0.bin")
sr = read_bin_as_hwc(bin_file)
out_img = sr
out_img = np.clip(out_img, 0, 255)
out_img = np.uint8(out_img)
out_img = Image.fromarray(out_img)
out_img.save(os.path.join(save_path, file), quality=95)
get_metric(dataset_path, save_path)
if __name__ == "__main__":
run_post_process(args.dataset_path, args.save_path, args.bin_path)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get input data."""
import os
import glob
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
def padding(im, target_shape):
"padding for 310 infer"
h, w = target_shape[0], target_shape[1]
img_h, img_w = im.shape[0], im.shape[1]
dh, dw = h - img_h, w - img_w
if dh < 0 or dw < 0:
raise RuntimeError(f"target_shape is bigger than img.shape, {target_shape} > {im.shape}")
if dh != 0 or dw != 0:
im = np.pad(im, ((0, int(dh)), (0, int(dw)), (0, 0)), "constant")
return im
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, default='./data/BSD200', help='evaluation image path')
parser.add_argument('--noise_path', type=str, default='./data/BSD200_310', help='output noise image path')
parser.add_argument('--output_path', type=str, default='./data/BSD200_jpeg_quality10_310', help='output image path')
opt = parser.parse_args()
if not os.path.exists(opt.noise_path):
os.makedirs(opt.noise_path)
if not os.path.exists(opt.output_path):
os.makedirs(opt.output_path)
# data
files = glob.glob(opt.dataset_path + '/*')
for file in tqdm(files):
name = file.split("/")[-1]
img = Image.open(file)
img = np.array(img).astype(np.uint8)
img = padding(img, target_shape=[480, 480])
img = Image.fromarray(img.astype('uint8')).convert('RGB')
img.save(os.path.join(opt.output_path, name), format='jpeg', quality=95)
img.save(os.path.join(opt.noise_path, name), format='jpeg', quality=10)
print("finished!")
numpy==1.21.5
scikit-image==0.18.3
Pillow==8.3.2
tqdm==4.63.0
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: bash run_distribute_train.sh [DATASET_PATH] [RANK_TABLE_FILE]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
RANK_TABLE_FILE=$(get_real_path $2)
echo $DATASET_PATH
echo $RANK_TABLE_FILE
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $RANK_TABLE_FILE ]
then
echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$RANK_TABLE_FILE
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
avg=`expr $cpus \/ $RANK_SIZE`
gap=`expr $avg \- 1`
for((i=0; i<${DEVICE_NUM}; i++))
do
start=`expr $i \* $avg`
end=`expr $start \+ $gap`
cmdopt=$start"-"$end
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ./*.py ./train_parallel$i
cp -r ./src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python train.py \
--platform="Ascend" \
--dataset_path=$DATASET_PATH \
--is_distributed=True > log.txt 2>&1 &
cd ..
done
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: bash run_distribute_train_gpu.sh [DATASET_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
export DEVICE_NUM=8
export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
rm -rf ./train_parallel
mkdir ./train_parallel
cp ./*.py ./train_parallel
cp -r ./src ./train_parallel
cd ./train_parallel || exit
env > env.log
mpirun --allow-run-as-root -n ${DEVICE_NUM} --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--platform="GPU" \
--dataset_path=$DATASET_PATH \
--is_distributed=True > log.txt 2>&1 &
cd ..
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DATA_DIR=$1
NOISE_DIR=$2
PATH_CHECKPOINT=$3
PLATFORM=$4
if [ $PLATFORM = "GPU" ]; then
export CUDA_VISIBLE_DEVICES='0'
fi
python eval.py \
--platform=$PLATFORM \
--dataset_path=$DATA_DIR \
--noise_path=$NOISE_DIR \
--ckpt_path=$PATH_CHECKPOINT > log.txt 2>&1 &
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 4 || $# -gt 5 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [SAVE_BIN_PATH] [SAVE_OUTPUT_PATH] [DEVICE_ID](optional)
DEVICE_ID can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
data_path=$(get_real_path $2)
save_bin_path=$(get_real_path $3)
save_output_path=$(get_real_path $4)
device_id=0
if [ $# == 5 ]; then
device_id=$5
fi
log_file="./run_infer.log"
echo "***************** param *****************"
echo "mindir name: "$model
echo "dataset path: "$data_path
echo "log file: "$log_file
echo "***************** param *****************"
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export ASCEND_HOME=/usr/local/Ascend/latest
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
export PYTHONPATH=$PWD:$PYTHONPATH
function preprocess()
{
echo "waitting for preprocess finish..."
python ./preprocess.py --dataset_path=$data_path --noise_path='./preprocess_Result/noise_data_310' --output_path='./preprocess_Result/data_310' >> $log_file 2>&1
echo "preprocess finished!"
}
function compile_app()
{
echo "begin to compile app..."
cd ./ascend310_infer || exit
bash build.sh >> $log_file 2>&1
cd -
echo "finish compile app"
}
function infer()
{
echo "begin to infer..."
if [ -d $save_bin_path ]; then
rm -rf $save_bin_path
fi
mkdir -p $save_bin_path
./ascend310_infer/out/main --mindir_path=$model --dataset_path='./preprocess_Result/noise_data_310' --device_id=$device_id --save_dir=$save_bin_path >> $log_file 2>&1
echo "finish infer"
}
function postprocess()
{
echo "begin to postprocess..."
export DEVICE_ID=$device_id
export RANK_SIZE=1
if [ -d $save_output_path ]; then
rm -rf $save_output_path
fi
mkdir -p $save_output_path
python ./postprocess.py --dataset_path='./preprocess_Result/data_310' --save_path=$save_output_path --bin_path=$save_bin_path >> $log_file 2>&1
echo "finish postprocess"
}
preprocess
if [ $? -ne 0 ]; then
echo "execute preprocess failed"
exit 1
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed, check $log_file"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed, check $log_file"
exit 1
fi
postprocess
if [ $? -ne 0 ]; then
echo "postprocess failed, check $log_file"
exit 1
fi
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: bash run_standalone_train.sh [DATASET_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
export CPU_BIND_NUM=24
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ./*.py ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
cmdopt=`lscpu | grep NUMA | tail -1 | awk '{print $4}'`
if test -z $cmdopt
then
cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
if [ $cpus -ge $CPU_BIND_NUM ]
then
start=`expr $cpus - $CPU_BIND_NUM`
end=`expr $cpus - 1`
else
start=0
end=`expr $cpus - 1`
fi
cmdopt=$start"-"$end
fi
taskset -c $cmdopt python train.py \
--platform="Ascend" \
--dataset_path=$DATASET_PATH > log.txt 2>&1 &
cd ..
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: bash run_standalone_train_gpu.sh [DATASET_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
export CUDA_VISIBLE_DEVICES='0'
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ./*.py ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py \
--platform="GPU" \
--dataset_path=$DATASET_PATH > log.txt 2>&1 &
cd ..
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset."""
import io
import random
import glob
import numpy as np
from PIL import Image
class Dataset():
"""Dataset."""
def __init__(self, dataset_path, patch_size):
self.image_files = sorted(glob.glob(dataset_path + '/*'))
self.patch_size = patch_size
def __getitem__(self, idx):
label = Image.open(self.image_files[idx]).convert('RGB')
# randomly crop patch from training set
crop_x = random.randint(0, label.width - self.patch_size)
crop_y = random.randint(0, label.height - self.patch_size)
label = label.crop((crop_x, crop_y, crop_x + self.patch_size, crop_y + self.patch_size))
        # simulate noise by JPEG-compressing the patch at quality 10
buffer = io.BytesIO()
label.save(buffer, format='jpeg', quality=10)
input_img = Image.open(buffer)
input_img = np.array(input_img).astype(np.float32)
label = np.array(label).astype(np.float32)
input_img = np.transpose(input_img, axes=[2, 0, 1])
label = np.transpose(label, axes=[2, 0, 1])
# normalization
input_img /= 255.0
label /= 255.0
return input_img, label
def __len__(self):
return len(self.image_files)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""get input data."""
import os
import glob
import argparse
from tqdm import tqdm
from PIL import Image
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, default='./data/BSD200', help='evaluation image path')
parser.add_argument('--output_path', type=str, default='./data/BSD200_jpeg_quality10', help='output image path')
opt = parser.parse_args()
if not os.path.exists(opt.output_path):
os.makedirs(opt.output_path)
# data
files = glob.glob(opt.dataset_path + '/*')
for file in tqdm(files):
name = file.split("/")[-1]
img = Image.open(file)
img.save(os.path.join(opt.output_path, name), format='jpeg', quality=10)
print("finished!")
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model."""
import mindspore.nn as nn
class REDNet30(nn.Cell):
"""model"""
def __init__(self, num_layers=15, num_features=64):
super(REDNet30, self).__init__()
self.num_layers = num_layers
conv_layers = []
deconv_layers = []
conv_layers.append(nn.Conv2d(3, num_features, kernel_size=3, stride=1,
pad_mode="pad", padding=1, has_bias=False,
weight_init="XavierUniform"))
for _ in range(num_layers - 1):
conv_layers.append(nn.Conv2d(num_features, num_features, kernel_size=3,
pad_mode="pad", padding=1, has_bias=False,
weight_init="XavierUniform"))
for _ in range(num_layers - 1):
deconv_layers.append(nn.Conv2dTranspose(num_features, num_features, kernel_size=3,
pad_mode="pad", padding=1, has_bias=False,
weight_init="XavierUniform"))
deconv_layers.append(nn.Conv2dTranspose(num_features, 3, kernel_size=3, stride=1,
pad_mode="same", has_bias=False,
weight_init="XavierUniform"))
self.conv_layers = nn.CellList(conv_layers)
self.deconv_layers = nn.CellList(deconv_layers)
self.relu = nn.ReLU()
def construct(self, x):
"""model"""
residual = x
conv_feats = []
for i in range(self.num_layers):
x = self.conv_layers[i](x)
x = self.relu(x)
if (i + 1) % 2 == 0 and len(conv_feats) < self.num_layers//2:
conv_feats.append(x)
conv_feats_idx = 0
for i in range(self.num_layers):
x = self.deconv_layers[i](x)
            if i != self.num_layers - 1:
x = self.relu(x)
if (i + 1 + self.num_layers) % 2 == 0 and conv_feats_idx < len(conv_feats):
conv_feat = conv_feats[-(conv_feats_idx + 1)]
conv_feats_idx += 1
x = x + conv_feat
x = self.relu(x)
x += residual
x = self.relu(x)
return x