Commit 8f8241d9 authored by i-robot, committed by Gitee

!797 [Model Champion Challenge]_[MindSpore Track]_[Neighbor2Neighbor][Ascend+GPU]

Merge pull request !797 from 谭华林/Neighbor2Neighbor
parents e32ce331 336882b9
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
use_modelarts: 0
# URLs required by ModelArts
data_url: ""
train_url: ""
outer_path: 's3://output/'
# main training hyperparameters
noisetype: "gauss25"
n_feature: 48
n_channel: 3
lr: 3e-4
gamma: 0.5
epoch: 100
batch_size: 4
patchsize: 256
increase_ratio: 2.0
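# Note on noisetype (parsed by AugmentNoise in src/dataset.py): a single number after the
# prefix gives a fixed noise level, e.g. "gauss25" means Gaussian noise with sigma 25/255
# and "poisson30" means Poisson noise with lambda 30; two numbers separated by '_' give a
# range sampled per image, e.g. "gauss5_50" or "poisson5_50".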
# dataset options; absolute paths are recommended
train_data: "/data/dataset"
test_dir: "/data/test_dataset"
# eval settings while training
eval_while_train: 1
eval_steps: 1
eval_start_epoch: 0
# checkpoint config while training
is_save_on_master: 1
output_path: './output/'
resume_path: ""
resume_name: ""
# standalone eval settings; other hyperparameters are shared with training
pretrain_path: "./"
ckpt_name: "best_map.ckpt"
save_denoised_images: 1
# standalone export settings; other hyperparameters are shared with training
export_batch_size: 1
image_height: 512
image_width: 512
ckpt_file: "./best_map.ckpt"
file_name: "neighbor2neighbor"
file_format: "AIR"
# ======================================================================================
# common options
device_target: 'Ascend'
is_distributed: 0
rank: 0
group_size: 1
log_interval: 10
---
# Help description for each configuration
use_modelarts: "Whether training on modelarts, 1 for True, 0 for False; default: 0"
data_url: "needed by ModelArts, but we do not use it because the name is ambiguous"
train_url: "needed by ModelArts, but we do not use it because the name is ambiguous"
outer_path: "OBS path used to store output files such as ckpt files"
noisetype: "noise type used for training and evaluation, e.g. 'gauss25' or 'poisson30'; see AugmentNoise in src/dataset.py for the accepted formats"
n_feature: "number of feature channels in the UNet"
n_channel: "number of image channels, 3 for RGB"
lr: "initial learning rate"
gamma: "learning rate decay factor, applied every 20 epochs"
epoch: "total number of training epochs"
batch_size: "training batch size"
patchsize: "side length of the random crop used as a training patch"
increase_ratio: "final weight (Lambda) of the regularization loss; Lambda ramps linearly from 0 up to this value over training"
train_data: "path to the folder that contains the training images"
test_dir: "test data root; it should be the directory that contains the test image folders, not a folder of images itself"
eval_while_train: "Whether eval while training, 1 for True, 0 for False; default: 1"
eval_steps: "run evaluation every N epochs"
eval_start_epoch: "evaluation starts after this epoch"
is_save_on_master: "save ckpt only on the master rank (1) or on all ranks (0)"
output_path: "local output path; when use_modelarts is set to 1, it should be cache/output/"
resume_path: "path of the folder that contains the checkpoint to resume from, if needed"
resume_name: "file name of the checkpoint to resume from"
pretrain_path: "path of the ckpt to eval"
ckpt_name: "name of the ckpt to eval"
save_denoised_images: "Whether to save the denoised images when eval, 1 for True, 0 for False"
export_batch_size: "batch size of the exported model input"
image_height: "input image height of the exported model"
image_width: "input image width of the exported model"
ckpt_file: "path of the ckpt to export"
file_name: "file name of the exported model"
file_format: "file format, choose from ['MINDIR','AIR','ONNX']"
device_target: "device to run the code on. (Default: Ascend)"
is_distributed: "whether to run distributed training on multiple devices, 1 for True, 0 for False"
rank: "rank id of this process in distributed training"
group_size: "total number of devices in distributed training"
log_interval: "Logging interval steps"
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''eval'''
import datetime
import os
import time
import glob
import pandas as pd
import numpy as np
import PIL.Image as Image
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore import load_checkpoint, load_param_into_net
from src.logger import get_logger
from src.models import UNet
from src.dataset import AugmentNoise
from src.config import config as cfg
def copy_data_from_obs():
'''copy_data_from_obs'''
if cfg.use_modelarts:
import moxing as mox
cfg.logger.info("copying test weights from obs to cache....")
mox.file.copy_parallel(cfg.pretrain_path, 'cache/weight')
cfg.logger.info("copying test weights finished....")
cfg.pretrain_path = 'cache/weight/'
cfg.logger.info("copying test dataset from obs to cache....")
mox.file.copy_parallel(cfg.test_dir, 'cache/test')
cfg.logger.info("copying test dataset finished....")
cfg.test_dir = 'cache/test/'
def copy_data_to_obs():
if cfg.use_modelarts:
import moxing as mox
cfg.logger.info("copying files from cache to obs....")
mox.file.copy_parallel(cfg.save_dir, cfg.outer_path)
cfg.logger.info("copying finished....")
def test(model_path):
'''test'''
model = UNet(in_nc=cfg.n_channel, out_nc=cfg.n_channel, n_feature=cfg.n_feature)
cfg.logger.info("load test weights from %s", str(model_path))
load_param_into_net(model, load_checkpoint(model_path))
cfg.logger.info("loaded test weights from %s", str(model_path))
noise_generator = AugmentNoise(cfg.noisetype)
model.set_train(False)
cast = P.Cast()
transpose = P.Transpose()
expand_dims = P.ExpandDims()
compare_psnr = nn.PSNR()
compare_ssim = nn.SSIM()
for filename in os.listdir(cfg.test_dir):
tem_path = os.path.join(cfg.test_dir, filename)
out_dir = os.path.join(cfg.save_dir, filename)
if not cfg.use_modelarts and not os.path.exists(out_dir):
os.makedirs(out_dir)
name = []
psnr = [] #after denoise
ssim = [] #after denoise
psnr_b = [] #before denoise
ssim_b = [] #before denoise
file_list = glob.glob(os.path.join(tem_path, '*'))
cfg.logger.info('Start to test on %s', str(tem_path))
start_time = time.time()
for file in file_list:
suffix = file.split('.')[-1]
# read image
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
img_test = noise_generator.add_noise(img_clean)
H = img_test.shape[0]
W = img_test.shape[1]
val_size = (max(H, W) + 31) // 32 * 32
img_test = np.pad(img_test,
[[0, val_size - H], [0, val_size - W], [0, 0]],
'reflect')
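            # The input is reflect-padded up to the next multiple of 32 because the UNet
            # downsamples five times (2**5 = 32); the prediction is cropped back to H x W below.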
img_clean = Tensor(img_clean, mindspore.float32) #HWC
img_test = Tensor(img_test, mindspore.float32) #HWC
# predict
img_clean = expand_dims(transpose(img_clean, (2, 0, 1)), 0)#NCHW
img_test = expand_dims(transpose(img_test, (2, 0, 1)), 0)#NCHW
prediction = model(img_test)
y_predict = prediction[:, :, :H, :W]
# calculate numeric metrics
img_out = C.clip_by_value(y_predict, 0, 1)
psnr_noise, psnr_denoised = compare_psnr(img_clean, img_test[:, :, :H, :W]), \
compare_psnr(img_clean, img_out)
ssim_noise, ssim_denoised = compare_ssim(img_clean, img_test[:, :, :H, :W]), \
compare_ssim(img_clean, img_out)
psnr.append(psnr_denoised.asnumpy()[0])
ssim.append(ssim_denoised.asnumpy()[0])
psnr_b.append(psnr_noise.asnumpy()[0])
ssim_b.append(ssim_noise.asnumpy()[0])
# save images
filename = file.split('/')[-1].split('.')[0] # get the name of image file
name.append(filename)
if not cfg.use_modelarts and cfg.save_denoised_images:
                # Image.save first checks whether a file with the same name already exists,
                # and that check is not allowed on ModelArts, so images are only saved locally.
img_test = cast(img_test*255, mindspore.uint8).asnumpy()
img_test = img_test.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
img_test = Image.fromarray(img_test)
img_test.save(os.path.join(out_dir, filename+'_noisetype'+'{}_psnr{:.2f}.'\
.format(cfg.noisetype, psnr_noise.asnumpy()[0])+str(suffix)))
img_out = cast(img_out*255, mindspore.uint8).asnumpy()
img_out = img_out.squeeze(0).transpose((1, 2, 0)) #turn into HWC to save as an image
img_out = Image.fromarray(img_out)
img_out.save(os.path.join(out_dir, filename+'_psnr{:.2f}.'.format(\
psnr_denoised.asnumpy()[0])+str(suffix)))
psnr_avg = sum(psnr)/len(psnr)
ssim_avg = sum(ssim)/len(ssim)
psnr_avg_b = sum(psnr_b)/len(psnr_b)
ssim_avg_b = sum(ssim_b)/len(ssim_b)
name.append('Average')
psnr.append(psnr_avg)
ssim.append(ssim_avg)
psnr_b.append(psnr_avg_b)
ssim_b.append(ssim_avg_b)
cfg.logger.info("Result in:%s", str(tem_path))
cfg.logger.info('Before denoise: Average PSNR_b = {0:.4f}, SSIM_b = {1:.4f};'\
.format(psnr_avg_b, ssim_avg_b))
cfg.logger.info('After denoise: Average PSNR = {0:.4f}, SSIM = {1:.4f}'\
.format(psnr_avg, ssim_avg))
cfg.logger.info("testing finished....")
time_used = time.time() - start_time
cfg.logger.info("time cost:%s seconds!", str(time_used))
if not cfg.use_modelarts:
pd.DataFrame({'name': np.array(name), 'psnr_b': np.array(psnr_b), \
'psnr': np.array(psnr), 'ssim_b': np.array(ssim_b), \
'ssim': np.array(ssim)}).to_csv(out_dir+'/metrics.csv', index=True)
if __name__ == '__main__':
set_seed(1)
cfg.save_dir = os.path.join(cfg.output_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if not cfg.use_modelarts and not os.path.exists(cfg.save_dir):
os.makedirs(cfg.save_dir)
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE,
device_target=cfg.device_target, device_id=device_id, save_graphs=False)
cfg.logger = get_logger(cfg.save_dir, "Neighbor2Neighbor", 0)
cfg.logger.save_args(cfg)
copy_data_from_obs()
test(os.path.join(cfg.pretrain_path, cfg.ckpt_name))
copy_data_to_obs()
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import os
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.models import UNet
from src.config import config as cfg
if __name__ == '__main__':
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target,
device_id=device_id)
net = UNet(in_nc=cfg.n_channel, out_nc=cfg.n_channel, n_feature=cfg.n_feature)
param_dict = load_checkpoint(cfg.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.zeros([cfg.export_batch_size, 3, \
cfg.image_height, cfg.image_width]), ms.float32)
export(net, input_arr, file_name=cfg.file_name, file_format=cfg.file_format)
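# A possible smoke test of the exported model (a sketch, not part of this script; it assumes
# file_format is 'MINDIR' and that mindspore.load / nn.GraphCell are available):
#
#     import mindspore.nn as nn
#     graph = ms.load(cfg.file_name + '.mindir')
#     net = nn.GraphCell(graph)
#     print(net(input_arr).shape)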
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DIR="$( cd "$( dirname "$0" )" && pwd )"
# help message
if [ $# != 3 ]; then
echo "Usage: bash run_distribute_train_ascend.sh [rank_size] [rank_start_id] [rank_table_file]"
exit 1
fi
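# Example (the rank table path below is only an illustration):
#   bash run_distribute_train_ascend.sh 8 0 /path/to/rank_table_8pcs.json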
ulimit -c unlimited
ulimit -n 65530
export SLOG_PRINT_TO_STDOUT=0
export RANK_TABLE_FILE=$3
export RANK_SIZE=$1
export RANK_START_ID=$2
rm -rf $DIR/../ascend_work_space
mkdir $DIR/../ascend_work_space
for((i=0;i<=$RANK_SIZE-1;i++));
do
export RANK_ID=${i}
export DEVICE_ID=$((i + RANK_START_ID))
echo 'start rank='${i}', device id='${DEVICE_ID}'...'
if [ -d $DIR/../ascend_work_space/device${DEVICE_ID} ]; then
rm -rf $DIR/../ascend_work_space/device${DEVICE_ID}
fi
mkdir $DIR/../ascend_work_space/device${DEVICE_ID}
cp -r $DIR/../src $DIR/../ascend_work_space/device${DEVICE_ID}
cp $DIR/../train.py $DIR/../default_config.yaml $DIR/../ascend_work_space/device${DEVICE_ID}
cd $DIR/../ascend_work_space/device${DEVICE_ID} || exit
nohup python ./train.py --device_target=Ascend --is_distributed=1 > log.txt 2>&1 &
done
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# help message
# "Usage: bash run_distribute_train_gpu.sh"
DIR="$( cd "$( dirname "$0" )" && pwd )"
ulimit -c unlimited
ulimit -n 65530
rm -rf $DIR/../gpu_work_space
mkdir $DIR/../gpu_work_space
cp -r $DIR/../src $DIR/../gpu_work_space
cp $DIR/../train.py $DIR/../default_config.yaml $DIR/../gpu_work_space
cd $DIR/../gpu_work_space
echo "start training"
mpirun --allow-run-as-root -n 8 python ./train.py \
--device_target=GPU \
--is_distributed=1 > log.txt 2>&1 &
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: bash run_eval_ascend.sh [device_id]"
exit 1
fi
export DEVICE_ID=$1
DIR="$( cd "$( dirname "$0" )" && pwd )"
ulimit -n 65530
cd $DIR/../ || exit
nohup python eval.py --device_target=Ascend > eval_ascend_log.txt 2>&1 &
echo 'Validation task has been started successfully!'
echo 'Please check the log at eval_ascend_log.txt'
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: bash run_eval_gpu.sh [device_id]"
exit 1
fi
export DEVICE_ID=$1
DIR="$( cd "$( dirname "$0" )" && pwd )"
ulimit -n 65530
cd $DIR/../ || exit
nohup python eval.py --device_target=GPU > eval_gpu_log.txt 2>&1 &
echo 'Validation task has been started successfully!'
echo 'Please check the log at eval_gpu_log.txt'
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: bash run_train_ascend.sh [device_id]"
exit 1
fi
export DEVICE_ID=$1
DIR="$( cd "$( dirname "$0" )" && pwd )"
ulimit -n 65530
cd $DIR/../ || exit
nohup python train.py --is_distributed=0 --device_target=Ascend > train_ascend_log.txt 2>&1 &
echo 'Train task has been started successfully!'
echo 'Please check the log at train_ascend_log.txt'
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: bash run_train_gpu.sh [device_id]"
exit 1
fi
export DEVICE_ID=$1
DIR="$( cd "$( dirname "$0" )" && pwd )"
ulimit -n 65530
cd $DIR/../ || exit
nohup python train.py --is_distributed=0 --device_target=GPU > train_gpu_log.txt 2>&1 &
echo 'Train task has been started successfully!'
echo 'Please check the log at train_gpu_log.txt'
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
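# In this repository default_config.yaml contains two YAML documents separated by '---'
# (the configuration values and the help descriptions), so cfg_choices stays empty.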
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
pprint(final_config)
print("Please check the above information for the configurations", flush=True)
return Config(final_config)
config = get_config()
if __name__ == '__main__':
print(config)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''dataloader'''
import os
import glob
import numpy as np
import PIL.Image as Image
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
class DataLoader_Imagenet_val:
'''DataLoader_Imagenet_val'''
def __init__(self, data_dir, patch=256, noise_style="gauss25", batch_size=4):
super(DataLoader_Imagenet_val, self).__init__()
self.data_dir = data_dir
self.patch = patch
self.train_fns = glob.glob(os.path.join(self.data_dir, "*"))
self.train_fns.sort()
print('fetch {} samples for training'.format(len(self.train_fns)))
self.noise_generator = AugmentNoise(noise_style)
self.batch_size = batch_size
self.test = 1
def __getitem__(self, index):
# fetch image
fn = self.train_fns[index]
im = Image.open(fn)
im = np.array(im, dtype=np.float32)
# random crop
H = im.shape[0]
W = im.shape[1]
if H - self.patch > 0:
xx = np.random.randint(0, H - self.patch)
im = im[xx:xx + self.patch, :, :]
if W - self.patch > 0:
yy = np.random.randint(0, W - self.patch)
im = im[:, yy:yy + self.patch, :]
im /= 255.0 #clean image
noisy = self.noise_generator.add_noise(im)
return im, noisy
def __len__(self):
return len(self.train_fns)
class AugmentNoise():
'''AugmentNoise'''
def __init__(self, style):
if style.startswith('gauss'):
self.params = [
float(p) / 255.0 for p in style.replace('gauss', '').split('_')
]
if len(self.params) == 1:
self.style = "gauss_fix"
elif len(self.params) == 2:
self.style = "gauss_range"
elif style.startswith('poisson'):
self.params = [
float(p) for p in style.replace('poisson', '').split('_')
]
if len(self.params) == 1:
self.style = "poisson_fix"
elif len(self.params) == 2:
self.style = "poisson_range"
def add_noise(self, x):
'''add_noise'''
shape = x.shape
if self.style == "gauss_fix":
std = self.params[0]
return np.array(x + np.random.normal(size=shape) * std,
dtype=np.float32)
if self.style == "gauss_range":
min_std, max_std = self.params
std = np.random.uniform(low=min_std, high=max_std, size=(1, 1, 1))
return np.array(x + np.random.normal(size=shape) * std,
dtype=np.float32)
if self.style == "poisson_fix":
lam = self.params[0]
return np.array(np.random.poisson(lam * x) / lam, dtype=np.float32)
assert self.style == "poisson_range"
min_lam, max_lam = self.params
lam = np.random.uniform(low=min_lam, high=max_lam, size=(1, 1, 1))
return np.array(np.random.poisson(lam * x) / lam, dtype=np.float32)
def create_Dataset(data_dir, patch, noise_style, batch_size, device_num, rank, shuffle):
dataset = DataLoader_Imagenet_val(data_dir, patch, noise_style, batch_size)
hwc_to_chw = CV.HWC2CHW()
data_set = ds.GeneratorDataset(dataset, column_names=["image", "noisy"], \
num_parallel_workers=8, shuffle=shuffle, num_shards=device_num, shard_id=rank)
data_set = data_set.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
data_set = data_set.map(input_columns=["noisy"], operations=hwc_to_chw, num_parallel_workers=8)
data_set = data_set.batch(batch_size, drop_remainder=True)
return data_set, data_set.get_dataset_size()
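if __name__ == '__main__':
    # Quick sanity check of the noise generator on synthetic data (illustration only,
    # not part of the training or evaluation pipeline).
    demo_img = np.random.rand(64, 64, 3).astype(np.float32)
    for demo_style in ('gauss25', 'gauss5_50', 'poisson30', 'poisson5_50'):
        demo_noisy = AugmentNoise(demo_style).add_noise(demo_img)
        print(demo_style, demo_noisy.shape, demo_noisy.dtype)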
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
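        # Only the first process of every 8-device group (rank % 8 == 0) attaches a stdout
        # handler, so console output appears once per node; every rank still writes its own
        # log file via setup_logging_file.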
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, logger_name, rank):
"""Get Logger."""
logger = LOGGER(logger_name, rank)
logger.setup_logging_file(path, rank)
return logger
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''model'''
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as np
class UpsampleCat(nn.Cell):
'''UpsampleCat'''
def __init__(self, in_nc, out_nc):
super(UpsampleCat, self).__init__()
self.in_nc = in_nc
self.out_nc = out_nc
self.deconv = nn.Conv2dTranspose(in_nc, out_nc, 2, 2, \
padding=0, has_bias=False, weight_init="HeNormal")#weight*=0.1
self.concat = ops.Concat(axis=1)#NCHW
def construct(self, x1, x2):
'''construct'''
x1 = self.deconv(x1)
return self.concat((x1, x2))
def rotate(x, angle):
if angle == 0:
return x
if angle == 90:
return np.rot90(x, 1, (3, 2))
if angle == 180:
return np.rot90(x, 2, (3, 2))
return np.rot90(x, 3, (3, 2))
def conv_func(x, conv, blindspot, pad):
ofs = 0 if (not blindspot) else 1
if ofs > 0:
x = pad(x)
x = conv(x)
if ofs > 0:
x = x[:, :, :-ofs, :]
return x
def pool_func(x, pool, blindspot, pad):
if blindspot:
x = pad(x[:, :, :-1, :])
x = pool(x)
return x
class UNet(nn.Cell):
"""
args:
in_nc=3,
out_nc=3,
n_feature=48,
blindspot=False,
zero_last=False
"""
def __init__(self,
in_nc=3,
out_nc=3,
n_feature=48,
blindspot=False,
zero_last=False):
super(UNet, self).__init__()
self.in_nc = in_nc
self.out_nc = out_nc
self.n_feature = n_feature
self.blindspot = blindspot
self.zero_last = zero_last
self.act = nn.LeakyReLU(alpha=0.2)
# Encoder part
self.enc_conv0 = nn.Conv2d(self.in_nc, self.n_feature, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.enc_conv1 = nn.Conv2d(self.n_feature, self.n_feature, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.pool1 = nn.MaxPool2d(2, 2)
self.enc_conv2 = nn.Conv2d(self.n_feature, self.n_feature, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.pool2 = nn.MaxPool2d(2, 2)
self.enc_conv3 = nn.Conv2d(self.n_feature, self.n_feature, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.pool3 = nn.MaxPool2d(2, 2)
self.enc_conv4 = nn.Conv2d(self.n_feature, self.n_feature, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.pool4 = nn.MaxPool2d(2, 2)
self.enc_conv5 = nn.Conv2d(self.n_feature, self.n_feature, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.pool5 = nn.MaxPool2d(2, 2)
self.enc_conv6 = nn.Conv2d(self.n_feature, self.n_feature, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
# Decoder part
self.up5 = UpsampleCat(self.n_feature, self.n_feature)
self.dec_conv5a = nn.Conv2d(self.n_feature * 2, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.dec_conv5b = nn.Conv2d(self.n_feature * 2, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.up4 = UpsampleCat(self.n_feature * 2, self.n_feature * 2)
self.dec_conv4a = nn.Conv2d(self.n_feature * 3, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.dec_conv4b = nn.Conv2d(self.n_feature * 2, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.up3 = UpsampleCat(self.n_feature * 2, self.n_feature * 2)
self.dec_conv3a = nn.Conv2d(self.n_feature * 3, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.dec_conv3b = nn.Conv2d(self.n_feature * 2, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.up2 = UpsampleCat(self.n_feature * 2, self.n_feature * 2)
self.dec_conv2a = nn.Conv2d(self.n_feature * 3, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.dec_conv2b = nn.Conv2d(self.n_feature * 2, self.n_feature * 2, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.up1 = UpsampleCat(self.n_feature * 2, self.n_feature * 2)
# Output stages
self.dec_conv1a = nn.Conv2d(self.n_feature * 2 + self.in_nc, 96, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.dec_conv1b = nn.Conv2d(96, 96, 3, 1,
pad_mode="pad", padding=1, has_bias=True, weight_init="HeNormal")#weight*=0.1
if blindspot:
self.nin_a = nn.Conv2d(96 * 4, 96 * 4, 1, 1,
pad_mode="pad", padding=0, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.nin_b = nn.Conv2d(96 * 4, 96, 1, 1,
pad_mode="pad", padding=0, has_bias=True, weight_init="HeNormal")#weight*=0.1
else:
self.nin_a = nn.Conv2d(96, 96, 1, 1,
pad_mode="pad", padding=0, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.nin_b = nn.Conv2d(96, 96, 1, 1,
pad_mode="pad", padding=0, has_bias=True, weight_init="HeNormal")#weight*=0.1
if self.zero_last:
self.nin_c = nn.Conv2d(96, self.out_nc, 1, 1,
pad_mode="pad", padding=0, has_bias=True)
else:
self.nin_c = nn.Conv2d(96, self.out_nc, 1, 1,
pad_mode="pad", padding=0, has_bias=True, weight_init="HeNormal")#weight*=0.1
self.concat0 = ops.Concat(axis=0)
self.concat1 = ops.Concat(axis=1)
# (padding_left, padding_right, padding_top, padding_bottom)
#self.pad = nn.ConstantPad2d(padding=(0, 0, 1, 0), value=0)
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 0), (0, 0)))
self.split = ops.Split(axis=0, output_num=4)
def construct(self, x):
'''construct'''
# Input stage
blindspot = self.blindspot
if blindspot:
x = self.concat0((rotate(x, 0), rotate(x, 90), rotate(x, 180), rotate(x, 270)))
# Encoder part
pool0 = x
x = self.act(conv_func(x, self.enc_conv0, blindspot, self.pad))
x = self.act(conv_func(x, self.enc_conv1, blindspot, self.pad))
x = pool_func(x, self.pool1, blindspot, self.pad)
pool1 = x
x = self.act(conv_func(x, self.enc_conv2, blindspot, self.pad))
x = pool_func(x, self.pool2, blindspot, self.pad)
pool2 = x
x = self.act(conv_func(x, self.enc_conv3, blindspot, self.pad))
x = pool_func(x, self.pool3, blindspot, self.pad)
pool3 = x
x = self.act(conv_func(x, self.enc_conv4, blindspot, self.pad))
x = pool_func(x, self.pool4, blindspot, self.pad)
pool4 = x
x = self.act(conv_func(x, self.enc_conv5, blindspot, self.pad))
x = pool_func(x, self.pool5, blindspot, self.pad)
x = self.act(conv_func(x, self.enc_conv6, blindspot, self.pad))
# Decoder part
x = self.up5(x, pool4)
x = self.act(conv_func(x, self.dec_conv5a, blindspot, self.pad))
x = self.act(conv_func(x, self.dec_conv5b, blindspot, self.pad))
x = self.up4(x, pool3)
x = self.act(conv_func(x, self.dec_conv4a, blindspot, self.pad))
x = self.act(conv_func(x, self.dec_conv4b, blindspot, self.pad))
x = self.up3(x, pool2)
x = self.act(conv_func(x, self.dec_conv3a, blindspot, self.pad))
x = self.act(conv_func(x, self.dec_conv3b, blindspot, self.pad))
x = self.up2(x, pool1)
x = self.act(conv_func(x, self.dec_conv2a, blindspot, self.pad))
x = self.act(conv_func(x, self.dec_conv2b, blindspot, self.pad))
x = self.up1(x, pool0)
# Output stage
if blindspot:
x = self.act(conv_func(x, self.dec_conv1a, blindspot, self.pad))
x = self.act(conv_func(x, self.dec_conv1b, blindspot, self.pad))
x = self.pad(x[:, :, :-1, :])
x = self.split(x)
x = self.concat1((rotate(x, 0), rotate(x, 270), rotate(x, 180), rotate(x, 90)))
x = self.act(conv_func(x, self.nin_a, blindspot, self.pad))
x = self.act(conv_func(x, self.nin_b, blindspot, self.pad))
x = conv_func(x, self.nin_c, blindspot, self.pad)
else:
x = self.act(conv_func(x, self.dec_conv1a, blindspot, self.pad))
x = self.act(conv_func(x, self.dec_conv1b, blindspot, self.pad))
x = self.act(conv_func(x, self.nin_a, blindspot, self.pad))
x = self.act(conv_func(x, self.nin_b, blindspot, self.pad))
x = conv_func(x, self.nin_c, blindspot, self.pad)
return x
class UNetWithLossCell(nn.Cell):
'''UNetWithLossCell'''
def __init__(self, network):
super(UNetWithLossCell, self).__init__()
self.network = network
self.reduceSum = ops.ReduceSum(keep_dims=False)
self.power = ops.Pow()
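    # Loss layout (cf. Neighbor2Neighbor): with f the network and g1/g2 the neighbor
    # sub-samplers, loss1 = ||f(g1(y)) - g2(y)||^2 is the reconstruction term and
    # loss2 = Lambda * ||f(g1(y)) - g2(y) - exp_diff||^2 is the regularizer, where
    # exp_diff = g1(f(y)) - g2(f(y)) is computed outside the graph in train.py.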
def construct(self, noisy_sub1, noisy_sub2, exp_diff, Lambda):
noisy_output = self.network(noisy_sub1)
diff = noisy_output - noisy_sub2
loss1 = self.power(diff, 2.0)
loss1 = self.reduceSum(loss1)
loss2 = Lambda * self.reduceSum(self.power((diff - exp_diff), 2.0))
return loss1 + loss2
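if __name__ == '__main__':
    # Hypothetical smoke test (not part of the original scripts): build the default UNet
    # and report the number of trainable parameters without running a forward pass.
    demo_net = UNet(in_nc=3, out_nc=3, n_feature=48)
    demo_total = sum(int(p.size) for p in demo_net.trainable_params())
    print('UNet trainable parameters:', demo_total)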
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''util'''
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
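if __name__ == '__main__':
    # Minimal usage sketch with made-up values, for illustration only.
    demo_meter = AverageMeter('loss', ':.4f')
    for demo_val in (0.52, 0.48):
        demo_meter.update(demo_val)
    print(demo_meter)  # -> loss:0.5000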
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''train'''
import os
import datetime
import time
import glob
import numpy as np
import PIL.Image as Image
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.common.tensor import Tensor
from mindspore import save_checkpoint
from mindspore.dataset import config
from mindspore import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from src.logger import get_logger
from src.dataset import create_Dataset
from src.models import UNet, UNetWithLossCell
from src.util import AverageMeter
from src.dataset import AugmentNoise
from src.config import config as cfg
def get_lr(steps_per_epoch, max_epoch, init_lr, gamma):
lr_each_step = []
while max_epoch > 0:
tem = min(20, max_epoch)
for _ in range(steps_per_epoch*tem):
lr_each_step.append(init_lr)
max_epoch -= tem
init_lr *= gamma
return lr_each_step
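# get_lr builds one learning-rate value per training step: the rate stays at init_lr for
# 20 epochs at a time and is then multiplied by gamma (0.5 by default), i.e. a step-decay
# schedule handed to nn.Adam below as a per-step list.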
def space_to_depth(x, block_size):
'''space_to_depth'''
n, c, h, w = x.shape #([4, 1, 256, 256])
unfolded_x = np.zeros((n, block_size*block_size, w*h // (block_size*block_size)))#([4, 4, 16384])
for j in range(n):
tx1 = x[j, 0, :, :][::2].reshape(-1, block_size)
tx2 = x[j, 0, :, :][1::2].reshape(-1, block_size)
for i in range(block_size):
unfolded_x[j, i] = tx1[:, i]
for i in range(block_size):
unfolded_x[j, i+block_size] = tx2[:, i]
return unfolded_x.reshape((n, c * block_size**2, h // block_size,
w // block_size))
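# space_to_depth is a NumPy pixel-unshuffle for a single-channel NCHW tensor: each 2x2 cell
# of the H x W plane becomes 4 channel entries ordered top-left, top-right, bottom-left,
# bottom-right, so that generate_subimages can pick exactly one pixel per cell via a mask.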
def generate_subimages(img, mask):
'''generate_subimages'''
n, c, h, w = img.shape
subimage = np.zeros((n, c, h // 2, w // 2), dtype=img.dtype)
# per channel
for i in range(c): #NCHW
img_per_channel = space_to_depth(img[:, i:i + 1, :, :], block_size=2)
img_per_channel = np.transpose(img_per_channel, (0, 2, 3, 1)).reshape(-1)
subimage[:, i:i + 1, :, :] = np.transpose(img_per_channel[mask].reshape(
n, h // 2, w // 2, 1), (0, 3, 1, 2))
return subimage
def generate_mask_pair(img):
'''generate_mask_pair'''
# prepare masks (N x C x H/2 x W/2)
n, _, h, w = img.shape
mask1 = np.zeros(shape=(n * h // 2 * w // 2 * 4,),
dtype=bool)
mask2 = np.zeros(shape=(n * h // 2 * w // 2 * 4,),
dtype=bool)
# prepare random mask pairs
idx_pair = np.array(
[[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
dtype=np.int64)
rd_idx = np.random.randint(low=0,
high=8,
size=(n * h // 2 * w // 2,),
dtype=np.int64)
rd_pair_idx = idx_pair[rd_idx]
rd_pair_idx += np.arange(0,
n * h // 2 * w // 2 * 4,
4,
dtype=np.int64).reshape(-1, 1)
# get masks
mask1[rd_pair_idx[:, 0]] = 1
mask2[rd_pair_idx[:, 1]] = 1
return mask1, mask2
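# generate_mask_pair draws, for every 2x2 cell, one of the 8 ordered pairs of horizontally
# or vertically adjacent positions (diagonals are excluded), so mask1 and mask2 select two
# neighboring pixels per cell and the resulting sub-images form the Neighbor2Neighbor
# training pair.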
def copy_data_from_obs():
'''copy_data_from_obs'''
if cfg.use_modelarts:
import moxing as mox
cfg.logger.info("copying train data from obs to cache....")
mox.file.copy_parallel(cfg.train_data, 'cache/dataset')
cfg.logger.info("copying traindata finished....")
cfg.train_data = 'cache/dataset/'
if cfg.eval_while_train:
cfg.logger.info("copying test data from obs to cache....")
mox.file.copy_parallel(cfg.test_dir, 'cache/test')
cfg.logger.info("copying test data finished....")
cfg.test_dir = 'cache/test/'
if cfg.resume_path:
cfg.logger.info("copying resume checkpoint from obs to cache....")
mox.file.copy_parallel(cfg.resume_path, 'cache/resume_path')
cfg.logger.info("copying resume checkpoint finished....")
cfg.resume_path = 'cache/resume_path/'
def copy_data_to_obs():
if cfg.use_modelarts:
import moxing as mox
cfg.logger.info("copying files from cache to obs....")
mox.file.copy_parallel(cfg.save_dir, cfg.outer_path)
cfg.logger.info("copying finished....")
def main():
'''main'''
dataset, cfg.steps_per_epoch = create_Dataset(cfg.train_data, cfg.patchsize, \
cfg.noisetype, cfg.batch_size, cfg.group_size, cfg.rank, shuffle=True)
f_model = UNet(in_nc=cfg.n_channel, out_nc=cfg.n_channel, n_feature=cfg.n_feature)
if cfg.resume_path:
cfg.resume_path = os.path.join(cfg.resume_path, cfg.resume_name)
cfg.logger.info('loading resume checkpoint %s into network', str(cfg.resume_path))
load_param_into_net(f_model, load_checkpoint(cfg.resume_path))
cfg.logger.info('loaded resume checkpoint %s into network', str(cfg.resume_path))
model = UNetWithLossCell(f_model)
model.set_train()
lr_list = get_lr(cfg.steps_per_epoch, cfg.epoch, float(cfg.lr), cfg.gamma)
optimizer = nn.Adam(params=model.trainable_params(), learning_rate=Tensor(lr_list, mindspore.float32))
model = nn.TrainOneStepCell(model, optimizer)
data_loader = dataset.create_dict_iterator()
loss_meter = AverageMeter('loss')
for k in range(cfg.epoch):
model.set_train(True)
t_end = time.time()
old_progress = -1
for i, data in enumerate(data_loader):
noisy = data["noisy"].asnumpy()
mask1, mask2 = generate_mask_pair(noisy)
noisy_sub1 = generate_subimages(noisy, mask1)
noisy_sub2 = generate_subimages(noisy, mask2)
noisy_denoised = f_model(data["noisy"]).asnumpy()
noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1)
noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2)
Lambda = k / cfg.epoch * cfg.increase_ratio
exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
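            # noisy_denoised is taken out of the graph via asnumpy(), so exp_diff acts as the
            # stop-gradient term of the regularizer; its weight Lambda ramps linearly from 0
            # to increase_ratio over the course of training.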
loss = model(Tensor(noisy_sub1, mindspore.float32), \
Tensor(noisy_sub2, mindspore.float32), \
Tensor(exp_diff, mindspore.float32), \
Tensor(Lambda, mindspore.float32))
loss_meter.update(loss.asnumpy())
if i % cfg.log_interval == 0:
if cfg.rank == 0:
time_used = time.time()- t_end
fps = cfg.batch_size * (i - old_progress) * cfg.group_size / time_used
cfg.logger.info(
'epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr:{}'.format(\
k, i, loss_meter, fps, lr_list[k*cfg.steps_per_epoch + i]))
t_end = time.time()
loss_meter.reset()
old_progress = i
if cfg.rank_save_ckpt_flag:
# checkpoint save
save_checkpoint(model, os.path.join(cfg.save_dir, str(cfg.rank)+"_last_map.ckpt"))
if cfg.eval_while_train and (k+1) > cfg.eval_start_epoch and (k+1)%cfg.eval_steps == 0:
test(f_model)
cfg.logger.info("training finished....")
def test(model):
'''test'''
noise_generator = AugmentNoise(cfg.noisetype)
model.set_train(False)
transpose = P.Transpose()
expand_dims = P.ExpandDims()
compare_psnr = nn.PSNR()
compare_ssim = nn.SSIM()
best_value = 0.
for filename in os.listdir(cfg.test_dir):
tem_path = os.path.join(cfg.test_dir, filename)
psnr = [] #after denoise
ssim = [] #after denoise
psnr_b = [] #before denoise
ssim_b = [] #before denoise
file_list = glob.glob(os.path.join(tem_path, '*'))
cfg.logger.info('Start to test on %s', str(tem_path))
start_time = time.time()
for file in file_list:
# read image
img_clean = np.array(Image.open(file), dtype='float32') / 255.0
img_test = noise_generator.add_noise(img_clean)
H = img_test.shape[0]
W = img_test.shape[1]
val_size = (max(H, W) + 31) // 32 * 32
img_test = np.pad(img_test,
[[0, val_size - H], [0, val_size - W], [0, 0]],
'reflect')
img_clean = Tensor(img_clean, mindspore.float32) #HWC
img_test = Tensor(img_test, mindspore.float32) #HWC
# predict
img_clean = expand_dims(transpose(img_clean, (2, 0, 1)), 0)#NCHW
img_test = expand_dims(transpose(img_test, (2, 0, 1)), 0)#NCHW
prediction = model(img_test)
y_predict = prediction[:, :, :H, :W]
# calculate numeric metrics
img_out = C.clip_by_value(y_predict, 0, 1)
psnr_noise, psnr_denoised = compare_psnr(img_clean, img_test[:, :, :H, :W]), \
compare_psnr(img_clean, img_out)
ssim_noise, ssim_denoised = compare_ssim(img_clean, img_test[:, :, :H, :W]), \
compare_ssim(img_clean, img_out)
psnr.append(psnr_denoised.asnumpy()[0])
ssim.append(ssim_denoised.asnumpy()[0])
psnr_b.append(psnr_noise.asnumpy()[0])
ssim_b.append(ssim_noise.asnumpy()[0])
psnr_avg = sum(psnr)/len(psnr)
ssim_avg = sum(ssim)/len(ssim)
psnr_avg_b = sum(psnr_b)/len(psnr_b)
ssim_avg_b = sum(ssim_b)/len(ssim_b)
best_value += (psnr_avg * 0.1 + ssim_avg * 10.)
cfg.logger.info("Result in:%s", str(tem_path))
cfg.logger.info('Before denoise: Average PSNR_b = {0:.4f}, SSIM_b = {1:.4f};'\
.format(psnr_avg_b, ssim_avg_b))
cfg.logger.info('After denoise: Average PSNR = {0:.4f}, SSIM = {1:.4f}'\
.format(psnr_avg, ssim_avg))
cfg.logger.info("testing finished....")
time_used = time.time() - start_time
cfg.logger.info("time cost:%s seconds!", str(time_used))
if cfg.best_value < best_value:
cfg.best_value = best_value
save_checkpoint(model, os.path.join(cfg.save_dir, str(cfg.rank)+"_best_map.ckpt"))
cfg.logger.info("Update newly best ckpt! best_value: %s", str(best_value))
if __name__ == '__main__':
set_seed(1)
cfg.save_dir = os.path.join(cfg.output_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
if not cfg.use_modelarts and not os.path.exists(cfg.save_dir):
os.makedirs(cfg.save_dir)
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE,
device_target=cfg.device_target, save_graphs=False)
if cfg.is_distributed:
if cfg.device_target == "Ascend":
context.set_context(device_id=device_id)
init("hccl")
else:
assert cfg.device_target == "GPU"
init("nccl")
cfg.rank = get_rank()
cfg.group_size = get_group_size()
device_num = cfg.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
else:
if cfg.device_target in ["Ascend", "GPU"]:
context.set_context(device_id=device_id)
    config.set_enable_shared_mem(False)  # we may get OOM when it is set to True
cfg.logger = get_logger(cfg.save_dir, "Neighbor2Neighbor", cfg.rank)
cfg.logger.save_args(cfg)
cfg.rank_save_ckpt_flag = not (cfg.is_save_on_master and cfg.rank)
cfg.best_value = 0.
copy_data_from_obs()
main()
copy_data_to_obs()
cfg.logger.info('All task finished!')