diff --git a/research/cv/siamRPN/README_CN.md b/research/cv/siamRPN/README_CN.md
index d7937fa68290760d3de87e1c7daa63dbfe4ac91b..bd7dd253341f5feb805ae516ee6dd99a583e6d1f 100644
--- a/research/cv/siamRPN/README_CN.md
+++ b/research/cv/siamRPN/README_CN.md
@@ -2,7 +2,7 @@
 
 - [目录](#目录)
 - [SiamRPN描述](#概述)
-- [模型架构](#s模型架构)
+- [模型架构](#模型架构)
 - [数据集](#数据集)
 - [特性](#特性)
     - [混合精度](#混合精度)
@@ -17,6 +17,7 @@
     - [评估过程](#评估过程)
         - [评估](#评估)
             - [910评估](#910评估)
+            - [GPU评估](#gpu评估)
             - [310评估·](#310评估)
 - [模型描述](#模型描述)
     - [性能](#性能)
@@ -82,6 +83,20 @@ Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN
 
   ```
 
+- GPU处理器环境运行
+
+  ```python
+  # 运行训练示例
+  bash scripts/run_gpu.sh 0
+
+  # 运行分布式训练示例
+  bash scripts/run_distribute_train_gpu.sh  device_num device_list
+
+  # 运行评估示例
+  bash scripts/run_eval_gpu.sh 0 /path/dataset /path/ckpt/siamRPN-50_1417.ckpt eval.json
+
+  ```
+
 # 脚本说明
 
 ## 脚本及样例代码
@@ -92,23 +107,44 @@ Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN
     ├── research
         ├── cv
             ├── siamRPN
-                ├── README_CN.md           // googlenet相关说明
+                ├── README_CN.md            // SiamRPN相关说明
                 ├── ascend_310_infer        // 实现310推理源代码
                 ├── scripts
                 │    ├──run.sh              // 训练脚本
-                |    |──run_distribute_train.sh //本地多卡训练脚本
+                |    |──run_distribute_train.sh //本地Ascend多卡训练脚本
                 |    |──run_eval.sh         //910评估脚本
+                |    |──run_eval_gpu.sh     // GPU评估脚本
+                |    |──run_distribute_train_gpu.sh      // 本地GPU多卡训练脚本
                 |    |──run_infer_310.sh    //310推理评估脚本
+                |    |──run_gpu.sh          //GPU单卡训练脚本
                 ├── src
-                │    ├──data_loader.py      // 数据集加载处理脚本
-                │    ├──net.py              //  siamRPN架构
-                │    ├──loss.py             //  损失函数
-                │    ├──util.py             //  工具脚本
-                │    ├──tracker.py
-                │    ├──generate_anchors.py
-                │    ├──tracker.py
-                │    ├──evaluation.py
-                │    ├──config.py          // 参数配置
+                │    ├── data_loader.py      // 数据集加载处理脚本
+                │    ├── net.py              //  siamRPN架构
+                │    ├── loss.py             //  损失函数
+                │    ├── util.py             //  工具脚本
+                │    ├── tracker.py
+                │    ├── generate_anchors.py
+                │    ├── tracker.py
+                │    ├── evaluation.py
+                │    ├── config.py          // 参数配置
+                ├── ytb_vid_filter         //训练集(需要自己下载)
+                │    ├── --0bLFuriZ4
+                │    ├── --4VWx_0Sc4
+                │    ├── ······
+                │    ├── ······
+                │    └── meta_data.pkl
+                ├── vot2015                //测试集(需要自己下载)
+                │    ├── bag
+                │    ├── ball1
+                │    ├── ······
+                │    ├── ······
+                │    └── list.txt
+                ├── vot2016                //测试集(需要自己下载)
+                │    ├── bag
+                │    ├── ball1
+                │    ├── ······
+                │    ├── ······
+                │    └── list.txt
                 ├── train.py               // 训练脚本
                 ├── eval.py                // 评估脚本
                 ├── export_mindir.py       // 将checkpoint文件导出到air/mindir
@@ -144,7 +180,7 @@ Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN
 - Ascend处理器环境运行
 
   ```bash
-  python train.py --device_id=0 > train.log 2>&1 &
+  python train.py --device_id=0 --device_target="Ascend"> train.log 2>&1 &
   ```
 
   上述python命令将在后台运行,您可以通过train.log文件查看结果。
@@ -160,7 +196,17 @@ Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN
 
   模型检查点保存在当前目录下。
 
-### 分布式训练
+- GPU处理器环境运行
+
+  在运行train.py文件前,需要手动配置src/config.py文件中的pretrain_model参数、train_path参数和checkpoint_path参数,pretrain_model参数代表预训练权重模型路径,train_path参数代表训练集存放的位置,checkpoint_path参数代表存放生成得到的训练模型的位置。
+
+  ```bash
+  python train.py --device_id=0 --device_target="GPU"> train.log 2>&1 &
+  ```
+
+  上述python命令将在后台运行,您可以通过train.log文件查看结果。
+
+### Ascend分布式训练
 
   对于分布式训练,需要提前创建JSON格式的hccl配置文件。
 
@@ -185,6 +231,17 @@ Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN
       # (6) 创建训练作业
       ```
 
+#### GPU分布式训练
+
+  ```bash
+  cd  SiamRPN      //进入到SiamRPN文件根目录
+
+  bash scripts/run_distribute_train_gpu.sh DEVICE_NUM DEVICE_ LIST //运行脚本
+
+  # DEVICE_NUM表示显卡数量
+  # DEVICE_LIST: GPU处理器的id,需用户指定,例如“0,1,2,3”
+  ```
+
 ## 评估过程
 
 ### 评估
@@ -194,8 +251,8 @@ Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN
 - 评估过程如下,需要vot数据集对应video的图片放于对应文件夹的color文件夹下,标签groundtruth.txt放于该目录下。
 
 ```bash
-# 使用数据集
-  python eval.py --device_id=0 --dataset_path=/path/dataset --checkpoint_path=/path/ckpt/siamRPN-50_1417.ckpt --filename=eval.json &> evallog &
+# 使用Ascend
+  python eval.py --device_id=0 --dataset_path=/path/dataset --checkpoint_path=/path/ckpt/siamRPN-xx_xxxx.ckpt --filename=eval.json --device_target="Ascend"&> evallog &
 ```
 
 - 上述python命令在后台运行,可通过`evallog`文件查看评估进程,结束后可通过`eval.json`文件查看评估结果。评估结果如下:
@@ -204,6 +261,21 @@ Siam-RPN提出了一种基于RPN的孪生网络结构。由孪生子网络和RPN
 {... "all_videos": {"accuracy": 0.5809545709441025, "robustness": 0.33422978326730364, "eao": 0.3102655908013835}}
 ```
 
+#### GPU评估
+
+- 评估过程如下,需要vot数据集对应video的图片放于对应文件夹的color文件夹下,标签groundtruth.txt放于该目录下。
+
+```bash
+# 使用gpu
+  python eval.py --device_id=0 --dataset_path=/path/dataset --checkpoint_path=/path/ckpt/siamRPN-xx_xxxx.ckpt --filename=eval.json --device_target="GPU"&> evallog &
+```
+
+- 上述python命令在后台运行,可通过`evallog`文件查看评估进程,结束后可通过`eval.json`文件查看评估结果。评估结果如下:
+
+```bash
+{... "all_videos": {"accuracy": 0.5826686315079969, "robustness": 0.2982987648566767, "eao": 0.3289693903290864}}
+```
+
 #### 310评估
 
 - 评估过程如下,需要vot数据集对应video的图片放于对应文件夹的color文件夹下,标签groundtruth.txt放于该目录下,并到script目录。
@@ -225,35 +297,35 @@ cat acc.log
 
 ### 训练性能
 
-| 参数           | siamRPN(Ascend)                                  |
-| -------------------------- | ---------------------------------------------- |
-| 模型版本                | siamRPN                                          |
-| 资源                   | Ascend 910;CPU:2.60GHz,192核;内存:755 GB    |
-| 上传日期              | 2021-07-22                                           |
-| MindSpore版本        | 1.2.0-alpha                                     |
-| 数据集                |VID-youtube-bb                                     |
-| 训练参数  |epoch=50, steps=1471, batch_size = 32 |
-| 优化器                  | SGD                                                        |
-| 损失函数 | 自定义损失函数 |
-| 输出              | 目标框                                                |
-| 损失             |100~0.05                                          |
-| 速度 | 8卡:120毫秒/步 |
-| 总时长 | 8卡:12.3小时 |
-| 调优检查点 |    247.58MB(.ckpt 文件)               |
-| 脚本                | [siamRPN脚本](https://gitee.com/mindspore/models/tree/master/research/cv/siamRPN) |
+| 参数           | siamRPN(Ascend)                                  | siamRPN(GPU) |
+| -------------------------- | ---------------------------------------------- | --------- |
+| 模型版本                | siamRPN                                          | siamRPN |
+| 资源                   | Ascend 910;CPU:2.60GHz,192核;内存:755 GB    | RTX3090 |
+| 上传日期              | 2021-07-22                                           |   |
+| MindSpore版本        | 1.2.0-alpha                                     |   |
+| 数据集                |VID-youtube-bb                                     | VID-youtube-bb|
+| 训练参数  |epoch=50, steps=1417, batch_size = 32                      | epoch=50, steps=1417, batch_size = 32  |
+| 优化器                  | SGD                                               | SGD  |
+| 损失函数 | 自定义损失函数 | 自定义损失函数 |
+| 输出              | 目标框                                                |目标框  |
+| 损失             |100~0.05                                          | 100~0.05     |
+| 速度 | 8卡:625毫秒/步 | 8卡:296毫秒/步  |
+| 总时长 | 8卡:12.3小时 | 8卡: 5.8小时|
+| 调优检查点 |    247.58MB(.ckpt 文件)               | 247.44MB(.ckpt 文件)|
+| 脚本                | [siamRPN脚本](https://gitee.com/mindspore/models/tree/master/research/cv/siamRPN) | [siamRPN脚本](https://gitee.com/mindspore/models/tree/master/research/cv/siamRPN) |
 
 ### 评估性能
 
-| 参数  | siamRPN(Ascend)                         | siamRPN(Ascend)                         |
-| ------------------- | --------------------------- | --------------------------- |
-| 模型版本      | simaRPN                       | simaRPN                       |
-| 资源        | Ascend 910                  | Ascend 910                  |
-| 上传日期              | 2021-07-22                    | 2021-07-22                    |
-| MindSpore版本   | 1.2.0-alpha                 | 1.2.0-alpha                 |
-| 数据集 | vot2015,60个video | vot2016,60个video |
-| batch_size          |   1                        |   1                        |
-| 输出 | 目标框 | 目标框 |
-| 准确率 | 单卡:accuracy:0.58,robustness:0.33,eao:0.31; | 单卡:accuracy:0.56,robustness:0.39,eao:0.28;|
+| 参数  | siamRPN(Ascend)        | siamRPN(Ascend)     | siamRPN(GPU)         | siamRPN(GPU)                   |
+| ------------------- | --------------------------- | --------------------------- |--------------------------- | --------------------------- |
+| 模型版本      | simaRPN               | simaRPN          |simaRPN                       | simaRPN                       |
+| 资源        | Ascend 910           | Ascend 910       |GPU         | GPU                       |
+| 上传日期              | 2021-07-22         | 2021-07-22         |     2021-12-7      |         2021-12-7             |
+| MindSpore版本   | 1.2.0-alpha                 | 1.2.0-alpha       |      1.5.0   |   1.5.0   |
+| 数据集 | vot2015,60个video | vot2016,60个video |vot2015,60个video          | vot2016,60个video            |
+| batch_size          |   1                |   1               |1           | 1                |
+| 输出 | 目标框 | 目标框 |目标框             | 目标框        |
+| 准确率 | 单卡:accuracy:0.58,robustness:0.33,eao:0.31; | 单卡:accuracy:0.56,robustness:0.39,eao:0.28;|单卡:accuracy:0.5826,robustness:0.298,eao:0.329;       | 单卡:accuracy:0.5538,robustness:0.345,eao:0.295;                  |
 
 # 随机情况说明
 
diff --git a/research/cv/siamRPN/eval.py b/research/cv/siamRPN/eval.py
index d5cdafd4667a39b4c7af2a815ca35766d8472470..7ae023f79c16d11be60e72f6a9311661b658ed79 100644
--- a/research/cv/siamRPN/eval.py
+++ b/research/cv/siamRPN/eval.py
@@ -144,7 +144,7 @@ def test(model_path, data_path, save_name):
 def parse_args():
     '''parse_args'''
     parser = argparse.ArgumentParser(description='Mindspore SiameseRPN Infering')
-    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend'), help='run platform')
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
     parser.add_argument('--device_id', type=int, default=0, help='DEVICE_ID')
     parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
     parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of siamRPN')
@@ -154,10 +154,10 @@ def parse_args():
 
 if __name__ == '__main__':
     args = parse_args()
-    if args.platform == 'Ascend':
+    if args.device_target == 'Ascend':
         device_id = args.device_id
         context.set_context(device_id=device_id)
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     model_file_path = args.checkpoint_path
     data_file_path = args.dataset_path
     save_file_name = args.filename
diff --git a/research/cv/siamRPN/scripts/run_distribute_train_gpu.sh b/research/cv/siamRPN/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b5a276b04cc0197c5e634a66f5e2521a4953d736
--- /dev/null
+++ b/research/cv/siamRPN/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]
+then 
+    echo "Usage: bash run_distribute_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]"
+exit 1
+fi
+
+
+DEVICE_NUM=$1
+echo $DEVICE_NUM
+
+export DEVICE_NUM=$1
+export RANK_SIZE=$DEVICE_NUM
+export CUDA_VISIBLE_DEVICES="$2"
+
+
+nohup mpirun -n $DEVICE_NUM --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
+python -u train.py  --device_target="GPU" --is_parallel=True > train_gpu.log 2>&1 &
diff --git a/research/cv/siamRPN/scripts/run_eval.sh b/research/cv/siamRPN/scripts/run_eval.sh
index 57b3014c22889a541369c5e232ff1cd94e205640..a8d21763872acbbe1d7959085ab3e71d228d51a5 100644
--- a/research/cv/siamRPN/scripts/run_eval.sh
+++ b/research/cv/siamRPN/scripts/run_eval.sh
@@ -17,5 +17,5 @@ export DEVICE_ID=$1
 export DATA_NAME=$2
 export MODEL_PATH=$3
 export FILENAME=$4
-python  eval.py  --device_id=$DEVICE_ID --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME &> evallog &
+python  eval.py  --device_id=$DEVICE_ID --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME &> eval.log &
 
diff --git a/research/cv/siamRPN/scripts/run_eval_gpu.sh b/research/cv/siamRPN/scripts/run_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1864dcce9bcecc37af3fbec74ee1a72f077850b4
--- /dev/null
+++ b/research/cv/siamRPN/scripts/run_eval_gpu.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 4 ]
+then 
+    echo "Usage: bash run_eval_gpu.sh [DEVICE_id] [DATA_NAME] [MODEL_PATH] [FILENAME]"
+exit 1
+fi
+
+export DEVICE_ID=$1
+export DATA_NAME=$2
+export MODEL_PATH=$3
+export FILENAME=$4
+python  eval.py  --device_id=$DEVICE_ID --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME --device_target="GPU" &> eval_gpu.log &
+
diff --git a/research/cv/siamRPN/scripts/run_gpu.sh b/research/cv/siamRPN/scripts/run_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..785a5b4ebab34b72c901f8e0cbd23bb100ac28df
--- /dev/null
+++ b/research/cv/siamRPN/scripts/run_gpu.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 1 ]
+then 
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash run.sh DEVICE_ID"
+    echo "For example: bash run_gpu.sh 0"
+    echo "=============================================================================================================="
+exit 1
+fi
+
+
+DEVICE_ID=$1
+
+export DEVICE_ID=$DEVICE_ID
+python3 train.py --device_id=$DEVICE_ID --device_target="GPU"> train.log 2>&1 &
diff --git a/research/cv/siamRPN/src/net.py b/research/cv/siamRPN/src/net.py
index fc3a1a4470d7be81f5663c79a251ab5b9864ec5e..3310101d6234662a11a5931a808744bc3d9c0cf6 100644
--- a/research/cv/siamRPN/src/net.py
+++ b/research/cv/siamRPN/src/net.py
@@ -82,6 +82,9 @@ class SiameseRPN(nn.Cell):
         self.softmax = ops.Softmax(axis=2)
         self.print = ops.Print()
 
+        self.anchor_num = config.anchor_num
+        self.score_size = config.score_size
+
     def construct(self, template=None, detection=None, ckernal=None, rkernal=None):
         """ forward function """
         if self.is_train is True and template is not None and detection is not None:
@@ -172,13 +175,13 @@ class SiameseRPN(nn.Cell):
                 routputs = routputs + (self.conv2d_rout(r_features[i], r_weights[i]),)
             coutputs = self.op_concat(coutputs)
             routputs = self.op_concat(routputs)
-            coutputs = self.reshape(coutputs, (self.groups, 2*config.anchor_num, config.score_size, config.score_size))
-            routputs = self.reshape(routputs, (self.groups, 4*config.anchor_num, config.score_size, config.score_size))
+            coutputs = self.reshape(coutputs, (self.groups, 2*self.anchor_num, self.score_size, self.score_size))
+            routputs = self.reshape(routputs, (self.groups, 4*self.anchor_num, self.score_size, self.score_size))
             routputs = self.regress_adjust(routputs)
             coutputs = self.transpose(
-                self.reshape(coutputs, (-1, 2, config.anchor_num * config.score_size* config.score_size)), (0, 2, 1))
+                self.reshape(coutputs, (-1, 2, self.anchor_num * self.score_size* self.score_size)), (0, 2, 1))
             routputs = self.transpose(
-                self.reshape(routputs, (-1, 4, config.anchor_num * config.score_size* config.score_size)),
+                self.reshape(routputs, (-1, 4, self.anchor_num * self.score_size* self.score_size)),
                 (0, 2, 1))
             out1, out2 = coutputs, routputs
         else:
diff --git a/research/cv/siamRPN/train.py b/research/cv/siamRPN/train.py
index 7267d7fd3610b95594cab332221e4431b9d88809..504492d431ddca860803cafd9e6605a36afe86f9 100644
--- a/research/cv/siamRPN/train.py
+++ b/research/cv/siamRPN/train.py
@@ -27,6 +27,7 @@ import mindspore.dataset as ds
 from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_rank
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.communication.management import get_group_size
 import numpy as np
 from src.data_loader import TrainDataLoader
 from src.net import SiameseRPN, BuildTrainNet, MyTrainOneStepCell
@@ -47,7 +48,10 @@ parser.add_argument('--data_url', default=None, help='Location of data.')
 
 parser.add_argument('--unzip_mode', default=0, type=int, metavar='N', help='unzip mode:0:no unzip,1:tar,2:unzip')
 
-parser.add_argument('--device_id', default=2, type=int, metavar='N', help='number of total epochs to run')
+parser.add_argument('--device_id', default=0, type=int, metavar='N', help='number of total epochs to run')
+
+parser.add_argument('--device_target', default="Ascend", type=str, choices=["Ascend", "GPU"],
+                    help='type of platform:Ascend or GPU')
 
 
 #add random seed
@@ -79,7 +83,7 @@ def main(args):
         # create dataset
         dataset = ds.GeneratorDataset(data_loader, ["template", "detection", "label"], shuffle=True,
                                       num_parallel_workers=rank_size, num_shards=rank_size, shard_id=rank_id)
-    else:
+    if not args.is_parallel:
         dataset = ds.GeneratorDataset(data_loader, ["template", "detection", "label"], shuffle=True)
     dataset = dataset.batch(config.batch_size, drop_remainder=True)
 
@@ -129,17 +133,19 @@ def main(args):
                   "avg_loss is %s, step time is %s" % (cb_params.cur_epoch_num, cb_params.cur_step_num, loss,
                                                        self.tlosses.avg, step_mseconds), flush=True)
     print_cb = Print_info()
+    cb = [loss_cb, print_cb]
     #save checkpoint
     ckpt_cfg = CheckpointConfig(save_checkpoint_steps=dataset.get_dataset_size(), keep_checkpoint_max=51)
     if args.is_cloudtrain:
         ckpt_cb = ModelCheckpoint(prefix='siamRPN', directory=config.train_path+'/ckpt', config=ckpt_cfg)
     else:
         ckpt_cb = ModelCheckpoint(prefix='siamRPN', directory='./ckpt', config=ckpt_cfg)
-
+    if rank == 0:
+        cb += [ckpt_cb]
     if config.checkpoint_path is not None and os.path.exists(config.checkpoint_path):
-        model.train(total_epoch, dataset, callbacks=[loss_cb, ckpt_cb, print_cb], dataset_sink_mode=False)
+        model.train(total_epoch, dataset, callbacks=cb, dataset_sink_mode=False)
     else:
-        model.train(epoch=total_epoch, train_dataset=dataset, callbacks=[loss_cb, ckpt_cb, print_cb],
+        model.train(epoch=total_epoch, train_dataset=dataset, callbacks=cb,
                     dataset_sink_mode=False)
 
 
@@ -178,7 +184,7 @@ def adjust_learning_rate(start_lr, end_lr, total_epochs, steps_pre_epoch):
 
 if __name__ == '__main__':
     Args = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    context.set_context(mode=context.GRAPH_MODE, device_target=Args.device_target)
     if Args.is_cloudtrain:
         import moxing as mox
         device_id = int(os.getenv('DEVICE_ID') if os.getenv('DEVICE_ID') is not None else 0)
@@ -199,18 +205,29 @@ if __name__ == '__main__':
             local_data_path = local_data_path + '/train/ytb_vid_filter'
         config.train_path = local_data_path
     else:
-        config.train_path = Args.train_url
+        rank = 0
         if Args.is_parallel:
-            device_id = int(os.getenv('DEVICE_ID'))
-            device_num = int(os.getenv('RANK_SIZE'))
-            if device_num > 1:
-                context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
-                context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                                  parameter_broadcast=True, gradients_mean=True)
+            if Args.device_target == "Ascend":
+                config.train_path = Args.train_url
+                device_id = int(os.getenv('DEVICE_ID'))
+                device_num = int(os.getenv('RANK_SIZE'))
+                if device_num > 1:
+                    context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
+                    context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                                      parameter_broadcast=True, gradients_mean=True)
+                    init()
+            elif Args.device_target == "GPU":
                 init()
+                context.set_context(device_id=Args.device_id)
+                device_num = get_group_size()
+                context.reset_auto_parallel_context()
+                rank = get_rank()
+                context.set_auto_parallel_context(device_num=device_num,
+                                                  parallel_mode=ParallelMode.DATA_PARALLEL,
+                                                  gradients_mean=True)
         else:
             device_id = Args.device_id
-            context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
+            context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target=Args.device_target)
     main(Args)
     if Args.is_cloudtrain:
         mox.file.copy_parallel(src_url=local_data_path + '/ckpt', dst_url=Args.train_url + '/ckpt')