diff --git a/research/cv/IRN/README.md b/research/cv/IRN/README.md
index 3566a8da165bae088fac43524a0ad124795a313b..aa80778ee202a29de7a991787c6d98124ef6a840 100644
--- a/research/cv/IRN/README.md
+++ b/research/cv/IRN/README.md
@@ -36,22 +36,23 @@ Mingqing Xiao, Shuxin Zheng, Chang Liu, Yaolong Wang, Di He, Guolin Ke, Jiang Bi
 
 ## Model Architecture
 
-![1](./figures/architecture.jpg)
+![1](./figures/architecture.png)
 
 ## Dataset
 
 This example uses [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/); the expected directory structure is as follows:
 
 ```bash
-DIV2K_data
-├── DIV2K_train_HR/                 # training set data
-└── DIV2K_valid_HR/                 # test set data
+data/
+    ├── DIV2K_train_HR/                 # high-resolution training data
+    └── DIV2K_valid_HR/                 # high-resolution test data
 ```
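+
+For a quick sanity check of this layout, the training pipeline can be built directly with `create_dataset` from `src/data` (a minimal sketch, assuming it is run from the repository root; the `batch_size` value here is illustrative):
+
+```python
+from src.data import create_dataset
+
+# Build the x4 training pipeline from the HR folder; the LR inputs are
+# produced by downscaling inside the dataset itself.
+train_ds = create_dataset("data/DIV2K_train_HR", 4, target="GPU", batch_size=16)
+print("batches per epoch:", train_ds.get_dataset_size())
+```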
 
 ## Environment Requirements
 
 - Hardware
     - Ascend processor
+    - GPU
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/)
 - For details, see the resources below:
@@ -62,17 +63,39 @@ DIV2K_data
 
 After preparing the compute device and framework environment, you can run the following commands to train and evaluate this example.
 
-- Running in an Ascend processor environment
+- Running in a GPU environment
 
 ```bash
-# 8-device distributed training
-Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [SCALE] [DATASET_PATH]
+# Single-device training
+# Usage: bash run_standalone_train_gpu.sh [SCALE] [DATASET_GT_PATH]
+bash run_standalone_train_gpu.sh 4 /home/nonroot/IRN/data/DIV2K_train_HR
+
+# Distributed training
+# Usage: bash run_distribute_train_gpu.sh [DEVICE_NUM] [SCALE] [DATASET_PATH]
+# Example: set DEVICE_NUM to 2, 4, or 8 for 2-, 4-, or 8-device distributed training
+bash run_distribute_train_gpu.sh 8 4 /home/nonroot/IRN/data/DIV2K_train_HR
+
+# Single-device evaluation
+# Usage: bash run_eval_gpu.sh [SCALE] [DATASET_PATH] [CHECKPOINT_PATH]
+bash run_eval_gpu.sh 4 /home/nonroot/IRN/data/DIV2K_valid_HR /home/nonroot/IRN/ckpt/latest.ckpt
+
+```
+
+- Running in an Ascend processor environment
 
+```bash
 # Single-device training
-Usage: bash run_standalone_train.sh [SCALE] [DATASET_GT_PATH]
+# Usage: bash run_standalone_train_ascend.sh [SCALE] [DATASET_GT_PATH]
+bash run_standalone_train_ascend.sh 4 /home/nonroot/IRN/data/DIV2K_train_HR
+
+# 8-device distributed training
+# Usage: bash run_distribute_train_ascend.sh [RANK_TABLE_FILE] [SCALE] [DATASET_PATH]
+bash run_distribute_train_ascend.sh rank_table_file.json 4 /home/nonroot/IRN/data/DIV2K_train_HR
 
 # Single-device evaluation
-Usage: bash run_eval.sh [SCALE] [DATASET_PATH] [CHECKPOINT_PATH]
+# Usage: bash run_eval_ascend.sh [SCALE] [DATASET_PATH] [CHECKPOINT_PATH]
+bash run_eval_ascend.sh 4 /home/nonroot/IRN/data/DIV2K_valid_HR /home/nonroot/IRN/ckpt/latest.ckpt
+
 ```
 
 Distributed training on Ascend requires an HCCL configuration file in JSON format to be created in advance.
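+
+A minimal single-server sketch of such a rank table is shown below (all IDs and IP addresses are placeholders; generate the real file for your own cluster, for example with the hccl_tools helper from the MindSpore models repository):
+
+```json
+{
+    "version": "1.0",
+    "server_count": "1",
+    "server_list": [
+        {
+            "server_id": "10.0.0.1",
+            "device": [
+                {"device_id": "0", "device_ip": "192.1.27.6", "rank_id": "0"},
+                {"device_id": "1", "device_ip": "192.2.27.6", "rank_id": "1"}
+            ],
+            "host_nic_ip": "reserve"
+        }
+    ],
+    "status": "completed"
+}
+```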
@@ -85,9 +108,12 @@ DIV2K_data
 .
 ├── README.md                               # documentation
 ├── scripts
-│   ├── run_distribute_train.sh             # multi-device training script for Ascend
-│   ├── run_eval.sh                         # evaluation script for Ascend
-│   └── run_standalone_train.sh             # single-device training script for Ascend
+│   ├── run_distribute_train_ascend.sh      # multi-device training script for Ascend
+│   ├── run_distribute_train_gpu.sh         # multi-device training script for GPU
+│   ├── run_eval_ascend.sh                  # evaluation script for Ascend
+│   ├── run_eval_gpu.sh                     # evaluation script for GPU
+│   ├── run_standalone_train_ascend.sh      # single-device training script for Ascend
+│   └── run_standalone_train_gpu.sh         # single-device training script for GPU
 ├── src
 │   ├── data
 │   │   ├── dataset.py                      # dataset processing
@@ -111,7 +137,9 @@ DIV2K_data
 │   └── utils
 │       └── util.py                         # metric computation
 ├── train.py                                # training script
-└── eval.py                                 # evaluation script
+├── export.py                               # model export script
+├── requirements.txt                        # dependency list
+└── eval.py                                 # evaluation script
 ```
 
 ## Script Parameters
@@ -201,8 +229,7 @@ optional arguments:
 # python eval.py -h
 usage: eval.py  [--scale {2,4}] [--dataset_GT_path {path of intended GT dataset}] [--dataset_LQ_path {path of intended LQ dataset}] [--resume_state {path of the checkpoint}] [--device_target {Ascend,GPU,CPU}]
 
-
-AutoAugment for image classification.
+IRN for image rescaling.
 
 optional arguments:
   -h, --help            Show this help message and exit
@@ -225,18 +252,17 @@ optional arguments:
 
 ```bash
 # python export.py -h
-
 usage: export.py [-h] [--scale {2,4}] [--device_id DEVICE_ID] --checkpoint_path
                  CHECKPOINT_PATH [--file_name FILE_NAME]
                  [--file_format {AIR,ONNX,MINDIR}]
                  [--device_target {Ascend,GPU,CPU}]
 
-WRN with AutoAugment export.
+IRN image rescaling model export.
 
 optional arguments:
   -h, --help            Show this help message and exit
   --scale {2,4}
                         Rescaling parameter
   --device_id DEVICE_ID
                         Device id.
   --checkpoint_path CHECKPOINT_PATH
@@ -252,20 +278,20 @@ optional arguments:
 
 ## Model Description
 
-| Parameter | Single GPU | Single Ascend 910 | 8x Ascend 910 |
-|:---|:---|:---|:---|
-| Resource | GTX 1080ti | Ascend 910 | Ascend 910 |
-| Upload date | 2021.09.25 | 2021.09.25 | 2021.11.01 |
-| MindSpore version | 1.2.0 | 1.3.0 | 1.3.0 |
-| Training dataset | DIV2K | DIV2K | DIV2K |
-| Optimizer | Adam | Adam | Adam |
-| Output | Reconstructed HR image | Reconstructed HR image | Reconstructed HR image |
-| PSNR | 34.83 | 34.11 | 33.88 |
-| SSIM | 0.9287 | 0.9206 | 0.9167 |
-| Speed | 1534 ms/step | 271 ms/step | 409 ms/step |
-| Total time | 3162 mins | 2258 mins | 409 mins |
-| Fine-tuned checkpoint | 50.1M (.ckpt file) | 50.1M (.ckpt file) | 50.1M (.ckpt file) |
-| Script | [IRN](./) | [IRN](./) | [IRN](./) |
+| Parameter             | Single GPU             | 4x GPU                  | Single Ascend 910      | 8x Ascend 910          |
+| :-------------------- | :--------------------- | :---------------------- | :--------------------- | :--------------------- |
+| Resource              | NVIDIA V100            | NVIDIA GeForce RTX 3090 | Ascend 910             | Ascend 910             |
+| Upload date           | 2022.6.13              | 2022.6.13               | 2021.09.25             | 2021.11.01             |
+| MindSpore version     | 1.6.1                  | 1.6.1                   | 1.3.0                  | 1.3.0                  |
+| Training dataset      | DIV2K                  | DIV2K                   | DIV2K                  | DIV2K                  |
+| Optimizer             | Adam                   | Adam                    | Adam                   | Adam                   |
+| Output                | Reconstructed HR image | Reconstructed HR image  | Reconstructed HR image | Reconstructed HR image |
+| PSNR                  | NaN                    | 34.53                   | 34.11                  | 33.88                  |
+| SSIM                  | NaN                    | 0.9246                  | 0.9206                 | 0.9167                 |
+| Speed                 | 836 ms/step            | 1417 ms/step            | 271 ms/step            | 409 ms/step            |
+| Total time            | NaN                    | 2952 mins               | 2258 mins              | 409 mins               |
+| Fine-tuned checkpoint | 50.1M (.ckpt file)     | 50.1M (.ckpt file)      | 50.1M (.ckpt file)     | 50.1M (.ckpt file)     |
+| Script                | [IRN](./)              | [IRN](./)               | [IRN](./)              | [IRN](./)              |
 
 ## Description of Randomness
 
@@ -273,4 +299,4 @@ optional arguments:
 
 ## Official Homepage
 
 Please visit the official [homepage](https://gitee.com/mindspore/models).
diff --git a/research/cv/IRN/eval.py b/research/cv/IRN/eval.py
index 32d5984668d05046fb1cd292d33ddcb4eedd5bbb..3d247e783091c6d451561329cd587328b68845b8 100644
--- a/research/cv/IRN/eval.py
+++ b/research/cv/IRN/eval.py
@@ -46,7 +46,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="irn testing")
     parser.add_argument('--scale', type=int, default=4, choices=(2, 4),
                         help='Rescaling Parameter.')
-    parser.add_argument('--dataset_GT_path', type=str, default='/home/nonroot/DIV2K/DIV2K_train_HR',
+    parser.add_argument('--dataset_GT_path', type=str, default='/home/nonroot/IRN/data/DIV2K_train_HR',
                         help='Path to the folder where the intended GT dataset is stored.')
     parser.add_argument('--dataset_LQ_path', type=str, default=None,
                         help='Path to the folder where the intended LQ dataset is stored.')
@@ -86,6 +86,7 @@ if __name__ == '__main__':
     val_dataset = create_dataset(
         args.dataset_GT_path,
         args.scale,
+        target=args.device_target,
         do_train=False,
         batch_size=1)
 
@@ -109,7 +110,12 @@ if __name__ == '__main__':
                             learning_rate=0.1, momentum=0.9)
 
     # Model
-    model = Model(network=loss, optimizer=optimizer, amp_level="O3")
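+    # Mixed precision: amp_level O3 runs the network in float16, matching the
+    # float16 cast_type the cells select on Ascend; O0 keeps float32 on GPU.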
+    if args.device_target == "Ascend":
+        model = Model(network=loss, optimizer=optimizer, amp_level="O3")
+    elif args.device_target == "GPU":
+        model = Model(network=loss, optimizer=optimizer, amp_level="O0")
+    else:
+        raise ValueError("Unsupported device target.")
 
     val_iter = val_dataset.create_dict_iterator()
 
diff --git a/research/cv/IRN/figures/architecture.jpg b/research/cv/IRN/figures/architecture.jpg
deleted file mode 100644
index 6f66a1a2754cc8fc2169b1455caf94c54aa3950b..0000000000000000000000000000000000000000
Binary files a/research/cv/IRN/figures/architecture.jpg and /dev/null differ
diff --git a/research/cv/IRN/figures/architecture.png b/research/cv/IRN/figures/architecture.png
new file mode 100644
index 0000000000000000000000000000000000000000..06a6bbeac7e79278e3a59f183a58a2f235b5af8c
Binary files /dev/null and b/research/cv/IRN/figures/architecture.png differ
diff --git a/research/cv/IRN/requirements.txt b/research/cv/IRN/requirements.txt
index b3ac74b1b9f9067b37237241683d585912f32818..5724b18e615f316f0b1ae3de50961793015d1a0e 100644
--- a/research/cv/IRN/requirements.txt
+++ b/research/cv/IRN/requirements.txt
@@ -1,3 +1,4 @@
-yaml
+PyYAML
 numpy
-cv2
\ No newline at end of file
+opencv-python
+mindspore-gpu==1.6.1
\ No newline at end of file
diff --git a/research/cv/IRN/scripts/run_distribute_train.sh b/research/cv/IRN/scripts/run_distribute_train_ascend.sh
similarity index 89%
rename from research/cv/IRN/scripts/run_distribute_train.sh
rename to research/cv/IRN/scripts/run_distribute_train_ascend.sh
index cbcbbc7835a4996d9f3b97a5b05e8c8fd45035cd..487fdca17430ad59cdedf074dd00da6e2d75b29d 100644
--- a/research/cv/IRN/scripts/run_distribute_train.sh
+++ b/research/cv/IRN/scripts/run_distribute_train_ascend.sh
@@ -16,7 +16,7 @@
 
 if [ $# != 3 ]; then
     echo "Usage: 
-bash run_distribute_train.sh [RANK_TABLE_FILE] [SCALE] [DATASET_PATH]"
+bash run_distribute_train_ascend.sh [RANK_TABLE_FILE] [SCALE] [DATASET_PATH]"
     exit 1
 fi
 
@@ -39,11 +39,12 @@ for ((i=0; i<${RANK_SIZE}; i++)); do
     export DEVICE_ID=${i}
     export RANK_ID=$((i))
     echo "Start distributed training for rank ${RANK_ID}, device ${DEVICE_ID}"
-    python ../train.py \
+    nohup python ../train.py \
         --scale $2 \
         --dataset_GT_path $3 \
+        --device_target Ascend \
         --run_distribute=True \
-        > train-${i}.log 2>&1 & 
+        > ../train-${i}.log 2>&1 & 
     pid=$!
     PID_LIST+=("${pid}")
 done
diff --git a/research/cv/IRN/scripts/run_distribute_train_gpu.sh b/research/cv/IRN/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b36344e8a94d930b86a8e9831ff037b83bf6d051
--- /dev/null
+++ b/research/cv/IRN/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ]; then
+    echo "Usage: \
+bash run_distribute_train_gpu.sh [DEVICE_NUM] [SCALE] [DATASET_PATH]"
+    exit 1
+fi
+
+echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
+
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export DEVICE_NUM=$1
+export SCALE=$2
+export HCCL_CONNECT_TIMEOUT=200
+
+if [ $DEVICE_NUM -lt 2 ] || [ $DEVICE_NUM -gt 8 ]
+then
+    echo "error: DEVICE_NUM=$DEVICE_NUM is not in [2, 8]"
+    exit 1
+fi
+
+if [ $SCALE -ne 2 ] && [ $SCALE -ne 4 ]
+then
+    echo "error: SCALE=$SCALE is not 2 or 4."
+    exit 1
+fi
+
+if [ ! -d "$3" ]; then
+    echo "error: DATASET_PATH:$3 does not exist"
+    exit 1
+fi
+
+rm -rf ./train_distribute
+mkdir ./train_distribute
+cp -r ../src ./train_distribute
+cp -r ../*.py ./train_distribute
+
+echo "start distribute training"
+env > env.log
+
+if [ $# == 3 ]
+then
+    nohup mpirun -n $DEVICE_NUM --allow-run-as-root --output-filename ./train_distribute/log_output --merge-stderr-to-stdout \
+    python ../train.py --run_distribute True \
+        --device_num $DEVICE_NUM \
+        --scale $SCALE \
+        --dataset_GT_path $3 \
+        --device_target GPU \
+        > ../train_dis.log 2>&1 &
+fi
+cd ..
diff --git a/research/cv/IRN/scripts/run_eval.sh b/research/cv/IRN/scripts/run_eval_ascend.sh
similarity index 87%
rename from research/cv/IRN/scripts/run_eval.sh
rename to research/cv/IRN/scripts/run_eval_ascend.sh
index 8375f27fdd78ef50cabc838d64e311ee3dbc7374..efebeb6c08879888861d8f3d313a5cadda7dedc3 100644
--- a/research/cv/IRN/scripts/run_eval.sh
+++ b/research/cv/IRN/scripts/run_eval_ascend.sh
@@ -14,14 +14,14 @@
 # limitations under the License.
 # ============================================================================
 
-export DEVICE_ID=1
+export DEVICE_ID=0
 export DEVICE_NUM=1
 export RANK_ID=0
 export RANK_SIZE=1
 
 if [ $# != 3 ]; then
     echo "Usage: \
-bash run_eval.sh [SCALE] [DATASET_PATH] [CHECKPOINT_PATH]"
+bash run_eval_ascend.sh [SCALE] [DATASET_PATH] [CHECKPOINT_PATH]"
     exit 1
 fi
 
@@ -35,10 +35,11 @@ if [ ! -d $2 ]; then
     exit 1
 fi
 
-python ../eval.py \
+nohup python ../eval.py \
     --scale $1 \
     --dataset_GT_path $2 \
+    --device_target Ascend \
     --resume_state $3 \
-    > eval.log 2>&1 &
+    > ../eval.log 2>&1 &
 pid=$!
 echo "Start evaluating with rank ${RANK_ID} on device ${DEVICE_ID}: ${pid}"
diff --git a/research/cv/IRN/scripts/run_eval_gpu.sh b/research/cv/IRN/scripts/run_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..38d78282bc69d61e52a35d98fc6638ea5b75f525
--- /dev/null
+++ b/research/cv/IRN/scripts/run_eval_gpu.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+export DEVICE_ID=0
+export DEVICE_NUM=1
+export RANK_ID=0
+export RANK_SIZE=1
+
+if [ $# != 3 ]; then
+    echo "Usage: \
+bash run_eval_gpu.sh [SCALE] [DATASET_PATH] [CHECKPOINT_PATH]"
+    exit 1
+fi
+
+if [ ! -f "$3" ]; then
+    echo "error: CHECKPOINT_PATH:$3 does not exist"
+    exit 1
+fi
+
+if [ ! -d "$2" ]; then
+    echo "error: DATASET_PATH:$2 does not exist"
+    exit 1
+fi
+
+if [ $1 -ne 2 ] && [ $1 -ne 4 ]
+then
+    echo "error: SCALE=$1 is not 2 or 4."
+    exit 1
+fi
+
+nohup python ../eval.py \
+    --scale $1 \
+    --dataset_GT_path $2 \
+    --device_target GPU \
+    --resume_state $3 \
+    > ../eval.log 2>&1 &
+pid=$!
+echo "Start evaluating with rank ${RANK_ID} on device ${DEVICE_ID}: ${pid}"
diff --git a/research/cv/IRN/scripts/run_standalone_train_ascend.sh b/research/cv/IRN/scripts/run_standalone_train_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8ab2fe2bb9ac6a792a65839750d42ec739e45e69
--- /dev/null
+++ b/research/cv/IRN/scripts/run_standalone_train_ascend.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+export DEVICE_ID=0
+export DEVICE_NUM=1
+export RANK_ID=0
+export RANK_SIZE=1
+if [ $# != 2 ]; then
+    echo "Usage: \
+bash run_standalone_train_ascend.sh [SCALE] [DATASET_GT_PATH]"
+    exit 1
+fi
+if [ ! -d "$2" ]; then
+    echo "error: DATASET_PATH:$2 does not exist"
+    exit 1
+fi
+nohup python ../train.py \
+    --scale $1 \
+    --device_target Ascend \
+    --dataset_GT_path $2 \
+    > ../train.log 2>&1 &
+pid=$!
+echo "Start training with rank ${RANK_ID} on device ${DEVICE_ID}: ${pid}"
\ No newline at end of file
diff --git a/research/cv/IRN/scripts/run_standalone_train.sh b/research/cv/IRN/scripts/run_standalone_train_gpu.sh
similarity index 84%
rename from research/cv/IRN/scripts/run_standalone_train.sh
rename to research/cv/IRN/scripts/run_standalone_train_gpu.sh
index 675b067f9f06958fe93bdeb12e6f831dbcbaa1e6..f16043f7d8526db1df5cbb6a3f79539e8d08b1fe 100644
--- a/research/cv/IRN/scripts/run_standalone_train.sh
+++ b/research/cv/IRN/scripts/run_standalone_train_gpu.sh
@@ -14,26 +14,26 @@
 # limitations under the License.
 # ============================================================================
 
-export DEVICE_ID=3
+export DEVICE_ID=0
 export DEVICE_NUM=1
 export RANK_ID=0
 export RANK_SIZE=1
 
 if [ $# != 2 ]; then
     echo "Usage: \
-bash run_standalone_train.sh [SCALE] [DATASET_GT_PATH]"
+bash run_standalone_train_gpu.sh [SCALE] [DATASET_GT_PATH]"
     exit 1
 fi
 
-if [ ! -d $2 ]; then
+if [ ! -d "$2" ]; then
     echo "error: DATASET_PATH:$2 does not exist"
     exit 1
 fi
 
-
-python ../train.py \
+nohup python ../train.py \
     --scale $1 \
+    --device_target GPU \
     --dataset_GT_path $2 \
-    > train.log 2>&1 &
+    > ../train.log 2>&1 &
 pid=$!
 echo "Start training with rank ${RANK_ID} on device ${DEVICE_ID}: ${pid}"
\ No newline at end of file
diff --git a/research/cv/IRN/src/data/__init__.py b/research/cv/IRN/src/data/__init__.py
index ff0afaed16cdcf2cd85623684eed931b4579c3d6..8c7c952db021f2471a76d970d176bed655dec088 100644
--- a/research/cv/IRN/src/data/__init__.py
+++ b/research/cv/IRN/src/data/__init__.py
@@ -87,19 +87,21 @@ def create_dataset(dataset_path, scale, do_train=True, repeat_num=1,
         rank_size, rank_id = _get_rank_info()
+        sampler = None  # this branch shards via num_shards/shard_id below
     else:
         if distribute:
-            init()
             rank_id = get_rank()
             rank_size = get_group_size()
+            sampler = Sampler(len(sr_ds), rank_id, rank_size)
         else:
+            sampler = None
             rank_size = 1
+            rank_id = 0
 
     num_shards = None if rank_size == 1 else rank_size
     shard_id = None if rank_size == 1 else rank_id
     if do_train:
         dataset = ds.GeneratorDataset(
             sr_ds, ["downscaled", "original"],
-            num_parallel_workers=1, shuffle=True,
-            sampler=Sampler(len(sr_ds), rank_id, rank_size),
+            num_parallel_workers=4 * rank_size, shuffle=True,
+            sampler=sampler,
             num_shards=num_shards, shard_id=shard_id,
         )
     else:
diff --git a/research/cv/IRN/src/data/dataset.py b/research/cv/IRN/src/data/dataset.py
index 5a2025adc87356f092ef539893a458ab69715734..8b7bede00c524fe9596bd01ea74b90102c1f4543 100644
--- a/research/cv/IRN/src/data/dataset.py
+++ b/research/cv/IRN/src/data/dataset.py
@@ -109,7 +109,8 @@ class SRDataset():
 
         # pil -> numpy, HWC -> CHW
         orig = np.transpose(
-            np.asarray(orig), (2, 0, 1)).astype(np.float16) / 255.
+            np.asarray(orig), (2, 0, 1)).astype(np.float32) / 255.
         downscaled = np.transpose(
-            np.asarray(downscaled), (2, 0, 1)).astype(np.float16) / 255.
+            np.asarray(downscaled), (2, 0, 1)).astype(np.float32) / 255.
+
         return downscaled, orig
diff --git a/research/cv/IRN/src/network/Invnet.py b/research/cv/IRN/src/network/Invnet.py
index 17d40e0dd70fbf793809d011baf156a68e00eb87..58e37979a3db5ce6b17322997b4b8c1fc42d7666 100644
--- a/research/cv/IRN/src/network/Invnet.py
+++ b/research/cv/IRN/src/network/Invnet.py
@@ -20,7 +20,7 @@ import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore.ops import operations as ops
-
+from mindspore import dtype as mstype, context
 import src.network.util as mutil
 
 
@@ -110,6 +110,12 @@ class HaarDownsampling(nn.Cell):
 
     def __init__(self, channel_in):
         super(HaarDownsampling, self).__init__()
+
+        if context.get_context("device_target") == "Ascend":
+            self.cast_type = mstype.float16
+        else:
+            self.cast_type = mstype.float32
+
         self.channel_in = channel_in
 
         self.haar_weights = np.ones((4, 1, 2, 2))
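+        # The four 2x2 kernels form the Haar basis: row 0 keeps the 2x2
+        # average (low-pass) and rows 1-3 hold the +1/-1 detail filters, so
+        # the downsampling loses no information and stays invertible.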
@@ -124,8 +130,8 @@ class HaarDownsampling(nn.Cell):
         self.haar_weights[3, 0, 0, 1] = -1
 
         self.haar_weights = np.concatenate(
-            [self.haar_weights] * self.channel_in, 0).astype(np.float16)
-        self.haar_weights = ms.Tensor(self.haar_weights)
+            [self.haar_weights] * self.channel_in, 0)
+        self.haar_weights = ms.Tensor(self.haar_weights).astype(self.cast_type)
         self.haar_weights.requires_grad = False
 
         self.conv2d = mutil.GroupConv(
diff --git a/research/cv/IRN/src/network/__init__.py b/research/cv/IRN/src/network/__init__.py
index 55dbc8770bd6f799f70069e4823d923849e11c0c..f0a7d533da0af46fe474f2b586bb9d651b121905 100644
--- a/research/cv/IRN/src/network/__init__.py
+++ b/research/cv/IRN/src/network/__init__.py
@@ -14,19 +14,15 @@
 # ============================================================================
 """Initializing for network"""
 
-
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as ops
 from mindspore import dtype as mstype
-
-
 import src.network.Invnet as Invnet
 from .net_with_loss import IRN_loss
 
-
 def create_model(opt):
     """
         create invertible rescaling network
@@ -37,7 +33,6 @@ def create_model(opt):
     print('Model [{:s}] is created.'.format(m.__class__.__name__))
     return m
 
-
 class TrainOneStepCell_IRN(nn.TrainOneStepCell):
     """
         Encapsulation class of IRN network training
@@ -87,7 +82,7 @@ class TrainOneStepCell_IRN(nn.TrainOneStepCell):
         if clip_coef < 1:
             new_grads = ()
             for grad in grads:
-                new_grads += (self.mul(grad, clip_coef),)  # update gradients
+                new_grads += (self.mul(grad, clip_coef),)
             grads = new_grads
         self.optimizer(grads)
         return loss
diff --git a/research/cv/IRN/src/network/net_with_loss.py b/research/cv/IRN/src/network/net_with_loss.py
index e32f7d97c557dcc7d213bfb42b6ba6ae9287a7d8..19c725ae56ed58aa079fd52d1ccbe386d2d97532 100644
--- a/research/cv/IRN/src/network/net_with_loss.py
+++ b/research/cv/IRN/src/network/net_with_loss.py
@@ -14,13 +14,11 @@
 # ============================================================================
 """define network with loss function"""
 
-
-import mindspore as ms
 from mindspore.ops import operations as ops
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 import mindspore.nn as nn
-
+from mindspore import dtype as mstype, context
 from src.network.util import ReconstructionLoss
 
 class Rounding(nn.Cell):
@@ -52,6 +50,13 @@ class IRN_loss(nn.Cell):
 
     def __init__(self, net_G, opt):
         super(IRN_loss, self).__init__()
+
+        if context.get_context("device_target") == "Ascend":
+            self.cast_type = mstype.float16
+        else:
+            self.cast_type = mstype.float32
+
         self.netG = net_G
 
         train_opt = opt['train']
@@ -71,25 +76,23 @@ class IRN_loss(nn.Cell):
         self.cat = ops.Concat(1)
         self.reshape = ops.Reshape()
         self.cast = ops.Cast()
-
         self.Quantization = Quantization()
 
     def loss_forward(self, out, y, z):
         l_forw_fit = self.train_opt['lambda_fit_forw'] * self.Reconstruction_forw(out, y)
 
-        z = self.reshape(z, (out.shape[0], -1))
+        z = self.cast(self.reshape(z, (out.shape[0], -1)), self.cast_type)
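+        # Penalizing the mean squared norm of the latent z pushes it toward
+        # the standard normal prior that the inverse pass later samples from.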
         l_forw_ce = self.train_opt['lambda_ce_forw'] * self.ms_sum(z**2) / z.shape[0]
 
         return l_forw_fit, l_forw_ce
 
     def gaussian_batch(self, dims):
-        return self.cast(self.stdnormal(dims), ms.float16)
+        return self.cast(self.stdnormal(dims), self.cast_type)
 
     def loss_backward(self, x, y):
         x_samples = self.netG(x=y, rev=True)
         x_samples_image = x_samples[:, :3, :, :]
         l_back_rec = self.train_opt['lambda_rec_back'] * self.Reconstruction_back(x, x_samples_image)
-
         return l_back_rec
 
     def backward(self, x, y):
@@ -125,17 +128,17 @@ class IRN_loss(nn.Cell):
         ##     l_forw_ce   --distribution matching loss
         l_forw_fit, l_forw_ce = self.loss_forward(output[:, :3, :, :], LR_ref, output[:, 3:, :, :])
 
-        ### backward upscaling
+        ## backward upscaling
         zshape = output[:, 3:, :, :].shape
         LR = self.Quantization(output[:, :3, :, :])
 
         gaussian_scale = 1
         T = gaussian_scale * self.gaussian_batch(zshape)
+
         y_ = self.cat((LR, T))
 
         l_back_rec = self.loss_backward(real_H, y_)
 
         ## total loss
         loss = l_forw_fit + l_forw_ce + l_back_rec
-
         return loss
diff --git a/research/cv/IRN/src/network/util.py b/research/cv/IRN/src/network/util.py
index f073549ee19fa7a10949305d9e2a1bf248366333..f06e65e647c505c2d139a4e6f974a2f29e260186 100644
--- a/research/cv/IRN/src/network/util.py
+++ b/research/cv/IRN/src/network/util.py
@@ -17,34 +17,40 @@
 import mindspore
 import mindspore.nn as nn
 import mindspore.common.initializer as init
-from mindspore.ops import operations as ops
-from mindspore.common import dtype as mstype
-
+import mindspore.ops as ops
+from mindspore import Tensor
+from mindspore import dtype as mstype, context
 
 class ReconstructionLoss(nn.Cell):
     """L1 and L2 Loss """
-
     def __init__(self, losstype='l2', eps=1e-6):
         super(ReconstructionLoss, self).__init__()
         self.losstype = losstype
-        self.eps = eps
+
+        if context.get_context("device_target") == "Ascend":
+            self.cast_type = mstype.float16
+        else:
+            self.cast_type = mstype.float32
 
         self.mean = ops.ReduceMean()
-        self.sum = ops.ReduceSum()
+        self.sum = ops.ReduceSum(keep_dims=False)
+        self.abs = ops.Abs()
+        self.cast = ops.Cast()
         self.sqrt = ops.Sqrt()
+        self.eps = Tensor(eps, self.cast_type)
 
     def construct(self, x, target):
         '''construct method for loss'''
+        x = self.cast(x, self.cast_type)
+        target = self.cast(target, self.cast_type)
         if self.losstype == 'l2':
             return self.mean(self.sum((x - target)**2, (1, 2, 3)))
         if self.losstype == 'l1':
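+            # Charbonnier-style smooth L1: sqrt(d^2 + eps) stays differentiable at 0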
             diff = x - target
             return self.mean(self.sum(self.sqrt(diff * diff + self.eps), (1, 2, 3)))
-
         print("reconstruction loss type error!")
         return 0
 
-
 def initialize_weights(net_l, scale=1):
     """weights initialization"""
     if not isinstance(net_l, list):
@@ -163,7 +169,10 @@ class GroupConv(nn.Cell):
         self.op_split_w = ops.Split(axis=0, output_num=self.groups)
         self.op_concat = ops.Concat(axis=1)
         self.cast = ops.Cast()
-
+        if context.get_context("device_target") == "Ascend":
+            self.cast_type = mstype.float16
+        else:
+            self.cast_type = mstype.float32
         for _ in range(groups):
             self.convs.append(mindspore.ops.Conv2D(out_channels//groups,
                                                    kernel_size=kernel_size, stride=stride,
@@ -176,11 +185,12 @@ class GroupConv(nn.Cell):
         for i in range(self.groups):
             outputs = outputs + \
                 (self.convs[i](
-                    self.cast(features[i], mstype.float16), weights[i]),)
+                    self.cast(features[i], self.cast_type), weights[i]),)
         out = self.op_concat(outputs)
         return out
 
 
 class GroupTransConv(nn.Cell):
     """
     group transposed convolution operation.
@@ -208,7 +218,10 @@ class GroupTransConv(nn.Cell):
 
         self.op_concat = ops.Concat(axis=1)
         self.cast = ops.Cast()
-
+        if context.get_context("device_target") == "Ascend":
+            self.cast_type = mstype.float16
+        else:
+            self.cast_type = mstype.float32
         weights = self.op_split_w(weight_init)
 
         for i in range(groups):
@@ -222,6 +235,6 @@ class GroupTransConv(nn.Cell):
         outputs = ()
         for i in range(self.groups):
             outputs = outputs + \
-                (self.convsTrans[i](self.cast(features[i], mstype.float16)),)
+                (self.convsTrans[i](self.cast(features[i], self.cast_type)),)
         out = self.op_concat(outputs)
         return out
diff --git a/research/cv/IRN/train.py b/research/cv/IRN/train.py
index f89d1970ecc0b2bd66481d1e8deb0c2e102b1418..6b67e19d5e8ddacd3a380c408f6c3df882d225da 100644
--- a/research/cv/IRN/train.py
+++ b/research/cv/IRN/train.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,23 +19,20 @@ Model training entrypoint.
 import os
 import ast
 import argparse
-
 import src.options.options as option
 import src.utils.util as util
 from src.data import create_dataset
 from src.optim import warmup_step_lr, warmup_cosine_annealing_lr
 from src.optim.adam_clip import AdamClipped
 from src.network import create_model, IRN_loss
-
 from mindspore import context, Tensor
-
 from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
-from mindspore.communication.management import init, get_group_size
+from mindspore.communication.management import init, get_rank
 from mindspore.train.model import ParallelMode
-from mindspore import Model, load_checkpoint, load_param_into_net
+from mindspore.train.model import Model
+from mindspore import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed
-set_seed(0)
-
+set_seed(10)
 
 current_path = os.path.abspath(__file__)
 root_path = os.path.dirname(current_path)
@@ -44,12 +41,11 @@ X2_TRAIN_YAML_FILE = os.path.join(
 X4_TRAIN_YAML_FILE = os.path.join(
     root_path, "src", "options", "train", "train_IRN_x4.yml")
 
-
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="irn training")
     parser.add_argument('--scale', type=int, default=4, choices=(2, 4),
                         help='Rescaling Parameter.')
-    parser.add_argument('--dataset_GT_path', type=str, default='/home/nonroot/DIV2K/DIV2K_train_HR',
+    parser.add_argument('--dataset_GT_path', type=str, default='/home/nonroot/IRN/data/DIV2K_train_HR',
                         help='Path to the folder where the intended GT dataset is stored.')
     parser.add_argument('--dataset_LQ_path', type=str, default=None,
                         help='Path to the folder where the intended LQ dataset is stored.')
@@ -88,18 +84,21 @@ if __name__ == '__main__':
             context.set_auto_parallel_context(device_num=args.device_num,
                                               parallel_mode=ParallelMode.DATA_PARALLEL,
                                               gradients_mean=True)
+            init()
         elif args.device_target == "GPU":
-            context.set_context(device_num=get_group_size(),
-                                parallel_mode=ParallelMode.DATA_PARALLEL,
-                                gradients_mean=True)
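+            # Each process launched by mpirun joins the NCCL communicator via
+            # init("nccl"); rank and world size then come from that layer.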
+            init("nccl")
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True,
+                                              device_num=args.device_num)
         else:
             raise ValueError("Unsupported device target.")
-        init()
+        rank_id = get_rank()
     else:
         if args.device_target == "Ascend":
             context.set_context(device_id=int(os.getenv('DEVICE_ID')))
         opt['dist'] = False
-        rank = -1
+        rank_id = 0
         print('Disabled distributed training.')
 
     context.set_context(max_call_depth=4030)
@@ -115,6 +114,7 @@ if __name__ == '__main__':
     train_dataset = create_dataset(
         dataset_opt["dataroot_GT"],
         dataset_opt["scale"],
+        target=args.device_target,
         batch_size=dataset_opt["batch_size"],
         distribute=args.run_distribute,
     )
@@ -124,10 +124,10 @@ if __name__ == '__main__':
     # learning rate
     wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
     if train_opt['lr_scheme'] == 'MultiStepLR':
-        lr = warmup_step_lr(train_opt['lr_G']*2,
+        lr = warmup_step_lr(train_opt['lr_G'],
                             train_opt['lr_steps'],
                             step_size,
-                            200,
+                            0,
                             total_epochs,
                             train_opt['lr_gamma'],
                             )
@@ -160,19 +160,24 @@ if __name__ == '__main__':
         beta1=train_opt['beta1'], beta2=train_opt['beta2'], weight_decay=wd_G)
 
     # Model
-    model = Model(network=loss, optimizer=optimizer, amp_level="O3")
+    if args.device_target == "Ascend":
+        model = Model(network=loss, optimizer=optimizer, amp_level="O3")
+    elif args.device_target == "GPU":
+        model = Model(network=loss, optimizer=optimizer, amp_level="O0")
+    else:
+        raise ValueError("Unsupported device target.")
 
     # define callbacks
     ckpt_save_steps = step_size*100
     callbacks = [LossMonitor(), TimeMonitor(data_size=ckpt_save_steps)]
     config_ck = CheckpointConfig(
-        save_checkpoint_steps=ckpt_save_steps, keep_checkpoint_max=50)
+        save_checkpoint_steps=ckpt_save_steps, keep_checkpoint_max=10)
     save_ckpt_path = os.path.join(
-        'ckpt/', 'ckpt_one_step_x4/', util.get_timestamp() + '/')
+        'ckpt/', 'ckpt_one_step_x{}/'.format(args.scale), util.get_timestamp() + '/')
     ckpt_cb = ModelCheckpoint(
         prefix="irn_onestep", directory=save_ckpt_path, config=config_ck)
-    callbacks.append(ckpt_cb)
 
     # training
-    model.train(total_epochs, train_dataset, callbacks=callbacks,
-                dataset_sink_mode=True)
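+    # Register the checkpoint callback on rank 0 only, so parallel workers
+    # do not write checkpoint files concurrently.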
+    if rank_id == 0:
+        callbacks.append(ckpt_cb)
+    model.train(total_epochs, train_dataset, callbacks=callbacks, dataset_sink_mode=True)