diff --git a/research/cv/RCAN/README.md b/research/cv/RCAN/README.md index 687e4d47a47958793af45aa7679358b487bd2b76..5e710052f38770bb1a3cb3ef14b466c1e6eda253 100644 --- a/research/cv/RCAN/README.md +++ b/research/cv/RCAN/README.md @@ -42,7 +42,7 @@ - 数据集大小:约7.12GB,共900张图像 - 训练集:800张图像 - 测试集:100张图像 -- 基准数据集可下载如下:[Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html)、[Set14](https://deepai.org/dataset/set14-super-resolution)、[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/)、[Urban100](http://vllab.ucmerced.edu/wlai24/LapSRN/)。 +- 基准数据集可下载如下:[Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html)、[Set14](https://deepai.org/dataset/set14-super-resolution)、[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/)、[Urban100](https://deepai.org/dataset/urban100-4x-upscaling/)。 - 数据格式:png文件 - 注:数据将在src/data/DIV2K.py中处理。 - 注:dir_data中需要指定数据集所在位置的上一层目录。 @@ -128,6 +128,7 @@ DIV2K ├── script │ ├── run_distribute_train.sh // Ascend分布式训练shell脚本 │ ├── run_eval.sh // eval验证shell脚本 + │ ├── run_eval_onnx.sh // eval_onnx验证shell脚本 │ ├── run_ascend_standalone.sh // Ascend训练shell脚本 ├── src │ ├── data @@ -139,6 +140,7 @@ DIV2K │ ├── args.py //超参数 ├── train.py //训练脚本 ├── eval.py //评估脚本 + ├── eval_onnx.py //评估ONNX脚本 ├── export.py //模型导出 ├── README.md // 自述文件 ``` @@ -257,6 +259,32 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DATASET_TYPE] [SCALE] [DEVICE_I - DEVICE_ID 设备ID, 默认为:0 - 上述python命令在后台运行,可通过`run_infer.log`文件查看结果。 +## ONNX评估 + +### 导出ONNX模型 + +```bash +python export.py [--dir_data] [--file_format] [--ckpt_path] +``` + +选项: + --dir_data 数据集目录 + --file_format 需为 [ONNX] + --ckpt_path 检查点路径 + +### ONNX评估 + +- 评估过程如下,需要指定数据集类型为“DIV2K” + +```bash +bash script/run_eval_onnx.sh [TEST_DATA_DIR] [ONNX_PATH] [DATASET_TYPE] +``` + +- TEST_DATA_DIR 测试数据文件路径 +- ONNX_PATH ONNX模型路径 +- DATASET_TYPE 数据集名称(DIV2K) +- 上述python命令在后台运行,可通过`eval_onnx.log`文件查看结果。 + ## 模型导出 ```bash diff --git a/research/cv/RCAN/eval.py b/research/cv/RCAN/eval.py index 61a0573d5f740564baaf2cff67b9be5404b0d3f6..ee406362d4b909e827f70aad14bf118ab58082db 100644 --- a/research/cv/RCAN/eval.py +++ b/research/cv/RCAN/eval.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================ """eval script""" -import os import time import numpy as np import mindspore.dataset as ds @@ -26,8 +25,11 @@ from src.data.srdata import SRData from src.metrics import calc_psnr, quantize, calc_ssim from src.data.div2k import DIV2K -device_id = int(os.getenv('DEVICE_ID', '0')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) + +context.set_context(mode=context.GRAPH_MODE, + device_target=args.device_target, + device_id=args.device_id, + save_graphs=False) context.set_context(max_call_depth=10000) def eval_net(): """eval""" @@ -62,16 +64,16 @@ def eval_net(): pred = net_m(lr) pred_np = pred.asnumpy() pred_np = quantize(pred_np, 255) - psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0) + psnr = calc_psnr(pred_np, hr, args.scale, 255.0) pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0) hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0) - ssim = calc_ssim(pred_np, hr, args.scale[0]) + ssim = calc_ssim(pred_np, hr, args.scale) print("current psnr: ", psnr) print("current ssim: ", ssim) psnrs[batch_idx, 0] = psnr ssims[batch_idx, 0] = ssim - print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0])) - print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0])) + print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale, psnrs.mean(axis=0)[0])) + print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale, ssims.mean(axis=0)[0])) if __name__ == '__main__': diff --git a/research/cv/RCAN/eval_onnx.py b/research/cv/RCAN/eval_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..1a103c35b9a6d1dbeebe2e4b8b4c88c8918d0bf0 --- /dev/null +++ b/research/cv/RCAN/eval_onnx.py @@ -0,0 +1,87 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""eval script""" +import time +import numpy as np +import mindspore.dataset as ds +import onnxruntime as ort +from src.args import args +from src.data.srdata import SRData +from src.metrics import calc_psnr, quantize, calc_ssim +from src.data.div2k import DIV2K + + +def create_session(checkpoint_path, target_device): + if target_device == 'GPU': + providers = ['CUDAExecutionProvider'] + elif target_device == 'CPU': + providers = ['CPUExecutionProvider'] + else: + raise ValueError( + f'Unsupported target device {target_device}, ' + f'Expected one of: "CPU", "GPU"' + ) + sess = ort.InferenceSession(checkpoint_path, providers=providers) + name = sess.get_inputs()[0].name + return sess, name + + +def eval_net(): + """eval""" + if args.epochs == 0: + args.epochs = 100 + for arg in vars(args): + if vars(args)[arg] == 'True': + vars(args)[arg] = True + elif vars(args)[arg] == 'False': + vars(args)[arg] = False + if args.data_test[0] == 'DIV2K': + train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False) + else: + train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False) + train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False) + train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) + train_loader = train_de_dataset.create_dict_iterator(output_numpy=True) + print('load mindspore net successfully.') + num_imgs = train_de_dataset.get_dataset_size() + psnrs = np.zeros((num_imgs, 1)) + ssims = np.zeros((num_imgs, 1)) + for batch_idx, imgs in enumerate(train_loader): + lr = imgs['LR'] + hr = imgs['HR'] + img_shape = lr.shape + onnx_file = args.onnx_path + '//' + str(img_shape[2]) + '_' + str(img_shape[3]) + '.onnx' + session, input_name = create_session(onnx_file, 'GPU') + pred = session.run(None, {input_name: lr})[0] + pred_np = pred + pred_np = quantize(pred_np, 255) + psnr = calc_psnr(pred_np, hr, args.scale, 255.0) + pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0) + hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0) + ssim = calc_ssim(pred_np, hr, args.scale) + print("current psnr: ", psnr) + print("current ssim: ", ssim) + psnrs[batch_idx, 0] = psnr + ssims[batch_idx, 0] = ssim + print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale, psnrs.mean(axis=0)[0])) + print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale, ssims.mean(axis=0)[0])) + + +if __name__ == '__main__': + time_start = time.time() + print("Start eval function!") + eval_net() + time_end = time.time() + print('eval_time: %f' % (time_end - time_start)) diff --git a/research/cv/RCAN/export.py b/research/cv/RCAN/export.py index 692ec39e6cac224b3b20731114922da3f5c6717d..3d765b21d28866a04d75a1d2a9dd3a6748c35a94 100644 --- a/research/cv/RCAN/export.py +++ b/research/cv/RCAN/export.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,8 +16,12 @@ import os import argparse import numpy as np +from src.args import args as args_1 +from src.data.srdata import SRData +from src.data.div2k import DIV2K from src.rcan_model import RCAN import mindspore as ms +import mindspore.dataset as ds from mindspore import Tensor, context, load_checkpoint, export @@ -32,29 +36,48 @@ parser.add_argument('--n_colors', type=int, default=3, help='number of color cha parser.add_argument('--n_resblocks', type=int, default=20, help='number of residual blocks') parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps') parser.add_argument('--res_scale', type=float, default=1, help='residual scaling') -parser.add_argument('--task_id', type=int, default=0) -parser.add_argument('--n_resgroups', type=int, default=10, - help='number of residual groups') -parser.add_argument('--reduction', type=int, default=16, - help='number of feature maps reduction') -args_1 = parser.parse_args() +parser.add_argument('--n_resgroups', type=int, default=10, help='number of residual groups') +parser.add_argument('--reduction', type=int, default=16, help='number of feature maps reduction') +parser.add_argument('--data_range', type=str, default='1-800/801-810', help='train/test data range') +parser.add_argument('--test_only', action='store_true', help='set this option to test the model') +parser.add_argument('--model', default='RCAN', help='model name') +parser.add_argument('--dir_data', type=str, default='', help='dataset directory') +parser.add_argument('--ext', type=str, default='sep', help='dataset file extension') + +args = parser.parse_args() MAX_HR_SIZE = 2040 -def run_export(args): +def run_export(): """ export """ device_id = int(os.getenv('DEVICE_ID', '0')) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id) net = RCAN(args) - max_lr_size = MAX_HR_SIZE // args.scale #max_lr_size = MAX_HR_SIZE / scale + max_lr_size = MAX_HR_SIZE // args.scale # max_lr_size = MAX_HR_SIZE / scale param_dict = load_checkpoint(args.ckpt_path) net.load_pre_trained_param_dict(param_dict, strict=False) net.set_train(False) print('load mindspore net and checkpoint successfully.') - inputs = Tensor(np.ones([args.batch_size, 3, max_lr_size, max_lr_size]), ms.float32) - export(net, inputs, file_name=args.file_name, file_format=args.file_format) + + if args.file_format == 'ONNX': + if args_1.data_test[0] == 'DIV2K': + train_dataset = DIV2K(args_1, name=args_1.data_test, train=False, benchmark=False) + else: + train_dataset = SRData(args_1, name=args_1.data_test, train=False, benchmark=False) + train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False) + train_de_dataset = train_de_dataset.batch(1, drop_remainder=True) + train_loader = train_de_dataset.create_dict_iterator(output_numpy=True) + + for _, imgs in enumerate(train_loader): + img_shape = imgs['LR'].shape + export_path = str(img_shape[2]) + '_' + str(img_shape[3]) + inputs = Tensor(np.ones([args.batch_size, 3, img_shape[2], img_shape[3]]), ms.float32) + export(net, inputs, file_name=export_path, file_format=args.file_format) + else: + inputs = Tensor(np.ones([args.batch_size, 3, 678, max_lr_size]), ms.float32) + export(net, inputs, file_name=args.file_name, file_format=args.file_format) print('export successfully!') if __name__ == "__main__": - run_export(args_1) + run_export() diff --git a/research/cv/RCAN/script/run_eval_onnx.sh b/research/cv/RCAN/script/run_eval_onnx.sh new file mode 100644 index 0000000000000000000000000000000000000000..fcf8bed66e509206fc13fa15f462b571325f14b0 --- /dev/null +++ b/research/cv/RCAN/script/run_eval_onnx.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ]; then + echo "Usage: sh run_eval_onnx.sh [TEST_DATA_DIR] [ONNX_PATH] [DATASET_TYPE]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) +DATASET_TYPE=$3 + +if [ ! -d $PATH1 ]; then + echo "error: TEST_DATA_DIR=$PATH1 is not a directory" + exit 1 +fi + +if [ ! -d $PATH2 ]; then + echo "error: ONNX_PATH=$PATH2 is not a directory" + exit 1 +fi + +if [ -d "eval_onnx" ]; then + rm -rf ./eval_onnx +fi +mkdir ./eval_onnx +cp ../*.py ./eval_onnx +cp -r ../src ./eval_onnx +cd ./eval_onnx || exit +env >env.log +echo "start evaluation ..." + +python eval_onnx.py \ + --dir_data=${PATH1} \ + --batch_size 1 \ + --test_only \ + --ext "img" \ + --data_test=${DATASET_TYPE} \ + --onnx_path=${PATH2} \ + --task_id 0 \ + --scale 2 > eval_onnx.log 2>&1 & diff --git a/research/cv/RCAN/src/args.py b/research/cv/RCAN/src/args.py index 12f983989b2effb2ed587a4b1f2af88a44cf7afc..bd5ae796a4b77722bced1c4f8f40a85b686be558 100644 --- a/research/cv/RCAN/src/args.py +++ b/research/cv/RCAN/src/args.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2021-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ parser.add_argument('--data_range', type=str, default='1-800/801-810', help='train/test data range') parser.add_argument('--ext', type=str, default='sep', help='dataset file extension') -parser.add_argument('--scale', type=str, default='4', +parser.add_argument('--scale', type=int, default=2, help='super resolution scale') parser.add_argument('--patch_size', type=int, default=48, help='output patch size') @@ -64,9 +64,14 @@ parser.add_argument('--reduction', type=int, default=16, help='number of feature maps reduction') # Training specifications +parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, + help="Run distribute, default is false.") +parser.add_argument('--device_target', type=str, default='Ascend', + help='device target, Ascend or GPU (Default: Ascend)') +parser.add_argument('--device_id', type=int, default=0, help='device id') parser.add_argument('--test_every', type=int, default=4000, help='do test per every N batches') -parser.add_argument('--epochs', type=int, default=1000, +parser.add_argument('--epochs', type=int, default=500, help='number of epochs to train') parser.add_argument('--batch_size', type=int, default=16, help='input batch size for training') @@ -75,7 +80,7 @@ parser.add_argument('--test_only', action='store_true', # Optimization specifications -parser.add_argument('--lr', type=float, default=1e-5, +parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') parser.add_argument('--loss_scale', type=float, default=1024.0, help='scaling factor for optim') @@ -97,13 +102,12 @@ parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/', help='path to save ckpt') parser.add_argument('--ckpt_save_interval', type=int, default=10, help='save ckpt frequency, unit is epoch') -parser.add_argument('--ckpt_save_max', type=int, default=100, +parser.add_argument('--ckpt_save_max', type=int, default=10, help='max number of saved ckpt') parser.add_argument('--ckpt_path', type=str, default='', help='path of saved ckpt') - -# Task -parser.add_argument('--task_id', type=int, default=0) +parser.add_argument('--onnx_path', type=str, default='', + help='path of exported onnx model') # ModelArts parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False, @@ -113,12 +117,11 @@ parser.add_argument('--data_url', type=str, default='', help='the directory path args, unparsed = parser.parse_known_args() -args.scale = [int(x) for x in args.scale.split("+")] args.data_train = args.data_train.split('+') args.data_test = args.data_test.split('+') if args.epochs == 0: - args.epochs = 1e8 + args.epochs = 100 for arg in vars(args): if vars(args)[arg] == 'True': diff --git a/research/cv/RCAN/src/data/srdata.py b/research/cv/RCAN/src/data/srdata.py index b94ee4a3b26ef411d641c027faff173cd3862e32..d526303c23c62a964ec413943c11d01dc9093ad6 100644 --- a/research/cv/RCAN/src/data/srdata.py +++ b/research/cv/RCAN/src/data/srdata.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2021-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,7 +40,8 @@ class SRData: self.benchmark = benchmark self.input_large = (args.model == 'VDSR') self.scale = args.scale - self.idx_scale = 0 + self.scales = [2, 3, 4] + self.set_scale() self._set_filesystem(args.dir_data) self._set_img(args) if train: @@ -56,13 +57,13 @@ class SRData: self.images_hr, self.images_lr = list_hr, list_lr elif args.ext.find('sep') >= 0: os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True) - for s in self.scale: + for s in self.scales: if s == 1: os.makedirs(os.path.join(self.dir_hr), exist_ok=True) else: os.makedirs( os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True) - self.images_hr, self.images_lr = [], [[] for _ in self.scale] + self.images_hr, self.images_lr = [], [[] for _ in self.scales] for h in list_hr: b = h.replace(self.apath, path_bin) b = b.replace(self.ext[0], '.pt') @@ -88,15 +89,15 @@ class SRData: """_scan""" names_hr = sorted( glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))) - names_lr = [[] for _ in self.scale] + names_lr = [[] for _ in self.scales] for f in names_hr: filename, _ = os.path.splitext(os.path.basename(f)) - for si, s in enumerate(self.scale): + for si, s in enumerate(self.scales): if s != 1: scale = s names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \ .format(s, filename, scale, self.ext[1]))) - for si, s in enumerate(self.scale): + for si, s in enumerate(self.scales): if s == 1: names_lr[si] = names_hr return names_hr, names_lr @@ -182,7 +183,7 @@ class SRData: def get_patch(self, lr, hr): """get_patch""" - scale = self.scale[self.idx_scale] + scale = self.scales[self.idx_scale] if self.train: lr, hr = common.get_patch( lr, hr, @@ -195,9 +196,14 @@ class SRData: hr = hr[0:ih * scale, 0:iw * scale] return lr, hr - def set_scale(self, idx_scale): + def set_scale(self): """set_scale""" if not self.input_large: - self.idx_scale = idx_scale + if self.scale == 2: + self.idx_scale = 0 + elif self.scale == 3: + self.idx_scale = 1 + elif self.scale == 4: + self.idx_scale = 2 else: - self.idx_scale = random.randint(0, len(self.scale) - 1) + self.idx_scale = random.randint(0, len(self.scales) - 1) diff --git a/research/cv/RCAN/src/rcan_model.py b/research/cv/RCAN/src/rcan_model.py index 0fd5020768cfc18ef43945bcb4557f2fb31270f7..b04559161f2671f58dc215a6b16a57e1fd6e9d55 100644 --- a/research/cv/RCAN/src/rcan_model.py +++ b/research/cv/RCAN/src/rcan_model.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2021-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -48,38 +48,28 @@ class MeanShift(nn.Conv2d): self.has_bias = True -def _pixelsf_(x, scale): - """rcan""" - n, c, ih, iw = x.shape - oh = ih * scale - ow = iw * scale - oc = c // (scale ** 2) - output = P.Transpose()(x, (0, 2, 1, 3)) - output = P.Reshape()(output, (n, ih, oc * scale, scale, iw)) - output = P.Transpose()(output, (0, 1, 2, 4, 3)) - output = P.Reshape()(output, (n, ih, oc, scale, ow)) - output = P.Transpose()(output, (0, 2, 1, 3, 4)) - output = P.Reshape()(output, (n, oc, oh, ow)) - return output - - -class SmallUpSampler(nn.Cell): - """rcan""" - def __init__(self, conv, upsize, n_feats, has_bias=True): - """rcan""" - super(SmallUpSampler, self).__init__() - self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias) - self.reshape = P.Reshape() - self.upsize = upsize - self.pixelsf = _pixelsf_ + +class PixelShuffle(nn.Cell): + """PixelShuffle""" + def __init__(self, scale): + super(PixelShuffle, self).__init__() + self.scale = scale def construct(self, x): - """rcan""" - x = self.conv(x) - output = self.pixelsf(x, self.upsize) + n, c, ih, iw = x.shape + oh = ih * self.scale + ow = iw * self.scale + oc = c // (self.scale ** 2) + output = P.Transpose()(x, (0, 2, 1, 3)) + output = P.Reshape()(output, (n, ih, oc * self.scale, self.scale, iw)) + output = P.Transpose()(output, (0, 1, 2, 4, 3)) + output = P.Reshape()(output, (n, ih, oc, self.scale, ow)) + output = P.Transpose()(output, (0, 2, 1, 3, 4)) + output = P.Reshape()(output, (n, oc, oh, ow)) return output + class Upsampler(nn.Cell): """rcan""" def __init__(self, conv, scale, n_feats, has_bias=True): @@ -88,16 +78,19 @@ class Upsampler(nn.Cell): m = [] if (scale & (scale - 1)) == 0: for _ in range(int(math.log(scale, 2))): - m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias)) + m.append(conv(n_feats, 4 * n_feats, 3, has_bias)) + m.append(PixelShuffle(2)) elif scale == 3: - m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias)) + m.append(conv(n_feats, 9 * n_feats, 3, has_bias)) + else: + raise NotImplementedError + self.net = nn.SequentialCell(m) def construct(self, x): """rcan""" return self.net(x) - class AdaptiveAvgPool2d(nn.Cell): """rcan""" def __init__(self): @@ -107,8 +100,7 @@ class AdaptiveAvgPool2d(nn.Cell): def construct(self, x): """rcan""" - return self.ReduceMean(x, 0) - + return self.ReduceMean(x, (2, 3)) class CALayer(nn.Cell): """rcan""" @@ -131,7 +123,6 @@ class CALayer(nn.Cell): y = self.conv_du(y) return x * y - class RCAB(nn.Cell): """rcan""" def __init__(self, conv, n_feat, kernel_size, reduction, has_bias=True @@ -159,7 +150,6 @@ class ResidualGroup(nn.Cell): def __init__(self, conv, n_feat, kernel_size, reduction, n_resblocks): """rcan""" super(ResidualGroup, self).__init__() - modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, has_bias=True, bn=False, act=nn.ReLU(), res_scale=1) \ @@ -192,27 +182,26 @@ class RCAN(nn.Cell): rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) - self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dytpe) + self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std) # define head module - modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe) + modules_head = [conv(args.n_colors, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( - conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks).to_float(self.dytpe) \ + conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks)\ for _ in range(n_resgroups)] - modules_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dytpe)) + modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ - Upsampler(conv, scale, n_feats).to_float(self.dytpe), - conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)] - - self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dytpe) + Upsampler(conv, scale, n_feats), + conv(n_feats, args.n_colors, kernel_size)] - self.head = modules_head + self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) + self.head = nn.SequentialCell(modules_head) self.body = nn.SequentialCell(modules_body) self.tail = nn.SequentialCell(modules_tail)