From 748eba339ce5a60bdc73cdc95cb8aba09c252d28 Mon Sep 17 00:00:00 2001
From: jialingqu <qu0604@163.com>
Date: Tue, 5 Jul 2022 20:25:04 +0800
Subject: [PATCH] update LEO

---
 research/cv/LEO/README.md                     | 101 +++++++++++++-----
 .../config/LEO-N5-K1_miniImageNet_config.yaml |   6 +-
 .../LEO-N5-K1_tieredImageNet_config.yaml      |   6 +-
 .../config/LEO-N5-K5_miniImageNet_config.yaml |  15 ++-
 .../LEO-N5-K5_tieredImageNet_config.yaml      |   4 +-
 research/cv/LEO/eval.py                       |  14 ++-
 research/cv/LEO/export.py                     |  91 ++++++++++++++++
 research/cv/LEO/model_utils/config.py         |  28 +++--
 research/cv/LEO/scripts/export.sh             |  36 +++++++
 .../cv/LEO/scripts/run_distribution_ascend.sh |  71 ++++++++++++
 research/cv/LEO/scripts/run_eval_ascend.sh    |  36 +++++++
 research/cv/LEO/scripts/run_eval_gpu.sh       |  20 ++--
 research/cv/LEO/scripts/run_train_ascend.sh   |  33 ++++++
 research/cv/LEO/scripts/run_train_gpu.sh      |  32 +++---
 research/cv/LEO/train.py                      |  69 ++++++++----
 15 files changed, 475 insertions(+), 87 deletions(-)
 create mode 100644 research/cv/LEO/export.py
 create mode 100644 research/cv/LEO/scripts/export.sh
 create mode 100644 research/cv/LEO/scripts/run_distribution_ascend.sh
 create mode 100644 research/cv/LEO/scripts/run_eval_ascend.sh
 create mode 100644 research/cv/LEO/scripts/run_train_ascend.sh

diff --git a/research/cv/LEO/README.md b/research/cv/LEO/README.md
index 20cb4db10..919e97a62 100644
--- a/research/cv/LEO/README.md
+++ b/research/cv/LEO/README.md
@@ -105,8 +105,9 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
 
 # 鐜瑕佹眰
 
-- 纭欢锛圙PU锛�
+- 纭欢锛圙PU or Ascend锛�
     - 浣跨敤GPU澶勭悊鍣ㄦ潵鎼缓纭欢鐜銆�
+    - 浣跨敤Ascend澶勭悊鍣ㄦ潵鎼缓纭欢鐜銆�
 - 妗嗘灦
     - [MindSpore](https://www.mindspore.cn/install/en)
 - 濡傞渶鏌ョ湅璇︽儏锛岃鍙傝濡備笅璧勬簮锛�
@@ -130,6 +131,24 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
   bash scripts/run_eval_gpu.sh [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [CKPT_FILE]
   # 杩愯璇勪及绀轰緥
   bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/1P_mini_1/xxx.ckpt
+
+  ```
+
+- Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
+
+  ```bash
+  # 杩愯璁粌绀轰緥
+  bash scripts/run_train_ascend.sh [DEVICE_ID] [DEVICE_TARGET] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [SAVE_PATH]
+  # 渚嬪锛�
+  bash scripts/run_train_ascend.sh 6 Ascend /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpts/1P_mini_5
+  # 杩愯鍒嗗竷寮忚缁冪ず渚�
+  bash scripts/run_distribution_ascend.sh [RANK_TABLE_FILE] [DEVICE_TARGET] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [SAVE_PATH]
+  # 渚嬪锛�
+  bash scripts/run_distribution_ascend.sh ./hccl_8p_01234567_127.0.0.1.json Ascend /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpts/8P_mini_5
+  # 杩愯璇勪及绀轰緥
+  bash scripts/run_eval_ascend.sh [DEVICE_ID] [DEVICE_TARGET] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [CKPT_FILE]
+  # 渚嬪
+  bash scripts/run_eval_ascend.sh 4 Ascend /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/1P_mini_5/xxx.ckpt
   ```
 
 浠ヤ笂涓虹涓€涓疄楠岀ず渚嬶紝鍏朵綑涓変釜瀹為獙璇峰弬鑰冭缁冮儴鍒嗐€�
@@ -144,8 +163,11 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
    鈹溾攢 train.py                    # 璁粌鑴氭湰
    鈹溾攢 eval.py                     # 璇勪及鑴氭湰
    鈹溾攢 scripts
-   鈹�  鈹溾攢 run_eval_gpu.sh          # 鍚姩璇勪及
-   鈹�  鈹斺攢 run_train_gpu.sh         # 鍚姩璁粌
+   鈹�  鈹溾攢 run_distribution_ascend.sh          # 鍚姩8鍗scend璁粌
+   鈹�  鈹溾攢 run_eval_ascend.sh           # ascend鍚姩璇勪及
+   鈹�  鈹溾攢 run_eval_gpu.sh              # gpu鍚姩璇勪及
+   鈹�  鈹溾攢 run_train_ascend.sh          # ascend鍚姩璁粌
+   鈹�  鈹斺攢 run_train_gpu.sh             # gpu鍚姩璁粌
    鈹溾攢 src
    鈹�  鈹溾攢 data.py                  # 鏁版嵁澶勭悊
    鈹�  鈹溾攢 model.py                 # LEO妯″瀷
@@ -211,7 +233,7 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
   outer_lr: 0.004      #瓒呭弬
   gradient_threshold: 0.1
   gradient_norm_threshold: 0.1
-  total_steps: 100000
+  total_steps: 200000
   ```
 
 鏇村閰嶇疆缁嗚妭璇峰弬鑰僣onfig鏂囦欢澶癸紝**鍚姩璁粌涔嬪墠璇锋牴鎹笉鍚岀殑瀹為獙璁剧疆涓婅堪瓒呭弬鏁般€�**
@@ -221,13 +243,13 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
 - 鍥涗釜瀹為獙璁剧疆涓嶅悓鐨勮秴鍙�
 
 | 瓒呭弬                           | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
-| ------------------------------ | ------------------- | ------------------- | --------------------- | --------------------- |
+| ------------------------------ |---------------------|---------------------|-----------------------| --------------------- |
 | `dropout`                      | 0.3                 | 0.3                 | 0.2                   | 0.3                   |
 | `kl_weight`                    | 0.001               | 0.001               | 0                     | 0.001                 |
 | `encoder_penalty_weight`       | 1E-9                | 2.66E-7             | 5.7E-1                | 5.7E-6                |
 | `l2_penalty_weight`            | 0.0001              | 8.5E-6              | 5.10E-6               | 3.6E-10               |
-| `orthogonality_penalty_weight` | 303.0               | 0.00152             | 4.88E-1              | 0.188                 |
-| `outer_lr`                     | 0.004               | 0.004               | 0.004                 | 0.0025                |
+| `orthogonality_penalty_weight` | 303.0               | 0.00152             | 4.88E-1               | 0.188                 |
+| `outer_lr`                     | 0.005               | 0.005               | 0.005                 | 0.0025                |
 
 ### 璁粌
 
@@ -240,6 +262,15 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
   bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5
   ```
 
+- 閰嶇疆濂戒笂杩板弬鏁板悗锛孉scend鐜杩愯
+
+  ```bash
+  bash scripts/run_train_ascend.sh 6 Ascend /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpts/1P_mini_1
+  bash scripts/run_train_ascend.sh 6 Ascend /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpts/1P_mini_5
+  bash scripts/run_train_ascend.sh 6 Ascend /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/1P_tiered_1
+  bash scripts/run_train_ascend.sh 6 Ascend /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5
+  ```
+
   璁粌灏嗗湪鍚庡彴杩愯锛屾偍鍙互閫氳繃`1P_miniImageNet_1_train.log`绛夋棩蹇楁枃浠舵煡鐪嬭缁冭繃绋嬨€�
   璁粌缁撴潫鍚庯紝鎮ㄥ彲鍦� ` ./ckpt/1P_mini_1` 绛塩heckpoint鏂囦欢澶逛笅鎵惧埌妫€鏌ョ偣鏂囦欢銆�
 
@@ -256,6 +287,15 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
   bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/8P_tiered_5
   ```
 
+- 閰嶇疆濂戒笂杩板弬鏁板悗锛孉scend鐜杩愯
+
+  ```bash
+  bash scripts/run_distribution_ascend.sh ./hccl_8p_01234567_127.0.0.1.json Ascend /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpts/8P_mini_1
+  bash scripts/run_distribution_ascend.sh ./hccl_8p_01234567_127.0.0.1.json Ascend /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpts/8P_mini_5
+  bash scripts/run_distribution_ascend.sh ./hccl_8p_01234567_127.0.0.1.json Ascend /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpts/8P_tiered_1
+  bash scripts/run_distribution_ascend.sh ./hccl_8p_01234567_127.0.0.1.json Ascend /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpts/8P_tiered_5
+  ```
+
   涓庡崟鍗¤缁冧竴鏍凤紝鍙互鍦╜8P_miniImageNet_1_train.log`鏂囦欢鏌ョ湅璁粌杩囩▼锛屽苟鍦ㄩ粯璁./ckpt/8P_mini_1`绛塩heckpoint鏂囦欢澶逛笅鎵惧埌妫€鏌ョ偣鏂囦欢銆�
 
 ## 璇勪及杩囩▼
@@ -273,6 +313,15 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
   bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5/xxx.ckpt
   ```
 
+- Ascend鐜杩愯
+
+  ```bash
+  bash scripts/run_eval_ascend.sh 0 Ascend /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/1P_mini_1/xxx.ckpt
+  bash scripts/run_eval_ascend.sh 0 Ascend /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/1P_mini_5/xxx.ckpt
+  bash scripts/run_eval_ascend.sh 0 Ascend /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/1P_tiered_1/xxx.ckpt
+  bash scripts/run_eval_ascend.sh 0 Ascend /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5/xxx.ckpt
+  ```
+
   璇勪及灏嗗湪鍚庡彴杩愯锛屾偍鍙互閫氳繃`1P_miniImageNet_1_eval.log`绛夋棩蹇楁枃浠舵煡鐪嬭瘎浼拌繃绋嬨€�
 
 # 妯″瀷鎻忚堪
@@ -283,19 +332,19 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
 
 - 璁粌鍙傛暟
 
-| 鍙傛暟          | LEO                                                         |
-| -------------| ----------------------------------------------------------- |
-| 璧勬簮          | NVIDIA GeForce RTX 3090锛汣UDA鏍稿績 10496涓紱鏄惧瓨 24GB |
-| 涓婁紶鏃ユ湡       | 2022-03-27                                             |
-| MindSpore鐗堟湰 | 1.7.0                                                      |
-| 鏁版嵁闆�        | miniImageNet                                                 |
-| 浼樺寲鍣�        | Adam                                                         |
-| 鎹熷け鍑芥暟       | Cross Entropy Loss                                           |
-| 杈撳嚭          | 鍑嗙‘鐜�                                                        |
-| 鎹熷け          | GANLoss,L1Loss,localLoss,DTLoss                             |
-| 寰皟妫€鏌ョ偣     | 672KB (.ckpt鏂囦欢)                                     |
+| 鍙傛暟          | LEO                                                         | Ascend                                        |
+| -------------| ----------------------------------------------------------- |-----------------------------------------------|
+| 璧勬簮          | NVIDIA GeForce RTX 3090锛汣UDA鏍稿績 10496涓紱鏄惧瓨 24GB | Ascend 910; CPU 24cores; 鏄惧瓨 256G; OS Euler2.8 |
+| 涓婁紶鏃ユ湡       | 2022-03-27                                             | 2022-06-12                                    |
+| MindSpore鐗堟湰 | 1.7.0                                                      | 1.5.0                                         |
+| 鏁版嵁闆�        | miniImageNet                                                 | miniImageNet                                  |
+| 浼樺寲鍣�        | Adam                                                         | Adam                                          |
+| 鎹熷け鍑芥暟       | Cross Entropy Loss                                           | Cross Entropy Loss                            |
+| 杈撳嚭          | 鍑嗙‘鐜�                                                        | 鍑嗙‘鐜�                                           |
+| 鎹熷け          | GANLoss,L1Loss,localLoss,DTLoss                             | GANLoss,L1Loss,localLoss,DTLoss               |
+| 寰皟妫€鏌ョ偣     | 672KB (.ckpt鏂囦欢)                                     | 672KB (.ckpt鏂囦欢)                               |
 
-- 璇勪及鎬ц兘
+- GPU璇勪及鎬ц兘
 
 | 瀹為獙 | miniImageNet 1-shot | miniImageNet 5-shot | tieredImageNet 1-shot | tieredImageNet 5-shot |
 | ----- | ------------------- | ------------------- | --------------------- | --------------------- |
@@ -306,13 +355,13 @@ LEO鐢变互涓嬪嚑涓ā鍧楃粍鎴愶紝鍒嗙被鍣紝缂栫爜鍣紝鍏崇郴缃戠粶鍜岀紪鐮�
 
 - 璇勪及鍙傛暟
 
-| 鍙傛暟          | LEO                                                         |
-| ------------ | ----------------------------------------------------------- |
-| 璧勬簮          | NVIDIA GeForce RTX 3090锛汣UDA鏍稿績 10496涓紱鏄惧瓨 24GB |
-| 涓婁紶鏃ユ湡       | 2022-03-27                                              |
-| MindSpore鐗堟湰 | 1.7.0                                                      |
-| 鏁版嵁闆�        | miniImageNet                                                 |
-| 杈撳嚭          | 鍑嗙‘鐜�                                                        |
+| 鍙傛暟          | LEO                                                         | Ascend                                        |
+| ------------ | ----------------------------------------------------------- |-----------------------------------------------|
+| 璧勬簮          | NVIDIA GeForce RTX 3090锛汣UDA鏍稿績 10496涓紱鏄惧瓨 24GB | Ascend 910; CPU 24cores; 鏄惧瓨 256G; OS Euler2.8 |
+| 涓婁紶鏃ユ湡       | 2022-03-27                                              | 2022-06-12                                    |
+| MindSpore鐗堟湰 | 1.7.0                                                      | 1.5.0                                         |
+| 鏁版嵁闆�        | miniImageNet                                                 | miniImageNet                                  |
+| 杈撳嚭          | 鍑嗙‘鐜�                                                        | 鍑嗙‘鐜�                                           |
 
 - 璇勪及绮惧害
 
diff --git a/research/cv/LEO/config/LEO-N5-K1_miniImageNet_config.yaml b/research/cv/LEO/config/LEO-N5-K1_miniImageNet_config.yaml
index 4675844d4..71b2aee87 100644
--- a/research/cv/LEO/config/LEO-N5-K1_miniImageNet_config.yaml
+++ b/research/cv/LEO/config/LEO-N5-K1_miniImageNet_config.yaml
@@ -2,6 +2,8 @@
 enable_modelarts: False
 data_url: ""
 train_url: ""
+ckpt_url: 'ckpt files'
+result_url: 'infer result files'
 checkpoint_url: ""
 device_target: "GPU"
 device_num: 1
@@ -36,10 +38,10 @@ metatrain_batch_size: 12
 metavalid_batch_size: 200
 metatest_batch_size: 200
 num_steps_limit: int(1e5)
-outer_lr: 0.004 # parameters
+outer_lr: 0.005 # parameters
 gradient_threshold: 0.1
 gradient_norm_threshold: 0.1
-total_steps: 100000
+total_steps: 200000
 
 # Model Description
 model_name: LEO
diff --git a/research/cv/LEO/config/LEO-N5-K1_tieredImageNet_config.yaml b/research/cv/LEO/config/LEO-N5-K1_tieredImageNet_config.yaml
index e63670346..8255b4107 100644
--- a/research/cv/LEO/config/LEO-N5-K1_tieredImageNet_config.yaml
+++ b/research/cv/LEO/config/LEO-N5-K1_tieredImageNet_config.yaml
@@ -2,6 +2,8 @@
 enable_modelarts: False
 data_url: ""
 train_url: ""
+ckpt_url: 'ckpt files'
+result_url: 'infer result files'
 checkpoint_url: ""
 device_target: "GPU"
 device_num: 1
@@ -36,10 +38,10 @@ metatrain_batch_size: 12
 metavalid_batch_size: 200
 metatest_batch_size: 200
 num_steps_limit: int(1e5)
-outer_lr: 0.004 # parameters
+outer_lr: 0.005 # parameters
 gradient_threshold: 0.1
 gradient_norm_threshold: 0.1
-total_steps: 100000
+total_steps: 200000
 
 # Model Description
 model_name: LEO
diff --git a/research/cv/LEO/config/LEO-N5-K5_miniImageNet_config.yaml b/research/cv/LEO/config/LEO-N5-K5_miniImageNet_config.yaml
index 38e74815a..d3e5838d2 100644
--- a/research/cv/LEO/config/LEO-N5-K5_miniImageNet_config.yaml
+++ b/research/cv/LEO/config/LEO-N5-K5_miniImageNet_config.yaml
@@ -2,6 +2,8 @@
 enable_modelarts: False
 data_url: ""
 train_url: ""
+ckpt_url: 'ckpt files'
+result_url: 'infer result files'
 checkpoint_url: ""
 device_target: "GPU"
 device_num: 1
@@ -20,10 +22,11 @@ inner_unroll_length: 5
 finetuning_unroll_length: 5
 num_latents: 64
 inner_lr_init: 1.0
-finetuning_lr_init: 0.001
+finetuning_lr_init: 0.0005
 dropout_rate: 0.3 # parameters
 kl_weight: 0.001 # parameters
-encoder_penalty_weight: 2.66E-7 # parameters
+encoder_penalty_weight: 2.66E-7 # parameters origin
+
 l2_penalty_weight: 8.5E-6 # parameters
 orthogonality_penalty_weight: 0.00152 # parameters
 # ==============================================================================
@@ -36,15 +39,15 @@ metatrain_batch_size: 12
 metavalid_batch_size: 200
 metatest_batch_size: 200
 num_steps_limit: int(1e5)
-outer_lr: 0.004 # parameters
+outer_lr: 0.005 # parameters origin
 gradient_threshold: 0.1
 gradient_norm_threshold: 0.1
-total_steps: 100000
+total_steps: 200000
 
 # Model Description
 model_name: LEO
 file_name: 'leo'
-file_format: 'MINDIR'  # ['AIR', 'MINDIR']
+file_format: 'AIR'
 
 
 ---
@@ -54,6 +57,8 @@ data_url: 'Dataset url for obs'
 train_url: 'Training output url for obs'
 data_path: 'Dataset path for local'
 output_path: 'Training output path for local'
+ckpt_url: 'ckpt files'
+result_url: 'infer result files'
 
 device_target: 'Target device type'
 enable_profiling: 'Whether enable profiling while training, default: False'
diff --git a/research/cv/LEO/config/LEO-N5-K5_tieredImageNet_config.yaml b/research/cv/LEO/config/LEO-N5-K5_tieredImageNet_config.yaml
index 92056d7b6..97f7cb50e 100644
--- a/research/cv/LEO/config/LEO-N5-K5_tieredImageNet_config.yaml
+++ b/research/cv/LEO/config/LEO-N5-K5_tieredImageNet_config.yaml
@@ -2,6 +2,8 @@
 enable_modelarts: False
 data_url: ""
 train_url: ""
+ckpt_url: 'ckpt files'
+result_url: 'infer result files'
 checkpoint_url: ""
 device_target: "GPU"
 device_num: 1
@@ -39,7 +41,7 @@ num_steps_limit: int(1e5)
 outer_lr: 0.0025 # parameters
 gradient_threshold: 0.1
 gradient_norm_threshold: 0.1
-total_steps: 100000
+total_steps: 200000
 
 # Model Description
 model_name: LEO
diff --git a/research/cv/LEO/eval.py b/research/cv/LEO/eval.py
index 62169c7b4..167a1f7ce 100644
--- a/research/cv/LEO/eval.py
+++ b/research/cv/LEO/eval.py
@@ -35,7 +35,7 @@ def eval_leo(init_config, inner_model_config, outer_model_config):
     total_test_steps = 100
 
     data_utils = data.Data_Utils(
-        train=False, seed=100, way=outer_model_config['num_classes'],
+        train=False, seed=1, way=outer_model_config['num_classes'],
         shot=outer_model_config['num_tr_examples_per_class'],
         data_path=init_config['data_path'], dataset_name=init_config['dataset_name'],
         embedding_crop=init_config['embedding_crop'],
@@ -75,6 +75,7 @@ if __name__ == '__main__':
     initConfig = config.get_init_config()
     inner_model_Config = config.get_inner_model_config()
     outer_model_Config = config.get_outer_model_config()
+    args = config.get_config(get_args=True)
 
     print("===============inner_model_config=================")
     for key, value in inner_model_Config.items():
@@ -84,5 +85,16 @@ if __name__ == '__main__':
         print(key+": "+str(value))
 
     context.set_context(mode=context.GRAPH_MODE, device_target=initConfig['device_target'])
+    if args.enable_modelarts:
+        import moxing as mox
+
+        mox.file.copy_parallel(
+            src_url=args.data_url, dst_url='/cache/dataset/device_' + os.getenv('DEVICE_ID'))
+        train_dataset_path = os.path.join('/cache/dataset/device_' + os.getenv('DEVICE_ID'), "embeddings")
+        ckpt_path = '/home/work/user-job-dir/checkpoint.ckpt'
+        mox.file.copy(args.ckpt_url, ckpt_path)
+        initConfig['data_path'] = train_dataset_path
+        initConfig['ckpt_file'] = ckpt_path
+
 
     eval_leo(initConfig, inner_model_Config, outer_model_Config)
diff --git a/research/cv/LEO/export.py b/research/cv/LEO/export.py
new file mode 100644
index 000000000..5c755a3c3
--- /dev/null
+++ b/research/cv/LEO/export.py
@@ -0,0 +1,91 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import src.data as data
+import src.outerloop as outerloop
+import model_utils.config as config
+import mindspore as ms
+from mindspore import context, Tensor, export
+from mindspore import load_checkpoint, load_param_into_net
+import numpy as np
+
+os.environ['GLOG_v'] = "3"
+os.environ['GLOG_log_dir'] = '/var/log'
+
+
+def export_leo(init_config, inner_model_config, outer_model_config):
+    inner_lr_init = inner_model_config['inner_lr_init']
+    finetuning_lr_init = inner_model_config['finetuning_lr_init']
+
+    data_utils = data.Data_Utils(
+        train=False, seed=100, way=outer_model_config['num_classes'],
+        shot=outer_model_config['num_tr_examples_per_class'],
+        data_path=init_config['data_path'], dataset_name=init_config['dataset_name'],
+        embedding_crop=init_config['embedding_crop'],
+        batchsize=outer_model_config['metatrain_batch_size'],
+        val_batch_size=outer_model_config['metavalid_batch_size'],
+        test_batch_size=outer_model_config['metatest_batch_size'],
+        meta_val_steps=outer_model_config['num_val_examples_per_class'], embedding_size=640, verbose=True)
+
+    test_outer_loop = outerloop.OuterLoop(
+        batchsize=outer_model_config['metavalid_batch_size'], input_size=640,
+        latent_size=inner_model_config['num_latents'],
+        way=outer_model_config['num_classes'], shot=outer_model_config['num_tr_examples_per_class'],
+        dropout=inner_model_config['dropout_rate'], kl_weight=inner_model_config['kl_weight'],
+        encoder_penalty_weight=inner_model_config['encoder_penalty_weight'],
+        orthogonality_weight=inner_model_config['orthogonality_penalty_weight'],
+        inner_lr_init=inner_lr_init, finetuning_lr_init=finetuning_lr_init,
+        inner_step=inner_model_config['inner_unroll_length'],
+        finetune_inner_step=inner_model_config['finetuning_unroll_length'], is_meta_training=False)
+
+    parm_dict = load_checkpoint(init_config['ckpt_file'])
+    load_param_into_net(test_outer_loop, parm_dict)
+
+    batch = data_utils.get_batch('test')
+    print(batch['train']['input'].shape)  # [200,5,5,640]
+    print(batch['train']['input'].dtype)  # Float32
+    print(batch['train']['target'].shape)  # [200,5,5,1]
+    print(batch['train']['target'].dtype)  # Int64
+    print(batch['val']['input'].shape)  # [200,5,15,640]
+    print(batch['val']['input'].dtype)  # Float32
+    print(batch['val']['target'].shape)  # [200,5,15,1]
+    print(batch['val']['target'].dtype)  # Int64
+    train_input = Tensor(np.zeros(batch['train']['input'].shape), ms.float32)
+    train_target = Tensor(np.zeros(batch['train']['target'].shape), ms.int64)
+    val_input = Tensor(np.zeros(batch['val']['input'].shape), ms.float32)
+    val_target = Tensor(np.zeros(batch['val']['target'].shape), ms.int64)
+    result_name = "LEO-" + init_config['dataset_name'] + str(outer_model_config['num_classes']) +\
+                  "N" + str(outer_model_config['num_tr_examples_per_class']) + "K"
+    export(test_outer_loop, train_input, train_target, val_input, val_target,
+           file_name=result_name, file_format="MINDIR")
+
+
+if __name__ == '__main__':
+    initConfig = config.get_init_config()
+    inner_model_Config = config.get_inner_model_config()
+    outer_model_Config = config.get_outer_model_config()
+
+    print("===============inner_model_config=================")
+    for key, value in inner_model_Config.items():
+        print(key + ": " + str(value))
+    print("===============outer_model_config=================")
+    for key, value in outer_model_Config.items():
+        print(key + ": " + str(value))
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=initConfig['device_target'])
+
+    export_leo(initConfig, inner_model_Config, outer_model_Config)
+    print("successfully export LEO model!")
diff --git a/research/cv/LEO/model_utils/config.py b/research/cv/LEO/model_utils/config.py
index ac01fc5da..461d5ba85 100644
--- a/research/cv/LEO/model_utils/config.py
+++ b/research/cv/LEO/model_utils/config.py
@@ -56,6 +56,8 @@ def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="LEO-N5-K
     helper = {} if helper is None else helper
     choices = {} if choices is None else choices
     for item in cfg:
+        if item in ("dataset_name", "num_tr_examples_per_class"):
+            continue
         if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
             help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
             choice = choices[item] if item in choices else None
@@ -110,20 +112,28 @@ def merge(args, cfg):
     return cfg
 
 
-def get_config():
+def get_config(get_args=False):
     """
     Get Config according to the yaml file and cli arguments.
     """
     parser = argparse.ArgumentParser(description="default name", add_help=False)
-    config_dir = os.path.join(os.path.abspath(os.getcwd()), "config")
-    config_name = "LEO-N5-K" + str(os.getenv("NUM_TR_EXAMPLES_PER_CLASS")) \
-                  + "_" + os.getenv("DATA_NAME") + "_config.yaml"
-    parser.add_argument("--config_path", type=str,
-                        default=os.path.join(config_dir, config_name),
-                        help="Config file path")
+
+    parser.add_argument("--num_tr_examples_per_class", type=int,
+                        default=5,
+                        help="num_tr_examples_per_class")
+    parser.add_argument("--dataset_name", type=str,
+                        default="miniImageNet",
+                        help="dataset_name")
     path_args, _ = parser.parse_known_args()
-    default, helper, choices = parse_yaml(path_args.config_path)
-    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
+    config_name = "LEO-N5-K" + str(path_args.num_tr_examples_per_class) \
+                  + "_" + path_args.dataset_name + "_config.yaml"
+    config_path = os.path.join(os.path.abspath(os.path.join(__file__, "../..")), "config", config_name)
+
+
+    default, helper, choices = parse_yaml(config_path)
+    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=config_path)
+    if get_args:
+        return args
     final_config = merge(args, default)
     return Config(final_config)
 
diff --git a/research/cv/LEO/scripts/export.sh b/research/cv/LEO/scripts/export.sh
new file mode 100644
index 000000000..16872d732
--- /dev/null
+++ b/research/cv/LEO/scripts/export.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+if [[ $# -ne 6 ]]; then
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash scripts/run_eval_gpu.sh [DEVICE_NUM] [DEVICE_TARGET] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [CKPT_FILE] "
+    echo "For example: bash scripts/export.sh 1 Ascend ../leo/leo-mindspore/embeddings miniImageNet 5 ./ckpts/xxx.ckpt "
+    echo "=============================================================================================================="
+    exit 1;
+fi
+
+export  GLOG_v=3
+export  DEVICE_ID=$1
+export  DEVICE_TARGET=$2
+export  DATA_PATH=$3
+export  DATA_NAME=$4
+export  NUM_TR_EXAMPLES_PER_CLASS=$5
+export  CKPT_FILE=$6
+
+nohup python export.py --device_target $DEVICE_TARGET \
+                     --data_path $DATA_PATH \
+                     --dataset_name $DATA_NAME \
+                     --num_tr_examples_per_class $NUM_TR_EXAMPLES_PER_CLASS \
+                     --ckpt_file $CKPT_FILE  > export.log 2>&1 &
diff --git a/research/cv/LEO/scripts/run_distribution_ascend.sh b/research/cv/LEO/scripts/run_distribution_ascend.sh
new file mode 100644
index 000000000..2f15b7fb5
--- /dev/null
+++ b/research/cv/LEO/scripts/run_distribution_ascend.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# an simple tutorial as follows, more parameters can be setting
+
+if [ $# != 6 ]
+then
+    echo "Usage: bash scripts/run_distribution_ascend.sh [RANK_TABLE_FILE] [DEVICE_TARGET] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [SAVE_PATH]"
+    echo "For example: bash scripts/run_distribution_ascend.sh ./hccl_8p_01234567_127.0.0.1.json Ascend /home/jialing/leo/leo-mindspore/embeddings miniImageNet 5 ./ckpts/8P_mini_5
+"
+exit 1
+fi
+
+if [ ! -f $1 ]
+then
+    echo "error: RANK_TABLE_FILE=$1 is not a file"
+exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+ulimit -u unlimited
+export DEVICE_NUM=8
+export RANK_SIZE=8
+RANK_TABLE_FILE=$(realpath $1)
+export RANK_TABLE_FILE
+echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
+
+export  DEVICE_TARGET=$2
+export  DATA_PATH=$3
+export  DATA_NAME=$4
+export  NUM_TR_EXAMPLES_PER_CLASS=$5
+export  SAVE_PATH=$6
+
+export SERVER_ID=0
+rank_start=$((DEVICE_NUM * SERVER_ID))
+for((i=0; i<${DEVICE_NUM}; i++))
+do
+    export DEVICE_ID=$i
+    export RANK_ID=$((rank_start + i))
+    rm -rf ./train_parallel$i
+    mkdir ./train_parallel$i
+    cp -r ./src ./train_parallel$i
+    cp -r ./config ./train_parallel$i
+    cp -r ./model_utils ./train_parallel$i
+    cp ./train.py ./train_parallel$i
+    echo "start training for rank $RANK_ID, device $DEVICE_ID"
+    cd ./train_parallel$i ||exit
+    env > env.log
+    nohup python -u train.py --device_target $DEVICE_TARGET --data_path $DATA_PATH --dataset_name $DATA_NAME --num_tr_examples_per_class $NUM_TR_EXAMPLES_PER_CLASS --save_path $SAVE_PATH >log_distribution_ascend 2>&1 &
+    cd ..
+done
+
diff --git a/research/cv/LEO/scripts/run_eval_ascend.sh b/research/cv/LEO/scripts/run_eval_ascend.sh
new file mode 100644
index 000000000..7781e3d2a
--- /dev/null
+++ b/research/cv/LEO/scripts/run_eval_ascend.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+if [[ $# -ne 6 ]]; then
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash scripts/run_eval_gpu.sh [DEVICE_ID] [DEVICE_TARGET] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [CKPT_FILE] "
+    echo "For example: bash scripts/run_eval_ascend.sh 4 Ascend ../leo/leo-mindspore/embeddings miniImageNet 5 ./ckpt/1P_mini_5/xxx.ckpt "
+    echo "=============================================================================================================="
+    exit 1;
+fi
+
+export  GLOG_v=3
+export  DEVICE_ID=$1
+export  DEVICE_TARGET=$2
+export  DATA_PATH=$3
+export  DATA_NAME=$4
+export  NUM_TR_EXAMPLES_PER_CLASS=$5
+export  CKPT_FILE=$6
+
+nohup python eval.py --device_target $DEVICE_TARGET \
+                     --data_path $DATA_PATH \
+                     --dataset_name $DATA_NAME \
+                     --num_tr_examples_per_class $NUM_TR_EXAMPLES_PER_CLASS \
+                     --ckpt_file $CKPT_FILE  > ${DATA_NAME}_${NUM_TR_EXAMPLES_PER_CLASS}_eval.log 2>&1 &
diff --git a/research/cv/LEO/scripts/run_eval_gpu.sh b/research/cv/LEO/scripts/run_eval_gpu.sh
index c3ced3c09..bfb8d3373 100644
--- a/research/cv/LEO/scripts/run_eval_gpu.sh
+++ b/research/cv/LEO/scripts/run_eval_gpu.sh
@@ -12,14 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-echo "============================================================================================================================"
-echo "Please run the script as: "
-echo "bash scripts/run_eval_gpu.sh [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [CKPT_FILE] "
-echo "For example: bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/1P_mini_1/xxx.ckpt "
-echo "============ bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/1P_mini_5/xxx.ckpt "
-echo "============ bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/1P_tiered_1/xxx.ckpt "
-echo "============ bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5/xxx.ckpt "
-echo "============================================================================================================================"
+if [[ $# -ne 4 ]]; then
+    echo "============================================================================================================================"
+    echo "Please run the script as: "
+    echo "bash scripts/run_eval_gpu.sh [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [CKPT_FILE] "
+    echo "For example: bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/1P_mini_1/xxx.ckpt "
+    echo "============ bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/1P_mini_5/xxx.ckpt "
+    echo "============ bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/1P_tiered_1/xxx.ckpt "
+    echo "============ bash scripts/run_eval_gpu.sh /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5/xxx.ckpt "
+    echo "============================================================================================================================"
+    exit 1;
+fi
+
 export  GLOG_v=3
 export  DEVICE_TARGET=GPU
 export  DATA_PATH=$1
diff --git a/research/cv/LEO/scripts/run_train_ascend.sh b/research/cv/LEO/scripts/run_train_ascend.sh
new file mode 100644
index 000000000..bd42217f8
--- /dev/null
+++ b/research/cv/LEO/scripts/run_train_ascend.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+if [[ $# -ne 6 ]]; then
+    echo "=============================================================================================================="
+    echo "Please run the script as: "
+    echo "bash scripts/run_train_ascend.sh [DEVICE_ID] [DEVICE_TARGET] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [SAVE_PATH] "
+    echo "For example: bash scripts/run_train_ascend.sh 6 Ascend /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpts/1P_mini_5"
+    echo "=============================================================================================================="
+    exit 1;
+fi
+
+export  DEVICE_ID=$1
+export  DEVICE_TARGET=$2
+export  DATA_PATH=$3
+export  DATA_NAME=$4
+export  NUM_TR_EXAMPLES_PER_CLASS=$5
+export  SAVE_PATH=$6
+
+export  GLOG_v=3
+export  DEVICE_ID=$DEVICE_ID
+nohup python -u train.py --device_target $DEVICE_TARGET --data_path $DATA_PATH --dataset_name $DATA_NAME --num_tr_examples_per_class $NUM_TR_EXAMPLES_PER_CLASS --save_path $SAVE_PATH > 1P_${DATA_NAME}_${NUM_TR_EXAMPLES_PER_CLASS}_train.log 2>&1 &
diff --git a/research/cv/LEO/scripts/run_train_gpu.sh b/research/cv/LEO/scripts/run_train_gpu.sh
index 4c3d099dd..2161d4a84 100644
--- a/research/cv/LEO/scripts/run_train_gpu.sh
+++ b/research/cv/LEO/scripts/run_train_gpu.sh
@@ -12,20 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-echo "===================================================================================================================="
-echo "Please run the script as: "
-echo "bash scripts/run_train_gpu.sh [DEVICE_NUM] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [SAVE_PATH] "
-echo "For example: bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/1P_mini_1"
-echo " ============bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/1P_mini_5"
-echo " ============bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/1P_tiered_1"
-echo " ============bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5"  
-echo "===================================================================================================================="
-echo "Please run distributed training script as: "
-echo "For example: bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/8P_mini_1 "
-echo " ============bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/8P_mini_5"
-echo " ============bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/8P_tiered_1"
-echo " ============bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/8P_tiered_5"            
-echo "===================================================================================================================="
+if [[ $# -ne 5 ]]; then
+    echo "===================================================================================================================="
+    echo "Please run the script as: "
+    echo "bash scripts/run_train_gpu.sh [DEVICE_NUM] [DATA_PATH] [DATA_NAME] [NUM_TR_EXAMPLES_PER_CLASS] [SAVE_PATH] "
+    echo "For example: bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/1P_mini_1"
+    echo " ============bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/1P_mini_5"
+    echo " ============bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/1P_tiered_1"
+    echo " ============bash scripts/run_train_gpu.sh 1 /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/1P_tiered_5"  
+    echo "===================================================================================================================="
+    echo "Please run distributed training script as: "
+    echo "For example: bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ miniImageNet 1 ./ckpt/8P_mini_1 "
+    echo " ============bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ miniImageNet 5 ./ckpt/8P_mini_5"
+    echo " ============bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ tieredImageNet 1 ./ckpt/8P_tiered_1"
+    echo " ============bash scripts/run_train_gpu.sh 8 /home/mindspore/dataset/embeddings/ tieredImageNet 5 ./ckpt/8P_tiered_5"            
+    echo "===================================================================================================================="
+    exit 1;
+fi
+
 export  DEVICE_NUM=$1
 export  DEVICE_TARGET=GPU
 export  DATA_PATH=$2
diff --git a/research/cv/LEO/train.py b/research/cv/LEO/train.py
index 812955ac4..574e0a722 100644
--- a/research/cv/LEO/train.py
+++ b/research/cv/LEO/train.py
@@ -19,8 +19,10 @@ import model_utils.config as config
 import src.data as data
 import src.outerloop as outerloop
 from src.trainonestepcell import TrainOneStepCell
+import mindspore
+import mindspore.nn as nn
 from mindspore import context
-from mindspore import save_checkpoint, load_param_into_net, load_checkpoint
+from mindspore import save_checkpoint, load_param_into_net
 from mindspore.communication.management import init
 from mindspore.context import ParallelMode
 
@@ -29,12 +31,35 @@ os.environ['GLOG_v'] = "3"
 os.environ['GLOG_log_dir'] = '/var/log'
 
 
+def save_checkpoint_to_file(if_save_checkpoint, val_accs, best_acc, step, val_losses, init_config, train_outer_loop):
+    if if_save_checkpoint:
+        if not sum(val_accs) / len(val_accs) < best_acc:
+            best_acc = sum(val_accs) / len(val_accs)
+            model_name = '%dk_%4.4f_%4.4f_model.ckpt' % (
+                (step // 1000 + 1),
+                sum(val_losses) / len(val_losses),
+                sum(val_accs) / len(val_accs))
+
+            check_dir(init_config['save_path'])
+
+            if args.enable_modelarts:
+                # NOTE(review): a per-device path was previously computed here but was
+                # dead code (immediately overwritten); all devices share one output dir.
+                save_checkpoint_path = '/cache/train_output/'
+                if not os.path.exists(save_checkpoint_path):
+                    os.makedirs(save_checkpoint_path)
+                save_checkpoint(train_outer_loop, os.path.join(save_checkpoint_path, model_name))
+            else:
+                save_checkpoint(train_outer_loop, os.path.join(init_config['save_path'], model_name))
+            print('Saved checkpoint %s...' % model_name)
+    return best_acc
+
 def train_leo(init_config, inner_model_config, outer_model_config):
     inner_lr_init = inner_model_config['inner_lr_init']
     finetuning_lr_init = inner_model_config['finetuning_lr_init']
 
     total_train_steps = outer_model_config['total_steps']
-    val_every_step = 5000
+    val_every_step = 3000
     total_val_steps = 100
     if_save_checkpoint = True
     best_acc = 0
@@ -72,8 +97,11 @@ def train_leo(init_config, inner_model_config, outer_model_config):
         inner_step=inner_model_config['inner_unroll_length'],
         finetune_inner_step=inner_model_config['finetuning_unroll_length'], is_meta_training=True)
 
-    parm_dict = load_checkpoint('./resource/leo_ms_init.ckpt')
-    load_param_into_net(train_outer_loop, parm_dict)
+    if context.get_context("device_target") == "Ascend":
+        train_outer_loop.to_float(mindspore.float32)
+        for _, cell in train_outer_loop.cells_and_names():
+            if isinstance(cell, nn.Dense):
+                cell.to_float(mindspore.float16)
 
     train_net = TrainOneStepCell(train_outer_loop,
                                  outer_model_config['outer_lr'],
@@ -105,8 +133,8 @@ def train_leo(init_config, inner_model_config, outer_model_config):
                                                         train_net.group_params[0]['params'][1].T.asnumpy(),
                                                         val_acc.asnumpy(), now_t-old_t))
 
-        if step % val_every_step == 4999:
-            print('5000 step average time: %4.4f second...'%(sum_steptime/5000))
+        if step % val_every_step == 2999:
+            print('3000 step average time: %4.4f second...'%(sum_steptime/3000))
             sum_steptime = 0
 
             val_losses = []
@@ -128,18 +156,8 @@ def train_leo(init_config, inner_model_config, outer_model_config):
                   (sum(val_losses)/len(val_losses), sum(val_accs)/len(val_accs)))
             print('=' * 50)
 
-            if if_save_checkpoint:
-                if not sum(val_accs)/len(val_accs) < best_acc:
-                    best_acc = sum(val_accs)/len(val_accs)
-                    model_name = '%dk_%4.4f_%4.4f_model.ckpt' % (
-                        (step//1000+1),
-                        sum(val_losses)/len(val_losses),
-                        sum(val_accs)/len(val_accs))
-
-                    check_dir(init_config['save_path'])
-
-                    save_checkpoint(train_outer_loop, os.path.join(init_config['save_path'], model_name))
-                    print('Saved checkpoint %s...'%model_name)
+            best_acc = save_checkpoint_to_file(if_save_checkpoint, val_accs, best_acc, step, val_losses,
+                                               init_config, train_outer_loop) or best_acc
 
         if step == (total_train_steps-1):
             train_end = time.time()
@@ -166,6 +184,8 @@ if __name__ == '__main__':
     initConfig = config.get_init_config()
     inner_model_Config = config.get_inner_model_config()
     outer_model_Config = config.get_outer_model_config()
+    args = config.get_config(get_args=True)
+
 
     print("===============inner_model_config=================")
     for key, value in inner_model_Config.items():
@@ -175,10 +195,21 @@ if __name__ == '__main__':
         print(key+": "+str(value))
 
     context.set_context(mode=context.GRAPH_MODE, device_target=initConfig['device_target'])
-    if initConfig['device_num'] > 1:
+    if args.enable_modelarts:
+        import moxing as mox
+
+        mox.file.copy_parallel(
+            src_url=args.data_url, dst_url='/cache/dataset/device_' + os.getenv('DEVICE_ID'))
+        train_dataset_path = os.path.join('/cache/dataset/device_' + os.getenv('DEVICE_ID'), "embeddings")
+        initConfig['data_path'] = train_dataset_path
+
+    elif initConfig['device_num'] > 1:
         init('nccl')
         context.set_auto_parallel_context(device_num=initConfig['device_num'],
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
 
     train_leo(initConfig, inner_model_Config, outer_model_Config)
+    if args.enable_modelarts:
+        mox.file.copy_parallel(
+            src_url='/cache/train_output', dst_url=args.train_url)
-- 
GitLab