From 29bd65f89ed22eb4bcfb711a44623bd759e6de16 Mon Sep 17 00:00:00 2001
From: kman066 <894855066@qq.com>
Date: Thu, 28 Apr 2022 11:54:35 +0800
Subject: [PATCH] add SiamRPN-GPU

---
 research/cv/siamRPN/README_CN.md              | 156 +++++++++++++-----
 research/cv/siamRPN/eval.py                   |   6 +-
 .../scripts/run_distribute_train_gpu.sh       |  33 ++++
 research/cv/siamRPN/scripts/run_eval.sh       |   2 +-
 research/cv/siamRPN/scripts/run_eval_gpu.sh   |  28 ++++
 research/cv/siamRPN/scripts/run_gpu.sh        |  31 ++++
 research/cv/siamRPN/src/net.py                |  11 +-
 research/cv/siamRPN/train.py                  |  45 +++--
 8 files changed, 248 insertions(+), 64 deletions(-)
 create mode 100644 research/cv/siamRPN/scripts/run_distribute_train_gpu.sh
 create mode 100644 research/cv/siamRPN/scripts/run_eval_gpu.sh
 create mode 100644 research/cv/siamRPN/scripts/run_gpu.sh

diff --git a/research/cv/siamRPN/README_CN.md b/research/cv/siamRPN/README_CN.md
index d7937fa68..bd7dd2533 100644
--- a/research/cv/siamRPN/README_CN.md
+++ b/research/cv/siamRPN/README_CN.md
@@ -2,7 +2,7 @@
 
 - [鐩綍](#鐩綍)
 - [SiamRPN鎻忚堪](#姒傝堪)
-- [妯″瀷鏋舵瀯](#s妯″瀷鏋舵瀯)
+- [妯″瀷鏋舵瀯](#妯″瀷鏋舵瀯)
 - [鏁版嵁闆哴(#鏁版嵁闆�)
 - [鐗规€(#鐗规€�)
     - [娣峰悎绮惧害](#娣峰悎绮惧害)
@@ -17,6 +17,7 @@
     - [璇勪及杩囩▼](#璇勪及杩囩▼)
         - [璇勪及](#璇勪及)
             - [910璇勪及](#910璇勪及)
+            - [GPU璇勪及](#gpu璇勪及)
             - [310璇勪及路](#310璇勪及)
 - [妯″瀷鎻忚堪](#妯″瀷鎻忚堪)
     - [鎬ц兘](#鎬ц兘)
@@ -82,6 +83,20 @@ Siam-RPN鎻愬嚭浜嗕竴绉嶅熀浜嶳PN鐨勫鐢熺綉缁滅粨鏋勩€傜敱瀛敓瀛愮綉缁滃拰RPN
 
   ```
 
+- GPU澶勭悊鍣ㄧ幆澧冭繍琛�
+
+  ```python
+  # 杩愯璁粌绀轰緥
+  bash scripts/run_gpu.sh 0
+
+  # 杩愯鍒嗗竷寮忚缁冪ず渚�
+  bash scripts/run_distribute_train_gpu.sh  device_num device_list
+
+  # 杩愯璇勪及绀轰緥
+  bash scripts/run_eval_gpu.sh 0 /path/dataset /path/ckpt/siamRPN-50_1417.ckpt eval.json
+
+  ```
+
 # 鑴氭湰璇存槑
 
 ## 鑴氭湰鍙婃牱渚嬩唬鐮�
@@ -92,23 +107,44 @@ Siam-RPN鎻愬嚭浜嗕竴绉嶅熀浜嶳PN鐨勫鐢熺綉缁滅粨鏋勩€傜敱瀛敓瀛愮綉缁滃拰RPN
     鈹溾攢鈹€ research
         鈹溾攢鈹€ cv
             鈹溾攢鈹€ siamRPN
-                鈹溾攢鈹€ README_CN.md           // googlenet鐩稿叧璇存槑
+                鈹溾攢鈹€ README_CN.md            // SiamRPN鐩稿叧璇存槑
                 鈹溾攢鈹€ ascend_310_infer        // 瀹炵幇310鎺ㄧ悊婧愪唬鐮�
                 鈹溾攢鈹€ scripts
                 鈹�    鈹溾攢鈹€run.sh              // 璁粌鑴氭湰
-                |    |鈹€鈹€run_distribute_train.sh //鏈湴澶氬崱璁粌鑴氭湰
+                |    |鈹€鈹€run_distribute_train.sh //鏈湴Ascend澶氬崱璁粌鑴氭湰
                 |    |鈹€鈹€run_eval.sh         //910璇勪及鑴氭湰
+                |    |鈹€鈹€run_eval_gpu.sh     // GPU璇勪及鑴氭湰
+                |    |鈹€鈹€run_distribute_train_gpu.sh      // 鏈湴GPU澶氬崱璁粌鑴氭湰
                 |    |鈹€鈹€run_infer_310.sh    //310鎺ㄧ悊璇勪及鑴氭湰
+                |    |鈹€鈹€run_gpu.sh          //GPU鍗曞崱璁粌鑴氭湰
                 鈹溾攢鈹€ src
-                鈹�    鈹溾攢鈹€data_loader.py      // 鏁版嵁闆嗗姞杞藉鐞嗚剼鏈�
-                鈹�    鈹溾攢鈹€net.py              //  siamRPN鏋舵瀯
-                鈹�    鈹溾攢鈹€loss.py             //  鎹熷け鍑芥暟
-                鈹�    鈹溾攢鈹€util.py             //  宸ュ叿鑴氭湰
-                鈹�    鈹溾攢鈹€tracker.py
-                鈹�    鈹溾攢鈹€generate_anchors.py
-                鈹�    鈹溾攢鈹€tracker.py
-                鈹�    鈹溾攢鈹€evaluation.py
-                鈹�    鈹溾攢鈹€config.py          // 鍙傛暟閰嶇疆
+                鈹�    鈹溾攢鈹€ data_loader.py      // 鏁版嵁闆嗗姞杞藉鐞嗚剼鏈�
+                鈹�    鈹溾攢鈹€ net.py              //  siamRPN鏋舵瀯
+                鈹�    鈹溾攢鈹€ loss.py             //  鎹熷け鍑芥暟
+                鈹�    鈹溾攢鈹€ util.py             //  宸ュ叿鑴氭湰
+                鈹�    鈹溾攢鈹€ tracker.py
+                鈹�    鈹溾攢鈹€ generate_anchors.py
+                鈹�    鈹溾攢鈹€ tracker.py
+                鈹�    鈹溾攢鈹€ evaluation.py
+                鈹�    鈹溾攢鈹€ config.py          // 鍙傛暟閰嶇疆
+                鈹溾攢鈹€ ytb_vid_filter         //璁粌闆�(闇€瑕佽嚜宸变笅杞�)
+                鈹�    鈹溾攢鈹€ --0bLFuriZ4
+                鈹�    鈹溾攢鈹€ --4VWx_0Sc4
+                鈹�    鈹溾攢鈹€ 路路路路路路
+                鈹�    鈹溾攢鈹€ 路路路路路路
+                鈹�    鈹斺攢鈹€ meta_data.pkl
+                鈹溾攢鈹€ vot2015                //娴嬭瘯闆�(闇€瑕佽嚜宸变笅杞�)
+                鈹�    鈹溾攢鈹€ bag
+                鈹�    鈹溾攢鈹€ ball1
+                鈹�    鈹溾攢鈹€ 路路路路路路
+                鈹�    鈹溾攢鈹€ 路路路路路路
+                鈹�    鈹斺攢鈹€ list.txt
+                鈹溾攢鈹€ vot2016                //娴嬭瘯闆�(闇€瑕佽嚜宸变笅杞�)
+                鈹�    鈹溾攢鈹€ bag
+                鈹�    鈹溾攢鈹€ ball1
+                鈹�    鈹溾攢鈹€ 路路路路路路
+                鈹�    鈹溾攢鈹€ 路路路路路路
+                鈹�    鈹斺攢鈹€ list.txt
                 鈹溾攢鈹€ train.py               // 璁粌鑴氭湰
                 鈹溾攢鈹€ eval.py                // 璇勪及鑴氭湰
                 鈹溾攢鈹€ export_mindir.py       // 灏哻heckpoint鏂囦欢瀵煎嚭鍒癮ir/mindir
@@ -144,7 +180,7 @@ Siam-RPN鎻愬嚭浜嗕竴绉嶅熀浜嶳PN鐨勫鐢熺綉缁滅粨鏋勩€傜敱瀛敓瀛愮綉缁滃拰RPN
 - Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
 
   ```bash
-  python train.py --device_id=0 > train.log 2>&1 &
+  python train.py --device_id=0 --device_target="Ascend"> train.log 2>&1 &
   ```
 
   涓婅堪python鍛戒护灏嗗湪鍚庡彴杩愯锛屾偍鍙互閫氳繃train.log鏂囦欢鏌ョ湅缁撴灉銆�
@@ -160,7 +196,17 @@ Siam-RPN鎻愬嚭浜嗕竴绉嶅熀浜嶳PN鐨勫鐢熺綉缁滅粨鏋勩€傜敱瀛敓瀛愮綉缁滃拰RPN
 
   妯″瀷妫€鏌ョ偣淇濆瓨鍦ㄥ綋鍓嶇洰褰曚笅銆�
 
-### 鍒嗗竷寮忚缁�
+- GPU澶勭悊鍣ㄧ幆澧冭繍琛�
+
+  鍦ㄨ繍琛宼rain.py鏂囦欢鍓嶏紝闇€瑕佹墜鍔ㄩ厤缃畇rc/config.py鏂囦欢涓殑pretrain_model鍙傛暟銆乼rain_path鍙傛暟鍜宑heckpoint_path鍙傛暟锛宲retrain_model鍙傛暟浠h〃棰勮缁冩潈閲嶆ā鍨嬭矾寰勶紝train_path鍙傛暟浠h〃璁粌闆嗗瓨鏀剧殑浣嶇疆锛宑heckpoint_path鍙傛暟浠h〃瀛樻斁鐢熸垚寰楀埌鐨勮缁冩ā鍨嬬殑浣嶇疆銆�
+
+  ```bash
+  python train.py --device_id=0 --device_target="GPU"> train.log 2>&1 &
+  ```
+
+  涓婅堪python鍛戒护灏嗗湪鍚庡彴杩愯锛屾偍鍙互閫氳繃train.log鏂囦欢鏌ョ湅缁撴灉銆�
+
+### Ascend鍒嗗竷寮忚缁�
 
   瀵逛簬鍒嗗竷寮忚缁冿紝闇€瑕佹彁鍓嶅垱寤篔SON鏍煎紡鐨刪ccl閰嶇疆鏂囦欢銆�
 
@@ -185,6 +231,17 @@ Siam-RPN鎻愬嚭浜嗕竴绉嶅熀浜嶳PN鐨勫鐢熺綉缁滅粨鏋勩€傜敱瀛敓瀛愮綉缁滃拰RPN
       # (6) 鍒涘缓璁粌浣滀笟
       ```
 
+#### GPU鍒嗗竷寮忚缁�
+
+  ```bash
+  cd  SiamRPN      //杩涘叆鍒癝iamRPN鏂囦欢鏍圭洰褰�
+
+  bash scripts/run_distribute_train_gpu.sh DEVICE_NUM DEVICE_ LIST //杩愯鑴氭湰
+
+  # DEVICE_NUM琛ㄧず鏄惧崱鏁伴噺
+  # DEVICE_LIST: GPU澶勭悊鍣ㄧ殑id锛岄渶鐢ㄦ埛鎸囧畾锛屼緥濡傗€�0,1,2,3鈥�
+  ```
+
 ## 璇勪及杩囩▼
 
 ### 璇勪及
@@ -194,8 +251,8 @@ Siam-RPN鎻愬嚭浜嗕竴绉嶅熀浜嶳PN鐨勫鐢熺綉缁滅粨鏋勩€傜敱瀛敓瀛愮綉缁滃拰RPN
 - 璇勪及杩囩▼濡備笅锛岄渶瑕乿ot鏁版嵁闆嗗搴攙ideo鐨勫浘鐗囨斁浜庡搴旀枃浠跺す鐨刢olor鏂囦欢澶逛笅锛屾爣绛緂roundtruth.txt鏀句簬璇ョ洰褰曚笅銆�
 
 ```bash
-# 浣跨敤鏁版嵁闆�
-  python eval.py --device_id=0 --dataset_path=/path/dataset --checkpoint_path=/path/ckpt/siamRPN-50_1417.ckpt --filename=eval.json &> evallog &
+# 浣跨敤Ascend
+  python eval.py --device_id=0 --dataset_path=/path/dataset --checkpoint_path=/path/ckpt/siamRPN-xx_xxxx.ckpt --filename=eval.json --device_target="Ascend"&> evallog &
 ```
 
 - 涓婅堪python鍛戒护鍦ㄥ悗鍙拌繍琛岋紝鍙€氳繃`evallog`鏂囦欢鏌ョ湅璇勪及杩涚▼锛岀粨鏉熷悗鍙€氳繃`eval.json`鏂囦欢鏌ョ湅璇勪及缁撴灉銆傝瘎浼扮粨鏋滃涓嬶細
@@ -204,6 +261,21 @@ Siam-RPN鎻愬嚭浜嗕竴绉嶅熀浜嶳PN鐨勫鐢熺綉缁滅粨鏋勩€傜敱瀛敓瀛愮綉缁滃拰RPN
 {... "all_videos": {"accuracy": 0.5809545709441025, "robustness": 0.33422978326730364, "eao": 0.3102655908013835}}
 ```
 
+#### GPU璇勪及
+
+- 璇勪及杩囩▼濡備笅锛岄渶瑕乿ot鏁版嵁闆嗗搴攙ideo鐨勫浘鐗囨斁浜庡搴旀枃浠跺す鐨刢olor鏂囦欢澶逛笅锛屾爣绛緂roundtruth.txt鏀句簬璇ョ洰褰曚笅銆�
+
+```bash
+# 浣跨敤gpu
+  python eval.py --device_id=0 --dataset_path=/path/dataset --checkpoint_path=/path/ckpt/siamRPN-xx_xxxx.ckpt --filename=eval.json --device_target="GPU"&> evallog &
+```
+
+- 涓婅堪python鍛戒护鍦ㄥ悗鍙拌繍琛岋紝鍙€氳繃`evallog`鏂囦欢鏌ョ湅璇勪及杩涚▼锛岀粨鏉熷悗鍙€氳繃`eval.json`鏂囦欢鏌ョ湅璇勪及缁撴灉銆傝瘎浼扮粨鏋滃涓嬶細
+
+```bash
+{... "all_videos": {"accuracy": 0.5826686315079969, "robustness": 0.2982987648566767, "eao": 0.3289693903290864}}
+```
+
 #### 310璇勪及
 
 - 璇勪及杩囩▼濡備笅锛岄渶瑕乿ot鏁版嵁闆嗗搴攙ideo鐨勫浘鐗囨斁浜庡搴旀枃浠跺す鐨刢olor鏂囦欢澶逛笅锛屾爣绛緂roundtruth.txt鏀句簬璇ョ洰褰曚笅锛屽苟鍒皊cript鐩綍銆�
@@ -225,35 +297,35 @@ cat acc.log
 
 ### 璁粌鎬ц兘
 
-| 鍙傛暟           | siamRPN(Ascend)                                  |
-| -------------------------- | ---------------------------------------------- |
-| 妯″瀷鐗堟湰                | siamRPN                                          |
-| 璧勬簮                   | Ascend 910锛汣PU锛�2.60GHz锛�192鏍革紱鍐呭瓨锛�755 GB    |
-| 涓婁紶鏃ユ湡              | 2021-07-22                                           |
-| MindSpore鐗堟湰        | 1.2.0-alpha                                     |
-| 鏁版嵁闆�                |VID-youtube-bb                                     |
-| 璁粌鍙傛暟  |epoch=50, steps=1471, batch_size = 32 |
-| 浼樺寲鍣�                  | SGD                                                        |
-| 鎹熷け鍑芥暟 | 鑷畾涔夋崯澶卞嚱鏁� |
-| 杈撳嚭              | 鐩爣妗�                                                |
-| 鎹熷け             |100~0.05                                          |
-| 閫熷害 | 8鍗★細120姣/姝� |
-| 鎬绘椂闀� | 8鍗★細12.3灏忔椂 |
-| 璋冧紭妫€鏌ョ偣 |    247.58MB锛�.ckpt 鏂囦欢锛�               |
-| 鑴氭湰                | [siamRPN鑴氭湰](https://gitee.com/mindspore/models/tree/master/research/cv/siamRPN) |
+| 鍙傛暟           | siamRPN(Ascend)                                  | siamRPN(GPU) |
+| -------------------------- | ---------------------------------------------- | --------- |
+| 妯″瀷鐗堟湰                | siamRPN                                          | siamRPN |
+| 璧勬簮                   | Ascend 910锛汣PU锛�2.60GHz锛�192鏍革紱鍐呭瓨锛�755 GB    | RTX3090 |
+| 涓婁紶鏃ユ湡              | 2021-07-22                                           |   |
+| MindSpore鐗堟湰        | 1.2.0-alpha                                     |   |
+| 鏁版嵁闆�                |VID-youtube-bb                                     | VID-youtube-bb|
+| 璁粌鍙傛暟  |epoch=50, steps=1417, batch_size = 32                      | epoch=50, steps=1417, batch_size = 32  |
+| 浼樺寲鍣�                  | SGD                                               | SGD  |
+| 鎹熷け鍑芥暟 | 鑷畾涔夋崯澶卞嚱鏁� | 鑷畾涔夋崯澶卞嚱鏁� |
+| 杈撳嚭              | 鐩爣妗�                                                |鐩爣妗�  |
+| 鎹熷け             |100~0.05                                          | 100~0.05     |
+| 閫熷害 | 8鍗★細625姣/姝� | 8鍗★細296姣/姝�  |
+| 鎬绘椂闀� | 8鍗★細12.3灏忔椂 | 8鍗★細 5.8灏忔椂|
+| 璋冧紭妫€鏌ョ偣 |    247.58MB锛�.ckpt 鏂囦欢锛�               | 247.44MB锛�.ckpt 鏂囦欢锛墊
+| 鑴氭湰                | [siamRPN鑴氭湰](https://gitee.com/mindspore/models/tree/master/research/cv/siamRPN) | [siamRPN鑴氭湰](https://gitee.com/mindspore/models/tree/master/research/cv/siamRPN) |
 
 ### 璇勪及鎬ц兘
 
-| 鍙傛暟  | siamRPN(Ascend)                         | siamRPN(Ascend)                         |
-| ------------------- | --------------------------- | --------------------------- |
-| 妯″瀷鐗堟湰      | simaRPN                       | simaRPN                       |
-| 璧勬簮        | Ascend 910                  | Ascend 910                  |
-| 涓婁紶鏃ユ湡              | 2021-07-22                    | 2021-07-22                    |
-| MindSpore鐗堟湰   | 1.2.0-alpha                 | 1.2.0-alpha                 |
-| 鏁版嵁闆� | vot2015锛�60涓獀ideo | vot2016锛�60涓獀ideo |
-| batch_size          |   1                        |   1                        |
-| 杈撳嚭 | 鐩爣妗� | 鐩爣妗� |
-| 鍑嗙‘鐜� | 鍗曞崱锛歛ccuracy锛�0.58,robustness锛�0.33,eao:0.31; | 鍗曞崱锛歛ccuracy锛�0.56,robustness锛�0.39,eao:0.28;|
+| 鍙傛暟  | siamRPN(Ascend)        | siamRPN(Ascend)     | siamRPN(GPU)         | siamRPN(GPU)                   |
+| ------------------- | --------------------------- | --------------------------- |--------------------------- | --------------------------- |
+| 妯″瀷鐗堟湰      | simaRPN               | simaRPN          |simaRPN                       | simaRPN                       |
+| 璧勬簮        | Ascend 910           | Ascend 910       |GPU         | GPU                       |
+| 涓婁紶鏃ユ湡              | 2021-07-22         | 2021-07-22         |     2021-12-7      |         2021-12-7             |
+| MindSpore鐗堟湰   | 1.2.0-alpha                 | 1.2.0-alpha       |      1.5.0   |   1.5.0   |
+| 鏁版嵁闆� | vot2015锛�60涓獀ideo | vot2016锛�60涓獀ideo |vot2015锛�60涓獀ideo          | vot2016锛�60涓獀ideo            |
+| batch_size          |   1                |   1               |1           | 1                |
+| 杈撳嚭 | 鐩爣妗� | 鐩爣妗� |鐩爣妗�             | 鐩爣妗�        |
+| 鍑嗙‘鐜� | 鍗曞崱锛歛ccuracy锛�0.58,robustness锛�0.33,eao:0.31; | 鍗曞崱锛歛ccuracy锛�0.56,robustness锛�0.39,eao:0.28;|鍗曞崱锛歛ccuracy锛�0.5826,robustness锛�0.298,eao:0.329;       | 鍗曞崱锛歛ccuracy锛�0.5538,robustness锛�0.345,eao:0.295;                  |
 
 # 闅忔満鎯呭喌璇存槑
 
diff --git a/research/cv/siamRPN/eval.py b/research/cv/siamRPN/eval.py
index d5cdafd46..7ae023f79 100644
--- a/research/cv/siamRPN/eval.py
+++ b/research/cv/siamRPN/eval.py
@@ -144,7 +144,7 @@ def test(model_path, data_path, save_name):
 def parse_args():
     '''parse_args'''
     parser = argparse.ArgumentParser(description='Mindspore SiameseRPN Infering')
-    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend'), help='run platform')
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
     parser.add_argument('--device_id', type=int, default=0, help='DEVICE_ID')
     parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
     parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of siamRPN')
@@ -154,10 +154,10 @@ def parse_args():
 
 if __name__ == '__main__':
     args = parse_args()
-    if args.platform == 'Ascend':
+    if args.device_target == 'Ascend':
         device_id = args.device_id
         context.set_context(device_id=device_id)
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     model_file_path = args.checkpoint_path
     data_file_path = args.dataset_path
     save_file_name = args.filename
diff --git a/research/cv/siamRPN/scripts/run_distribute_train_gpu.sh b/research/cv/siamRPN/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 000000000..b5a276b04
--- /dev/null
+++ b/research/cv/siamRPN/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 2 ]
+then 
+    echo "Usage: bash run_distribute_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]"
+exit 1
+fi
+
+
+DEVICE_NUM=$1
+echo $DEVICE_NUM
+
+export DEVICE_NUM=$1
+export RANK_SIZE=$DEVICE_NUM
+export CUDA_VISIBLE_DEVICES="$2"
+
+
+nohup mpirun -n $DEVICE_NUM --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
+python -u train.py  --device_target="GPU" --is_parallel=True > train_gpu.log 2>&1 &
diff --git a/research/cv/siamRPN/scripts/run_eval.sh b/research/cv/siamRPN/scripts/run_eval.sh
index 57b3014c2..a8d217638 100644
--- a/research/cv/siamRPN/scripts/run_eval.sh
+++ b/research/cv/siamRPN/scripts/run_eval.sh
@@ -17,5 +17,5 @@ export DEVICE_ID=$1
 export DATA_NAME=$2
 export MODEL_PATH=$3
 export FILENAME=$4
-python  eval.py  --device_id=$DEVICE_ID --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME &> evallog &
+python  eval.py  --device_id=$DEVICE_ID --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME &> eval.log &
 
diff --git a/research/cv/siamRPN/scripts/run_eval_gpu.sh b/research/cv/siamRPN/scripts/run_eval_gpu.sh
new file mode 100644
index 000000000..1864dcce9
--- /dev/null
+++ b/research/cv/siamRPN/scripts/run_eval_gpu.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 4 ]
+then 
+    echo "Usage: bash run_eval_gpu.sh [DEVICE_id] [DATA_NAME] [MODEL_PATH] [FILENAME]"
+exit 1
+fi
+
+export DEVICE_ID=$1
+export DATA_NAME=$2
+export MODEL_PATH=$3
+export FILENAME=$4
+python  eval.py  --device_id=$DEVICE_ID --dataset_path=$DATA_NAME --checkpoint_path=$MODEL_PATH --filename=$FILENAME --device_target="GPU" &> eval_gpu.log &
+
diff --git a/research/cv/siamRPN/scripts/run_gpu.sh b/research/cv/siamRPN/scripts/run_gpu.sh
new file mode 100644
index 000000000..785a5b4eb
--- /dev/null
+++ b/research/cv/siamRPN/scripts/run_gpu.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 1 ]
+then 
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash run.sh DEVICE_ID"
+    echo "For example: bash run_gpu.sh 0"
+    echo "=============================================================================================================="
+exit 1
+fi
+
+
+DEVICE_ID=$1
+
+export DEVICE_ID=$DEVICE_ID
+python3 train.py --device_id=$DEVICE_ID --device_target="GPU"> train.log 2>&1 &
diff --git a/research/cv/siamRPN/src/net.py b/research/cv/siamRPN/src/net.py
index fc3a1a447..3310101d6 100644
--- a/research/cv/siamRPN/src/net.py
+++ b/research/cv/siamRPN/src/net.py
@@ -82,6 +82,9 @@ class SiameseRPN(nn.Cell):
         self.softmax = ops.Softmax(axis=2)
         self.print = ops.Print()
 
+        self.anchor_num = config.anchor_num
+        self.score_size = config.score_size
+
     def construct(self, template=None, detection=None, ckernal=None, rkernal=None):
         """ forward function """
         if self.is_train is True and template is not None and detection is not None:
@@ -172,13 +175,13 @@ class SiameseRPN(nn.Cell):
                 routputs = routputs + (self.conv2d_rout(r_features[i], r_weights[i]),)
             coutputs = self.op_concat(coutputs)
             routputs = self.op_concat(routputs)
-            coutputs = self.reshape(coutputs, (self.groups, 2*config.anchor_num, config.score_size, config.score_size))
-            routputs = self.reshape(routputs, (self.groups, 4*config.anchor_num, config.score_size, config.score_size))
+            coutputs = self.reshape(coutputs, (self.groups, 2*self.anchor_num, self.score_size, self.score_size))
+            routputs = self.reshape(routputs, (self.groups, 4*self.anchor_num, self.score_size, self.score_size))
             routputs = self.regress_adjust(routputs)
             coutputs = self.transpose(
-                self.reshape(coutputs, (-1, 2, config.anchor_num * config.score_size* config.score_size)), (0, 2, 1))
+                self.reshape(coutputs, (-1, 2, self.anchor_num * self.score_size* self.score_size)), (0, 2, 1))
             routputs = self.transpose(
-                self.reshape(routputs, (-1, 4, config.anchor_num * config.score_size* config.score_size)),
+                self.reshape(routputs, (-1, 4, self.anchor_num * self.score_size* self.score_size)),
                 (0, 2, 1))
             out1, out2 = coutputs, routputs
         else:
diff --git a/research/cv/siamRPN/train.py b/research/cv/siamRPN/train.py
index 7267d7fd3..504492d43 100644
--- a/research/cv/siamRPN/train.py
+++ b/research/cv/siamRPN/train.py
@@ -27,6 +27,7 @@ import mindspore.dataset as ds
 from mindspore.context import ParallelMode
 from mindspore.communication.management import init, get_rank
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.communication.management import get_group_size
 import numpy as np
 from src.data_loader import TrainDataLoader
 from src.net import SiameseRPN, BuildTrainNet, MyTrainOneStepCell
@@ -47,7 +48,10 @@ parser.add_argument('--data_url', default=None, help='Location of data.')
 
 parser.add_argument('--unzip_mode', default=0, type=int, metavar='N', help='unzip mode:0:no unzip,1:tar,2:unzip')
 
-parser.add_argument('--device_id', default=2, type=int, metavar='N', help='number of total epochs to run')
+parser.add_argument('--device_id', default=0, type=int, metavar='N', help='number of total epochs to run')
+
+parser.add_argument('--device_target', default="Ascend", type=str, choices=["Ascend", "GPU"],
+                    help='type of platform:Ascend or GPU')
 
 
 #add random seed
@@ -79,7 +83,7 @@ def main(args):
         # create dataset
         dataset = ds.GeneratorDataset(data_loader, ["template", "detection", "label"], shuffle=True,
                                       num_parallel_workers=rank_size, num_shards=rank_size, shard_id=rank_id)
-    else:
+    if not args.is_parallel:
         dataset = ds.GeneratorDataset(data_loader, ["template", "detection", "label"], shuffle=True)
     dataset = dataset.batch(config.batch_size, drop_remainder=True)
 
@@ -129,17 +133,19 @@ def main(args):
                   "avg_loss is %s, step time is %s" % (cb_params.cur_epoch_num, cb_params.cur_step_num, loss,
                                                        self.tlosses.avg, step_mseconds), flush=True)
     print_cb = Print_info()
+    cb = [loss_cb, print_cb]
     #save checkpoint
     ckpt_cfg = CheckpointConfig(save_checkpoint_steps=dataset.get_dataset_size(), keep_checkpoint_max=51)
     if args.is_cloudtrain:
         ckpt_cb = ModelCheckpoint(prefix='siamRPN', directory=config.train_path+'/ckpt', config=ckpt_cfg)
     else:
         ckpt_cb = ModelCheckpoint(prefix='siamRPN', directory='./ckpt', config=ckpt_cfg)
-
+    if rank == 0:
+        cb += [ckpt_cb]
     if config.checkpoint_path is not None and os.path.exists(config.checkpoint_path):
-        model.train(total_epoch, dataset, callbacks=[loss_cb, ckpt_cb, print_cb], dataset_sink_mode=False)
+        model.train(total_epoch, dataset, callbacks=cb, dataset_sink_mode=False)
     else:
-        model.train(epoch=total_epoch, train_dataset=dataset, callbacks=[loss_cb, ckpt_cb, print_cb],
+        model.train(epoch=total_epoch, train_dataset=dataset, callbacks=cb,
                     dataset_sink_mode=False)
 
 
@@ -178,7 +184,7 @@ def adjust_learning_rate(start_lr, end_lr, total_epochs, steps_pre_epoch):
 
 if __name__ == '__main__':
     Args = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    context.set_context(mode=context.GRAPH_MODE, device_target=Args.device_target)
     if Args.is_cloudtrain:
         import moxing as mox
         device_id = int(os.getenv('DEVICE_ID') if os.getenv('DEVICE_ID') is not None else 0)
@@ -199,18 +205,29 @@ if __name__ == '__main__':
             local_data_path = local_data_path + '/train/ytb_vid_filter'
         config.train_path = local_data_path
     else:
-        config.train_path = Args.train_url
+        rank = 0
         if Args.is_parallel:
-            device_id = int(os.getenv('DEVICE_ID'))
-            device_num = int(os.getenv('RANK_SIZE'))
-            if device_num > 1:
-                context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
-                context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                                  parameter_broadcast=True, gradients_mean=True)
+            if Args.device_target == "Ascend":
+                config.train_path = Args.train_url
+                device_id = int(os.getenv('DEVICE_ID'))
+                device_num = int(os.getenv('RANK_SIZE'))
+                if device_num > 1:
+                    context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
+                    context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                                      parameter_broadcast=True, gradients_mean=True)
+                    init()
+            elif Args.device_target == "GPU":
                 init()
+                context.set_context(device_id=Args.device_id)
+                device_num = get_group_size()
+                context.reset_auto_parallel_context()
+                rank = get_rank()
+                context.set_auto_parallel_context(device_num=device_num,
+                                                  parallel_mode=ParallelMode.DATA_PARALLEL,
+                                                  gradients_mean=True)
         else:
             device_id = Args.device_id
-            context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
+            context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target=Args.device_target)
     main(Args)
     if Args.is_cloudtrain:
         mox.file.copy_parallel(src_url=local_data_path + '/ckpt', dst_url=Args.train_url + '/ckpt')
-- 
GitLab