diff --git a/research/cv/RCAN/README.md b/research/cv/RCAN/README.md
index 687e4d47a47958793af45aa7679358b487bd2b76..5e710052f38770bb1a3cb3ef14b466c1e6eda253 100644
--- a/research/cv/RCAN/README.md
+++ b/research/cv/RCAN/README.md
@@ -42,7 +42,7 @@
 - 数据集大小:约7.12GB,共900张图像
  - 训练集:800张图像
  - 测试集:100张图像
-- 基准数据集可下载如下:[Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html)、[Set14](https://deepai.org/dataset/set14-super-resolution)、[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/)、[Urban100](http://vllab.ucmerced.edu/wlai24/LapSRN/)。
+- 基准数据集可下载如下:[Set5](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html)、[Set14](https://deepai.org/dataset/set14-super-resolution)、[B100](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/)、[Urban100](https://deepai.org/dataset/urban100-4x-upscaling/)。
 - 数据格式:png文件
  - 注:数据将在src/data/DIV2K.py中处理。
  - 注:dir_data中需要指定数据集所在位置的上一层目录。
@@ -128,6 +128,7 @@ DIV2K
         ├── script
         │   ├── run_distribute_train.sh           // Ascend分布式训练shell脚本
         │   ├── run_eval.sh                       // eval验证shell脚本
+        │   ├── run_eval_onnx.sh                  // eval_onnx验证shell脚本
         │   ├── run_ascend_standalone.sh          // Ascend训练shell脚本
         ├── src
         │   ├── data
@@ -139,6 +140,7 @@ DIV2K
         │   ├── args.py                           //超参数
         ├── train.py                              //训练脚本
         ├── eval.py                               //评估脚本
+        ├── eval_onnx.py                          //评估ONNX脚本
         ├── export.py                             //模型导出
         ├── README.md                             // 自述文件
 ```
@@ -257,6 +259,32 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DATASET_TYPE] [SCALE] [DEVICE_I
 - DEVICE_ID 设备ID, 默认为:0
 - 上述python命令在后台运行,可通过`run_infer.log`文件查看结果。
 
+## ONNX评估
+
+### 导出ONNX模型
+
+```bash
+python export.py [--dir_data] [--file_format] [--ckpt_path]
+```
+
+选项:
+  --dir_data        数据集目录
+  --file_format     需为 [ONNX]
+  --ckpt_path       检查点路径
+
+### ONNX评估
+
+- 评估过程如下,需要指定数据集类型为“DIV2K”
+
+```bash
+bash script/run_eval_onnx.sh [TEST_DATA_DIR] [ONNX_PATH] [DATASET_TYPE]
+```
+
+- TEST_DATA_DIR 测试数据文件路径
+- ONNX_PATH ONNX模型路径
+- DATASET_TYPE 数据集名称(DIV2K)
+- 上述python命令在后台运行,可通过`eval_onnx.log`文件查看结果。
+
 ## 模型导出
 
 ```bash
diff --git a/research/cv/RCAN/eval.py b/research/cv/RCAN/eval.py
index 61a0573d5f740564baaf2cff67b9be5404b0d3f6..ee406362d4b909e827f70aad14bf118ab58082db 100644
--- a/research/cv/RCAN/eval.py
+++ b/research/cv/RCAN/eval.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """eval script"""
-import os
 import time
 import numpy as np
 import mindspore.dataset as ds
@@ -26,8 +25,11 @@ from src.data.srdata import SRData
 from src.metrics import calc_psnr, quantize, calc_ssim
 from src.data.div2k import DIV2K
 
-device_id = int(os.getenv('DEVICE_ID', '0'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
+
+context.set_context(mode=context.GRAPH_MODE,
+                    device_target=args.device_target,
+                    device_id=args.device_id,
+                    save_graphs=False)
 context.set_context(max_call_depth=10000)
 def eval_net():
     """eval"""
@@ -62,16 +64,16 @@ def eval_net():
         pred = net_m(lr)
         pred_np = pred.asnumpy()
         pred_np = quantize(pred_np, 255)
-        psnr = calc_psnr(pred_np, hr, args.scale[0], 255.0)
+        psnr = calc_psnr(pred_np, hr, args.scale, 255.0)
         pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
         hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
-        ssim = calc_ssim(pred_np, hr, args.scale[0])
+        ssim = calc_ssim(pred_np, hr, args.scale)
         print("current psnr: ", psnr)
         print("current ssim: ", ssim)
         psnrs[batch_idx, 0] = psnr
         ssims[batch_idx, 0] = ssim
-    print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale[0], psnrs.mean(axis=0)[0]))
-    print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale[0], ssims.mean(axis=0)[0]))
+    print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale, psnrs.mean(axis=0)[0]))
+    print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale, ssims.mean(axis=0)[0]))
 
 
 if __name__ == '__main__':
diff --git a/research/cv/RCAN/eval_onnx.py b/research/cv/RCAN/eval_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a103c35b9a6d1dbeebe2e4b8b4c88c8918d0bf0
--- /dev/null
+++ b/research/cv/RCAN/eval_onnx.py
@@ -0,0 +1,87 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""eval script"""
+import time
+import numpy as np
+import mindspore.dataset as ds
+import onnxruntime as ort
+from src.args import args
+from src.data.srdata import SRData
+from src.metrics import calc_psnr, quantize, calc_ssim
+from src.data.div2k import DIV2K
+
+
+def create_session(checkpoint_path, target_device):
+    if target_device == 'GPU':
+        providers = ['CUDAExecutionProvider']
+    elif target_device == 'CPU':
+        providers = ['CPUExecutionProvider']
+    else:
+        raise ValueError(
+            f'Unsupported target device {target_device}, '
+            f'Expected one of: "CPU", "GPU"'
+        )
+    sess = ort.InferenceSession(checkpoint_path, providers=providers)
+    name = sess.get_inputs()[0].name
+    return sess, name
+
+
+def eval_net():
+    """eval"""
+    if args.epochs == 0:
+        args.epochs = 100
+    for arg in vars(args):
+        if vars(args)[arg] == 'True':
+            vars(args)[arg] = True
+        elif vars(args)[arg] == 'False':
+            vars(args)[arg] = False
+    if args.data_test[0] == 'DIV2K':
+        train_dataset = DIV2K(args, name=args.data_test, train=False, benchmark=False)
+    else:
+        train_dataset = SRData(args, name=args.data_test, train=False, benchmark=False)
+    train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
+    train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
+    train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
+    print('load mindspore net successfully.')
+    num_imgs = train_de_dataset.get_dataset_size()
+    psnrs = np.zeros((num_imgs, 1))
+    ssims = np.zeros((num_imgs, 1))
+    for batch_idx, imgs in enumerate(train_loader):
+        lr = imgs['LR']
+        hr = imgs['HR']
+        img_shape = lr.shape
+        onnx_file = args.onnx_path + '//' + str(img_shape[2]) + '_' + str(img_shape[3]) + '.onnx'
+        session, input_name = create_session(onnx_file, 'GPU')
+        pred = session.run(None, {input_name: lr})[0]
+        pred_np = pred
+        pred_np = quantize(pred_np, 255)
+        psnr = calc_psnr(pred_np, hr, args.scale, 255.0)
+        pred_np = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
+        hr = hr.reshape(hr.shape[-3:]).transpose(1, 2, 0)
+        ssim = calc_ssim(pred_np, hr, args.scale)
+        print("current psnr: ", psnr)
+        print("current ssim: ", ssim)
+        psnrs[batch_idx, 0] = psnr
+        ssims[batch_idx, 0] = ssim
+    print('Mean psnr of %s x%s is %.4f' % (args.data_test[0], args.scale, psnrs.mean(axis=0)[0]))
+    print('Mean ssim of %s x%s is %.4f' % (args.data_test[0], args.scale, ssims.mean(axis=0)[0]))
+
+
+if __name__ == '__main__':
+    time_start = time.time()
+    print("Start eval function!")
+    eval_net()
+    time_end = time.time()
+    print('eval_time: %f' % (time_end - time_start))
diff --git a/research/cv/RCAN/export.py b/research/cv/RCAN/export.py
index 692ec39e6cac224b3b20731114922da3f5c6717d..3d765b21d28866a04d75a1d2a9dd3a6748c35a94 100644
--- a/research/cv/RCAN/export.py
+++ b/research/cv/RCAN/export.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,8 +16,12 @@
 import os
 import argparse
 import numpy as np
+from src.args import args as args_1
+from src.data.srdata import SRData
+from src.data.div2k import DIV2K
 from src.rcan_model import RCAN
 import mindspore as ms
+import mindspore.dataset as ds
 from mindspore import Tensor, context, load_checkpoint, export
 
 
@@ -32,29 +36,48 @@ parser.add_argument('--n_colors', type=int, default=3, help='number of color cha
 parser.add_argument('--n_resblocks', type=int, default=20, help='number of residual blocks')
 parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
 parser.add_argument('--res_scale', type=float, default=1, help='residual scaling')
-parser.add_argument('--task_id', type=int, default=0)
-parser.add_argument('--n_resgroups', type=int, default=10,
-                    help='number of residual groups')
-parser.add_argument('--reduction', type=int, default=16,
-                    help='number of feature maps reduction')
-args_1 = parser.parse_args()
+parser.add_argument('--n_resgroups', type=int, default=10, help='number of residual groups')
+parser.add_argument('--reduction', type=int, default=16, help='number of feature maps reduction')
+parser.add_argument('--data_range', type=str, default='1-800/801-810', help='train/test data range')
+parser.add_argument('--test_only', action='store_true', help='set this option to test the model')
+parser.add_argument('--model', default='RCAN', help='model name')
+parser.add_argument('--dir_data', type=str, default='', help='dataset directory')
+parser.add_argument('--ext', type=str, default='sep', help='dataset file extension')
+
+args = parser.parse_args()
 
 MAX_HR_SIZE = 2040
 
-def run_export(args):
+def run_export():
     """ export """
     device_id = int(os.getenv('DEVICE_ID', '0'))
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=device_id)
     net = RCAN(args)
-    max_lr_size = MAX_HR_SIZE // args.scale  #max_lr_size = MAX_HR_SIZE / scale
+    max_lr_size = MAX_HR_SIZE // args.scale  #  max_lr_size = MAX_HR_SIZE / scale
     param_dict = load_checkpoint(args.ckpt_path)
     net.load_pre_trained_param_dict(param_dict, strict=False)
     net.set_train(False)
     print('load mindspore net and checkpoint successfully.')
-    inputs = Tensor(np.ones([args.batch_size, 3, max_lr_size, max_lr_size]), ms.float32)
-    export(net, inputs, file_name=args.file_name, file_format=args.file_format)
+
+    if args.file_format == 'ONNX':
+        if args_1.data_test[0] == 'DIV2K':
+            train_dataset = DIV2K(args_1, name=args_1.data_test, train=False, benchmark=False)
+        else:
+            train_dataset = SRData(args_1, name=args_1.data_test, train=False, benchmark=False)
+        train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
+        train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
+        train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
+
+        for _, imgs in enumerate(train_loader):
+            img_shape = imgs['LR'].shape
+            export_path = str(img_shape[2]) + '_' + str(img_shape[3])
+            inputs = Tensor(np.ones([args.batch_size, 3, img_shape[2], img_shape[3]]), ms.float32)
+            export(net, inputs, file_name=export_path, file_format=args.file_format)
+    else:
+        inputs = Tensor(np.ones([args.batch_size, 3, 678, max_lr_size]), ms.float32)
+        export(net, inputs, file_name=args.file_name, file_format=args.file_format)
     print('export successfully!')
 
 
 if __name__ == "__main__":
-    run_export(args_1)
+    run_export()
diff --git a/research/cv/RCAN/script/run_eval_onnx.sh b/research/cv/RCAN/script/run_eval_onnx.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fcf8bed66e509206fc13fa15f462b571325f14b0
--- /dev/null
+++ b/research/cv/RCAN/script/run_eval_onnx.sh
@@ -0,0 +1,62 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ]; then
+  echo "Usage: sh run_eval_onnx.sh [TEST_DATA_DIR] [ONNX_PATH] [DATASET_TYPE]"
+  exit 1
+fi
+
+get_real_path() {
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+PATH1=$(get_real_path $1)
+PATH2=$(get_real_path $2)
+DATASET_TYPE=$3
+
+if [ ! -d $PATH1 ]; then
+  echo "error: TEST_DATA_DIR=$PATH1 is not a directory"
+  exit 1
+fi
+
+if [ ! -d $PATH2 ]; then
+  echo "error: ONNX_PATH=$PATH2 is not a directory"
+  exit 1
+fi
+
+if [ -d "eval_onnx" ]; then
+  rm -rf ./eval_onnx
+fi
+mkdir ./eval_onnx
+cp ../*.py ./eval_onnx
+cp -r ../src ./eval_onnx
+cd ./eval_onnx || exit
+env >env.log
+echo "start evaluation ..."
+
+python eval_onnx.py \
+    --dir_data=${PATH1} \
+    --batch_size 1 \
+    --test_only \
+    --ext "img" \
+    --data_test=${DATASET_TYPE} \
+    --onnx_path=${PATH2} \
+    --task_id 0 \
+    --scale 2 > eval_onnx.log 2>&1 &
diff --git a/research/cv/RCAN/src/args.py b/research/cv/RCAN/src/args.py
index 12f983989b2effb2ed587a4b1f2af88a44cf7afc..bd5ae796a4b77722bced1c4f8f40a85b686be558 100644
--- a/research/cv/RCAN/src/args.py
+++ b/research/cv/RCAN/src/args.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ parser.add_argument('--data_range', type=str, default='1-800/801-810',
                     help='train/test data range')
 parser.add_argument('--ext', type=str, default='sep',
                     help='dataset file extension')
-parser.add_argument('--scale', type=str, default='4',
+parser.add_argument('--scale', type=int, default=2,
                     help='super resolution scale')
 parser.add_argument('--patch_size', type=int, default=48,
                     help='output patch size')
@@ -64,9 +64,14 @@ parser.add_argument('--reduction', type=int, default=16,
                     help='number of feature maps reduction')
 
 # Training specifications
+parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
+                    help="Run distribute, default is false.")
+parser.add_argument('--device_target', type=str, default='Ascend',
+                    help='device target, Ascend or GPU (Default: Ascend)')
+parser.add_argument('--device_id', type=int, default=0, help='device id')
 parser.add_argument('--test_every', type=int, default=4000,
                     help='do test per every N batches')
-parser.add_argument('--epochs', type=int, default=1000,
+parser.add_argument('--epochs', type=int, default=500,
                     help='number of epochs to train')
 parser.add_argument('--batch_size', type=int, default=16,
                     help='input batch size for training')
@@ -75,7 +80,7 @@ parser.add_argument('--test_only', action='store_true',
 
 
 # Optimization specifications
-parser.add_argument('--lr', type=float, default=1e-5,
+parser.add_argument('--lr', type=float, default=1e-4,
                     help='learning rate')
 parser.add_argument('--loss_scale', type=float, default=1024.0,
                     help='scaling factor for optim')
@@ -97,13 +102,12 @@ parser.add_argument('--ckpt_save_path', type=str, default='./ckpt/',
                     help='path to save ckpt')
 parser.add_argument('--ckpt_save_interval', type=int, default=10,
                     help='save ckpt frequency, unit is epoch')
-parser.add_argument('--ckpt_save_max', type=int, default=100,
+parser.add_argument('--ckpt_save_max', type=int, default=10,
                     help='max number of saved ckpt')
 parser.add_argument('--ckpt_path', type=str, default='',
                     help='path of saved ckpt')
-
-# Task
-parser.add_argument('--task_id', type=int, default=0)
+parser.add_argument('--onnx_path', type=str, default='',
+                    help='path of exported onnx model')
 
 # ModelArts
 parser.add_argument('--modelArts_mode', type=ast.literal_eval, default=False,
@@ -113,12 +117,11 @@ parser.add_argument('--data_url', type=str, default='', help='the directory path
 
 args, unparsed = parser.parse_known_args()
 
-args.scale = [int(x) for x in args.scale.split("+")]
 args.data_train = args.data_train.split('+')
 args.data_test = args.data_test.split('+')
 
 if args.epochs == 0:
-    args.epochs = 1e8
+    args.epochs = 100
 
 for arg in vars(args):
     if vars(args)[arg] == 'True':
diff --git a/research/cv/RCAN/src/data/srdata.py b/research/cv/RCAN/src/data/srdata.py
index b94ee4a3b26ef411d641c027faff173cd3862e32..d526303c23c62a964ec413943c11d01dc9093ad6 100644
--- a/research/cv/RCAN/src/data/srdata.py
+++ b/research/cv/RCAN/src/data/srdata.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,7 +40,8 @@ class SRData:
         self.benchmark = benchmark
         self.input_large = (args.model == 'VDSR')
         self.scale = args.scale
-        self.idx_scale = 0
+        self.scales = [2, 3, 4]
+        self.set_scale()
         self._set_filesystem(args.dir_data)
         self._set_img(args)
         if train:
@@ -56,13 +57,13 @@ class SRData:
             self.images_hr, self.images_lr = list_hr, list_lr
         elif args.ext.find('sep') >= 0:
             os.makedirs(self.dir_hr.replace(self.apath, path_bin), exist_ok=True)
-            for s in self.scale:
+            for s in self.scales:
                 if s == 1:
                     os.makedirs(os.path.join(self.dir_hr), exist_ok=True)
                 else:
                     os.makedirs(
                         os.path.join(self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s)), exist_ok=True)
-            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
+            self.images_hr, self.images_lr = [], [[] for _ in self.scales]
             for h in list_hr:
                 b = h.replace(self.apath, path_bin)
                 b = b.replace(self.ext[0], '.pt')
@@ -88,15 +89,15 @@ class SRData:
         """_scan"""
         names_hr = sorted(
             glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])))
-        names_lr = [[] for _ in self.scale]
+        names_lr = [[] for _ in self.scales]
         for f in names_hr:
             filename, _ = os.path.splitext(os.path.basename(f))
-            for si, s in enumerate(self.scale):
+            for si, s in enumerate(self.scales):
                 if s != 1:
                     scale = s
                     names_lr[si].append(os.path.join(self.dir_lr, 'X{}/{}x{}{}' \
                                                      .format(s, filename, scale, self.ext[1])))
-        for si, s in enumerate(self.scale):
+        for si, s in enumerate(self.scales):
             if s == 1:
                 names_lr[si] = names_hr
         return names_hr, names_lr
@@ -182,7 +183,7 @@ class SRData:
 
     def get_patch(self, lr, hr):
         """get_patch"""
-        scale = self.scale[self.idx_scale]
+        scale = self.scales[self.idx_scale]
         if self.train:
             lr, hr = common.get_patch(
                 lr, hr,
@@ -195,9 +196,14 @@ class SRData:
             hr = hr[0:ih * scale, 0:iw * scale]
         return lr, hr
 
-    def set_scale(self, idx_scale):
+    def set_scale(self):
         """set_scale"""
         if not self.input_large:
-            self.idx_scale = idx_scale
+            if self.scale == 2:
+                self.idx_scale = 0
+            elif self.scale == 3:
+                self.idx_scale = 1
+            elif self.scale == 4:
+                self.idx_scale = 2
         else:
-            self.idx_scale = random.randint(0, len(self.scale) - 1)
+            self.idx_scale = random.randint(0, len(self.scales) - 1)
diff --git a/research/cv/RCAN/src/rcan_model.py b/research/cv/RCAN/src/rcan_model.py
index 0fd5020768cfc18ef43945bcb4557f2fb31270f7..b04559161f2671f58dc215a6b16a57e1fd6e9d55 100644
--- a/research/cv/RCAN/src/rcan_model.py
+++ b/research/cv/RCAN/src/rcan_model.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -48,38 +48,28 @@ class MeanShift(nn.Conv2d):
         self.has_bias = True
 
 
-def _pixelsf_(x, scale):
-    """rcan"""
-    n, c, ih, iw = x.shape
-    oh = ih * scale
-    ow = iw * scale
-    oc = c // (scale ** 2)
-    output = P.Transpose()(x, (0, 2, 1, 3))
-    output = P.Reshape()(output, (n, ih, oc * scale, scale, iw))
-    output = P.Transpose()(output, (0, 1, 2, 4, 3))
-    output = P.Reshape()(output, (n, ih, oc, scale, ow))
-    output = P.Transpose()(output, (0, 2, 1, 3, 4))
-    output = P.Reshape()(output, (n, oc, oh, ow))
-    return output
-
-
-class SmallUpSampler(nn.Cell):
-    """rcan"""
-    def __init__(self, conv, upsize, n_feats, has_bias=True):
-        """rcan"""
-        super(SmallUpSampler, self).__init__()
-        self.conv = conv(n_feats, upsize * upsize * n_feats, 3, has_bias)
-        self.reshape = P.Reshape()
-        self.upsize = upsize
-        self.pixelsf = _pixelsf_
+
+class PixelShuffle(nn.Cell):
+    """PixelShuffle"""
+    def __init__(self, scale):
+        super(PixelShuffle, self).__init__()
+        self.scale = scale
 
     def construct(self, x):
-        """rcan"""
-        x = self.conv(x)
-        output = self.pixelsf(x, self.upsize)
+        n, c, ih, iw = x.shape
+        oh = ih * self.scale
+        ow = iw * self.scale
+        oc = c // (self.scale ** 2)
+        output = P.Transpose()(x, (0, 2, 1, 3))
+        output = P.Reshape()(output, (n, ih, oc * self.scale, self.scale, iw))
+        output = P.Transpose()(output, (0, 1, 2, 4, 3))
+        output = P.Reshape()(output, (n, ih, oc, self.scale, ow))
+        output = P.Transpose()(output, (0, 2, 1, 3, 4))
+        output = P.Reshape()(output, (n, oc, oh, ow))
         return output
 
 
+
 class Upsampler(nn.Cell):
     """rcan"""
     def __init__(self, conv, scale, n_feats, has_bias=True):
@@ -88,16 +78,19 @@ class Upsampler(nn.Cell):
         m = []
         if (scale & (scale - 1)) == 0:
             for _ in range(int(math.log(scale, 2))):
-                m.append(SmallUpSampler(conv, 2, n_feats, has_bias=has_bias))
+                m.append(conv(n_feats, 4 * n_feats, 3, has_bias))
+                m.append(PixelShuffle(2))
         elif scale == 3:
-            m.append(SmallUpSampler(conv, 3, n_feats, has_bias=has_bias))
+            m.append(conv(n_feats, 9 * n_feats, 3, has_bias))
+        else:
+            raise NotImplementedError
+
         self.net = nn.SequentialCell(m)
 
     def construct(self, x):
         """rcan"""
         return self.net(x)
 
-
 class AdaptiveAvgPool2d(nn.Cell):
     """rcan"""
     def __init__(self):
@@ -107,8 +100,7 @@ class AdaptiveAvgPool2d(nn.Cell):
 
     def construct(self, x):
         """rcan"""
-        return self.ReduceMean(x, 0)
-
+        return self.ReduceMean(x, (2, 3))
 
 class CALayer(nn.Cell):
     """rcan"""
@@ -131,7 +123,6 @@ class CALayer(nn.Cell):
         y = self.conv_du(y)
         return x * y
 
-
 class RCAB(nn.Cell):
     """rcan"""
     def __init__(self, conv, n_feat, kernel_size, reduction, has_bias=True
@@ -159,7 +150,6 @@ class ResidualGroup(nn.Cell):
     def __init__(self, conv, n_feat, kernel_size, reduction, n_resblocks):
         """rcan"""
         super(ResidualGroup, self).__init__()
-        modules_body = []
         modules_body = [
             RCAB(
                 conv, n_feat, kernel_size, reduction, has_bias=True, bn=False, act=nn.ReLU(), res_scale=1) \
@@ -192,27 +182,26 @@ class RCAN(nn.Cell):
         rgb_mean = (0.4488, 0.4371, 0.4040)
         rgb_std = (1.0, 1.0, 1.0)
 
-        self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std).to_float(self.dytpe)
+        self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)
 
         # define head module
-        modules_head = conv(args.n_colors, n_feats, kernel_size).to_float(self.dytpe)
+        modules_head = [conv(args.n_colors, n_feats, kernel_size)]
 
         # define body module
         modules_body = [
             ResidualGroup(
-                conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks).to_float(self.dytpe) \
+                conv, n_feats, kernel_size, reduction, n_resblocks=n_resblocks)\
             for _ in range(n_resgroups)]
 
-        modules_body.append(conv(n_feats, n_feats, kernel_size).to_float(self.dytpe))
+        modules_body.append(conv(n_feats, n_feats, kernel_size))
 
         # define tail module
         modules_tail = [
-            Upsampler(conv, scale, n_feats).to_float(self.dytpe),
-            conv(n_feats, args.n_colors, kernel_size).to_float(self.dytpe)]
-
-        self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1).to_float(self.dytpe)
+            Upsampler(conv, scale, n_feats),
+            conv(n_feats, args.n_colors, kernel_size)]
 
-        self.head = modules_head
+        self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
+        self.head = nn.SequentialCell(modules_head)
         self.body = nn.SequentialCell(modules_body)
         self.tail = nn.SequentialCell(modules_tail)