diff --git a/research/cv/wdsr/README_CN.md b/research/cv/wdsr/README_CN.md
index 8c9cb7faa51e8fb3d0059b2133eded4031ca74b0..60ee554f8f36167d95dc207f662687f8758ab56b 100644
--- a/research/cv/wdsr/README_CN.md
+++ b/research/cv/wdsr/README_CN.md
@@ -111,7 +111,7 @@ WDSR缃戠粶涓昏鐢卞嚑涓熀鏈ā鍧楋紙鍖呮嫭鍗风Н灞傚拰姹犲寲灞傦級缁勬垚銆�
 
 # 鐜瑕佹眰
 
-- 纭欢锛圓scend锛�
+- 纭欢锛圓scend/GPU锛�
     - 浣跨敤ascend澶勭悊鍣ㄦ潵鎼缓纭欢鐜銆�
 - 妗嗘灦
     - [MindSpore](https://www.mindspore.cn/install/en)
@@ -124,18 +124,24 @@ WDSR缃戠粶涓昏鐢卞嚑涓熀鏈ā鍧楋紙鍖呮嫭鍗风Н灞傚拰姹犲寲灞傦級缁勬垚銆�
 閫氳繃瀹樻柟缃戠珯瀹夎MindSpore鍚庯紝鎮ㄥ彲浠ユ寜鐓у涓嬫楠よ繘琛岃缁冨拰璇勪及锛�
 
 ```shell
-#鍗曞崱璁粌
+# 鍗曞崱璁粌
+# Ascend
 sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
+# GPU
+bash run_gpu_standalone.sh [TRAIN_DATA_DIR]
 ```
 
 ```shell
-#鍒嗗竷寮忚缁�
+# 鍒嗗竷寮忚缁�
+# Ascend
 sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
+# GPU
+bash run_gpu_distribute.sh [TRAIN_DATA_DIR] [DEVICE_NUM]
 ```
 
 ```python
 #璇勪及
-sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
+bash run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
 ```
 
 # 鑴氭湰璇存槑
@@ -145,11 +151,11 @@ sh run_eval.sh [TEST_DATA_DIR] [CHECKPOINT_PATH] [DATASET_TYPE]
 ```bash
 WDSR
    鈹溾攢鈹€ README_CN.md                           //鑷堪鏂囦欢
-   鈹溾攢鈹€ eval.py                                //璇勪及鑴氭湰
-   鈹溾攢鈹€ export.py
    鈹溾攢鈹€ script
    鈹偮犅�      鈹溾攢鈹€ run_ascend_distribute.sh      //Ascend鍒嗗竷寮忚缁僺hell鑴氭湰
    鈹偮犅�      鈹溾攢鈹€ run_ascend_standalone.sh      //Ascend鍗曞崱璁粌shell鑴氭湰
+   鈹偮犅�      鈹溾攢鈹€ run_gpu_distribute.sh         //GPU鍒嗗竷寮忚缁僺hell鑴氭湰
+   鈹偮犅�      鈹溾攢鈹€ run_gpu_standalone.sh         //GPU鍗曞崱璁粌shell鑴氭湰
    鈹偮犅�      鈹斺攢鈹€ run_eval.sh                   //eval楠岃瘉shell鑴氭湰
    鈹溾攢鈹€ src
    鈹偮犅�   鈹溾攢鈹€ args.py                          //瓒呭弬鏁�
@@ -160,8 +166,10 @@ WDSR
    鈹偮犅�   鈹偮犅�    鈹斺攢鈹€ srdata.py                 //鎵€鏈夋暟鎹泦
    鈹偮犅�   鈹溾攢鈹€ metrics.py                       //PSNR鍜孲SIM璁$畻鍣�
    鈹偮犅�   鈹溾攢鈹€ model.py                         //WDSR缃戠粶
-   鈹偮犅�   鈹斺攢鈹€ utils.py                         //璁粌鑴氭湰
-   鈹斺攢鈹€ train.py                               //璁粌鑴氭湰
+   鈹偮犅�   鈹斺攢鈹€ utils.py                         //杈呭姪鍑芥暟
+   鈹溾攢鈹€ train.py                               //璁粌鑴氭湰
+   鈹溾攢鈹€ eval.py                                //璇勪及鑴氭湰
+   鈹斺攢鈹€ export.py
 ```
 
 ## 鑴氭湰鍙傛暟
@@ -169,43 +177,30 @@ WDSR
 涓昏鍙傛暟濡備笅:
 
 ```python
-  -h, --help            show this help message and exit
-  --dir_data DIR_DATA   dataset directory
-  --data_train DATA_TRAIN
-                        train dataset name
-  --data_test DATA_TEST
-                        test dataset name
-  --data_range DATA_RANGE
-                        train/test data range
-  --ext EXT             dataset file extension
-  --scale SCALE         super resolution scale
-  --patch_size PATCH_SIZE
-                        output patch size
-  --rgb_range RGB_RANGE
-                        maximum value of RGB
-  --n_colors N_COLORS   number of color channels to use
-  --no_augment          do not use data augmentation
-  --model MODEL         model name
-  --n_resblocks N_RESBLOCKS
-                        number of residual blocks
-  --n_feats N_FEATS     number of feature maps
-  --res_scale RES_SCALE
-                        residual scaling
-  --test_every TEST_EVERY
-                        do test per every N batches
-  --epochs EPOCHS       number of epochs to train
-  --batch_size BATCH_SIZE
-                        input batch size for training
-  --test_only           set this option to test the model
-  --lr LR               learning rate
-  --ckpt_save_path CKPT_SAVE_PATH
-                        path to save ckpt
-  --ckpt_save_interval CKPT_SAVE_INTERVAL
-                        save ckpt frequency, unit is epoch
-  --ckpt_save_max CKPT_SAVE_MAX
-                        max number of saved ckpt
-  --ckpt_path CKPT_PATH
-                        path of saved ckpt
+  -h, --help                  show this help message and exit
+  --dir_data DIR_DATA         dataset directory
+  --data_train DATA_TRAIN     train dataset name
+  --data_test DATA_TEST       test dataset name
+  --data_range DATA_RANGE     train/test data range
+  --ext EXT                   dataset file extension
+  --scale SCALE               super-resolution scale
+  --patch_size PATCH_SIZE     output patch size
+  --rgb_range RGB_RANGE       maximum value of RGB
+  --n_colors N_COLORS         number of color channels to use
+  --no_augment                do not use data augmentation
+  --model MODEL               model name
+  --n_resblocks N_RESBLOCKS   number of residual blocks
+  --n_feats N_FEATS           number of feature maps
+  --res_scale RES_SCALE       residual scaling
+  --test_every TEST_EVERY     do test per every N batches
+  --epochs EPOCHS             number of epochs to train
+  --batch_size BATCH_SIZE     input batch size for training
+  --test_only                 set this option to test the model
+  --lr LR                     learning rate
+  --ckpt_path CKPT_PATH       path of saved ckpt
+  --ckpt_save_path CKPT_SAVE_PATH              path to save ckpt
+  --ckpt_save_interval CKPT_SAVE_INTERVAL      save ckpt frequency, unit is epoch
+  --ckpt_save_max CKPT_SAVE_MAX                max number of saved ckpt
   --task_id TASK_ID
 
 ```
@@ -220,6 +215,12 @@ WDSR
   sh run_ascend_standalone.sh [TRAIN_DATA_DIR]
   ```
 
+- GPU鐜杩愯
+
+  ```bash
+  sh run_gpu_standalone.sh [TRAIN_DATA_DIR]
+  ```
+
   涓婅堪python鍛戒护灏嗗湪鍚庡彴杩愯锛屾偍鍙互閫氳繃train.log鏂囦欢鏌ョ湅缁撴灉銆�
 
 ### 鍒嗗竷寮忚缁�
@@ -230,6 +231,12 @@ WDSR
   sh run_ascend_distribute.sh [RANK_TABLE_FILE] [TRAIN_DATA_DIR]
   ```
 
+- GPU鐜杩愯
+
+  ```bash
+  sh run_gpu_distribute.sh [TRAIN_DATA_DIR] [DEVICE_NUM]
+  ```
+
 TRAIN_DATA_DIR = "~DATA/"銆�
 
 ## 璇勪及杩囩▼
@@ -260,33 +267,33 @@ FILE_FORMAT 鍙€� ['MINDIR', 'AIR', 'ONNX'], 榛樿['MINDIR']銆�
 
 ### 璁粌鎬ц兘
 
-| 鍙傛暟          | Ascend                                                       |
-| ------------- | ------------------------------------------------------------ |
-| 璧勬簮          | Ascend 910                                                   |
-| 涓婁紶鏃ユ湡      | 2021-7-4                                                     |
-| MindSpore鐗堟湰 | 1.2.0                                                        |
-| 鏁版嵁闆�        | DIV2K                                                        |
-| 璁粌鍙傛暟      | epoch=1000, steps=100, batch_size =16, lr=0.0001            |
-| 浼樺寲鍣�        | Adam                                                         |
-| 鎹熷け鍑芥暟      | L1                                                           |
-| 杈撳嚭          | 瓒呭垎杈ㄧ巼鍥剧墖                                                 |
-| 鎹熷け          | 3.5                                                          |
-| 閫熷害          | 8鍗★細绾�130姣/姝�                                            |
-| 鎬绘椂闀�        | 8鍗★細0.5灏忔椂                                                   |
-| 寰皟妫€鏌ョ偣    | 35 MB(.ckpt鏂囦欢)                                        |
-| 鑴氭湰          | [WDSR](https://gitee.com/mindspore/models/tree/master/research/cv/wdsr) |
+| 鍙傛暟          | Ascend                                                       | GPU|
+| ------------- | ------------------------------------------------------------ |----|
+| 璧勬簮          | Ascend 910                                                   |NVIDIA GeForce RTX 3090|
+| 涓婁紶鏃ユ湡      | 2021-7-4                                                     |2021-11-22|
+| MindSpore鐗堟湰 | 1.2.0                                                        |1.5.0|
+| 鏁版嵁闆�        | DIV2K                                                        |DIV2K|
+| 璁粌鍙傛暟      | epoch=1000, steps=100, batch_size =16, lr=0.0001            |epoch=300, batch_size=16, lr=0.0005|
+| 浼樺寲鍣�        | Adam                                                         |Adam|
+| 鎹熷け鍑芥暟      | L1                                                           |L1|
+| 杈撳嚭          | 瓒呭垎杈ㄧ巼鍥剧墖                                                 |瓒呭垎杈ㄧ巼鍥剧墖|
+| 鎹熷け          | 3.5                                                          |3.3|
+| 閫熷害          | 8鍗★細绾�130姣/姝�                                            |8鍗★細绾�140姣/姝
+| 鎬绘椂闀�        | 8鍗★細0.5灏忔椂                                                   |8鍗★細1.5灏忔椂|
+| 寰皟妫€鏌ョ偣    | 35 MB(.ckpt鏂囦欢)                                        |14 MB(.ckpt鏂囦欢)|
+| 鑴氭湰          | [WDSR](https://gitee.com/mindspore/models/tree/master/research/cv/wdsr) |[WDSR](https://gitee.com/mindspore/models/tree/master/research/cv/wdsr)|
 
 ### 璇勪及鎬ц兘
 
-| 鍙傛暟          | Ascend                                                      |
-| ------------- | ----------------------------------------------------------- |
-| 璧勬簮          | Ascend 910                                                  |
-| 涓婁紶鏃ユ湡      | 2021-7-4                                                    |
-| MindSpore鐗堟湰 | 1.2.0                                                       |
-| 鏁版嵁闆�        | DIV2K                                                       |
-| batch_size    | 1                                                           |
-| 杈撳嚭          | 瓒呭垎杈ㄧ巼鍥剧墖                                                |
-| PSNR          | DIV2K 34.7780                                               |
+| 鍙傛暟          | Ascend                                                      |GPU                    |
+| ------------- | ----------------------------------------------------------- |----------------------|
+| 璧勬簮          | Ascend 910                                                  |NVIDIA GeForce RTX 3090|
+| 涓婁紶鏃ユ湡      | 2021-7-4                                                    |2021-11-22              |
+| MindSpore鐗堟湰 | 1.2.0                                                       |1.5.0                  |
+| 鏁版嵁闆�        | DIV2K                                                       |DIV2K                   |
+| batch_size    | 1                                                           |1                      |
+| 杈撳嚭          | 瓒呭垎杈ㄧ巼鍥剧墖                                                  |瓒呭垎杈ㄧ巼鍥剧墖              |
+| PSNR          | DIV2K 34.7780                                               |DIV2K 35.9735          |
 
 # 闅忔満鎯呭喌璇存槑
 
diff --git a/research/cv/wdsr/eval.py b/research/cv/wdsr/eval.py
index 97a785fc6c0e98a12e5c545f8a433834cbdd34a5..2f2c0f428db3719b93b6858475158f3b1fba95b6 100644
--- a/research/cv/wdsr/eval.py
+++ b/research/cv/wdsr/eval.py
@@ -20,13 +20,21 @@ from mindspore import Tensor, context
 from mindspore.common import dtype as mstype
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.args import args
-import src.model as wdsr
+from src.model import WDSR
 from src.data.srdata import SRData
 from src.data.div2k import DIV2K
 from src.metrics import calc_psnr, quantize, calc_ssim
-device_id = int(os.getenv('DEVICE_ID', '0'))
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
-context.set_context(max_call_depth=10000)
+
+if args.device_target == 'GPU':
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=args.device_target,
+                        save_graphs=False)
+    context.set_context(max_call_depth=10000)
+elif args.device_target == 'Ascend':
+    device_id = int(os.getenv('DEVICE_ID', '0'))
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
+    context.set_context(max_call_depth=10000)
+
 def eval_net():
     """eval"""
     if args.epochs == 0:
@@ -43,7 +51,7 @@ def eval_net():
     train_de_dataset = ds.GeneratorDataset(train_dataset, ['LR', 'HR'], shuffle=False)
     train_de_dataset = train_de_dataset.batch(1, drop_remainder=True)
     train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
-    net_m = wdsr.WDSR()
+    net_m = WDSR(scale=args.scale[args.task_id], n_resblocks=args.n_resblocks, n_feats=args.n_feats)
     if args.ckpt_path:
         param_dict = load_checkpoint(args.ckpt_path)
         load_param_into_net(net_m, param_dict)
diff --git a/research/cv/wdsr/script/run_eval.sh b/research/cv/wdsr/script/run_eval.sh
index f8828810f9e9398794e3cff33fa3fd1150b0695a..2942fec300e7b9a85c983132320cce3a9eb8eb98 100644
--- a/research/cv/wdsr/script/run_eval.sh
+++ b/research/cv/wdsr/script/run_eval.sh
@@ -57,5 +57,6 @@ python eval.py \
     --ext "img" \
     --data_test=${DATASET_TYPE} \
     --ckpt_path=${PATH2} \
+    --data_range "801-900" \
     --task_id 0 \
     --scale 2 > eval.log 2>&1 &
diff --git a/research/cv/wdsr/script/run_gpu_distribute.sh b/research/cv/wdsr/script/run_gpu_distribute.sh
new file mode 100644
index 0000000000000000000000000000000000000000..836d431e1c5fa647300453f0d3c2415b61d768aa
--- /dev/null
+++ b/research/cv/wdsr/script/run_gpu_distribute.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]; then
+  echo "Usage: sh run_gpu_distribute.sh  [TRAIN_DATA_DIR] [DEVICE_NUM]"
+  exit 1
+fi
+
+get_real_path() {
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+PATH1=$(get_real_path $1)
+DEVICE_NUM=$2
+
+if [ ! -d $PATH1 ]; then
+  echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
+  exit 1
+fi
+
+if [ -d "train_parallel" ]; then
+    rm -rf ./train_parallel
+fi
+mkdir ./train_parallel
+cp ../*.py ./train_parallel
+cp -r ../src ./train_parallel
+cd ./train_parallel || exit
+
+env >env.log
+
+nohup mpirun --allow-run-as-root -n $DEVICE_NUM \
+python train.py \
+      --run_distribute 1 \
+      --device_num $DEVICE_NUM \
+      --batch_size 16 \
+      --lr 5e-4 \
+      --scale 2 \
+      --task_id 0 \
+      --dir_data $PATH1 \
+      --epochs 300 \
+      --test_every 1000 \
+      --patch_size 48 > train.log 2>&1 &
diff --git a/research/cv/wdsr/script/run_gpu_standalone.sh b/research/cv/wdsr/script/run_gpu_standalone.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ac412fa1156f387b674195d254dd317b32a1f096
--- /dev/null
+++ b/research/cv/wdsr/script/run_gpu_standalone.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 1 ]; then
+  echo "Usage: sh run_gpu_standalone.sh [TRAIN_DATA_DIR]"
+  exit 1
+fi
+
+get_real_path() {
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+PATH1=$(get_real_path $1)
+
+if [ ! -d $PATH1 ]; then
+  echo "error: TRAIN_DATA_DIR=$PATH1 is not a directory"
+  exit 1
+fi
+
+
+if [ -d "train" ]; then
+    rm -rf ./train
+fi
+mkdir ./train
+cp ../*.py ./train
+cp -r ../src ./train
+cd ./train || exit
+
+env >env.log
+
+nohup python train.py \
+      --batch_size 16 \
+      --lr 1e-4 \
+      --scale 2 \
+      --task_id 0 \
+      --dir_data $PATH1 \
+      --epochs 300 \
+      --test_every 1000 \
+      --patch_size 48 > train.log 2>&1 &
diff --git a/research/cv/wdsr/src/args.py b/research/cv/wdsr/src/args.py
index 70b3bd43d08cfce76a665e8d306f7fadce39d4ad..358bd3ec74c63a90b138ab64e26c2c78f9c2f243 100644
--- a/research/cv/wdsr/src/args.py
+++ b/research/cv/wdsr/src/args.py
@@ -22,7 +22,7 @@ parser.add_argument('--data_train', type=str, default='DIV2K',
                     help='train dataset name')
 parser.add_argument('--data_test', type=str, default='DIV2K',
                     help='test dataset name')
-parser.add_argument('--data_range', type=str, default='1-800/801-900',
+parser.add_argument('--data_range', type=str, default='1-800/801-810',
                     help='train/test data range')
 parser.add_argument('--ext', type=str, default='sep',
                     help='dataset file extension')
@@ -61,7 +61,7 @@ parser.add_argument('--init_loss_scale', type=float, default=65536.,
                     help='scaling factor')
 parser.add_argument('--loss_scale', type=float, default=1024.0,
                     help='loss_scale')
-parser.add_argument('--decay', type=str, default='200',
+parser.add_argument('--decay', type=int, default=200,
                     help='learning rate decay type')
 parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                     help='ADAM beta')
@@ -80,8 +80,21 @@ parser.add_argument('--ckpt_save_max', type=int, default=5,
                     help='max number of saved ckpt')
 parser.add_argument('--ckpt_path', type=str, default='',
                     help='path of saved ckpt')
+# sr result specifications
+parser.add_argument('--save_dir', type=str, default='result',
+                    help='file name to save')
+parser.add_argument('--save_result', action='store_true',
+                    help='save output results')
 # alltask
 parser.add_argument('--task_id', type=int, default=0)
+parser.add_argument('--pre_trained', type=str, default='', help='model_path, local pretrained model to load')
+parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU"),
+                    help="Device target, support GPU.")
+parser.add_argument("--run_distribute", type=int, default=False,
+                    help="Run distribute, default: false.")
+parser.add_argument('--device_num', type=int, default=1, help='Device num.')
+parser.add_argument("--device_id", type=int, default=0, help="device id, default: 0.")
+parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.")
 # rgb_mean
 parser.add_argument('--r_mean', type=float, default=0.4488,
                     help='r_mean')
@@ -89,6 +102,7 @@ parser.add_argument('--g_mean', type=float, default=0.4371,
                     help='g_mean')
 parser.add_argument('--b_mean', type=float, default=0.4040,
                     help='b_mean')
+
 args, unparsed = parser.parse_known_args()
 args.scale = [int(x) for x in args.scale.split("+")]
 args.data_train = args.data_train.split('+')
diff --git a/research/cv/wdsr/src/model.py b/research/cv/wdsr/src/model.py
index 50922dcaa50fa3a6a1d42a5e026378e55303a04e..b10a950aafd178b4910dabc1fbb1ff28bcc79b07 100644
--- a/research/cv/wdsr/src/model.py
+++ b/research/cv/wdsr/src/model.py
@@ -65,11 +65,8 @@ class PixelShuffle(nn.Cell):
 
 class WDSR(nn.Cell):
     """main structure of wdsr"""
-    def __init__(self):
+    def __init__(self, scale=2, n_resblocks=8, n_feats=64):
         super(WDSR, self).__init__()
-        scale = 2
-        n_resblocks = 8
-        n_feats = 64
         self.sub_mean = MeanShift(255)
         self.add_mean = MeanShift(255, sign=1)
         # define head module
diff --git a/research/cv/wdsr/train.py b/research/cv/wdsr/train.py
index a858bb6d6b0bc85ca19f2f8bb5569c8c16065452..345043579c21cc0adae57a8137457beb13ec4800 100644
--- a/research/cv/wdsr/train.py
+++ b/research/cv/wdsr/train.py
@@ -19,7 +19,7 @@ from mindspore import dataset as ds
 import mindspore.nn as nn
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_rank
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.common import set_seed
 from mindspore.train.model import Model
@@ -31,21 +31,47 @@ from src.model import WDSR
 def train_net():
     """train wdsr"""
     set_seed(1)
-    device_id = int(os.getenv('DEVICE_ID', '0'))
-    rank_id = int(os.getenv('RANK_ID', '0'))
-    device_num = int(os.getenv('RANK_SIZE', '1'))
+    if args.device_target == 'GPU':
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target=args.device_target,
+                            device_id=args.device_id,
+                            save_graphs=False)
+    elif args.device_target == 'Ascend':
+        device_id = int(os.getenv('DEVICE_ID', '0'))
+        rank_id = int(os.getenv('RANK_ID', '0'))
+        device_num = int(os.getenv('RANK_SIZE', '1'))
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target="Ascend",
+                            save_graphs=False,
+                            device_id=device_id)
+    rank = 0
     # if distribute:
-    if device_num > 1:
-        init()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          device_num=device_num, gradients_mean=True)
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
+    if args.run_distribute:
+        print("distribute")
+        if args.device_target == 'GPU':
+            init("nccl")
+            device_num = args.device_num
+            context.reset_auto_parallel_context()
+            rank = get_rank()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+        elif args.device_target == 'Ascend':
+            init()
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              device_num=device_num, gradients_mean=True)
+
     train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
     train_dataset.set_scale(args.task_id)
-    train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=device_num,
-                                           shard_id=rank_id, shuffle=True)
+    if args.device_target == 'GPU':
+        train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"],
+                                               num_shards=args.device_num,
+                                               shard_id=rank, shuffle=True)
+    elif args.device_target == 'Ascend':
+        train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"],
+                                               num_shards=device_num,
+                                               shard_id=rank_id, shuffle=True)
     train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
-    net_m = WDSR()
+    net_m = WDSR(scale=args.scale[args.task_id], n_resblocks=args.n_resblocks, n_feats=args.n_feats)
     print("Init net successfully")
     if args.ckpt_path:
         param_dict = load_checkpoint(args.ckpt_path)
@@ -54,7 +80,7 @@ def train_net():
     step_size = train_de_dataset.get_dataset_size()
     lr = []
     for i in range(0, args.epochs):
-        cur_lr = args.lr / (2 ** ((i + 1)//200))
+        cur_lr = args.lr / (2 ** ((i + 1)//args.decay))
         lr.extend([cur_lr] * step_size)
     opt = nn.Adam(net_m.trainable_params(), learning_rate=lr, loss_scale=args.loss_scale)
     loss = nn.L1Loss()
@@ -67,7 +93,7 @@ def train_net():
     config_ck = CheckpointConfig(save_checkpoint_steps=args.ckpt_save_interval * step_size,
                                  keep_checkpoint_max=args.ckpt_save_max)
     ckpt_cb = ModelCheckpoint(prefix="wdsr", directory=args.ckpt_save_path, config=config_ck)
-    if device_id == 0:
+    if rank == 0:
         cb += [ckpt_cb]
     model.train(args.epochs, train_de_dataset, callbacks=cb, dataset_sink_mode=True)