From f8585c86a00f629ebcc3a8c7002e9c735a0f316a Mon Sep 17 00:00:00 2001
From: yexijoe <yexijoe@163.com>
Date: Thu, 30 Sep 2021 13:03:27 +0800
Subject: [PATCH] add vit_base

---
 research/cv/vit_base/README_CN.md             | 313 ++++++++++++++++++
 .../vit_base/ascend310_infer/CMakeLists.txt   |  15 +
 research/cv/vit_base/ascend310_infer/build.sh |  23 ++
 .../cv/vit_base/ascend310_infer/inc/utils.h   |  32 ++
 .../cv/vit_base/ascend310_infer/src/main.cc   | 162 +++++++++
 .../cv/vit_base/ascend310_infer/src/utils.cc  | 159 +++++++++
 research/cv/vit_base/eval.py                  |  89 +++++
 research/cv/vit_base/export.py                |  62 ++++
 research/cv/vit_base/postprocess.py           |  45 +++
 research/cv/vit_base/preprocess.py            |  38 +++
 research/cv/vit_base/requirements.txt         |   1 +
 .../scripts/run_distribution_train_ascend.sh  |  43 +++
 research/cv/vit_base/scripts/run_infer_310.sh | 121 +++++++
 .../scripts/run_standalone_eval_ascend.sh     |  21 ++
 .../scripts/run_standalone_train_ascend.sh    |  21 ++
 research/cv/vit_base/src/config.py            |  53 +++
 research/cv/vit_base/src/dataset.py           |  80 +++++
 research/cv/vit_base/src/modeling_ms.py       | 221 +++++++++++++
 research/cv/vit_base/src/net_config.py        | 117 +++++++
 research/cv/vit_base/train.py                 | 211 ++++++++++++
 20 files changed, 1827 insertions(+)
 create mode 100644 research/cv/vit_base/README_CN.md
 create mode 100644 research/cv/vit_base/ascend310_infer/CMakeLists.txt
 create mode 100644 research/cv/vit_base/ascend310_infer/build.sh
 create mode 100644 research/cv/vit_base/ascend310_infer/inc/utils.h
 create mode 100644 research/cv/vit_base/ascend310_infer/src/main.cc
 create mode 100644 research/cv/vit_base/ascend310_infer/src/utils.cc
 create mode 100644 research/cv/vit_base/eval.py
 create mode 100644 research/cv/vit_base/export.py
 create mode 100644 research/cv/vit_base/postprocess.py
 create mode 100644 research/cv/vit_base/preprocess.py
 create mode 100644 research/cv/vit_base/requirements.txt
 create mode 100644 research/cv/vit_base/scripts/run_distribution_train_ascend.sh
 create mode 100644 research/cv/vit_base/scripts/run_infer_310.sh
 create mode 100644 research/cv/vit_base/scripts/run_standalone_eval_ascend.sh
 create mode 100644 research/cv/vit_base/scripts/run_standalone_train_ascend.sh
 create mode 100644 research/cv/vit_base/src/config.py
 create mode 100644 research/cv/vit_base/src/dataset.py
 create mode 100644 research/cv/vit_base/src/modeling_ms.py
 create mode 100644 research/cv/vit_base/src/net_config.py
 create mode 100644 research/cv/vit_base/train.py

diff --git a/research/cv/vit_base/README_CN.md b/research/cv/vit_base/README_CN.md
new file mode 100644
index 000000000..b9cf74c33
--- /dev/null
+++ b/research/cv/vit_base/README_CN.md
@@ -0,0 +1,313 @@
+# 鐩綍
+
+<!-- TOC -->
+
+- [鐩綍](#鐩綍)
+- [vit_base鎻忚堪](#vit_base鎻忚堪)
+- [妯″瀷鏋舵瀯](#妯″瀷鏋舵瀯)
+- [鏁版嵁闆哴(#鏁版嵁闆�)
+- [鐗规€(#鐗规€�)
+    - [娣峰悎绮惧害](#娣峰悎绮惧害)
+- [鐜瑕佹眰](#鐜瑕佹眰)
+- [蹇€熷叆闂╙(#蹇€熷叆闂�)
+- [鑴氭湰璇存槑](#鑴氭湰璇存槑)
+    - [鑴氭湰鍙婃牱渚嬩唬鐮乚(#鑴氭湰鍙婃牱渚嬩唬鐮�)
+    - [鑴氭湰鍙傛暟](#鑴氭湰鍙傛暟)
+    - [璁粌杩囩▼](#璁粌杩囩▼)
+        - [璁粌](#璁粌)
+        - [鍒嗗竷寮忚缁僝(#鍒嗗竷寮忚缁�)
+    - [璇勪及杩囩▼](#璇勪及杩囩▼)
+        - [璇勪及](#璇勪及)
+    - [瀵煎嚭杩囩▼](#瀵煎嚭杩囩▼)
+        - [瀵煎嚭](#瀵煎嚭)
+    - [鎺ㄧ悊杩囩▼](#鎺ㄧ悊杩囩▼)
+        - [鎺ㄧ悊](#鎺ㄧ悊)
+- [妯″瀷鎻忚堪](#妯″瀷鎻忚堪)
+    - [鎬ц兘](#鎬ц兘)
+        - [璇勪及鎬ц兘](#璇勪及鎬ц兘)
+            - [CIFAR-10涓婄殑vit_base](#cifar-10涓婄殑vit_base)
+        - [鎺ㄧ悊鎬ц兘](#鎺ㄧ悊鎬ц兘)
+            - [CIFAR-10涓婄殑vit_base](#cifar-10涓婄殑vit_base)
+- [ModelZoo涓婚〉](#modelzoo涓婚〉)
+
+<!-- /TOC -->
+
+# vit_base鎻忚堪
+
+Transformer鏋舵瀯宸插箍娉涘簲鐢ㄤ簬鑷劧璇█澶勭悊棰嗗煙銆傛湰妯″瀷鐨勪綔鑰呭彂鐜帮紝Vision Transformer锛圴iT锛夋ā鍨嬪湪璁$畻鏈鸿瑙夐鍩熶腑瀵笴NN鐨勪緷璧栦笉鏄繀闇€鐨勶紝鐩存帴灏嗗叾搴旂敤浜庡浘鍍忓潡搴忓垪鏉ヨ繘琛屽浘鍍忓垎绫绘椂锛屼篃鑳藉緱鍒板拰鐩墠鍗风Н缃戠粶鐩稿缇庣殑鍑嗙‘鐜囥€�
+
+[璁烘枃](https://arxiv.org/abs/2010.11929) 锛欴osovitskiy, A. , Beyer, L. , Kolesnikov, A. , Weissenborn, D. , & Houlsby, N.. (2020). An image is worth 16x16 words: transformers for image recognition at scale.
+
+# 妯″瀷鏋舵瀯
+
+vit_base鐨勬€讳綋缃戠粶鏋舵瀯濡備笅锛� [閾炬帴](https://arxiv.org/abs/2010.11929)
+
+# 鏁版嵁闆�
+
+浣跨敤鐨勬暟鎹泦锛歔CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)
+
+- 鏁版嵁闆嗗ぇ灏忥細175M锛屽叡10涓被銆�6涓囧紶褰╄壊鍥惧儚
+    - 璁粌闆嗭細146M锛屽叡5涓囧紶鍥惧儚
+    - 娴嬭瘯闆嗭細29M锛屽叡1涓囧紶鍥惧儚
+- 鏁版嵁鏍煎紡锛氫簩杩涘埗鏂囦欢
+    - 娉細鏁版嵁灏嗗湪src/dataset.py涓鐞嗐€�
+
+# 鐗规€�
+
+## 娣峰悎绮惧害
+
+閲囩敤[娣峰悎绮惧害](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.3/enable_mixed_precision.html) 鐨勮缁冩柟娉曪紝浣跨敤鏀寔鍗曠簿搴﹀拰鍗婄簿搴︽暟鎹潵鎻愰珮娣卞害瀛︿範绁炵粡缃戠粶鐨勮缁冮€熷害锛屽悓鏃朵繚鎸佸崟绮惧害璁粌鎵€鑳借揪鍒扮殑缃戠粶绮惧害銆傛贩鍚堢簿搴﹁缁冩彁楂樿绠楅€熷害銆佸噺灏戝唴瀛樹娇鐢ㄧ殑鍚屾椂锛屾敮鎸佸湪鐗瑰畾纭欢涓婅缁冩洿澶х殑妯″瀷鎴栧疄鐜版洿澶ф壒娆$殑璁粌銆�
+
+# 鐜瑕佹眰
+
+- 纭欢锛圓scend锛�
+    - 浣跨敤Ascend鏉ユ惌寤虹‖浠剁幆澧冦€�
+- 妗嗘灦
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- 濡傞渶鏌ョ湅璇︽儏锛岃鍙傝濡備笅璧勬簮锛�
+    - [MindSpore鏁欑▼](https://www.mindspore.cn/tutorials/zh-CN/r1.3/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/r1.3/index.html)
+
+# 蹇€熷叆闂�
+
+閫氳繃瀹樻柟缃戠珯瀹夎MindSpore鍚庯紝鎮ㄥ彲浠ユ寜鐓у涓嬫楠よ繘琛岃缁冨拰璇勪及锛岀壒鍒湴锛岃繘琛岃缁冨墠闇€瑕佸厛涓嬭浇瀹樻柟鍩轰簬ImageNet21k鐨勯璁粌妯″瀷[ViT-B_16](https://console.cloud.google.com/storage/vit_models/) 锛屽苟灏嗗叾杞崲涓篗indSpore鏀寔鐨刢kpt鏍煎紡妯″瀷锛屽懡鍚嶄负"cifar10_pre_checkpoint_based_imagenet21k.ckpt"锛屽拰璁粌闆嗘祴璇曢泦鏁版嵁鏀句簬鍚屼竴绾х洰褰曚笅锛�
+
+- Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
+
+  ```python
+  # 杩愯璁粌绀轰緥
+  python train.py --device_id=0 --dataset_name='cifar10' > train.log 2>&1 &
+  OR
+  bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [DATASET_NAME]
+
+  # 杩愯鍒嗗竷寮忚缁冪ず渚�
+  bash ./scripts/run_distribution_train_ascend.sh [RANK_TABLE] [DEVICE_NUM] [DEVICE_START] [DATASET_NAME]
+
+  # 杩愯璇勪及绀轰緥
+  python eval.py --checkpoint_path [CKPT_PATH] ./eval.log 2>&1 &
+  OR
+  bash ./scripts/run_standalone_eval_ascend.sh [CKPT_PATH]
+
+  # 杩愯鎺ㄧ悊绀轰緥
+  bash run_infer_310.sh ../vit_base.mindir Cifar10 /home/dataset/cifar-10-verify-bin/ 0
+  ```
+
+  瀵逛簬鍒嗗竷寮忚缁冿紝闇€瑕佹彁鍓嶅垱寤篔SON鏍煎紡鐨刪ccl閰嶇疆鏂囦欢銆�
+
+  璇烽伒寰互涓嬮摼鎺ヤ腑鐨勮鏄庯細
+
+ <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
+
+- 鍦� ModelArts 杩涜璁粌 (濡傛灉浣犳兂鍦╩odelarts涓婅繍琛岋紝鍙互鍙傝€冧互涓嬫枃妗� [modelarts](https://support.huaweicloud.com/modelarts/))
+
+    - 鍦� ModelArts 涓婁娇鐢ㄥ鍗¤缁� cifar10 鏁版嵁闆�
+
+      ```python
+      # (1) 鍦ㄧ綉椤典笂璁剧疆AI寮曟搸涓篗indSpore
+      # (2) 鍦ㄧ綉椤典笂璁剧疆 "ckpt_url=obs://path/pre_ckpt/"锛堥璁粌妯″瀷鍛藉悕涓�"cifar10_pre_checkpoint_based_imagenet21k.ckpt"锛�
+      #     鍦ㄧ綉椤典笂璁剧疆 "modelarts=True"
+      #     鍦ㄧ綉椤典笂璁剧疆 鍏朵粬鍙傛暟
+      # (3) 涓婁紶浣犵殑鏁版嵁闆嗗埌 S3 妗朵笂
+      # (4) 鍦ㄧ綉椤典笂璁剧疆浣犵殑浠g爜璺緞涓� "/path/vit_base"
+      # (5) 鍦ㄧ綉椤典笂璁剧疆鍚姩鏂囦欢涓� "train.py"
+      # (6) 鍦ㄧ綉椤典笂璁剧疆"璁粌鏁版嵁闆嗭紙濡�/dataset/cifar10/cifar-10-batches-bin/锛�"銆�"璁粌杈撳嚭鏂囦欢璺緞"銆�"浣滀笟鏃ュ織璺緞"绛�
+      # (7) 鍒涘缓璁粌浣滀笟
+      ```
+
+# 鑴氭湰璇存槑
+
+## 鑴氭湰鍙婃牱渚嬩唬鐮�
+
+```bash
+鈹溾攢鈹€ models
+    鈹溾攢鈹€ README.md                                  // 鎵€鏈夋ā鍨嬬浉鍏宠鏄�
+    鈹溾攢鈹€ vit_base
+        鈹溾攢鈹€ README_CN.md                           // vit_base鐩稿叧璇存槑
+        鈹溾攢鈹€ ascend310_infer                        // 瀹炵幇310鎺ㄧ悊婧愪唬鐮�
+        鈹溾攢鈹€ scripts
+        鈹�   鈹溾攢鈹€run_distribution_train_ascend.sh    // 鍒嗗竷寮忓埌Ascend鐨剆hell鑴氭湰
+        鈹�   鈹溾攢鈹€run_infer_310.sh                    // Ascend鎺ㄧ悊鐨剆hell鑴氭湰
+        鈹�   鈹溾攢鈹€run_standalone_eval_ascend.sh       // Ascend璇勪及鐨剆hell鑴氭湰
+        鈹�   鈹溾攢鈹€run_standalone_train_ascend.sh      // Ascend鍗曞崱璁粌鐨剆hell鑴氭湰
+        鈹溾攢鈹€ src
+        鈹�   鈹溾攢鈹€config.py                           // 鍙傛暟閰嶇疆
+        鈹�   鈹溾攢鈹€dataset.py                          // 鍒涘缓鏁版嵁闆�
+        鈹�   鈹溾攢鈹€modeling_ms.py                      // vit_base鏋舵瀯
+        鈹�   鈹溾攢鈹€net_config.py                       // 缁撴瀯鍙傛暟閰嶇疆
+        鈹溾攢鈹€ eval.py                                // 璇勪及鑴氭湰
+        鈹溾攢鈹€ export.py                              // 灏哻heckpoint鏂囦欢瀵煎嚭鍒癮ir/mindir
+        鈹溾攢鈹€ postprocess.py                         // 310鎺ㄧ悊鍚庡鐞嗚剼鏈�
+        鈹溾攢鈹€ preprocess.py                          // 310鎺ㄧ悊鍓嶅鐞嗚剼鏈�
+        鈹溾攢鈹€ train.py                               // 璁粌鑴氭湰
+```
+
+## 鑴氭湰鍙傛暟
+
+鍦╟onfig.py涓彲浠ュ悓鏃堕厤缃缁冨弬鏁板拰璇勪及鍙傛暟銆�
+
+- 閰嶇疆vit_base鍜孋IFAR-10鏁版嵁闆嗐€�
+
+  ```python
+  'name':'cifar10'         # 鏁版嵁闆�
+  'pre_trained':True       # 鏄惁鍩轰簬棰勮缁冩ā鍨嬭缁�
+  'num_classes':10         # 鏁版嵁闆嗙被鏁�
+  'lr_init':0.013          # 鍒濆瀛︿範鐜囷紝鍙屽崱骞惰璁粌
+  'batch_size':32          # 璁粌鎵规澶у皬
+  'epoch_size':60          # 鎬昏璁粌epoch鏁�
+  'momentum':0.9           # 鍔ㄩ噺
+  'weight_decay':1e-4      # 鏉冮噸琛板噺鍊�
+  'image_height':224       # 杈撳叆鍒版ā鍨嬬殑鍥惧儚楂樺害
+  'image_width':224        # 杈撳叆鍒版ā鍨嬬殑鍥惧儚瀹藉害
+  'data_path':'/dataset/cifar10/cifar-10-batches-bin/'     # 璁粌鏁版嵁闆嗙殑缁濆鍏ㄨ矾寰�
+  'val_data_path':'/dataset/cifar10/cifar-10-verify-bin/'  # 璇勪及鏁版嵁闆嗙殑缁濆鍏ㄨ矾寰�
+  'device_target':'Ascend' # 杩愯璁惧
+  'device_id':0            # 鐢ㄤ簬璁粌鎴栬瘎浼版暟鎹泦鐨勮澶嘔D锛岃繘琛屽垎甯冨紡璁粌鏃跺彲浠ュ拷鐣�
+  'keep_checkpoint_max':2  # 鏈€澶氫繚瀛�2涓猚kpt妯″瀷鏂囦欢
+  'checkpoint_path':'/dataset/cifar10_pre_checkpoint_based_imagenet21k.ckpt'  # 淇濆瓨棰勮缁冩ā鍨嬬殑缁濆鍏ㄨ矾寰�
+  # optimizer and lr related
+  'lr_scheduler':'cosine_annealing'
+  'T_max':50
+  ```
+
+鏇村閰嶇疆缁嗚妭璇峰弬鑰冭剼鏈琡config.py`銆�
+
+## 璁粌杩囩▼
+
+### 璁粌
+
+- Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
+
+  ```bash
+  python train.py --device_id=0 --dataset_name='cifar10' > train.log 2>&1 &
+  OR
+  bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [DATASET_NAME]
+  ```
+
+  涓婅堪python鍛戒护灏嗗湪鍚庡彴杩愯锛屽彲浠ラ€氳繃鐢熸垚鐨則rain.log鏂囦欢鏌ョ湅缁撴灉銆�
+
+  璁粌缁撴潫鍚庯紝鍙互鍦ㄩ粯璁よ剼鏈枃浠跺す涓嬪緱鍒版崯澶卞€硷細
+
+  ```bash
+  Load pre_trained ckpt: ./cifar10_pre_checkpoint_based_imagenet21k.ckpt
+  epoch: 1 step: 1562, loss is 0.12886986
+  epoch time: 289458.121 ms, per step time: 185.312 ms
+  epoch: 2 step: 1562, loss is 0.15596801
+  epoch time: 245404.168 ms, per step time: 157.109 ms
+  {'acc': 0.9240785256410257}
+  epoch: 3 step: 1562, loss is 0.06133139
+  epoch time: 244538.410 ms, per step time: 156.555 ms
+  epoch: 4 step: 1562, loss is 0.28615832
+  epoch time: 245382.652 ms, per step time: 157.095 ms
+  {'acc': 0.9597355769230769}
+  ```
+
+### 鍒嗗竷寮忚缁�
+
+- Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
+
+  ```bash
+  bash ./scripts/run_distribution_train_ascend.sh [RANK_TABLE] [DEVICE_NUM] [DEVICE_START] [DATASET_NAME]
+  ```
+
+  涓婅堪shell鑴氭湰灏嗗湪鍚庡彴杩愯鍒嗗竷璁粌銆�
+
+  璁粌缁撴潫鍚庯紝鍙互寰楀埌鎹熷け鍊硷細
+
+  ```bash
+  Load pre_trained ckpt: ./cifar10_pre_checkpoint_based_imagenet21k.ckpt
+  epoch: 1 step: 781, loss is 0.015172593
+  epoch time: 195952.289 ms, per step time: 250.899 ms
+  epoch: 2 step: 781, loss is 0.06709316
+  epoch time: 135894.327 ms, per step time: 174.000 ms
+  {'acc': 0.9853766025641025}
+  epoch: 3 step: 781, loss is 0.050968178
+  epoch time: 135056.020 ms, per step time: 172.927 ms
+  epoch: 4 step: 781, loss is 0.01949552
+  epoch time: 136084.816 ms, per step time: 174.244 ms
+  {'acc': 0.9854767628205128}
+  ```
+
+## 璇勪及杩囩▼
+
+### 璇勪及
+
+- 鍦ˋscend鐜杩愯鏃惰瘎浼癈IFAR-10鏁版嵁闆�
+
+  ```bash
+  python eval.py --checkpoint_path [CKPT_PATH] ./eval.log 2>&1 &
+  OR
+  bash ./scripts/run_standalone_eval_ascend.sh [CKPT_PATH]
+  ```
+
+## 瀵煎嚭杩囩▼
+
+### 瀵煎嚭
+
+灏哻heckpoint鏂囦欢瀵煎嚭鎴恗indir鏍煎紡妯″瀷銆�
+
+  ```shell
+  python export.py --ckpt_file [CKPT_FILE]
+  ```
+
+## 鎺ㄧ悊杩囩▼
+
+### 鎺ㄧ悊
+
+鍦ㄨ繘琛屾帹鐞嗕箣鍓嶆垜浠渶瑕佸厛瀵煎嚭妯″瀷銆俶indir鍙互鍦ㄤ换鎰忕幆澧冧笂瀵煎嚭锛宎ir妯″瀷鍙兘鍦ㄦ槆鑵�910鐜涓婂鍑恒€備互涓嬪睍绀轰簡浣跨敤mindir妯″瀷鎵ц鎺ㄧ悊鐨勭ず渚嬨€�
+
+- 鍦ㄦ槆鑵�310涓婁娇鐢–IFAR-10鏁版嵁闆嗚繘琛屾帹鐞�
+
+  鎵ц鎺ㄧ悊鐨勫懡浠ゅ涓嬫墍绀猴紝鍏朵腑'MINDIR_PATH'鏄痬indir鏂囦欢璺緞锛�'DATASET'鏄娇鐢ㄧ殑鎺ㄧ悊鏁版嵁闆嗗悕绉帮紝涓�'Cifar10'锛�'DATA_PATH'鏄帹鐞嗘暟鎹泦璺緞锛�'DEVICE_ID'鍙€夛紝榛樿鍊间负0銆�
+
+  ```shell
+  # Ascend310 inference
+  bash run_infer_310.sh [MINDIR_PATH] [DATASET] [DATA_PATH] [DEVICE_ID]
+  ```
+
+  鎺ㄧ悊鐨勭簿搴︾粨鏋滀繚瀛樺湪scripts鐩綍涓嬶紝鍦╝cc.log鏃ュ織鏂囦欢涓彲浠ユ壘鍒扮被浼间互涓嬬殑鍒嗙被鍑嗙‘鐜囩粨鏋溿€傛帹鐞嗙殑鎬ц兘缁撴灉淇濆瓨鍦╯cripts/time_Result鐩綍涓嬶紝鍦╰est_perform_static.txt鏂囦欢涓彲浠ユ壘鍒扮被浼间互涓嬬殑鎬ц兘缁撴灉銆�
+
+  ```shell
+  after allreduce eval: top1_correct=9854, tot=10000, acc=98.54%
+  NN inference cost average time: 52.2274 ms of infer_count 10000
+  ```
+
+# 妯″瀷鎻忚堪
+
+## 鎬ц兘
+
+### 璇勪及鎬ц兘
+
+#### CIFAR-10涓婄殑vit_base
+
+| 鍙傛暟                 | Ascend                                                      |
+| -------------------------- | ----------------------------------------------------------- |
+| 妯″瀷鐗堟湰              | vit_base                                                |
+| 璧勬簮                   | Ascend 910锛汣PU 2.60GHz锛�192鏍革紱鍐呭瓨 755G锛涚郴缁� Red Hat 8.3.1-5         |
+| 涓婁紶鏃ユ湡              | 2021-10-26                                 |
+| MindSpore鐗堟湰          | 1.3.0                                                 |
+| 鏁版嵁闆�                    | CIFAR-10                                                |
+| 璁粌鍙傛暟        | epoch=60, batch_size=32, lr_init=0.013锛堝弻鍗″苟琛岃缁冩椂锛�             |
+| 浼樺寲鍣�                  | Momentum                                                    |
+| 鎹熷け鍑芥暟              | Softmax浜ゅ弶鐔�                                       |
+| 杈撳嚭                    | 姒傜巼                                                 |
+| 鍒嗙被鍑嗙‘鐜�             | 鍙屽崱锛�98.99%               |
+| 閫熷害                      | 鍗曞崱锛�157姣/姝ワ紱鍏崱锛�174姣/姝�                        |
+| 鎬绘椂闀�                 | 鍙屽崱锛�2.48灏忔椂/60杞�                                             |
+
+### 鎺ㄧ悊鎬ц兘
+
+#### CIFAR-10涓婄殑vit_base
+
+| 鍙傛暟                 | Ascend                                                       |
+| -------------------------- | ----------------------------------------------------------- |
+| 妯″瀷鐗堟湰              | vit_base                                                |
+| 璧勬簮                   | Ascend 310               |
+| 涓婁紶鏃ユ湡              | 2021-10-26                                 |
+| MindSpore鐗堟湰          | 1.3.0                                                 |
+| 鏁版嵁闆�                    | CIFAR-10                                                |
+| 鍒嗙被鍑嗙‘鐜�             | 98.54%                       |
+| 閫熷害                      | NN inference cost average time: 52.2274 ms of infer_count 10000           |
+
+# ModelZoo涓婚〉  
+
+ 璇锋祻瑙堝畼缃慬涓婚〉](https://gitee.com/mindspore/models) 銆�
\ No newline at end of file
diff --git a/research/cv/vit_base/ascend310_infer/CMakeLists.txt b/research/cv/vit_base/ascend310_infer/CMakeLists.txt
new file mode 100644
index 000000000..d3fa58018
--- /dev/null
+++ b/research/cv/vit_base/ascend310_infer/CMakeLists.txt
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 3.14.1)
+project(Ascend310Infer)
+add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
+set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
+option(MINDSPORE_PATH "mindspore install path" "")
+include_directories(${MINDSPORE_PATH})
+include_directories(${MINDSPORE_PATH}/include)
+include_directories(${PROJECT_SRC_ROOT})
+find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
+file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
+find_package(gflags REQUIRED)
+
+add_executable(main src/main.cc src/utils.cc)
+target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
diff --git a/research/cv/vit_base/ascend310_infer/build.sh b/research/cv/vit_base/ascend310_infer/build.sh
new file mode 100644
index 000000000..770a8851e
--- /dev/null
+++ b/research/cv/vit_base/ascend310_infer/build.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ ! -d out ]; then
+  mkdir out
+fi
+cd out || exit
+cmake .. \
+    -DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
+make
diff --git a/research/cv/vit_base/ascend310_infer/inc/utils.h b/research/cv/vit_base/ascend310_infer/inc/utils.h
new file mode 100644
index 000000000..efebe03a8
--- /dev/null
+++ b/research/cv/vit_base/ascend310_infer/inc/utils.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_INFERENCE_UTILS_H_
+#define MINDSPORE_INFERENCE_UTILS_H_
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <vector>
+#include <string>
+#include <memory>
+#include "include/api/types.h"
+
+std::vector<std::string> GetAllFiles(std::string_view dirName);
+DIR *OpenDir(std::string_view dirName);
+std::string RealPath(std::string_view path);
+mindspore::MSTensor ReadFileToTensor(const std::string &file);
+int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
+#endif
diff --git a/research/cv/vit_base/ascend310_infer/src/main.cc b/research/cv/vit_base/ascend310_infer/src/main.cc
new file mode 100644
index 000000000..1b3387b1c
--- /dev/null
+++ b/research/cv/vit_base/ascend310_infer/src/main.cc
@@ -0,0 +1,162 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <sys/time.h>
+#include <gflags/gflags.h>
+#include <dirent.h>
+#include <iostream>
+#include <string>
+#include <algorithm>
+#include <iosfwd>
+#include <vector>
+#include <fstream>
+#include <sstream>
+
+#include "../inc/utils.h"
+#include "include/dataset/execute.h"
+#include "include/dataset/transforms.h"
+#include "include/dataset/vision.h"
+#include "include/dataset/vision_ascend.h"
+#include "include/api/types.h"
+#include "include/api/model.h"
+#include "include/api/serialization.h"
+#include "include/api/context.h"
+
+using mindspore::Serialization;
+using mindspore::Model;
+using mindspore::Context;
+using mindspore::Status;
+using mindspore::ModelType;
+using mindspore::Graph;
+using mindspore::GraphCell;
+using mindspore::kSuccess;
+using mindspore::MSTensor;
+using mindspore::DataType;
+using mindspore::dataset::Execute;
+using mindspore::dataset::TensorTransform;
+using mindspore::dataset::vision::Decode;
+using mindspore::dataset::vision::Resize;
+using mindspore::dataset::vision::CenterCrop;
+using mindspore::dataset::vision::Normalize;
+using mindspore::dataset::vision::HWC2CHW;
+
+using mindspore::dataset::transforms::TypeCast;
+
+
+DEFINE_string(model_path, "", "model path");
+DEFINE_string(dataset, "Cifar10", "dataset: ImageNet or Cifar10");
+DEFINE_string(dataset_path, ".", "dataset path");
+DEFINE_int32(device_id, 0, "device id");
+
+int main(int argc, char **argv) {
+    gflags::ParseCommandLineFlags(&argc, &argv, true);
+    if (RealPath(FLAGS_model_path).empty()) {
+        std::cout << "Invalid model" << std::endl;
+        return 1;
+    }
+
+    std::transform(FLAGS_dataset.begin(), FLAGS_dataset.end(), FLAGS_dataset.begin(), ::tolower);
+
+    auto context = std::make_shared<Context>();
+    auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
+    ascend310_info->SetDeviceID(FLAGS_device_id);
+    context->MutableDeviceInfo().push_back(ascend310_info);
+
+    Graph graph;
+    Status ret = Serialization::Load(FLAGS_model_path, ModelType::kMindIR, &graph);
+    if (ret != kSuccess) {
+        std::cout << "Load model failed." << std::endl;
+        return 1;
+    }
+
+    Model model;
+    ret = model.Build(GraphCell(graph), context);
+    if (ret != kSuccess) {
+        std::cout << "ERROR: Build failed." << std::endl;
+        return 1;
+    }
+
+    std::vector<MSTensor> modelInputs = model.GetInputs();
+
+    auto all_files = GetAllFiles(FLAGS_dataset_path);
+    if (all_files.empty()) {
+        std::cout << "ERROR: no input data." << std::endl;
+        return 1;
+    }
+
+    auto decode = Decode();
+    auto resizeImageNet = Resize({256});
+    auto centerCrop = CenterCrop({224});
+    auto normalizeImageNet = Normalize({123.675, 116.28, 103.53}, {58.395, 57.12, 57.375});
+    auto hwc2chw = HWC2CHW();
+
+    mindspore::dataset::Execute transformImageNet({decode, resizeImageNet, centerCrop, normalizeImageNet, hwc2chw});
+
+    std::map<double, double> costTime_map;
+
+    size_t size = all_files.size();
+    for (size_t i = 0; i < size; ++i) {
+        struct timeval start;
+        struct timeval end;
+        double startTime_ms;
+        double endTime_ms;
+        std::vector<MSTensor> inputs;
+        std::vector<MSTensor> outputs;
+
+        std::cout << "Start predict input files:" << all_files[i] << std::endl;
+        mindspore::MSTensor image =  ReadFileToTensor(all_files[i]);
+
+        if (FLAGS_dataset.compare("imagenet") == 0) {
+            transformImageNet(image, &image);
+        }
+
+        inputs.emplace_back(modelInputs[0].Name(), modelInputs[0].DataType(), modelInputs[0].Shape(),
+                            image.Data().get(), image.DataSize());
+
+        gettimeofday(&start, NULL);
+        model.Predict(inputs, &outputs);
+        gettimeofday(&end, NULL);
+
+        startTime_ms = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
+        endTime_ms = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
+        costTime_map.insert(std::pair<double, double>(startTime_ms, endTime_ms));
+        int ret_ = WriteResult(all_files[i], outputs);
+        if (ret_ != kSuccess) {
+          std::cout << "write result failed." << std::endl;
+          return 1;
+        }
+    }
+    double average = 0.0;
+    int infer_cnt = 0;
+    for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
+        double diff = 0.0;
+        diff = iter->second - iter->first;
+        average += diff;
+        infer_cnt++;
+    }
+
+    average = average / infer_cnt;
+    std::stringstream timeCost;
+    timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << infer_cnt << std::endl;
+    std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << infer_cnt << std::endl;
+
+    std::string file_name = "./time_Result" + std::string("/test_perform_static.txt");
+    std::ofstream file_stream(file_name.c_str(), std::ios::trunc);
+    file_stream << timeCost.str();
+    file_stream.close();
+    costTime_map.clear();
+    return 0;
+}
diff --git a/research/cv/vit_base/ascend310_infer/src/utils.cc b/research/cv/vit_base/ascend310_infer/src/utils.cc
new file mode 100644
index 000000000..dcd00ac3b
--- /dev/null
+++ b/research/cv/vit_base/ascend310_infer/src/utils.cc
@@ -0,0 +1,159 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "inc/utils.h"
+
+#include <fstream>
+#include <algorithm>
+#include <iostream>
+
+using mindspore::MSTensor;
+using mindspore::DataType;
+
+std::vector<std::string> GetAllFiles(std::string_view dirName) {
+    struct dirent *filename;
+    DIR *dir = OpenDir(dirName);
+    if (dir == nullptr) {
+        return {};
+    }
+    std::vector<std::string> dirs;
+    std::vector<std::string> files;
+    while ((filename = readdir(dir)) != nullptr) {
+        std::string dName = std::string(filename->d_name);
+        if (dName == "." || dName == "..") {
+            continue;
+        } else if (filename->d_type == DT_DIR) {
+            dirs.emplace_back(std::string(dirName) + "/" + filename->d_name);
+        } else if (filename->d_type == DT_REG) {
+            files.emplace_back(std::string(dirName) + "/" + filename->d_name);
+        } else {
+            continue;
+        }
+    }
+
+    for (auto d : dirs) {
+        dir = OpenDir(d);
+        while ((filename = readdir(dir)) != nullptr) {
+            std::string dName = std::string(filename->d_name);
+            if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
+                continue;
+            }
+            files.emplace_back(std::string(d) + "/" + filename->d_name);
+        }
+    }
+    std::sort(files.begin(), files.end());
+    for (auto &f : files) {
+        std::cout << "image file: " << f << std::endl;
+    }
+    return files;
+}
+
+int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
+    std::string homePath = "./result_Files";
+    for (size_t i = 0; i < outputs.size(); ++i) {
+        size_t outputSize;
+        std::shared_ptr<const void> netOutput;
+        netOutput = outputs[i].Data();
+        outputSize = outputs[i].DataSize();
+        int pos = imageFile.rfind('/');
+        std::string fileName(imageFile, pos + 1);
+        fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
+        std::string outFileName = homePath + "/" + fileName;
+        FILE * outputFile = fopen(outFileName.c_str(), "wb");
+        if (outputFile == nullptr) {
+          std::cout << "open result file" << outFileName << "failed" << std::endl;
+          return -1;
+        }
+        size_t size = fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
+        if (size != outputSize) {
+          fclose(outputFile);
+          outputFile = nullptr;
+          std::cout << "writer result file" << outFileName << "failed write size[" << size <<
+              "] is smaller than output size[" << outputSize << "], maybe the disk is full" << std::endl;
+          return -1;
+        }
+
+        fclose(outputFile);
+        outputFile = nullptr;
+    }
+    return 0;
+}
+
+mindspore::MSTensor ReadFileToTensor(const std::string &file) {
+  if (file.empty()) {
+    std::cout << "Pointer file is nullptr" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  std::ifstream ifs(file);
+  if (!ifs.good()) {
+    std::cout << "File: " << file << " is not exist" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  if (!ifs.is_open()) {
+    std::cout << "File: " << file << "open failed" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  ifs.seekg(0, std::ios::end);
+  size_t size = ifs.tellg();
+  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
+  ifs.close();
+
+  return buffer;
+}
+
+
+DIR *OpenDir(std::string_view dirName) {
+    if (dirName.empty()) {
+        std::cout << " dirName is null ! " << std::endl;
+        return nullptr;
+    }
+    std::string realPath = RealPath(dirName);
+    struct stat s;
+    lstat(realPath.c_str(), &s);
+    if (!S_ISDIR(s.st_mode)) {
+        std::cout << "dirName is not a valid directory !" << std::endl;
+        return nullptr;
+    }
+    DIR *dir;
+    dir = opendir(realPath.c_str());
+    if (dir == nullptr) {
+        std::cout << "Can not open dir " << dirName << std::endl;
+        return nullptr;
+    }
+    std::cout << "Successfully opened the dir " << dirName << std::endl;
+    return dir;
+}
+
+std::string RealPath(std::string_view path) {
+    char realPathMem[PATH_MAX] = {0};
+    char *realPathRet = nullptr;
+    realPathRet = realpath(path.data(), realPathMem);
+
+    if (realPathRet == nullptr) {
+        std::cout << "File: " << path << " is not exist.";
+        return "";
+    }
+
+    std::string realPath(realPathMem);
+    std::cout << path << " realpath is: " << realPath << std::endl;
+    return realPath;
+}
diff --git a/research/cv/vit_base/eval.py b/research/cv/vit_base/eval.py
new file mode 100644
index 000000000..106c3edae
--- /dev/null
+++ b/research/cv/vit_base/eval.py
@@ -0,0 +1,89 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Process the test set with the .ckpt model in turn.
+"""
+import argparse
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.common import set_seed
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.nn.loss.loss import LossBase
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+from src.config import cifar10_cfg
+from src.dataset import create_dataset_cifar10
+
+from src.modeling_ms import VisionTransformer
+import src.net_config as configs
+
+set_seed(1)
+
+parser = argparse.ArgumentParser(description='vit_base')
+parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10'],
+                    help='dataset name.')
+parser.add_argument('--sub_type', type=str, default='ViT-B_16',
+                    choices=['ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'ViT-H_14', 'testing'])
+parser.add_argument('--checkpoint_path', type=str, default='./ckpt_0', help='Checkpoint file path')
+parser.add_argument('--id', type=int, default=0, help='Device id')
+args_opt = parser.parse_args()
+
+
+class CrossEntropySmooth(LossBase):
+    """CrossEntropy"""
+    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
+        super(CrossEntropySmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.sparse = sparse
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
+
+    def construct(self, logit, label):
+        if self.sparse:
+            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
+        loss_ = self.ce(logit, label)
+        return loss_
+
+
+if __name__ == '__main__':
+    CONFIGS = {'ViT-B_16': configs.get_b16_config,
+               'ViT-B_32': configs.get_b32_config,
+               'ViT-L_16': configs.get_l16_config,
+               'ViT-L_32': configs.get_l32_config,
+               'ViT-H_14': configs.get_h14_config,
+               'R50-ViT-B_16': configs.get_r50_b16_config,
+               'testing': configs.get_testing}
+    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=args_opt.id)
+    if args_opt.dataset_name == "cifar10":
+        cfg = cifar10_cfg
+        net = VisionTransformer(CONFIGS[args_opt.sub_type], num_classes=cfg.num_classes)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+        opt = nn.Momentum(net.trainable_params(), 0.01, cfg.momentum, weight_decay=cfg.weight_decay)
+        dataset = create_dataset_cifar10(cfg.val_data_path, 1, False)
+        param_dict = load_checkpoint(args_opt.checkpoint_path)
+        print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
+        load_param_into_net(net, param_dict)
+        net.set_train(False)
+        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
+    else:
+        raise ValueError("dataset is not support.")
+
+    acc = model.eval(dataset)
+    print(f"model's accuracy is {acc}")
diff --git a/research/cv/vit_base/export.py b/research/cv/vit_base/export.py
new file mode 100644
index 000000000..da344da4e
--- /dev/null
+++ b/research/cv/vit_base/export.py
@@ -0,0 +1,62 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+##############export checkpoint file into air, onnx or mindir model#################
+python export.py
+"""
+import argparse
+import numpy as np
+
+from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
+
+from src.modeling_ms import VisionTransformer
+import src.net_config as configs
+
+parser = argparse.ArgumentParser(description='vit_base export')
+parser.add_argument("--device_id", type=int, default=0, help="Device id")
+parser.add_argument('--sub_type', type=str, default='ViT-B_16',
+                    choices=['ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'ViT-H_14', 'testing'])
+parser.add_argument("--batch_size", type=int, default=1, help="batch size")
+parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
+parser.add_argument("--file_name", type=str, default="vit_base", help="output file name.")
+parser.add_argument('--width', type=int, default=224, help='input width')
+parser.add_argument('--height', type=int, default=224, help='input height')
+parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
+parser.add_argument("--device_target", type=str, default="Ascend",
+                    choices=["Ascend", "GPU", "CPU"], help="device target(default: Ascend)")
+args = parser.parse_args()
+
+context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+if args.device_target == "Ascend":
+    context.set_context(device_id=args.device_id)
+
+if __name__ == '__main__':
+
+    CONFIGS = {'ViT-B_16': configs.get_b16_config,
+               'ViT-B_32': configs.get_b32_config,
+               'ViT-L_16': configs.get_l16_config,
+               'ViT-L_32': configs.get_l32_config,
+               'ViT-H_14': configs.get_h14_config,
+               'R50-ViT-B_16': configs.get_r50_b16_config,
+               'testing': configs.get_testing}
+    net = VisionTransformer(CONFIGS[args.sub_type], num_classes=10)
+
+    assert args.ckpt_file is not None, "checkpoint_path is None."
+
+    param_dict = load_checkpoint(args.ckpt_file)
+    load_param_into_net(net, param_dict)
+
+    input_arr = Tensor(np.zeros([args.batch_size, 3, args.height, args.width], np.float32))
+    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
diff --git a/research/cv/vit_base/postprocess.py b/research/cv/vit_base/postprocess.py
new file mode 100644
index 000000000..3b4372ea3
--- /dev/null
+++ b/research/cv/vit_base/postprocess.py
@@ -0,0 +1,45 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""postprocess for 310 inference"""
+import os
+import argparse
+import numpy as np
+
+parser = argparse.ArgumentParser(description="postprocess")
+parser.add_argument("--result_path", type=str, required=True, help="result files path.")
+parser.add_argument("--label_file", type=str, required=True, help="label file path.")
+args = parser.parse_args()
+
+
+if __name__ == '__main__':
+    img_tot = 0
+    top1_correct = 0
+    result_shape = (1, 10)
+    files = os.listdir(args.result_path)
+    for file in files:
+        full_file_path = os.path.join(args.result_path, file)
+        if os.path.isfile(full_file_path):
+            result = np.fromfile(full_file_path, dtype=np.float32).reshape(result_shape)
+            label_path = os.path.join(args.label_file, file.split(".bin")[0][:-2] + ".bin")
+            gt_classes = np.fromfile(label_path, dtype=np.int32)
+
+            top1_output = np.argmax(result, (-1))
+
+            t1_correct = np.equal(top1_output, gt_classes).sum()
+            top1_correct += t1_correct
+            img_tot += 1
+
+    acc1 = 100.0 * top1_correct / img_tot
+    print('after allreduce eval: top1_correct={}, tot={}, acc={:.2f}%'.format(top1_correct, img_tot, acc1))
diff --git a/research/cv/vit_base/preprocess.py b/research/cv/vit_base/preprocess.py
new file mode 100644
index 000000000..0e72521ab
--- /dev/null
+++ b/research/cv/vit_base/preprocess.py
@@ -0,0 +1,38 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""preprocess"""
+import os
+import argparse
+from src.dataset import create_dataset_cifar10
+parser = argparse.ArgumentParser('preprocess')
+parser.add_argument('--data_path', type=str, default='', help='eval data dir')
+
+args = parser.parse_args()
+if __name__ == "__main__":
+    dataset = create_dataset_cifar10(args.data_path, 1, 1, False)
+    img_path = os.path.join('./preprocess_Result/', "img_data")
+    label_path = os.path.join('./preprocess_Result/', "label")
+    os.makedirs(img_path)
+    os.makedirs(label_path)
+    batch_size = 1
+    for idx, data in enumerate(dataset.create_dict_iterator(output_numpy=True, num_epochs=1)):
+        img_data = data["image"]
+        img_label = data["label"]
+        file_name = "vit_base_cifar10_" + str(batch_size) + "_" + str(idx) + ".bin"
+        img_file_path = os.path.join(img_path, file_name)
+        img_data.tofile(img_file_path)
+        label_file_path = os.path.join(label_path, file_name)
+        img_label.tofile(label_file_path)
+    print("=" * 20, "export bin files finished", "=" * 20)
diff --git a/research/cv/vit_base/requirements.txt b/research/cv/vit_base/requirements.txt
new file mode 100644
index 000000000..66799b763
--- /dev/null
+++ b/research/cv/vit_base/requirements.txt
@@ -0,0 +1 @@
+easydict
diff --git a/research/cv/vit_base/scripts/run_distribution_train_ascend.sh b/research/cv/vit_base/scripts/run_distribution_train_ascend.sh
new file mode 100644
index 000000000..958f9a074
--- /dev/null
+++ b/research/cv/vit_base/scripts/run_distribution_train_ascend.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [[ $# -gt 4 ]]; then 
+    echo "Usage: bash ./scripts/run_distribution_train_ascend.sh [RANK_TABLE] [DEVICE_NUM] [DEVICE_START] [DATASET_NAME]"
+exit 1
+fi
+
+ulimit -u unlimited
+export DEVICE_NUM=$2
+export RANK_SIZE=$3
+RANK_TABLE_FILE=$(realpath $1)
+export RANK_TABLE_FILE
+echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
+
+device_start=$3
+for((i=0; i<${DEVICE_NUM}; i++))
+do
+    export DEVICE_ID=$((device_start + i))
+    export RANK_ID=$i
+    rm -rf ./train_parallel$i
+    mkdir ./train_parallel$i
+    cp -r ./src ./train_parallel$i
+    cp ./train.py ./train_parallel$i
+    echo "start training for rank $RANK_ID, device $DEVICE_ID"
+    cd ./train_parallel$i ||exit
+    env > env.log
+    python train.py --device_id=$DEVICE_ID --dataset_name=$4 > log 2>&1 &
+    cd ..
+done
diff --git a/research/cv/vit_base/scripts/run_infer_310.sh b/research/cv/vit_base/scripts/run_infer_310.sh
new file mode 100644
index 000000000..7af917e9f
--- /dev/null
+++ b/research/cv/vit_base/scripts/run_infer_310.sh
@@ -0,0 +1,121 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [[ $# -lt 3 || $# -gt 4 ]]; then 
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET] [DATA_PATH] [DEVICE_ID]
+    DEVICE_ID is optional, default value is zero"
+exit 1
+fi
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+typeset -l dataset
+model=$(get_real_path $1)
+dataset=$2
+data_path=$(get_real_path $3)
+
+device_id=0
+
+if [ $# == 4 ]; then
+    device_id=$4
+fi
+
+echo $model
+echo $dataset
+echo $data_path
+echo $device_id
+
+export ASCEND_HOME=/usr/local/Ascend/
+if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
+    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
+    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
+    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
+else
+    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
+    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
+fi
+
+function compile_app()
+{
+    cd ../ascend310_infer || exit
+    if [ -f "Makefile" ]; then
+        make clean
+    fi
+    sh build.sh &> build.log
+
+    if [ $? -ne 0 ]; then
+        echo "compile app code failed"
+        exit 1
+    fi
+    cd - || exit
+}
+
+function preprocess_data()
+{
+    if [ -d preprocess_Result ]; then
+        rm -rf ./preprocess_Result
+    fi
+    mkdir preprocess_Result
+
+    python3.7 ../preprocess.py --data_path=$data_path #--output_path=./preprocess_Result
+}
+
+function infer()
+{
+    if [ -d result_Files ]; then
+        rm -rf ./result_Files
+    fi
+     if [ -d time_Result ]; then
+        rm -rf ./time_Result
+    fi
+    mkdir result_Files
+    mkdir time_Result
+    ../ascend310_infer/out/main --model_path=$model --dataset=$dataset --dataset_path=$data_path --device_id=$device_id &> infer.log
+
+    if [ $? -ne 0 ]; then
+        echo "execute inference failed"
+        exit 1
+    fi
+}
+
+function cal_acc()
+{
+    if [ "x${dataset}" == "xcifar10" ] || [ "x${dataset}" == "xCifar10" ]; then
+        python ../postprocess.py --label_file=./preprocess_Result/label --result_path=result_Files &> acc.log
+    fi
+    if [ $? -ne 0 ]; then
+        echo "calculate accuracy failed"
+        exit 1
+    fi
+}
+
+if [ "x${dataset}" == "xcifar10" ] || [ "x${dataset}" == "xCifar10" ]; then
+    preprocess_data
+    data_path=./preprocess_Result/img_data
+fi
+compile_app
+infer
+cal_acc
diff --git a/research/cv/vit_base/scripts/run_standalone_eval_ascend.sh b/research/cv/vit_base/scripts/run_standalone_eval_ascend.sh
new file mode 100644
index 000000000..f9a09f2b3
--- /dev/null
+++ b/research/cv/vit_base/scripts/run_standalone_eval_ascend.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "Usage: bash ./scripts/run_standalone_eval_ascend.sh [CKPT_PATH]"
+
+export CKPT=$1
+
+python eval.py --checkpoint_path $CKPT > ./eval.log 2>&1 &
diff --git a/research/cv/vit_base/scripts/run_standalone_train_ascend.sh b/research/cv/vit_base/scripts/run_standalone_train_ascend.sh
new file mode 100644
index 000000000..2c1af299f
--- /dev/null
+++ b/research/cv/vit_base/scripts/run_standalone_train_ascend.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [DATASET_NAME]"
+
+export DEVICE_ID=$1
+
+python train.py --device_id=$DEVICE_ID --dataset_name=$2 > train.log 2>&1 &
diff --git a/research/cv/vit_base/src/config.py b/research/cv/vit_base/src/config.py
new file mode 100644
index 000000000..f52467abf
--- /dev/null
+++ b/research/cv/vit_base/src/config.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+network config setting, will be used in main.py
+"""
+from easydict import EasyDict as edict
+
+cifar10_cfg = edict({
+    'name': 'cifar10',
+    'pre_trained': True,  # False
+    'num_classes': 10,
+    'lr_init': 0.013,  # 2P
+    'batch_size': 32,
+    'epoch_size': 60,
+    'momentum': 0.9,
+    'weight_decay': 1e-4,
+    'image_height': 224,
+    'image_width': 224,
+    'data_path': '/dataset/cifar10/cifar-10-batches-bin/',
+    'val_data_path': '/dataset/cifar10/cifar-10-verify-bin/',
+    'device_target': 'Ascend',
+    'device_id': 0,
+    'keep_checkpoint_max': 2,
+    'checkpoint_path': '/dataset/cifar10_pre_checkpoint_based_imagenet21k.ckpt',
+    'onnx_filename': 'vit_base',
+    'air_filename': 'vit_base',
+
+    # optimizer and lr related
+    'lr_scheduler': 'cosine_annealing',
+    'lr_epochs': [30, 60, 90, 120],
+    'lr_gamma': 0.3,
+    'eta_min': 0.0,
+    'T_max': 50,
+    'warmup_epochs': 0,
+
+    # loss related
+    'is_dynamic_loss_scale': 0,
+    'loss_scale': 1024,
+    'label_smooth_factor': 0.1,
+    'use_label_smooth': True,
+})
diff --git a/research/cv/vit_base/src/dataset.py b/research/cv/vit_base/src/dataset.py
new file mode 100644
index 000000000..aed90ec97
--- /dev/null
+++ b/research/cv/vit_base/src/dataset.py
@@ -0,0 +1,80 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Data operations, will be used in train.py and eval.py
+"""
+import os
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as C
+import mindspore.dataset.vision.c_transforms as vision
+from src.config import cifar10_cfg
+
+def create_dataset_cifar10(data_home, repeat_num=1, device_num=1, training=True):
+    """Data operations."""
+    if device_num > 1:
+        rank_size, rank_id = _get_rank_info()
+        data_set = ds.Cifar10Dataset(data_home, num_shards=rank_size, shard_id=rank_id, shuffle=True)
+    else:
+        data_set = ds.Cifar10Dataset(data_home, shuffle=False)
+
+    resize_height = cifar10_cfg.image_height
+    resize_width = cifar10_cfg.image_width
+
+    # define map operations
+    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
+    random_horizontal_op = vision.RandomHorizontalFlip()
+    resize_op = vision.Resize((resize_height, resize_width))  # interpolation default BILINEAR
+    rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
+    normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+    changeswap_op = vision.HWC2CHW()
+    type_cast_op = C.TypeCast(mstype.int32)
+
+    c_trans = []
+    if training:
+        c_trans = [random_crop_op, random_horizontal_op]
+    c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
+
+    # apply map operations on images
+    data_set = data_set.map(operations=type_cast_op, input_columns="label")
+    data_set = data_set.map(operations=c_trans, input_columns="image")
+
+    # apply batch operations
+    if training:
+        data_set = data_set.batch(batch_size=cifar10_cfg.batch_size, drop_remainder=True)
+    else:
+        data_set = data_set.batch(batch_size=1, drop_remainder=True)
+
+    # apply repeat operations
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+
+    if rank_size > 1:
+        from mindspore.communication.management import get_rank, get_group_size
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        rank_size = rank_id = None
+
+    return rank_size, rank_id
diff --git a/research/cv/vit_base/src/modeling_ms.py b/research/cv/vit_base/src/modeling_ms.py
new file mode 100644
index 000000000..dba23f9fc
--- /dev/null
+++ b/research/cv/vit_base/src/modeling_ms.py
@@ -0,0 +1,221 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+model.
+"""
+import copy
+
+import mindspore
+from mindspore import Parameter, Tensor
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+
+
+def swish(x):
+    return x * P.Sigmoid()(x)
+
+
+ACT2FN = {"gelu": nn.GELU(), "relu": P.ReLU(), "swish": swish}
+
+
+class Attention(nn.Cell):
+    """Attention"""
+    def __init__(self, config):
+        super(Attention, self).__init__()
+        self.num_attention_heads = config.transformer_num_heads
+        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
+        self.attention_head_size2 = Tensor(config.hidden_size / self.num_attention_heads, mindspore.float32)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Dense(config.hidden_size, self.all_head_size)
+        self.key = nn.Dense(config.hidden_size, self.all_head_size)
+        self.value = nn.Dense(config.hidden_size, self.all_head_size)
+
+        self.out = nn.Dense(config.hidden_size, config.hidden_size)
+        self.attn_dropout = nn.Dropout(config.transformer_attention_dropout_rate)
+        self.proj_dropout = nn.Dropout(config.transformer_attention_dropout_rate)
+
+        self.softmax = nn.Softmax(axis=-1)
+
+    def transpose_for_scores(self, x):
+        """transpose_for_scores"""
+        new_x_shape = P.Shape()(x)[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = P.Reshape()(x, new_x_shape)
+        return P.Transpose()(x, (0, 2, 1, 3,))
+
+    def construct(self, hidden_states):
+        """construct"""
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        attention_scores = mindspore.ops.matmul(query_layer, P.Transpose()(key_layer, (0, 1, 3, 2)))
+        attention_scores = attention_scores / P.Sqrt()(self.attention_head_size2)
+        attention_probs = self.softmax(attention_scores)
+        attention_probs = self.attn_dropout(attention_probs)
+
+        context_layer = mindspore.ops.matmul(attention_probs, value_layer)
+        context_layer = P.Transpose()(context_layer, (0, 2, 1, 3))
+        new_context_layer_shape = P.Shape()(context_layer)[:-2] + (self.all_head_size,)
+        context_layer = P.Reshape()(context_layer, new_context_layer_shape)
+        attention_output = self.out(context_layer)
+        attention_output = self.proj_dropout(attention_output)
+        return attention_output
+
+
+class Mlp(nn.Cell):
+    """Mlp"""
+    def __init__(self, config):
+        super(Mlp, self).__init__()
+        self.fc1 = nn.Dense(config.hidden_size, config.transformer_mlp_dim,
+                            weight_init='XavierUniform', bias_init='Normal')
+        self.fc2 = nn.Dense(config.transformer_mlp_dim, config.hidden_size,
+                            weight_init='XavierUniform', bias_init='Normal')
+        self.act_fn = ACT2FN["gelu"]
+        self.dropout = nn.Dropout(config.transformer_dropout_rate)
+
+    def construct(self, x):
+        """construct"""
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        x = self.dropout(x)
+        return x
+
+
+class Embeddings(nn.Cell):
+    """Construct the embeddings from patch, position embeddings."""
+    def __init__(self, config, img_size, in_channels=3):
+        super(Embeddings, self).__init__()
+        self.hybrid = None
+
+        if config.patches_grid is not None:
+            grid_size = config.patches_grid
+            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
+            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
+            self.hybrid = True
+        else:
+            patch_size = config.patches_size
+            n_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+            self.hybrid = False
+
+        if self.hybrid:
+            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
+                                         width_factor=config.resnet.width_factor)
+            in_channels = self.hybrid_model.width * 16
+        self.patch_embeddings = nn.Conv2d(in_channels=in_channels,
+                                          out_channels=config.hidden_size,
+                                          kernel_size=patch_size,
+                                          stride=patch_size, has_bias=True)
+        self.position_embeddings = Parameter(P.Zeros()((1, n_patches+1, config.hidden_size), mindspore.float32),
+                                             name="q1", requires_grad=True)
+        self.cls_token = Parameter(P.Zeros()((1, 1, config.hidden_size), mindspore.float32), name="q2",
+                                   requires_grad=True)
+
+        self.dropout = nn.Dropout(config.transformer_dropout_rate)
+
+    def construct(self, x):
+        """construct"""
+        B = x.shape[0]
+        cls_tokens = P.BroadcastTo((B, self.cls_token.shape[1], self.cls_token.shape[2]))(self.cls_token)
+
+        if self.hybrid:
+            x = self.hybrid_model(x)
+        x = self.patch_embeddings(x)
+        x = P.Reshape()(x, (x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
+        x = P.Transpose()(x, (0, 2, 1))
+        x = P.Concat(1)((cls_tokens, x))
+
+        embeddings = x + self.position_embeddings
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class Block(nn.Cell):
+    """Block"""
+    def __init__(self, config):
+        super(Block, self).__init__()
+        self.hidden_size = config.hidden_size
+        self.attention_norm = nn.LayerNorm([config.hidden_size], epsilon=1e-6)
+        self.ffn_norm = nn.LayerNorm([config.hidden_size], epsilon=1e-6)
+        self.ffn = Mlp(config)
+        self.attn = Attention(config)
+
+    def construct(self, x):
+        """construct"""
+        h = x
+        x = self.attention_norm(x)
+        x = self.attn(x)
+        x = x + h
+
+        h = x
+        x = self.ffn_norm(x)
+        x = self.ffn(x)
+        x = x + h
+        return x
+
+
+class Encoder(nn.Cell):
+    """Encoder"""
+    def __init__(self, config):
+        super(Encoder, self).__init__()
+        self.layer = nn.CellList([])
+        self.encoder_norm = nn.LayerNorm([config.hidden_size], epsilon=1e-6)
+        for _ in range(config.transformer_num_layers):
+            layer = Block(config)
+            self.layer.append(copy.deepcopy(layer))
+
+    def construct(self, hidden_states):
+        """construct"""
+        for layer_block in self.layer:
+            hidden_states = layer_block(hidden_states)
+        encoded = self.encoder_norm(hidden_states)
+        return encoded
+
+
+class Transformer(nn.Cell):
+    """Transformer"""
+    def __init__(self, config, img_size):
+        super(Transformer, self).__init__()
+        self.embeddings = Embeddings(config, img_size=img_size)
+        self.encoder = Encoder(config)
+
+    def construct(self, input_ids):
+        """construct"""
+        embedding_output = self.embeddings(input_ids)
+        encoded = self.encoder(embedding_output)
+        return encoded
+
+
+class VisionTransformer(nn.Cell):
+    """VisionTransformer"""
+    def __init__(self, config, img_size=(224, 224), num_classes=21843):
+        super(VisionTransformer, self).__init__()
+        self.num_classes = num_classes
+        self.classifier = config.classifier
+
+        self.transformer = Transformer(config, img_size)
+        self.head = nn.Dense(config.hidden_size, num_classes)
+
+    def construct(self, x, labels=None):
+        """construct"""
+        x = self.transformer(x)
+        logits = self.head(x[:, 0])
+        return logits
diff --git a/research/cv/vit_base/src/net_config.py b/research/cv/vit_base/src/net_config.py
new file mode 100644
index 000000000..8c5ff1510
--- /dev/null
+++ b/research/cv/vit_base/src/net_config.py
@@ -0,0 +1,117 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Configurations.
+"""
+
+from easydict import EasyDict as edict
+
+
+# Returns a minimal configuration for testing.
+get_testing = edict({
+    'patches_grid': None,
+    'patches_size': 16,
+    'hidden_size': 1,
+    'transformer_mlp_dim': 1,
+    'transformer_num_heads': 1,
+    'transformer_num_layers': 1,
+    'transformer_attention_dropout_rate': 1.0,
+    'transformer_dropout_rate': 0.9,
+    'classifier': 'token',
+    'representation_size': None,
+})
+
+
+# Returns the ViT-B/16 configuration.
+get_b16_config = edict({
+    'patches_grid': None,
+    'patches_size': 16,
+    'hidden_size': 768,
+    'transformer_mlp_dim': 3072,
+    'transformer_num_heads': 12,
+    'transformer_num_layers': 12,
+    'transformer_attention_dropout_rate': 1.0,
+    'transformer_dropout_rate': 1.0,  # 0.9
+    'classifier': 'token',
+    'representation_size': None,
+})
+
+
+# Returns the Resnet50 + ViT-B/16 configuration.
+get_r50_b16_config = edict({
+    'patches_grid': 14,
+    'resnet_num_layers': (3, 4, 9),
+    'resnet_width_factor': 1,
+})
+
+
+# Returns the ViT-B/32 configuration.
+get_b32_config = edict({
+    'patches_grid': None,
+    'patches_size': 32,
+    'hidden_size': 768,
+    'transformer_mlp_dim': 3072,
+    'transformer_num_heads': 12,
+    'transformer_num_layers': 12,
+    'transformer_attention_dropout_rate': 1.0,
+    'transformer_dropout_rate': 0.9,
+    'classifier': 'token',
+    'representation_size': None,
+})
+
+
+# Returns the ViT-L/16 configuration.
+get_l16_config = edict({
+    'patches_grid': None,
+    'patches_size': 16,
+    'hidden_size': 1024,
+    'transformer_mlp_dim': 4096,
+    'transformer_num_heads': 16,
+    'transformer_num_layers': 24,
+    'transformer_attention_dropout_rate': 1.0,
+    'transformer_dropout_rate': 0.9,
+    'classifier': 'token',
+    'representation_size': None,
+})
+
+
+# Returns the ViT-L/32 configuration.
+get_l32_config = edict({
+    'patches_grid': None,
+    'patches_size': 32,
+    'hidden_size': 1024,
+    'transformer_mlp_dim': 4096,
+    'transformer_num_heads': 16,
+    'transformer_num_layers': 24,
+    'transformer_attention_dropout_rate': 1.0,
+    'transformer_dropout_rate': 0.9,
+    'classifier': 'token',
+    'representation_size': None,
+})
+
+
+# Returns the ViT-L/16 configuration.
+get_h14_config = edict({
+    'patches_grid': None,
+    'patches_size': 14,
+    'hidden_size': 1280,
+    'transformer_mlp_dim': 5120,
+    'transformer_num_heads': 16,
+    'transformer_num_layers': 32,
+    'transformer_attention_dropout_rate': 1.0,
+    'transformer_dropout_rate': 0.9,
+    'classifier': 'token',
+    'representation_size': None,
+})
diff --git a/research/cv/vit_base/train.py b/research/cv/vit_base/train.py
new file mode 100644
index 000000000..3a9972e49
--- /dev/null
+++ b/research/cv/vit_base/train.py
@@ -0,0 +1,211 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Train.
+"""
+import argparse
+import os
+
+import math
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.communication.management import init
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.common import set_seed
+from mindspore.train.callback import Callback
+
+from src.config import cifar10_cfg
+from src.dataset import create_dataset_cifar10
+from src.modeling_ms import VisionTransformer
+import src.net_config as configs
+
+set_seed(2)
+
+def lr_steps_imagenet(_cfg, steps_per_epoch):
+    """lr step for imagenet"""
+    if _cfg.lr_scheduler == 'cosine_annealing':
+        _lr = warmup_cosine_annealing_lr(_cfg.lr_init,
+                                         steps_per_epoch,
+                                         _cfg.warmup_epochs,
+                                         _cfg.epoch_size,
+                                         _cfg.T_max,
+                                         _cfg.eta_min)
+    else:
+        raise NotImplementedError(_cfg.lr_scheduler)
+
+    return _lr
+
+
+def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    lr1 = float(init_lr) + lr_inc * current_step
+    return lr1
+
+
+def warmup_cosine_annealing_lr(lr5, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
+    """ warmup cosine annealing lr"""
+    base_lr = lr5
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+
+    lr_each_step = []
+    for i in range(total_steps):
+        last_epoch = i // steps_per_epoch
+        if i < warmup_steps:
+            lr5 = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr5 = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
+        lr_each_step.append(lr5)
+
+    return np.array(lr_each_step).astype(np.float32)
+
+
+class EvalCallBack(Callback):
+    """EvalCallBack"""
+    def __init__(self, model0, eval_dataset, eval_per_epoch, epoch_per_eval0):
+        self.model = model0
+        self.eval_dataset = eval_dataset
+        self.eval_per_epoch = eval_per_epoch
+        self.epoch_per_eval = epoch_per_eval0
+
+    def epoch_end(self, run_context):
+        """epoch_end"""
+        cb_param = run_context.original_args()
+        cur_epoch = cb_param.cur_epoch_num
+        if cur_epoch % self.eval_per_epoch == 0:
+            acc = self.model.eval(self.eval_dataset)
+            self.epoch_per_eval["epoch"].append(cur_epoch)
+            self.epoch_per_eval["acc"].append(acc)
+            print(acc)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Classification')
+    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10'],
+                        help='dataset name.')
+    parser.add_argument('--sub_type', type=str, default='ViT-B_16',
+                        choices=['ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'ViT-H_14', 'testing'])
+    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
+    parser.add_argument('--device_start', type=int, default=0, help='start device id. (Default: 0)')
+    parser.add_argument('--data_url', default=None, help='Location of data.')
+    parser.add_argument('--train_url', default=None, help='Location of training outputs.')
+    parser.add_argument('--ckpt_url', default=None, help='Location of ckpt.')
+    parser.add_argument('--modelarts', default=False, help='Use ModelArts or not.')
+    args_opt = parser.parse_args()
+
+    if args_opt.modelarts:
+        import moxing as mox
+        local_data_path = '/cache/data'
+        local_ckpt_path = '/cache/data/pre_ckpt'
+
+    if args_opt.dataset_name == "cifar10":
+        cfg = cifar10_cfg
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    # set context
+    device_target = cfg.device_target
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
+    device_num = int(os.getenv('RANK_SIZE', '1'))
+
+    if device_target == "Ascend":
+        device_id = int(os.getenv('DEVICE_ID', '0'))
+        if args_opt.device_id is not None:
+            context.set_context(device_id=args_opt.device_id)
+        else:
+            context.set_context(device_id=cfg.device_id)
+
+        if device_num > 1:
+            if args_opt.modelarts:
+                context.set_context(device_id=int(os.getenv('DEVICE_ID')))
+            init(backend_name='hccl')
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            if args_opt.modelarts:
+                local_data_path = os.path.join(local_data_path, str(device_id))
+    else:
+        raise ValueError("Unsupported platform.")
+
+    if args_opt.modelarts:
+        mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)
+
+    if args_opt.dataset_name == "cifar10":
+        if args_opt.modelarts:
+            dataset = create_dataset_cifar10(local_data_path, 1, device_num)
+        else:
+            dataset = create_dataset_cifar10(cfg.data_path, 1, device_num)
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    batch_num = dataset.get_dataset_size()
+
+    CONFIGS = {'ViT-B_16': configs.get_b16_config,
+               'ViT-B_32': configs.get_b32_config,
+               'ViT-L_16': configs.get_l16_config,
+               'ViT-L_32': configs.get_l32_config,
+               'ViT-H_14': configs.get_h14_config,
+               'R50-ViT-B_16': configs.get_r50_b16_config,
+               'testing': configs.get_testing}
+
+    net = VisionTransformer(CONFIGS[args_opt.sub_type], num_classes=cfg.num_classes)
+
+    if args_opt.modelarts:
+        mox.file.copy_parallel(src_url=args_opt.ckpt_url, dst_url=local_ckpt_path)
+
+    if cfg.pre_trained:
+        if args_opt.modelarts:
+            param_dict = load_checkpoint(os.path.join(local_ckpt_path, "cifar10_pre_checkpoint_based_imagenet21k.ckpt"))
+        else:
+            param_dict = load_checkpoint(cfg.checkpoint_path)
+        load_param_into_net(net, param_dict)
+        print("Load pre_trained ckpt: {}".format(cfg.checkpoint_path))
+
+    loss_scale_manager = None
+    if args_opt.dataset_name == 'cifar10':
+        lr = lr_steps_imagenet(cfg, batch_num)
+        opt = nn.Momentum(params=net.trainable_params(),
+                          learning_rate=Tensor(lr),
+                          momentum=cfg.momentum,
+                          weight_decay=cfg.weight_decay)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
+                  amp_level="O3", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
+
+    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 2, keep_checkpoint_max=cfg.keep_checkpoint_max)
+    time_cb = TimeMonitor(data_size=batch_num)
+    ckpt_save_dir = "./ckpt/"
+    ckpoint_cb = ModelCheckpoint(prefix="train_vit_" + args_opt.dataset_name, directory=ckpt_save_dir,
+                                 config=config_ck)
+    loss_cb = LossMonitor()
+    if args_opt.modelarts:
+        cbs = [time_cb, ModelCheckpoint(prefix="train_vit_" + args_opt.dataset_name, config=config_ck), loss_cb]
+    else:
+        epoch_per_eval = {"epoch": [], "acc": []}
+        eval_cb = EvalCallBack(model, create_dataset_cifar10(cfg.val_data_path, 1, False), 2, epoch_per_eval)
+        cbs = [time_cb, ckpoint_cb, loss_cb, eval_cb]
+        if device_num > 1 and device_id != args_opt.device_start:
+            cbs = [time_cb, loss_cb]
+    model.train(cfg.epoch_size, dataset, callbacks=cbs)
+    print("train success")
-- 
GitLab