Commit d89c7482 authored by 王超

rcan_onnx

parent 85b0931b
@@ -42,7 +42,7 @@
- Dataset size: about 7.12 GB, 900 images in total
- Training set: 800 images
- Test set: 100 images
- The benchmark datasets can be downloaded here: [Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html), [Set14](https://deepai.org/dataset/set14-super-resolution), [B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), [Urban100](http://vllab.ucmerced.edu/wlai24/LapSRN/)
- The benchmark datasets can be downloaded here: [Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html), [Set14](https://deepai.org/dataset/set14-super-resolution), [B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), [Urban100](https://deepai.org/dataset/urban100-4x-upscaling/)
- Data format: PNG files
- Note: the data is processed in src/data/DIV2K.py.
- Note: dir_data must point to the parent directory of the dataset location.
@@ -128,6 +128,7 @@ DIV2K
├── script
│ ├── run_distribute_train.sh // shell script for distributed training on Ascend
│ ├── run_eval.sh // shell script for evaluation
│ ├── run_eval_onnx.sh // shell script for ONNX evaluation
│ ├── run_ascend_standalone.sh // shell script for standalone training on Ascend
├── src
│ ├── data
@@ -139,6 +140,7 @@ DIV2K
│ ├── args.py // hyperparameters
├── train.py // training script
├── eval.py // evaluation script
├── eval_onnx.py // ONNX evaluation script
├── export.py // model export
├── README.md // README file
```
@@ -257,6 +259,32 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DATASET_TYPE] [SCALE] [DEVICE_ID]
- DEVICE_ID: device ID, default: 0
- The above Python command runs in the background; results can be checked in the `run_infer.log` file. An example invocation follows.
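A minimal sketch of a 310 inference run, assuming the exported MINDIR is at ./rcan.mindir and DIV2K sits under ./data/DIV2K (both paths hypothetical), with scale 2 on device 0:
```bash
bash run_infer_310.sh ./rcan.mindir ./data/DIV2K DIV2K 2 0
```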
## ONNX Evaluation
### Exporting the ONNX Model
```bash
python export.py [--dir_data] [--file_format] [--ckpt_path]
```
Options:
--dir_data dataset directory
--file_format must be "ONNX"
--ckpt_path checkpoint file path
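For example, a minimal export invocation, assuming a checkpoint at ./ckpt/rcan.ckpt and the dataset under ./data (both paths hypothetical):
```bash
python export.py --dir_data ./data --file_format ONNX --ckpt_path ./ckpt/rcan.ckpt
```
With --file_format ONNX, export.py writes one `<height>_<width>.onnx` file per low-resolution input shape (see the export.py diff below).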
### Running ONNX Evaluation
- The evaluation procedure is as follows; the dataset type must be set to "DIV2K":
```bash
bash script/run_eval_onnx.sh [TEST_DATA_DIR] [ONNX_PATH] [DATASET_TYPE]
```
- TEST_DATA_DIR: path to the test data directory
- ONNX_PATH: directory containing the exported ONNX models
- DATASET_TYPE: dataset name (DIV2K)
- The command runs in the background; results can be checked in the `eval_onnx.log` file. An example invocation follows.
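Assuming the DIV2K data is under ./data/DIV2K and the ONNX models were exported into ./onnx_models (both paths hypothetical):
```bash
bash script/run_eval_onnx.sh ./data/DIV2K ./onnx_models DIV2K
```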
## Model Export
```bash
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""eval script"""
import os
import time
import numpy as np
import mindspore.dataset as ds
@@ -26,8 +25,11 @@ from src.data.srdata import SRData
from src.metrics import calc_psnr, quantize, calc_ssim
from src.data.div2k import DIV2K
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(mode=context.GRAPH_MODE,
device_target=args.device_target,
device_id=args.device_id,
save_graphs=False)
context.set_context(max_call_depth=10000)
def eval_net():
"""eval"""
@@ -62,16 +64,16 @@ def eval_net():
pred = net_m(lr)
pred_np = pred.asnumpy()
pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
psnr = calc_psnr(pred_np, hr, args.scale, 255.0)
pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
ssim = calc_ssim(pred_np, hr, args.scale[0])
ssim = calc_ssim(pred_np, hr, args.scale)
print("current psnr: ", psnr)
print("current ssim: ", ssim)
psnrs[batch_idx, 0] = psnr
ssims[batch_idx, 0] = ssim
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale, psnrs.mean(axis=0)[0]))
print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale, ssims.mean(axis=0)[0]))
if __name__ == '__main__':
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval script"""
import time
import numpy as np
import mindspore.dataset as ds
import onnxruntime as ort
from src.args import args
from src.data.srdata import SRData
from src.metrics import calc_psnr, quantize, calc_ssim
from src.data.div2k import DIV2K
def create_session(checkpoint_path, target_device):
if target_device == 'GPU':
providers = ['CUDAExecutionProvider']
elif target_device == 'CPU':
providers = ['CPUExecutionProvider']
else:
raise ValueError(
f'Unsupported target device {target_device}, '
f'Expected one of: "CPU", "GPU"'
)
sess = ort.InferenceSession(checkpoint_path, providers=providers)
name = sess.get_inputs()[0].name
return sess, name
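# usage sketch (hypothetical file name): session, input_name = create_session('./onnx_models/678_1020.onnx', 'GPU')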
def eval_net():
"""eval"""
if args.epochs == 0:
args.epochs = 100
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False
if args.data_test[0] == 'DIV2K':
train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
else:
train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
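# batch size stays at 1: validation images differ in spatial size, and each size maps to its own static-shape ONNX graph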
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
print('loaded the evaluation dataset successfully.')
num_imgs = train_de_dataset.get_dataset_size()
psnrs = np.zeros((num_imgs, 1))
ssims = np.zeros((num_imgs, 1))
for batch_idx, imgs in enumerate(train_loader):
lr = imgs['LR']
hr = imgs['HR']
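# export.py writes one static-shape ONNX graph per LR size, named '<height>_<width>.onnx'; pick the file matching this image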
img_shape = lr.shape
onnx_file = args.onnx_path + '/' + str(img_shape[2]) + '_' + str(img_shape[3]) + '.onnx'
session, input_name = create_session(onnx_file, 'GPU')
pred = session.run(None, {input_name: lr})[0]
pred_np = pred
pred_np = quantize(pred_np, 255)
psnr = calc_psnr(pred_np, hr, args.scale, 255.0)
pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
ssim = calc_ssim(pred_np, hr, args.scale)
print("current psnr: ", psnr)
print("current ssim: ", ssim)
psnrs[batch_idx, 0] = psnr
ssims[batch_idx, 0] = ssim
print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale, psnrs.mean(axis=0)[0]))
print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale, ssims.mean(axis=0)[0]))
if __name__ == '__main__':
time_start = time.time()
print("Start eval function!")
eval_net()
time_end = time.time()
print('eval_time: %f' % (time_end - time_start))
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +16,12 @@
import os
import argparse
import numpy as np
from src.args import args as args_1
from src.data.srdata import SRData
from src.data.div2k import DIV2K
from src.rcan_model import RCAN
import mindspore as ms
import mindspore.dataset as ds
from mindspore import Tensor, context, load_checkpoint, export
@@ -32,29 +36,48 @@ parser.add_argument('--n_colors', type=int, default=3, help='number of color cha
parser.add_argument('--n_resblocks', type=int, default=20, help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1, help='residual scaling')
parser.add_argument('--task_id', type=int, default=0)
parser.add_argument('--n_resgroups', type=int, default=10,
help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
help='number of feature maps reduction')
args_1 = parser.parse_args()
parser.add_argument('--n_resgroups', type=int, default=10, help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16, help='number of feature maps reduction')
parser.add_argument('--data_range', type=str, default='1-800/801-810', help='train/test data range')
parser.add_argument('--test_only', action='store_true', help='set this option to test the model')
parser.add_argument('--model', default='RCAN', help='model name')
parser.add_argument('--dir_data', type=str, default='', help='dataset directory')
parser.add_argument('--ext', type=str, default='sep', help='dataset file extension')
args = parser.parse_args()
MAX_HR_SIZE = 2040
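# assumption: 2040 is the largest HR side length in DIV2K, so MAX_HR_SIZE // scale bounds the LR input size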
def run_export(args):
def run_export():
""" export """
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id)
net = RCAN(args)
max_lr_size = MAX_HR_SIZE // args.scale #max_lr_size = MAX_HR_SIZE / scale
max_lr_size = MAX_HR_SIZE // args.scale # max_lr_size = MAX_HR_SIZE / scale
param_dict = load_checkpoint(args.ckpt_path)
net.load_pre_trained_param_dict(param_dict, strict=False)
net.set_train(False)
print('load mindspore net and checkpoint successfully.')
inputs = Tensor(np.ones([args.batch_size, 3, max_lr_size, max_lr_size]), ms.float32)
export(net, inputs, file_name=args.file_name, file_format=args.file_format)
if args.file_format == 'ONNX':
if args_1.data_test[0] == 'DIV2K':
train_dataset = DIV2K(args_1, name=args_1.data_test, train=False, benchmark=False)
else:
train_dataset = SRData(args_1, name=args_1.data_test, train=False, benchmark=False)
train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
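# export a static-shape ONNX graph for each LR image, named '<height>_<width>.onnx'; repeated shapes overwrite the same file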
for _, imgs in enumerate(train_loader):
img_shape = imgs['LR'].shape
export_path = str(img_shape[2]) + '_' + str(img_shape[3])
inputs = Tensor(np.ones([args.batch_size, 3, img_shape[2], img_shape[3]]), ms.float32)
export(net, inputs, file_name=export_path, file_format=args.file_format)
else:
inputs = Tensor(np.ones([args.batch_size, 3, 678, max_lr_size]), ms.float32)
export(net, inputs, file_name=args.file_name, file_format=args.file_format)
print('export successfully!')
if __name__ == "__main__":
run_export(args_1)
run_export()
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_eval_onnx.sh [TEST_DATA_DIR] [ONNX_PATH] [DATASET_TYPE]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
DATASET_TYPE=$3
if [ ! -d $PATH1 ]; then
echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
exit 1
fi
if [ ! -d $PATH2 ]; then
echo "error: ONNX_PATH=$PATH2 is not a directory"
exit 1
fi
if [ -d "eval_onnx" ]; then
rm -rf ./eval_onnx
fi
mkdir ./eval_onnx
cp ../*.py ./eval_onnx
cp -r ../src ./eval_onnx
cd ./eval_onnx || exit
env >env.log
echo "start evaluation ..."
python eval_onnx.py \
--dir_data=${PATH1} \
--batch_size 1 \
--test_only \
--ext "img" \
--data_test=${DATASET_TYPE} \
--onnx_path=${PATH2} \
--task_id 0 \
--scale 2 > eval_onnx.log 2>&1 &
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ parser.add_argument('--data_range', type=str, default='1-800/801-810',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
parser.add_argument('--scale', type=int, default=2,
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
help='output patch size')
@@ -64,9 +64,14 @@ parser.add_argument('--reduction', type=int, default=16,
help='number of feature maps reduction')
# Training specifications
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
help="Run distribute, default is false.")
parser.add_argument('--device_target', type=str, default='Ascend',
help='device target, Ascend or GPU (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id')
parser.add_argument('--test_every', type=int, default=4000,
help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
parser.add_argument('--epochs', type=int, default=500,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
help='input batch size for training')
@@ -75,7 +80,7 @@ parser.add_argument('--test_only', action='store_true',
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-5,
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--loss_scale', type=float, default=1024.0,
help='scaling factor for optim')
@@ -97,13 +102,12 @@ parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
help='path to save ckpt')
parser.add_argument('--ckpt_save_interval', type=int, default=10,
help='save ckpt frequency, unit is epoch')
parser.add_argument('--ckpt_save_max', type=int, default=100,
parser.add_argument('--ckpt_save_max', type=int, default=10,
help='max number of saved ckpt')
parser.add_argument('--ckpt_path', type=str, default='',
help='path of saved ckpt')
# Task
parser.add_argument('--task_id', type=int, default=0)
parser.add_argument('--onnx_path', type=str, default='',
help='path of exported onnx model')
# ModelArts
parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False,
@@ -113,12 +117,11 @@ parser.add_argument('--data_url', type=str, default='', help='the directory path
args, unparsed = parser.parse_known_args()
args.scale = [int(x) for x in args.scale.split("+")]
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
args.epochs = 1e8
args.epochs = 100
for arg in vars(args):
if vars(args)[arg] == 'True':
......
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,7 +40,8 @@ class SRData:
self.benchmark = benchmark
self.input_large = (args.model == 'VDSR')
self.scale = args.scale
self.idx_scale = 0
self.scales = [2, 3, 4]
self.set_scale()
self._set_filesystem(args.dir_data)
self._set_img(args)
if train:
@@ -56,13 +57,13 @@
self.images_hr, self.images_lr = list_hr, list_lr
elif args.ext.find('sep') >= 0:
os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
for s in self.scale:
for s in self.scales:
if s == 1:
os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
else:
os.makedirs(
os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
self.images_hr, self.images_lr = [], [[] for _ in self.scale]
self.images_hr, self.images_lr = [], [[] for _ in self.scales]
for h in list_hr:
b = h.replace(self.apath, path_bin)
b = b.replace(self.ext[0], '.pt')
@@ -88,15 +89,15 @@
"""_scan"""
names_hr = sorted(
glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
names_lr = [[] for _ in self.scale]
names_lr = [[] for _ in self.scales]
for f in names_hr:
filename, _ = os.path.splitext(os.path.basename(f))
for si, s in enumerate(self.scale):
for si, s in enumerate(self.scales):
if s != 1:
scale = s
names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
.format(s, filename, scale, self.ext[1])))
for si, s in enumerate(self.scale):
for si, s in enumerate(self.scales):
if s == 1:
names_lr[si] = names_hr
return names_hr, names_lr
@@ -182,7 +183,7 @@
def get_patch(self, lr, hr):
"""get_patch"""
scale = self.scale[self.idx_scale]
scale = self.scales[self.idx_scale]
if self.train:
lr, hr = common.get_patch(
lr, hr,
@@ -195,9 +196,14 @@
hr = hr[0:ih * scale, 0:iw * scale]
return lr, hr
def set_scale(self, idx_scale):
def set_scale(self):
"""set_scale"""
if not self.input_large:
self.idx_scale = idx_scale
if self.scale == 2:
self.idx_scale = 0
elif self.scale == 3:
self.idx_scale = 1
elif self.scale == 4:
self.idx_scale = 2
else:
self.idx_scale = random.randint(0, len(self.scale) - 1)
self.idx_scale = random.randint(0, len(self.scales) - 1)
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -48,38 +48,28 @@ class MeanShift(nn.Conv2d):
self.has_bias = True
def _pixelsf_(x, scale):
"""rcan"""
n, c, ih, iw = x.shape
oh = ih * scale
ow = iw * scale
oc = c // (scale ** 2)
output = P.Transpose()(x, (0, 2, 1, 3))
output = P.Reshape()(output, (n, ih, oc * scale, scale, iw))
output = P.Transpose()(output, (0, 1, 2, 4, 3))
output = P.Reshape()(output, (n, ih, oc, scale, ow))
output = P.Transpose()(output, (0, 2, 1, 3, 4))
output = P.Reshape()(output, (n, oc, oh, ow))
return output
class SmallUpSampler(nn.Cell):
"""rcan"""
def __init__(self, conv, upsize, n_feats, has_bias=True):
"""rcan"""
super(SmallUpSampler, self).__init__()
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias)
self.reshape = P.Reshape()
self.upsize = upsize
self.pixelsf = _pixelsf_
class PixelShuffle(nn.Cell):
"""PixelShuffle"""
def __init__(self, scale):
super(PixelShuffle, self).__init__()
self.scale = scale
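# construct() performs depth-to-space, (n, c, h, w) -> (n, c // scale**2, h*scale, w*scale),
# built from plain Transpose/Reshape ops, which presumably export to ONNX more cleanly than a fused pixel-shuffle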
def construct(self, x):
"""rcan"""
x = self.conv(x)
output = self.pixelsf(x, self.upsize)
n, c, ih, iw = x.shape
oh = ih * self.scale
ow = iw * self.scale
oc = c // (self.scale ** 2)
output = P.Transpose()(x, (0, 2, 1, 3))
output = P.Reshape()(output, (n, ih, oc * self.scale, self.scale, iw))
output = P.Transpose()(output, (0, 1, 2, 4, 3))
output = P.Reshape()(output, (n, ih, oc, self.scale, ow))
output = P.Transpose()(output, (0, 2, 1, 3, 4))
output = P.Reshape()(output, (n, oc, oh, ow))
return output
class Upsampler(nn.Cell):
"""rcan"""
def __init__(self, conv, scale, n_feats, has_bias=True):
@@ -88,16 +78,19 @@ class Upsampler(nn.Cell):
m = []
if (scale & (scale - 1)) == 0:
for _ in range(int(math.log(scale, 2))):
m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
m.append(conv(n_feats, 4 * n_feats, 3, has_bias))
m.append(PixelShuffle(2))
elif scale == 3:
m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
m.append(conv(n_feats, 9 * n_feats, 3, has_bias))
m.append(PixelShuffle(3))  # assumption: the x3 branch needs its own shuffle, mirroring the x2 path; without it the 9*n_feats channels are never rearranged
else:
raise NotImplementedError
self.net = nn.SequentialCell(m)
def construct(self, x):
"""rcan"""
return self.net(x)
class AdaptiveAvgPool2d(nn.Cell):
"""rcan"""
def __init__(self):
@@ -107,8 +100,7 @@ class AdaptiveAvgPool2d(nn.Cell):
def construct(self, x):
"""rcan"""
return self.ReduceMean(x, 0)
return self.ReduceMean(x, (2, 3))
class CALayer(nn.Cell):
"""rcan"""
@@ -131,7 +123,6 @@ class CALayer(nn.Cell):
y = self.conv_du(y)
return x * y
class RCAB(nn.Cell):
"""rcan"""
def __init__(self, conv, n_feat, kernel_size, reduction, has_bias=True
@@ -159,7 +150,6 @@ class ResidualGroup(nn.Cell):
def __init__(self, conv, n_feat, kernel_size, reduction, n_resblocks):
"""rcan"""
super(ResidualGroup, self).__init__()
modules_body = []
modules_body = [
RCAB(
conv, n_feat, kernel_size, reduction, has_bias=True, bn=False, act=nn.ReLU(), res_scale=1) \
@@ -185,34 +175,33 @@ class RCAN(nn.Cell):
n_feats = args.n_feats
kernel_size = 3
reduction = args.reduction
scale = args.scale[0]
scale = args.scale
self.dytpe = mstype.float16
# RGB mean for DIV2K
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dytpe)
self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module
modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe)
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
modules_body = [
ResidualGroup(
conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks).to_float(self.dytpe) \
conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks)\
for _ in range(n_resgroups)]
modules_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dytpe))
modules_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
modules_tail = [
Upsampler(conv, scale, n_feats).to_float(self.dytpe),
conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)]
self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dytpe)
Upsampler(conv, scale, n_feats),
conv(n_feats, args.n_colors, kernel_size)]
self.head = modules_head
self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.SequentialCell(modules_head)
self.body = nn.SequentialCell(modules_body)
self.tail = nn.SequentialCell(modules_tail)
......