diff --git a/research/cv/cait/README_CN.md b/research/cv/cait/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..6bc7238087894511e460afa2c04883299dcf57bc
--- /dev/null
+++ b/research/cv/cait/README_CN.md
@@ -0,0 +1,227 @@
+# Contents
+
+<!-- TOC -->
+
+- [Contents](#contents)
+- [CAIT Description](#cait-description)
+- [Dataset](#dataset)
+- [Features](#features)
+    - [Mixed Precision](#mixed-precision)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+- [Training and Evaluation](#training-and-evaluation)
+    - [Export Process](#export-process)
+        - [Export](#export)
+    - [Inference Process](#inference-process)
+        - [Inference](#inference)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Evaluation Performance](#evaluation-performance)
+            - [cait on ImageNet-1k](#cait-on-imagenet-1k)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+<!-- /TOC -->
+
+# [CAIT Description](#contents)
+
+Transformers have recently been applied to large-scale image classification, obtaining very high scores, which has shaken the long supremacy of convolutional neural networks. So far, however, the optimization of image Transformers has received little study. In this work, the authors build and optimize deeper Transformer networks for image classification. In particular, they investigate the interplay between the architecture and the optimization of such dedicated Transformers.
+The authors make two changes to the Transformer architecture that significantly improve the accuracy of deep Transformers.
+
+# [Dataset](#contents)
+
+Dataset used: [ImageNet2012](http://www.image-net.org/)
+
+- Dataset size: 1000 classes of 224*224 color images
+    - Training set: 1,281,167 images
+    - Test set: 50,000 images
+- Data format: JPEG
+    - Note: data is processed in src/data/imagenet.py.
+- Download the dataset; the directory structure is as follows:
+
+```text
+└─dataset
+    ├─train                 # training set
+    └─val                   # evaluation set
+```
+
+# [Features](#contents)
+
+## Mixed Precision
+
+The [mixed precision](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.6/enable_mixed_precision.html?highlight=%E6%B7%B7%E5%90%88%E7%B2%BE%E5%BA%A6) training method uses both single-precision and half-precision data to speed up the training of deep neural networks while maintaining the accuracy that single-precision training achieves. Mixed precision training speeds up computation and reduces memory usage, and makes it possible to train larger models or use larger batch sizes on specific hardware.
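+
+As a minimal sketch (the `net` and `dataset` names are illustrative, not from this repository), mixed precision in MindSpore is enabled through the `amp_level` argument of `Model`, matching the `amp_level: O2` entry in the configuration file:
+
+```python
+from mindspore import Model, nn
+
+# O2 casts the network to float16 but keeps batch norm in float32,
+# mirroring the `keep_bn_fp32: True` setting in the config.
+loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=5e-4)
+model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")
+model.train(1, dataset)
+```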
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend/GPU)
+    - Prepare a hardware environment with Ascend or GPU processors.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/r1.3/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/r1.3/index.html)
+
+# [Script Description](#contents)
+
+## Script and Sample Code
+
+```text
+├── CAIT
+  ├── README_CN.md                        // CAIT description
+  ├── ascend310_infer                     // files needed for Ascend 310 inference
+  ├── scripts
+      ├── run_standalone_train_ascend.sh  // single-device Ascend 910 training script
+      ├── run_distribute_train_ascend.sh  // multi-device Ascend 910 training script
+      ├── run_eval_ascend.sh              // evaluation script
+      ├── run_infer_310.sh                // Ascend 310 inference script
+  ├── src
+      ├── configs                         // CAIT configuration files
+      ├── data                            // dataset configuration files
+          ├── imagenet.py                 // ImageNet configuration file
+          ├── augment                     // data augmentation functions
+          └── data_utils                  // dataset copy functions for ModelArts runs
+      ├── models                          // model definition folder
+          └── cait                        // CAIT definition file
+      ├── trainers                        // custom TrainOneStep file
+      ├── tools                           // utility folder
+          ├── callback.py                 // custom callbacks (test after training)
+          ├── cell.py                     // general utilities for cells
+          ├── criterion.py                // loss function utilities
+          ├── get_misc.py                 // miscellaneous utilities
+          ├── optimizer.py                // optimizer and parameter functions
+          └── schedulers.py               // learning rate decay utilities
+  ├── train.py                            // training script
+  ├── eval.py                             // evaluation script
+  ├── export.py                           // model export script
+  ├── postprocess.py                      // accuracy computation for inference
+  ├── preprocess.py                       // image preprocessing for inference
+
+```
+
+## Script Parameters
+
+Training and evaluation parameters can both be configured in the YAML files under `src/configs`.
+
+- Configure CAIT and the ImageNet-1k dataset:
+
+  ```yaml
+    # Architecture
+    arch: cait_XXS24_224                  # CAIT architecture selection
+    # ===== Dataset ===== #
+    data_url: ./data/imagenet           # dataset path
+    set: ImageNet                       # dataset name
+    num_classes: 1000                   # number of classes
+    mix_up: 0.8                         # MixUp augmentation parameter
+    cutmix: 1.0                         # CutMix augmentation parameter
+    auto_augment: rand-m9-mstd0.5-inc1  # AutoAugment policy
+    interpolation: bicubic              # image interpolation method
+    re_prob: 0.25                       # random erasing probability
+    re_mode: pixel                      # random erasing mode
+    re_count: 1                         # random erasing count
+    mixup_prob: 1.                      # probability of applying MixUp/CutMix
+    switch_prob: 0.5                    # probability of switching to CutMix
+    mixup_mode: batch                   # how MixUp/CutMix is applied
+    # ===== Learning Rate Policy ======== #
+    optimizer: adamw                    # optimizer type
+    base_lr: 0.0005                     # base learning rate
+    warmup_lr: 0.00000007               # initial warmup learning rate
+    min_lr: 0.000006                    # minimum learning rate
+    lr_scheduler: cosine_lr             # learning rate decay policy
+    warmup_length: 5                    # number of warmup epochs
+    nonlinearity: GELU                  # activation function
+    # ===== Network training config ===== #
+    amp_level: O2                       # mixed precision policy
+    keep_bn_fp32: True                  # keep batch norm in fp32
+    beta: [ 0.9, 0.999 ]                # AdamW betas
+    clip_global_norm_value: 5.          # global gradient norm clipping threshold
+    is_dynamic_loss_scale: True         # whether to use dynamic loss scaling
+    epochs: 400                         # number of training epochs
+    label_smoothing: 0.1                # label smoothing factor
+    loss_scale: 1024                    # loss scale
+    weight_decay: 0.05                  # weight decay
+    momentum: 0.9                       # optimizer momentum
+    batch_size: 64                      # batch size
+    # ===== Hardware setup ===== #
+    num_parallel_workers: 16            # number of data preprocessing workers
+    device_target: GPU                  # GPU or Ascend
+    # ===== Model config ===== #
+    drop_path_rate: 0.05                # drop path rate
+    image_size: 224                     # image size
+  ```
+
+For more configuration details, see the config files under `src/configs`. After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
+
+# [Training and Evaluation](#contents)
+
+- GPU
+
+    ```bash
+    # start single-device training with python
+    python train.py --device_id 0 --device_target GPU --cait_config ./src/configs/cait_XXS24_224.yaml > train.log 2>&1 &
+
+    # start single-device training with the script
+    bash ./scripts/run_standalone_train_gpu.sh [DEVICE_ID] [CONFIG_PATH]
+
+    # start multi-device training with the script
+    bash ./scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [CUDA_VISIBLE_DEVICES] [CONFIG_PATH]
+
+    # run the evaluation example on a single device with python
+    python eval.py --device_id 0 --device_target GPU --cait_config ./src/configs/cait_XXS24_224.yaml \
+    --pretrained ./ckpt_0/cait_XXS24_224.ckpt > ./eval.log 2>&1 &
+
+    # run the evaluation example on a single device with the script
+    bash ./scripts/run_eval_gpu.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]
+    ```
+
+## Export Process
+
+### Export
+
+  ```shell
+  python export.py --pretrained [CKPT_FILE] --cait_config [CONFIG_PATH] --device_target [DEVICE_TARGET] --file_format [FILE_FORMAT]
+  ```
+
+The exported model is named after the model architecture and saved in the current directory.
+
+# [Model Description](#contents)
+
+## Performance
+
+### Evaluation Performance
+
+#### cait on ImageNet-1k
+
+| Parameters                 | GPU RTX3090                                                  |
+| -------------------------- | ------------------------------------------------------------ |
+| Model                      | CAIT                                                          |
+| Model version              | cait_XXS24_224                                                |
+| Resource                   | GPU                                                           |
+| Upload date                | 2021-12-09                                                    |
+| MindSpore version          | 1.5.0                                                         |
+| Dataset                    | ImageNet-1k train, 1,281,167 images                           |
+| Training parameters        | epoch=400, batch_size=64                                      |
+| Optimizer                  | AdamWeightDecay                                               |
+| Loss function              | SoftTargetCrossEntropy                                        |
+| Loss                       | 1.002                                                         |
+| Output                     | probability                                                   |
+| Accuracy                   | 8 devices: top1: 77.13%, top5: 93.61%                         |
+| Speed                      | 8 devices: 748 ms/step                                        |
+| Training time              | about 230 h                                                   |
+
+# ModelZoo Homepage
+
+Please visit the official [homepage](https://gitee.com/mindspore/models).
\ No newline at end of file
diff --git a/research/cv/cait/__init__.py b/research/cv/cait/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/cait/ascend310_infer/CMakeLists.txt b/research/cv/cait/ascend310_infer/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..23a3bce938d0e9f295c4a1bd12544bdea68752d1
--- /dev/null
+++ b/research/cv/cait/ascend310_infer/CMakeLists.txt
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 3.14.1)
+project(Ascend310Infer)
+add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
+set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
+option(MINDSPORE_PATH "mindspore install path" "")
+include_directories(${MINDSPORE_PATH})
+include_directories(${MINDSPORE_PATH}/include)
+include_directories(${PROJECT_SRC_ROOT})
+find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
+file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
+find_package(gflags REQUIRED)
+
+add_executable(main src/main.cc src/utils.cc)
+target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
diff --git a/research/cv/cait/ascend310_infer/build.sh b/research/cv/cait/ascend310_infer/build.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e429f3859b6cfd1f606608c9465ec90cecb20535
--- /dev/null
+++ b/research/cv/cait/ascend310_infer/build.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ -d out ]; then
+    rm -rf out
+fi
+
+mkdir out
+cd out || exit
+
+if [ -f "Makefile" ]; then
+  make clean
+fi
+
+cmake .. \
+    -DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
+make
diff --git a/research/cv/cait/ascend310_infer/inc/utils.h b/research/cv/cait/ascend310_infer/inc/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..058358228a1615f478e800d43bcf9197edb7bd47
--- /dev/null
+++ b/research/cv/cait/ascend310_infer/inc/utils.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_INFERENCE_UTILS_H_
+#define MINDSPORE_INFERENCE_UTILS_H_
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <vector>
+#include <string>
+#include <memory>
+#include "include/api/types.h"
+
+std::vector<std::string> GetAllFiles(std::string_view dirName);
+DIR *OpenDir(std::string_view dirName);
+std::string RealPath(std::string_view path);
+mindspore::MSTensor ReadFileToTensor(const std::string &file);
+int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
+std::vector<std::string> GetAllFiles(std::string dir_name);
+std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);
+
+#endif
diff --git a/research/cv/cait/ascend310_infer/src/main.cc b/research/cv/cait/ascend310_infer/src/main.cc
new file mode 100644
index 0000000000000000000000000000000000000000..55b3aa8259b1979230e36b926d592ca3a6543849
--- /dev/null
+++ b/research/cv/cait/ascend310_infer/src/main.cc
@@ -0,0 +1,167 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <sys/time.h>
+#include <gflags/gflags.h>
+#include <dirent.h>
+#include <iostream>
+#include <string>
+#include <algorithm>
+#include <iosfwd>
+#include <vector>
+#include <fstream>
+#include <sstream>
+#include <map>
+
+#include "include/api/model.h"
+#include "include/api/context.h"
+#include "include/api/types.h"
+#include "include/api/serialization.h"
+#include "include/dataset/vision_ascend.h"
+#include "include/dataset/execute.h"
+#include "include/dataset/transforms.h"
+#include "include/dataset/vision.h"
+#include "inc/utils.h"
+
+using mindspore::Context;
+using mindspore::Serialization;
+using mindspore::Model;
+using mindspore::Status;
+using mindspore::ModelType;
+using mindspore::GraphCell;
+using mindspore::kSuccess;
+using mindspore::MSTensor;
+using mindspore::dataset::Execute;
+using mindspore::dataset::vision::Decode;
+using mindspore::dataset::vision::Resize;
+using mindspore::dataset::vision::CenterCrop;
+using mindspore::dataset::vision::Normalize;
+using mindspore::dataset::vision::HWC2CHW;
+using mindspore::dataset::InterpolationMode;
+
+
+DEFINE_string(mindir_path, "", "mindir path");
+DEFINE_string(dataset_name, "imagenet2012", "dataset name");
+DEFINE_string(input0_path, ".", "input0 path");
+DEFINE_int32(device_id, 0, "device id");
+
+int load_model(Model *model, std::vector<MSTensor> *model_inputs, std::string mindir_path, int device_id) {
+  if (RealPath(mindir_path).empty()) {
+    std::cout << "Invalid mindir" << std::endl;
+    return 1;
+  }
+
+  auto context = std::make_shared<Context>();
+  auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
+  ascend310->SetDeviceID(device_id);
+  context->MutableDeviceInfo().push_back(ascend310);
+  mindspore::Graph graph;
+  Serialization::Load(mindir_path, ModelType::kMindIR, &graph);
+
+  Status ret = model->Build(GraphCell(graph), context);
+  if (ret != kSuccess) {
+    std::cout << "ERROR: Build failed." << std::endl;
+    return 1;
+  }
+
+  *model_inputs = model->GetInputs();
+  if (model_inputs->empty()) {
+    std::cout << "Invalid model, inputs is empty." << std::endl;
+    return 1;
+  }
+  return 0;
+}
+
+int main(int argc, char **argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  Model model;
+  std::vector<MSTensor> model_inputs;
+  if (load_model(&model, &model_inputs, FLAGS_mindir_path, FLAGS_device_id) != 0) {
+    return 1;
+  }
+
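+  // map each inference's start timestamp to its end timestamp; used to compute the average latency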
+  std::map<double, double> costTime_map;
+  struct timeval start = {0};
+  struct timeval end = {0};
+
+  if (FLAGS_dataset_name != "imagenet2012") {
+    std::cout << "ERROR: only support imagenet2012 dataset." << std::endl;
+    return 1;
+  } else {
+    auto input0_files = GetAllInputData(FLAGS_input0_path);
+    if (input0_files.empty()) {
+      std::cout << "ERROR: no input data." << std::endl;
+      return 1;
+    }
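+    // Standard ImageNet evaluation preprocessing:
+    // decode -> resize(256) -> center crop(224) -> normalize(ImageNet mean/std) -> HWC2CHW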
+    auto decode = Decode();
+    auto resize = Resize({256, 256}, InterpolationMode::kCubicPil);
+    auto centercrop = CenterCrop({224, 224});
+    auto normalize = Normalize({0.485 * 255, 0.456 * 255, 0.406 * 255}, {0.229 * 255, 0.224 * 255, 0.225 * 255});
+    auto hwc2chw = HWC2CHW();
+    Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw});
+    size_t size = input0_files.size();
+    for (size_t i = 0; i < size; ++i) {
+      for (size_t j = 0; j < input0_files[i].size(); ++j) {
+        std::vector<MSTensor> inputs;
+        std::vector<MSTensor> outputs;
+        std::cout << "Start predict input files:" << input0_files[i][j] <<std::endl;
+        auto imgDvpp = std::make_shared<MSTensor>();
+        SingleOp(ReadFileToTensor(input0_files[i][j]), imgDvpp.get());
+        inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
+                            imgDvpp->Data().get(), imgDvpp->DataSize());
+        gettimeofday(&start, nullptr);
+        Status ret = model.Predict(inputs, &outputs);
+        gettimeofday(&end, nullptr);
+        if (ret != kSuccess) {
+          std::cout << "Predict " << input0_files[i][j] << " failed." << std::endl;
+          return 1;
+        }
+        double startTimeMs;
+        double endTimeMs;
+        startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
+        endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
+        costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
+        int rst = WriteResult(input0_files[i][j], outputs);
+        if (rst != 0) {
+          std::cout << "write result failed." << std::endl;
+          return rst;
+        }
+      }
+    }
+  }
+  double average = 0.0;
+  int inferCount = 0;
+
+  for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
+    double diff = 0.0;
+    diff = iter->second - iter->first;
+    average += diff;
+    inferCount++;
+  }
+  average = average / inferCount;
+  std::stringstream timeCost;
+  timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
+  std::cout << "NN inference cost average time: " << average << " ms of infer_count " << inferCount << std::endl;
+  std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
+  std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
+  fileStream << timeCost.str();
+  fileStream.close();
+  costTime_map.clear();
+  return 0;
+}
diff --git a/research/cv/cait/ascend310_infer/src/utils.cc b/research/cv/cait/ascend310_infer/src/utils.cc
new file mode 100644
index 0000000000000000000000000000000000000000..537a5ea72b2ac56b302b16493ca11238ba608e89
--- /dev/null
+++ b/research/cv/cait/ascend310_infer/src/utils.cc
@@ -0,0 +1,197 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <algorithm>
+#include <iostream>
+#include "inc/utils.h"
+
+using mindspore::MSTensor;
+using mindspore::DataType;
+
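+// Collect image files grouped by class sub-directory: one sorted vector of file paths per class.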
+std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
+  std::vector<std::vector<std::string>> ret;
+
+  DIR *dir = OpenDir(dir_name);
+  if (dir == nullptr) {
+    return {};
+  }
+  struct dirent *filename;
+  std::vector<std::string> sub_dirs;
+  while ((filename = readdir(dir)) != nullptr) {
+    std::string d_name = std::string(filename->d_name);
+    // get rid of "." and ".."
+    if (d_name == "." || d_name == ".." || d_name.empty()) {
+      continue;
+    }
+    std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
+    struct stat s;
+    lstat(dir_path.c_str(), &s);
+    if (!S_ISDIR(s.st_mode)) {
+      continue;
+    }
+
+    sub_dirs.emplace_back(dir_path);
+  }
+  std::sort(sub_dirs.begin(), sub_dirs.end());
+
+  (void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
+                       [](const std::string &d) { return GetAllFiles(d); });
+
+  return ret;
+}
+
+
+std::vector<std::string> GetAllFiles(std::string dir_name) {
+  struct dirent *filename;
+  DIR *dir = OpenDir(dir_name);
+  if (dir == nullptr) {
+    return {};
+  }
+
+  std::vector<std::string> res;
+  while ((filename = readdir(dir)) != nullptr) {
+    std::string d_name = std::string(filename->d_name);
+    if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
+      continue;
+    }
+    res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
+  }
+  std::sort(res.begin(), res.end());
+
+  return res;
+}
+
+
+std::vector<std::string> GetAllFiles(std::string_view dirName) {
+  struct dirent *filename;
+  DIR *dir = OpenDir(dirName);
+  if (dir == nullptr) {
+    return {};
+  }
+  std::vector<std::string> res;
+  while ((filename = readdir(dir)) != nullptr) {
+    std::string dName = std::string(filename->d_name);
+    if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
+      continue;
+    }
+    res.emplace_back(std::string(dirName) + "/" + filename->d_name);
+  }
+  std::sort(res.begin(), res.end());
+  for (auto &f : res) {
+    std::cout << "image file: " << f << std::endl;
+  }
+  return res;
+}
+
+
+int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
+  std::string homePath = "./result_Files";
+  const int INVALID_POINTER = -1;
+  const int ERROR = -2;
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    size_t outputSize;
+    std::shared_ptr<const void> netOutput;
+    netOutput = outputs[i].Data();
+    outputSize = outputs[i].DataSize();
+    int pos = imageFile.rfind('/');
+    std::string fileName(imageFile, pos + 1);
+    fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
+    std::string outFileName = homePath + "/" + fileName;
+    FILE *outputFile = fopen(outFileName.c_str(), "wb");
+    if (outputFile == nullptr) {
+        std::cout << "open result file " << outFileName << " failed" << std::endl;
+        return INVALID_POINTER;
+    }
+    size_t size = fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
+    if (size != outputSize) {
+        fclose(outputFile);
+        outputFile = nullptr;
+        std::cout << "write result file " << outFileName << " failed, write size[" << size <<
+            "] is smaller than output size[" << outputSize << "], maybe the disk is full." << std::endl;
+        return ERROR;
+    }
+    fclose(outputFile);
+    outputFile = nullptr;
+  }
+  return 0;
+}
+
+mindspore::MSTensor ReadFileToTensor(const std::string &file) {
+  if (file.empty()) {
+    std::cout << "Pointer file is nullptr" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  std::ifstream ifs(file);
+  if (!ifs.good()) {
+    std::cout << "File: " << file << " is not exist" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  if (!ifs.is_open()) {
+    std::cout << "File: " << file << "open failed" << std::endl;
+    return mindspore::MSTensor();
+  }
+
+  ifs.seekg(0, std::ios::end);
+  size_t size = ifs.tellg();
+  mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
+
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
+  ifs.close();
+
+  return buffer;
+}
+
+
+DIR *OpenDir(std::string_view dirName) {
+  if (dirName.empty()) {
+    std::cout << " dirName is null ! " << std::endl;
+    return nullptr;
+  }
+  std::string realPath = RealPath(dirName);
+  struct stat s;
+  lstat(realPath.c_str(), &s);
+  if (!S_ISDIR(s.st_mode)) {
+    std::cout << "dirName is not a valid directory !" << std::endl;
+    return nullptr;
+  }
+  DIR *dir;
+  dir = opendir(realPath.c_str());
+  if (dir == nullptr) {
+    std::cout << "Can not open dir " << dirName << std::endl;
+    return nullptr;
+  }
+  std::cout << "Successfully opened the dir " << dirName << std::endl;
+  return dir;
+}
+
+std::string RealPath(std::string_view path) {
+  char realPathMem[PATH_MAX] = {0};
+  char *realPathRet = nullptr;
+  realPathRet = realpath(path.data(), realPathMem);
+  if (realPathRet == nullptr) {
+    std::cout << "File: " << path << " is not exist.";
+    return "";
+  }
+
+  std::string realPath(realPathMem);
+  std::cout << path << " realpath is: " << realPath << std::endl;
+  return realPath;
+}
diff --git a/research/cv/cait/eval.py b/research/cv/cait/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..196004fd42e595085962f0cafe124773c2b4838d
--- /dev/null
+++ b/research/cv/cait/eval.py
@@ -0,0 +1,73 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""eval"""
+
+from mindspore import Model
+from mindspore import context
+from mindspore import nn
+from mindspore.common import set_seed
+
+from src.args import args
+from src.tools.cell import cast_amp
+from src.tools.criterion import get_criterion, NetWithLoss
+from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
+from src.tools.optimizer import get_optimizer
+
+set_seed(args.seed)
+
+
+def main():
+    mode = {
+        0: context.GRAPH_MODE,
+        1: context.PYNATIVE_MODE
+    }
+    context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
+    context.set_context(enable_graph_kernel=False)
+    if args.device_target == "Ascend":
+        context.set_context(enable_auto_mixed_precision=True)
+    set_device(args)
+
+    # get model
+    net = get_model(args)
+    cast_amp(net)
+    criterion = get_criterion(args)
+    if args.pretrained:
+        pretrained(args, net)
+    net_with_loss = NetWithLoss(net, criterion)
+
+    data = get_dataset(args, training=False)
+    batch_num = data.val_dataset.get_dataset_size()
+    optimizer = get_optimizer(args, net, batch_num)
+    # the optimizer is only needed to build the TrainOneStep wrapper below; weights are not updated during eval
+
+    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
+    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
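+    # WithEvalCell outputs [loss, logits, label]; eval_indexes routes these to the Loss and accuracy metrics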
+    eval_indexes = [0, 1, 2]
+    eval_metrics = {'Loss': nn.Loss(),
+                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
+                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
+    model = Model(net_with_loss, metrics=eval_metrics,
+                  eval_network=eval_network,
+                  eval_indexes=eval_indexes)
+
+    print(f"=> begin eval")
+    results = model.eval(data.val_dataset)
+    print(f"=> eval results:{results}")
+    print(f"=> eval success")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/research/cv/cait/export.py b/research/cv/cait/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..767d2b0734dd1f43f59e2355bdbd17b613e84f9d
--- /dev/null
+++ b/research/cv/cait/export.py
@@ -0,0 +1,48 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+##############export checkpoint file into air or mindir model#################
+python export.py
+"""
+
+import numpy as np
+from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
+from mindspore import dtype as mstype
+
+from src.args import args
+from src.tools.cell import cast_amp
+from src.tools.criterion import get_criterion, NetWithLoss
+from src.tools.get_misc import get_model
+
+context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+
+if args.device_target in ["Ascend", "GPU"]:
+    context.set_context(device_id=args.device_id)
+
+if __name__ == '__main__':
+    net = get_model(args)
+    criterion = get_criterion(args)
+    cast_amp(net)
+    net_with_loss = NetWithLoss(net, criterion)
+    assert args.pretrained is not None, "checkpoint_path is None."
+
+    param_dict = load_checkpoint(args.pretrained)
+    load_param_into_net(net, param_dict)
+
+    net.set_train(False)
+    net.to_float(mstype.float32)
+
+    input_arr = Tensor(np.zeros([1, 3, args.image_size, args.image_size], np.float32))
+    export(net, input_arr, file_name=args.arch, file_format=args.file_format)
diff --git a/research/cv/cait/postprocess.py b/research/cv/cait/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..65b78267de3ad7b456ac10c7b19e1c45b45c783d
--- /dev/null
+++ b/research/cv/cait/postprocess.py
@@ -0,0 +1,51 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""postprocess for 310 inference"""
+import argparse
+import json
+import os
+
+import numpy as np
+from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy
+
+parser = argparse.ArgumentParser(description="postprocess")
+parser.add_argument("--result_dir", type=str, default="./result_Files", help="result files path.")
+parser.add_argument('--dataset_name', type=str, choices=["imagenet2012"], default="imagenet2012")
+args = parser.parse_args()
+
+def calcul_acc(lab, preds):
+    return sum(1 for x, y in zip(lab, preds) if x == y) / len(lab)
+
+
+if __name__ == '__main__':
+    batch_size = 1
+    top1_acc = Top1CategoricalAccuracy()
+    rst_path = args.result_dir
+    label_list = []
+    pred_list = []
+    file_list = os.listdir(rst_path)
+    top5_acc = Top5CategoricalAccuracy()
+    with open('./preprocess_Result/imagenet_label.json', "r") as label:
+        labels = json.load(label)
+    for f in file_list:
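+        # result files are named <image_name>_0.bin; recover the original JPEG name to look up its label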
+        label = f.split("_0.bin")[0] + ".JPEG"
+        label_list.append(labels[label])
+        pred = np.fromfile(os.path.join(rst_path, f), np.float32)
+        pred = pred.reshape(batch_size, int(pred.shape[0] / batch_size))
+        top1_acc.update(pred, [labels[label],])
+        top5_acc.update(pred, [labels[label],])
+    print("Top1 acc: ", top1_acc.eval())
+    print("Top5 acc: ", top5_acc.eval())
diff --git a/research/cv/cait/preprocess.py b/research/cv/cait/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..a124cdb0f8ad167a803b3f1cd209067907076ffb
--- /dev/null
+++ b/research/cv/cait/preprocess.py
@@ -0,0 +1,46 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""preprocess"""
+import argparse
+import json
+import os
+
+parser = argparse.ArgumentParser('preprocess')
+parser.add_argument('--dataset_name', type=str, choices=["imagenet2012"], default="imagenet2012")
+parser.add_argument('--data_path', type=str, default='', help='eval data dir')
+def create_label(result_path, dir_path):
+    """
+    create_label
+    """
+    dirs = os.listdir(dir_path)
+    file_list = []
+    for file in dirs:
+        file_list.append(file)
+    file_list = sorted(file_list)
+    total = 0
+    img_label = {}
+    for i, file_dir in enumerate(file_list):
+        files = os.listdir(os.path.join(dir_path, file_dir))
+        for f in files:
+            img_label[f] = i
+        total += len(files)
+    json_file = os.path.join(result_path, "imagenet_label.json")
+    with open(json_file, "w+") as label:
+        json.dump(img_label, label)
+    print("[INFO] Completed! Total {} data.".format(total))
+
+args = parser.parse_args()
+if __name__ == "__main__":
+    create_label('./preprocess_Result/', args.data_path)
diff --git a/research/cv/cait/requirements.txt b/research/cv/cait/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/cait/scripts/run_distribute_train_ascend.sh b/research/cv/cait/scripts/run_distribute_train_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5eb2906ff0f708c7a0a0bf67fef7041cfb18d9de
--- /dev/null
+++ b/research/cv/cait/scripts/run_distribute_train_ascend.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -lt 2 ]
+then
+    echo "Usage: bash ./scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [CONFIG_PATH]"
+exit 1
+fi
+export RANK_TABLE_FILE=$1
+CONFIG_PATH=$2
+export RANK_SIZE=8
+export DEVICE_NUM=8
+
+
+cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
+echo "the number of logical core" $cores
+avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
+core_gap=`expr $avg_core_per_rank \- 1`
+echo "avg_core_per_rank" $avg_core_per_rank
+echo "core_gap" $core_gap
+for((i=0;i<RANK_SIZE;i++))
+do
+    start=`expr $i \* $avg_core_per_rank`
+    export DEVICE_ID=$i
+    export RANK_ID=$i
+    export DEPLOY_MODE=0
+    export GE_USE_STATIC_MEMORY=1
+    end=`expr $start \+ $core_gap`
+    cmdopt=$start"-"$end
+
+    rm -rf train_parallel$i
+    mkdir ./train_parallel$i
+    cp -r ./src ./train_parallel$i
+    cp  *.py ./train_parallel$i
+    cd ./train_parallel$i || exit
+    echo "start training for rank $i, device $DEVICE_ID rank_id $RANK_ID"
+    env > env.log
+    taskset -c $cmdopt python -u ../train.py \
+    --device_target Ascend \
+    --device_id $i \
+    --cait_config=$CONFIG_PATH > log.txt 2>&1 &
+    cd ../
+done
diff --git a/research/cv/cait/scripts/run_distribute_train_gpu.sh b/research/cv/cait/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dbf9bb4135e388849d7019da075984ce9643c1f8
--- /dev/null
+++ b/research/cv/cait/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash ./scripts/run_distributed_train_gpu.sh [DEVICE_NUM] [CUDA_VISIBLE_DEVICES] [CONFIG_PATH]"
+exit 1
+fi
+
+export RANK_SIZE=$1
+export DEVICE_NUM=$1
+export CUDA_VISIBLE_DEVICES=$2
+CONFIG_PATH=$3
+
+mpirun -n ${DEVICE_NUM} --allow-run-as-root --output-filename log_output \
+                      --merge-stderr-to-stdout python train.py \
+                      --device_target=GPU \
+                      --cait_config=$CONFIG_PATH > log.txt 2>&1 &
+
diff --git a/research/cv/cait/scripts/run_eval_ascend.sh b/research/cv/cait/scripts/run_eval_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..89c10dea60f2cbe99dc0fdb41c223576b88474ec
--- /dev/null
+++ b/research/cv/cait/scripts/run_eval_ascend.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]"
+exit 1
+fi
+
+export DEVICE_ID=$1
+CONFIG_PATH=$2
+CHECKPOINT_PATH=$3
+export RANK_SIZE=1
+export DEVICE_NUM=1
+
+rm -rf evaluation_ascend
+mkdir ./evaluation_ascend
+cd ./evaluation_ascend || exit
+echo  "start training for device id $DEVICE_ID"
+env > env.log
+python ../eval.py --device_target=Ascend --device_id=$DEVICE_ID --cait_config=$CONFIG_PATH \
+                  --pretrained=$CHECKPOINT_PATH > eval.log 2>&1 &
+cd ../
diff --git a/research/cv/cait/scripts/run_eval_gpu.sh b/research/cv/cait/scripts/run_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2f97b058a1992e920358718f596cda0c856a8bad
--- /dev/null
+++ b/research/cv/cait/scripts/run_eval_gpu.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]"
+exit 1
+fi
+
+export DEVICE_ID=$1
+CONFIG_PATH=$2
+CHECKPOINT_PATH=$3
+export RANK_SIZE=1
+export DEVICE_NUM=1
+
+rm -rf evaluation_gpu
+mkdir ./evaluation_gpu
+cd ./evaluation_gpu || exit
+echo "start evaluation for device id $DEVICE_ID"
+env > env.log
+python ../eval.py --device_target=GPU --device_id=$DEVICE_ID --cait_config=$CONFIG_PATH \
+                  --pretrained=$CHECKPOINT_PATH > eval.log 2>&1 &
+cd ../
diff --git a/research/cv/cait/scripts/run_infer_310.sh b/research/cv/cait/scripts/run_infer_310.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f691a82f37101c6b996bb665a025897de21f21da
--- /dev/null
+++ b/research/cv/cait/scripts/run_infer_310.sh
@@ -0,0 +1,118 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [[ $# -lt 3 || $# -gt 4 ]]; then
+    echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_NAME] [DATASET_PATH] [DEVICE_ID]
+    DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
+exit 1
+fi
+
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+model=$(get_real_path $1)
+if [ $2 == 'imagenet2012' ]; then
+  dataset_name=$2
+else
+  echo "DATASET_NAME should be 'imagenet2012'"
+  exit 1
+fi
+
+dataset_path=$(get_real_path $3)
+
+device_id=0
+if [ $# == 4 ]; then
+    device_id=$4
+fi
+
+echo "mindir name: "$model
+echo "dataset name: "$dataset_name
+echo "dataset path: "$dataset_path
+echo "device id: "$device_id
+
+export ASCEND_HOME=/usr/local/Ascend/
+if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
+    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
+    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
+    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
+else
+    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
+    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
+    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
+    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
+fi
+function preprocess_data()
+{
+    if [ -d preprocess_Result ]; then
+        rm -rf ./preprocess_Result
+    fi
+    mkdir preprocess_Result
+    python ../preprocess.py --dataset_name=$dataset_name --data_path=$dataset_path
+}
+
+function compile_app()
+{
+    cd ../ascend310_infer/ || exit
+    bash build.sh &> build.log
+}
+
+function infer()
+{
+    cd - || exit
+    if [ -d result_Files ]; then
+        rm -rf ./result_Files
+    fi
+    if [ -d time_Result ]; then
+        rm -rf ./time_Result
+    fi
+    mkdir result_Files
+    mkdir time_Result
+
+    ../ascend310_infer/out/main --mindir_path=$model --dataset_name=$dataset_name --input0_path=$dataset_path \
+                                --device_id=$device_id  &> infer.log
+}
+
+function cal_acc()
+{
+    python ../postprocess.py --dataset_name=$dataset_name  &> acc.log
+}
+
+preprocess_data
+if [ $? -ne 0 ]; then
+    echo "preprocess dataset failed"
+    exit 1
+fi
+compile_app
+if [ $? -ne 0 ]; then
+    echo "compile app code failed"
+    exit 1
+fi
+infer
+if [ $? -ne 0 ]; then
+    echo " execute inference failed"
+    exit 1
+fi
+cal_acc
+if [ $? -ne 0 ]; then
+    echo "calculate accuracy failed"
+    exit 1
+fi
diff --git a/research/cv/cait/scripts/run_standalone_train_ascend.sh b/research/cv/cait/scripts/run_standalone_train_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ab28bf1f32fdb51348ce138ea5c5ff895f786908
--- /dev/null
+++ b/research/cv/cait/scripts/run_standalone_train_ascend.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 2 ]
+then
+    echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH]"
+exit 1
+fi
+
+export RANK_SIZE=1
+export DEVICE_NUM=1
+export DEVICE_ID=$1
+CONFIG_PATH=$2
+
+rm -rf train_standalone
+mkdir ./train_standalone
+cd ./train_standalone || exit
+echo  "start training for device id $DEVICE_ID"
+env > env.log
+python -u ../train.py \
+    --device_id=$DEVICE_ID \
+    --device_target="Ascend" \
+    --cait_config=$CONFIG_PATH > log.txt 2>&1 &
+cd ../
diff --git a/research/cv/cait/scripts/run_standalone_train_gpu.sh b/research/cv/cait/scripts/run_standalone_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b0efb2a3462551900de002e4e42ac5715b53240b
--- /dev/null
+++ b/research/cv/cait/scripts/run_standalone_train_gpu.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 2 ]
+then
+    echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH]"
+exit 1
+fi
+
+export RANK_SIZE=1
+export DEVICE_NUM=1
+export DEVICE_ID=$1
+CONFIG_PATH=$2
+
+rm -rf train_standalone
+mkdir ./train_standalone
+cd ./train_standalone || exit
+echo  "start training for device id $DEVICE_ID"
+env > env.log
+python -u ../train.py \
+    --device_id=$DEVICE_ID \
+    --device_target="GPU" \
+    --cait_config=$CONFIG_PATH > log.txt 2>&1 &
+cd ../
diff --git a/research/cv/cait/src/__init__.py b/research/cv/cait/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/cait/src/args.py b/research/cv/cait/src/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fbe02a7a3dceef9380c1b03c6f69aae6974f6b4
--- /dev/null
+++ b/research/cv/cait/src/args.py
@@ -0,0 +1,130 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""global args for CAIT"""
+import argparse
+import ast
+import os
+import sys
+
+import yaml
+
+from src.configs import parser as _parser
+
+args = None
+
+
+def parse_arguments():
+    """parse_arguments"""
+    global args
+    parser = argparse.ArgumentParser(description="MindSpore CAIT Training")
+
+    parser.add_argument("-a", "--arch", metavar="ARCH", default="cait_XXS24_224", help="model architecture")
+    parser.add_argument("--accumulation_step", default=1, type=int, help="accumulation step")
+    parser.add_argument("--amp_level", default="O1", choices=["O0", "O1", "O2", "O3"], help="AMP Level")
+    parser.add_argument("-b", "--batch_size", default=64, type=int, metavar="N",
+                        help="mini-batch size (default: 256), this is the total "
+                             "batch size of all GPUs on the current node when "
+                             "using Data Parallel or Distributed Data Parallel")
+    parser.add_argument("--beta", default=[0.9, 0.999], type=lambda x: [float(a) for a in x.split(",")],
+                        help="beta for optimizer")
+    parser.add_argument("--crop_pct", default=0.875, type=float, help="Crop Pct")
+    parser.add_argument("--clip_global_norm", default=False, type=ast.literal_eval, help="clip global norm")
+    parser.add_argument('--data_url', default="./data", help='Location of data.')
+    parser.add_argument('--clip_global_norm_value', default=5., type=float, help='clip_global_norm_value.')
+    parser.add_argument("--device_id", default=0, type=int, help="Device Id")
+    parser.add_argument("--device_num", default=1, type=int, help="device num")
+    parser.add_argument("--curr_epoch", default=0, type=int, help="curr epoch")
+    parser.add_argument("--device_target", default="GPU", choices=["GPU", "Ascend", "CPU"], type=str)
+    parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
+    parser.add_argument("--eps", default=1e-8, type=float)
+    parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
+    parser.add_argument("--in_channel", default=3, type=int)
+    parser.add_argument("--is_dynamic_loss_scale", default=1, type=int, help="is_dynamic_loss_scale ")
+    parser.add_argument("--keep_checkpoint_max", default=20, type=int, help="keep checkpoint max num")
+    parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd")
+    parser.add_argument("--set", help="name of dataset", type=str, default="ImageNet")
+    parser.add_argument("--graph_mode", default=0, type=int, help="graph mode with 0, python with 1")
+    parser.add_argument("--mix_up", default=0.8, type=float, help="mix up")
+    parser.add_argument("--mixup_off_epoch", default=0., type=int, help="mix_up off epoch")
+    parser.add_argument("--interpolation", default="bicubic", type=str)
+    parser.add_argument("-j", "--num_parallel_workers", default=20, type=int, metavar="N",
+                        help="number of data loading workers (default: 20)")
+    parser.add_argument("--start_epoch", default=0, type=int, metavar="N",
+                        help="manual epoch number (useful on restarts)")
+    parser.add_argument("--warmup_length", default=0, type=int, help="Number of warmup iterations")
+    parser.add_argument("--warmup_lr", default=5e-7, type=float, help="warm up learning rate")
+    parser.add_argument("--wd", "--weight_decay", default=0.05, type=float, metavar="W",
+                        help="weight decay (default: 1e-4)", dest="weight_decay")
+    parser.add_argument("--loss_scale", default=1024, type=int, help="loss_scale")
+    parser.add_argument("--base_lr", "--learning_rate", default=5e-4, type=float, help="initial lr", dest="base_lr")
+    parser.add_argument("--lr_scheduler", default="cosine_annealing", help="Schedule for the learning rate.")
+    parser.add_argument("--lr_adjust", default=30, type=float, help="Interval to drop lr")
+    parser.add_argument("--lr_gamma", default=0.97, type=int, help="Multistep multiplier")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument("--num_classes", default=1000, type=int)
+    parser.add_argument("--pretrained", dest="pretrained", default=None, type=str, help="use pre-trained model")
+    parser.add_argument("--cait_config", help="Config file to use (see configs dir)", default=None, required=True)
+    parser.add_argument("--seed", default=0, type=int, help="seed for initializing training. ")
+    parser.add_argument("--save_every", default=2, type=int, help="save_every:50")
+    parser.add_argument("--label_smoothing", type=float, help="Label smoothing to use, default 0.0", default=0.1)
+    parser.add_argument("--image_size", default=224, help="Image Size.", type=int)
+    parser.add_argument('--train_url', default="./", help='Location of training outputs.')
+    parser.add_argument("--run_modelarts", type=ast.literal_eval, default=False, help="Whether run on modelarts")
+    args = parser.parse_args()
+
+    # Allow for use from notebook without config file
+    if len(sys.argv) > 1:
+        get_config()
+
+
+def get_config():
+    """get_config"""
+    global args
+    override_args = _parser.argv_to_vars(sys.argv)
+
+    print(f"=> Reading YAML config from {args.cait_config}")
+    # load yaml file
+    if args.run_modelarts:
+        import moxing as mox
+        if not args.cait_config.startswith("obs:/"):
+            args.cait_config = "obs:/" + args.cait_config
+        with mox.file.File(args.cait_config, 'r') as f:
+            yaml_txt = f.read()
+    else:
+        yaml_txt = open(args.cait_config).read()
+
+    # override args
+    loaded_yaml = yaml.load(yaml_txt, Loader=yaml.FullLoader)
+    for v in override_args:
+        loaded_yaml[v] = getattr(args, v)
+
+    args.__dict__.update(loaded_yaml)
+    print(args)
+
+    if "DEVICE_NUM" not in os.environ.keys():
+        os.environ["DEVICE_NUM"] = str(args.device_num)
+        os.environ["RANK_SIZE"] = str(args.device_num)
+
+
+def run_args():
+    """run and get args"""
+    global args
+    if args is None:
+        parse_arguments()
+
+
+run_args()
diff --git a/research/cv/cait/src/configs/cait_XXS24_224.yaml b/research/cv/cait/src/configs/cait_XXS24_224.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..885eb6339868f1c47100dcca2b183c5ff2a12c48
--- /dev/null
+++ b/research/cv/cait/src/configs/cait_XXS24_224.yaml
@@ -0,0 +1,49 @@
+# Architecture Top1-77.6%
+arch: cait_XXS24_224
+
+# ===== Dataset ===== #
+data_url: ./data/imagenet
+set: ImageNet
+num_classes: 1000
+mix_up: 0.8
+cutmix: 1.0
+auto_augment: rand-m9-mstd0.5-inc1
+interpolation: bicubic
+re_prob: 0.25
+re_mode: pixel
+re_count: 1
+mixup_prob: 1.
+switch_prob: 0.5
+mixup_mode: batch
+
+
+# ===== Learning Rate Policy ======== #
+optimizer: adamw
+base_lr: 0.0005
+warmup_lr: 0.00000007
+min_lr: 0.000006
+lr_scheduler: cosine_lr
+warmup_length: 5
+nonlinearity: GELU
+
+
+# ===== Network training config ===== #
+amp_level: O2
+keep_bn_fp32: True
+beta: [ 0.9, 0.999 ]
+clip_global_norm_value: 5.
+is_dynamic_loss_scale: True
+epochs: 400
+label_smoothing: 0.1
+loss_scale: 1024
+weight_decay: 0.05
+momentum: 0.9
+batch_size: 128
+
+# ===== Hardware setup ===== #
+num_parallel_workers: 16
+device_target: GPU
+
+# ===== Model config ===== #
+drop_path_rate: 0.05
+image_size: 224
\ No newline at end of file
diff --git a/research/cv/cait/src/configs/parser.py b/research/cv/cait/src/configs/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..437ae943bbd1f6dd0657d186e3135c9850c52dbf
--- /dev/null
+++ b/research/cv/cait/src/configs/parser.py
@@ -0,0 +1,39 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""parser function"""
+USABLE_TYPES = set([float, int])
+
+
+def trim_preceding_hyphens(st):
+    i = 0
+    while st[i] == "-":
+        i += 1
+
+    return st[i:]
+
+
+def arg_to_varname(st: str):
+    st = trim_preceding_hyphens(st)
+    st = st.replace("-", "_")
+
+    return st.split("=")[0]
+
+
+def argv_to_vars(argv):
+    var_names = []
+    for arg in argv:
+        if arg.startswith("-") and arg_to_varname(arg) != "cait_config":
+            var_names.append(arg_to_varname(arg))
+    return var_names
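+
+
+# Example (sketch):
+#   argv_to_vars(["train.py", "--batch_size=64", "--cait_config", "cfg.yaml"])
+# returns ["batch_size"]: only flags other than --cait_config are collected as
+# YAML overrides.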
diff --git a/research/cv/cait/src/data/__init__.py b/research/cv/cait/src/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c031d10103182a72f3b348677b378dbcfc2a72d
--- /dev/null
+++ b/research/cv/cait/src/data/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init datasets"""
+from .imagenet import ImageNet
diff --git a/research/cv/cait/src/data/augment/__init__.py b/research/cv/cait/src/data/augment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a85287a9dcbd95414aa8ccba5e842439a3d0423
--- /dev/null
+++ b/research/cv/cait/src/data/augment/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init augment"""
+from .auto_augment import _pil_interp, rand_augment_transform
+from .mixup import Mixup
+from .random_erasing import RandomErasing
diff --git a/research/cv/cait/src/data/augment/auto_augment.py b/research/cv/cait/src/data/augment/auto_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65a1d81574ef523769151c5306ea40ef715a58b
--- /dev/null
+++ b/research/cv/cait/src/data/augment/auto_augment.py
@@ -0,0 +1,894 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" AutoAugment, RandAugment, and AugMix for MindSpore
+
+This code implements the searched ImageNet policies with various tweaks and improvements and
+does not include any of the search code.
+
+AA and RA Implementation adapted from:
+    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+
+AugMix adapted from:
+    https://github.com/google-research/augmix
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import random
+import re
+
+import numpy as np
+import PIL
+from PIL import Image, ImageOps, ImageEnhance
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+_HPARAMS_DEFAULT = dict(
+    translate_const=250,
+    img_mean=_FILL,
+)
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _pil_interp(method):
+    """Interpolation method selection"""
+    if method == 'bicubic':
+        func = Image.BICUBIC
+    elif method == 'lanczos':
+        func = Image.LANCZOS
+    elif method == 'hamming':
+        func = Image.HAMMING
+    else:
+        func = Image.BILINEAR
+    return func
+
+
+def _interpolation(kwargs):
+    """_interpolation"""
+    interpolation = kwargs.pop('resample', Image.BILINEAR)
+    interpolation = random.choice(interpolation) \
+        if isinstance(interpolation, (list, tuple)) else interpolation
+    return interpolation
+
+def _check_args_tf(kwargs):
+    """_check_args_tf"""
+    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+        kwargs.pop('fillcolor')
+    kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+    """shear_x"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+    """shear_y"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+    """translate_x_rel"""
+    pixels = pct * img.size[0]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+    """translate_y_rel"""
+    pixels = pct * img.size[1]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+    """translate_x_abs"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+    """translate_y_abs"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+    """rotate"""
+    _check_args_tf(kwargs)
+    if _PIL_VER >= (5, 2):
+        func = img.rotate(degrees, **kwargs)
+    elif _PIL_VER >= (5, 0):
+        w, h = img.size
+        post_trans = (0, 0)
+        rotn_center = (w / 2.0, h / 2.0)
+        angle = -math.radians(degrees)
+        matrix = [
+            round(math.cos(angle), 15),
+            round(math.sin(angle), 15),
+            0.0,
+            round(-math.sin(angle), 15),
+            round(math.cos(angle), 15),
+            0.0,
+        ]
+
+        def transform(x, y, matrix):
+            (a, b, c, d, e, f) = matrix
+            return a * x + b * y + c, d * x + e * y + f
+
+        matrix[2], matrix[5] = transform(
+            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
+        )
+        matrix[2] += rotn_center[0]
+        matrix[5] += rotn_center[1]
+        func = img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+    else:
+        func = img.rotate(degrees, resample=kwargs['resample'])
+    return func
+
+
+def auto_contrast(img, **__):
+    """auto_contrast"""
+    return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+    """invert"""
+    return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+    """equalize"""
+    return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+    """solarize"""
+    return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+    """solarize_add"""
+    lut = []
+    for i in range(256):
+        if i < thresh:
+            lut.append(min(255, i + add))
+        else:
+            lut.append(i)
+    if img.mode in ("L", "RGB"):
+        if img.mode == "RGB" and len(lut) == 256:
+            lut = lut + lut + lut
+        func = img.point(lut)
+    else:
+        func = img
+    return func
+
+
+def posterize(img, bits_to_keep, **__):
+    """posterize"""
+    if bits_to_keep >= 8:
+        func = img
+    else:
+        func = ImageOps.posterize(img, bits_to_keep)
+    return func
+
+
+def contrast(img, factor, **__):
+    """contrast"""
+    return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+    """color"""
+    return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+    """brightness"""
+    return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+    """sharpness"""
+    return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+    """With 50% prob, negate the value"""
+    return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+    """_randomly_negate"""
+    # range [-30, 30]
+    level = (level / _MAX_LEVEL) * 30.
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _enhance_level_to_arg(level, _hparams):
+    """_enhance_level_to_arg"""
+    # range [0.1, 1.9]
+    return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+    """_enhance_increasing_level_to_arg"""
+    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+    # range [0.1, 1.9]
+    level = (level / _MAX_LEVEL) * .9
+    level = 1.0 + _randomly_negate(level)
+    return (level,)
+
+
+def _shear_level_to_arg(level, _hparams):
+    """_shear_level_to_arg"""
+    # range [-0.3, 0.3]
+    level = (level / _MAX_LEVEL) * 0.3
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _translate_abs_level_to_arg(level, hparams):
+    """_translate_abs_level_to_arg"""
+    translate_const = hparams['translate_const']
+    level = (level / _MAX_LEVEL) * float(translate_const)
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _translate_rel_level_to_arg(level, hparams):
+    """_translate_rel_level_to_arg"""
+    # default range [-0.45, 0.45]
+    translate_pct = hparams.get('translate_pct', 0.45)
+    level = (level / _MAX_LEVEL) * translate_pct
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _posterize_level_to_arg(level, _hparams):
+    """_posterize_level_to_arg"""
+    # As per Tensorflow TPU EfficientNet impl
+    # range [0, 4], 'keep 0 up to 4 MSB of original image'
+    # intensity/severity of augmentation decreases with level
+    return (int((level / _MAX_LEVEL) * 4),)
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+    """_posterize_increasing_level_to_arg"""
+    # As per Tensorflow models research and UDA impl
+    # range [4, 0], 'keep 4 down to 0 MSB of original image',
+    # intensity/severity of augmentation increases with level
+    return (4 - _posterize_level_to_arg(level, hparams)[0],)
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+    """_posterize_original_level_to_arg"""
+    # As per original AutoAugment paper description
+    # range [4, 8], 'keep 4 up to 8 MSB of image'
+    # intensity/severity of augmentation decreases with level
+    return (int((level / _MAX_LEVEL) * 4) + 4,)
+
+
+def _solarize_level_to_arg(level, _hparams):
+    """_solarize_level_to_arg"""
+    # range [0, 256]
+    # intensity/severity of augmentation decreases with level
+    return (int((level / _MAX_LEVEL) * 256),)
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+    """_solarize_increasing_level_to_arg"""
+    # range [0, 256]
+    # intensity/severity of augmentation increases with level
+    return (256 - _solarize_level_to_arg(level, _hparams)[0],)
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+    """_solarize_add_level_to_arg"""
+    # range [0, 110]
+    return (int((level / _MAX_LEVEL) * 110),)
+
+
+LEVEL_TO_ARG = {
+    'AutoContrast': None,
+    'Equalize': None,
+    'Invert': None,
+    'Rotate': _rotate_level_to_arg,
+    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+    'Posterize': _posterize_level_to_arg,
+    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
+    'PosterizeOriginal': _posterize_original_level_to_arg,
+    'Solarize': _solarize_level_to_arg,
+    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
+    'SolarizeAdd': _solarize_add_level_to_arg,
+    'Color': _enhance_level_to_arg,
+    'ColorIncreasing': _enhance_increasing_level_to_arg,
+    'Contrast': _enhance_level_to_arg,
+    'ContrastIncreasing': _enhance_increasing_level_to_arg,
+    'Brightness': _enhance_level_to_arg,
+    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
+    'Sharpness': _enhance_level_to_arg,
+    'SharpnessIncreasing': _enhance_increasing_level_to_arg,
+    'ShearX': _shear_level_to_arg,
+    'ShearY': _shear_level_to_arg,
+    'TranslateX': _translate_abs_level_to_arg,
+    'TranslateY': _translate_abs_level_to_arg,
+    'TranslateXRel': _translate_rel_level_to_arg,
+    'TranslateYRel': _translate_rel_level_to_arg,
+}
+
+NAME_TO_OP = {
+    'AutoContrast': auto_contrast,
+    'Equalize': equalize,
+    'Invert': invert,
+    'Rotate': rotate,
+    'Posterize': posterize,
+    'PosterizeIncreasing': posterize,
+    'PosterizeOriginal': posterize,
+    'Solarize': solarize,
+    'SolarizeIncreasing': solarize,
+    'SolarizeAdd': solarize_add,
+    'Color': color,
+    'ColorIncreasing': color,
+    'Contrast': contrast,
+    'ContrastIncreasing': contrast,
+    'Brightness': brightness,
+    'BrightnessIncreasing': brightness,
+    'Sharpness': sharpness,
+    'SharpnessIncreasing': sharpness,
+    'ShearX': shear_x,
+    'ShearY': shear_y,
+    'TranslateX': translate_x_abs,
+    'TranslateY': translate_y_abs,
+    'TranslateXRel': translate_x_rel,
+    'TranslateYRel': translate_y_rel,
+}
+
+
+class AugmentOp:
+    """AugmentOp"""
+
+    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+        hparams = hparams or _HPARAMS_DEFAULT
+        self.aug_fn = NAME_TO_OP[name]
+        self.level_fn = LEVEL_TO_ARG[name]
+        self.prob = prob
+        self.magnitude = magnitude
+        self.hparams = hparams.copy()
+        self.kwargs = dict(
+            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+        )
+
+        # If magnitude_std is > 0, we introduce some randomness
+        # in the usually fixed policy and sample magnitude from a normal distribution
+        # with mean `magnitude` and std-dev of `magnitude_std`.
+        # NOTE This is my own hack, being tested, not in papers or reference impls.
+        # If magnitude_std is inf, we sample magnitude from a uniform distribution
+        self.magnitude_std = self.hparams.get('magnitude_std', 0)
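+        # e.g. magnitude=9 with magnitude_std=0.5 samples magnitude ~ N(9, 0.5)
+        # on each call; the sampled value is clipped to [0, _MAX_LEVEL] in __call__.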
+
+    def __call__(self, img):
+        """apply augment"""
+        if self.prob < 1.0 and random.random() > self.prob:
+            return img
+        magnitude = self.magnitude
+        if self.magnitude_std:
+            if self.magnitude_std == float('inf'):
+                magnitude = random.uniform(0, magnitude)
+            elif self.magnitude_std > 0:
+                magnitude = random.gauss(magnitude, self.magnitude_std)
+        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
+        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
+        return self.aug_fn(img, *level_args, **self.kwargs)
+
+
+def auto_augment_policy_v0(hparams):
+    """auto_augment_policy_v0"""
+    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
+    policy = [
+        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in a black image with TPU posterize
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_v0r(hparams):
+    """auto_augment_policy_v0r"""
+    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
+    # in Google research implementation (number of bits discarded increases with magnitude)
+    policy = [
+        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_original(hparams):
+    """auto_augment_policy_original"""
+    # ImageNet policy from
+    # https://arxiv.org/abs/1805.09501
+
+    policy = [
+        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
+        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_originalr(hparams):
+    """auto_augment_policy_originalr"""
+    # ImageNet policy from
+    # https://arxiv.org/abs/1805.09501 with research posterize variation
+    policy = [
+        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
+        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy(name='v0', hparams=None):
+    """auto_augment_policy"""
+    hparams = hparams or _HPARAMS_DEFAULT
+    if name == 'original':
+        func = auto_augment_policy_original(hparams)
+    elif name == 'originalr':
+        func = auto_augment_policy_originalr(hparams)
+    elif name == 'v0':
+        func = auto_augment_policy_v0(hparams)
+    elif name == 'v0r':
+        func = auto_augment_policy_v0r(hparams)
+    else:
+        assert False, 'Unknown AA policy (%s)' % name
+    return func
+
+class AutoAugment:
+    """AutoAugment"""
+    def __init__(self, policy):
+        self.policy = policy
+
+    def __call__(self, img):
+        """apply autoaugment"""
+        sub_policy = random.choice(self.policy)
+        for op in sub_policy:
+            img = op(img)
+        return img
+
+
+def auto_augment_transform(config_str, hparams):
+    """
+    Create an AutoAugment transform
+
+    :param config_str: String defining the auto augmentation configuration. Consists of multiple sections separated
+    by dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
+    The remaining sections (order does not matter) determine
+        'mstd' - float std deviation of magnitude noise applied
+    Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
+
+    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
+
+    :return: A MindSpore compatible Transform
+    """
+    config = config_str.split('-')
+    policy_name = config[0]
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        else:
+            assert False, 'Unknown AutoAugment config section'
+    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
+    return AutoAugment(aa_policy)
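+
+# Usage sketch (hypothetical hparams values):
+#   aa = auto_augment_transform('original-mstd0.5',
+#                               dict(translate_const=250, img_mean=(128, 128, 128)))
+#   augmented = aa(pil_img)  # applies one randomly chosen two-op sub-policy to a PIL image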
+
+
+_RAND_TRANSFORMS = [
+    'AutoContrast',
+    'Equalize',
+    'Invert',
+    'Rotate',
+    'Posterize',
+    'Solarize',
+    'SolarizeAdd',
+    'Color',
+    'Contrast',
+    'Brightness',
+    'Sharpness',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+    # 'Cutout'  # NOTE I've implemented this as random erasing separately
+]
+
+_RAND_INCREASING_TRANSFORMS = [
+    'AutoContrast',
+    'Equalize',
+    'Invert',
+    'Rotate',
+    'PosterizeIncreasing',
+    'SolarizeIncreasing',
+    'SolarizeAdd',
+    'ColorIncreasing',
+    'ContrastIncreasing',
+    'BrightnessIncreasing',
+    'SharpnessIncreasing',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+    # 'Cutout'  # NOTE I've implemented this as random erasing separately
+]
+
+# These experimental weights are based loosely on the relative improvements mentioned in paper.
+# They may not result in increased performance, but could likely be tuned to do so.
+_RAND_CHOICE_WEIGHTS_0 = {
+    'Rotate': 0.3,
+    'ShearX': 0.2,
+    'ShearY': 0.2,
+    'TranslateXRel': 0.1,
+    'TranslateYRel': 0.1,
+    'Color': .025,
+    'Sharpness': 0.025,
+    'AutoContrast': 0.025,
+    'Solarize': .005,
+    'SolarizeAdd': .005,
+    'Contrast': .005,
+    'Brightness': .005,
+    'Equalize': .005,
+    'Posterize': 0,
+    'Invert': 0,
+}
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+    """_select_rand_weights"""
+    transforms = transforms or _RAND_TRANSFORMS
+    assert weight_idx == 0  # only one set of weights currently
+    rand_weights = _RAND_CHOICE_WEIGHTS_0
+    probs = [rand_weights[k] for k in transforms]
+    probs /= np.sum(probs)
+    return probs
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+    """rand_augment_ops"""
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _RAND_TRANSFORMS
+    return [AugmentOp(
+        name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class RandAugment:
+    """RandAugment"""
+    def __init__(self, ops, num_layers=2, choice_weights=None):
+        self.ops = ops
+        self.num_layers = num_layers
+        self.choice_weights = choice_weights
+
+    def __call__(self, img):
+        # no replacement when using weighted choice
+        ops = np.random.choice(
+            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
+        for op in ops:
+            img = op(img)
+        return img
+
+
+def rand_augment_transform(config_str, hparams):
+    """
+    Create a RandAugment transform
+
+    :param config_str: String defining the random augmentation configuration. Consists of multiple sections separated
+    by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
+    The remaining sections (order does not matter) determine
+        'm' - integer magnitude of rand augment
+        'n' - integer num layers (number of transform ops selected per image)
+        'w' - integer probability weight index (index of a set of weights to influence choice of op)
+        'mstd' - float std deviation of magnitude noise applied
+        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
+    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+
+    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+
+    :return: A MindSpore compatible Transform
+    """
+    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
+    num_layers = 2  # default to 2 ops per image
+    weight_idx = None  # default to no probability weights for op choice
+    transforms = _RAND_TRANSFORMS
+    config = config_str.split('-')
+    assert config[0] == 'rand'
+    # [rand, m9, mstd0.5, inc1]
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        elif key == 'inc':
+            if int(val):  # bool(val) would treat 'inc0' as true, since any non-empty string is truthy
+                transforms = _RAND_INCREASING_TRANSFORMS
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'n':
+            num_layers = int(val)
+        elif key == 'w':
+            weight_idx = int(val)
+        else:
+            assert False, 'Unknown RandAugment config section'
+    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
+    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
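+
+# Usage sketch, mirroring the auto_augment value in this repo's configs
+# ('rand-m9-mstd0.5-inc1'); the hparams dict here is illustrative:
+#   ra = rand_augment_transform('rand-m9-mstd0.5-inc1',
+#                               dict(translate_const=100, img_mean=(128, 128, 128)))
+#   augmented = ra(pil_img)  # 2 ops per image, magnitude ~ N(9, 0.5), increasing transforms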
+
+
+_AUGMIX_TRANSFORMS = [
+    'AutoContrast',
+    'ColorIncreasing',  # not in paper
+    'ContrastIncreasing',  # not in paper
+    'BrightnessIncreasing',  # not in paper
+    'SharpnessIncreasing',  # not in paper
+    'Equalize',
+    'Rotate',
+    'PosterizeIncreasing',
+    'SolarizeIncreasing',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+]
+
+
+def augmix_ops(magnitude=10, hparams=None, transforms=None):
+    """augmix_ops"""
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _AUGMIX_TRANSFORMS
+    return [AugmentOp(
+        name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class AugMixAugment:
+    """ AugMix Transform
+    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
+
+    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
+    https://arxiv.org/abs/1912.02781
+    """
+
+    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
+        self.ops = ops
+        self.alpha = alpha
+        self.width = width
+        self.depth = depth
+        self.blended = blended  # blended mode is faster but not well tested
+
+    def _calc_blended_weights(self, ws, m):
+        """_calc_blended_weights"""
+        ws = ws * m
+        cump = 1.
+        rws = []
+        for w in ws[::-1]:
+            alpha = w / cump
+            cump *= (1 - alpha)
+            rws.append(alpha)
+        return np.array(rws[::-1], dtype=np.float32)
+
+    def _apply_blended(self, img, mixing_weights, m):
+        """_apply_blended"""
+        # This is my first crack at implementing a slightly faster mixed augmentation. Instead
+        # of accumulating the mix for each chain in a Numpy array and then blending with original,
+        # it recomputes the blending coefficients and applies one PIL image blend per chain.
+        # TODO the results appear in the right ballpark but they differ by more than rounding.
+        img_orig = img.copy()
+        ws = self._calc_blended_weights(mixing_weights, m)
+        for w in ws:
+            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+            ops = np.random.choice(self.ops, depth, replace=True)
+            img_aug = img_orig  # no ops are in-place, deep copy not necessary
+            for op in ops:
+                img_aug = op(img_aug)
+            img = Image.blend(img, img_aug, w)
+        return img
+
+    def _apply_basic(self, img, mixing_weights, m):
+        """_apply_basic"""
+        # This is a literal adaptation of the paper/official implementation without normalizations and
+        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
+        # typical augmentation transforms, could use a GPU / Kornia implementation.
+        img_shape = img.size[0], img.size[1], len(img.getbands())
+        mixed = np.zeros(img_shape, dtype=np.float32)
+        for mw in mixing_weights:
+            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+            ops = np.random.choice(self.ops, depth, replace=True)
+            img_aug = img  # no ops are in-place, deep copy not necessary
+            for op in ops:
+                img_aug = op(img_aug)
+            mixed += mw * np.asarray(img_aug, dtype=np.float32)
+        np.clip(mixed, 0, 255., out=mixed)
+        mixed = Image.fromarray(mixed.astype(np.uint8))
+        return Image.blend(img, mixed, m)
+
+    def __call__(self, img):
+        """AugMixAugment apply"""
+        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
+        m = np.float32(np.random.beta(self.alpha, self.alpha))
+        if self.blended:
+            mixed = self._apply_blended(img, mixing_weights, m)
+        else:
+            mixed = self._apply_basic(img, mixing_weights, m)
+        return mixed
+
+
+def augment_and_mix_transform(config_str, hparams):
+    """ Create AugMix MindSpore transform
+
+    :param config_str: String defining the AugMix configuration. Consists of multiple sections separated
+    by dashes ('-'). The first section defines the specific variant (currently only 'augmix').
+    The remaining sections (order does not matter) determine
+        'm' - integer magnitude (severity) of augmentation mix (default: 3)
+        'w' - integer width of augmentation chain (default: 3)
+        'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
+        'mstd' - float std deviation of magnitude noise applied (default: 0)
+    Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
+
+    :param hparams: Other hparams (kwargs) for the Augmentation transforms
+
+    :return: A MindSpore compatible Transform
+    """
+    magnitude = 3
+    width = 3
+    depth = -1
+    alpha = 1.
+    blended = False
+    hparams['magnitude_std'] = float('inf')
+    config = config_str.split('-')
+    assert config[0] == 'augmix'
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'w':
+            width = int(val)
+        elif key == 'd':
+            depth = int(val)
+        elif key == 'a':
+            alpha = float(val)
+        elif key == 'b':
+            blended = bool(int(val))  # bool(val) alone would treat 'b0' as true
+        else:
+            assert False, 'Unknown AugMix config section'
+    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
+    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
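+
+
+# Usage sketch (hypothetical values):
+#   am = augment_and_mix_transform('augmix-m5-w4-d2', dict(translate_const=250))
+#   augmented = am(pil_img)  # mixes 4 chains of depth 2 at severity 5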
diff --git a/research/cv/cait/src/data/augment/mixup.py b/research/cv/cait/src/data/augment/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..00d0e6fe856de5babeee29ede677f8835336d34c
--- /dev/null
+++ b/research/cv/cait/src/data/augment/mixup.py
@@ -0,0 +1,255 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import numpy as np
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import ops as P
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0.):
+    """one_hot"""
+    x = x.reshape(-1)
+    x = np.eye(num_classes)[x]
+    x = np.clip(x, a_min=off_value, a_max=on_value, dtype=np.float32)
+    return x
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
+    """mixup_target"""
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(np.flip(target, axis=0), num_classes, on_value=on_value, off_value=off_value)
+    return y1 * lam + y2 * (1. - lam)
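+
+# Worked example: num_classes=5, smoothing=0.1, lam=0.7 gives off_value=0.02 and
+# on_value=0.92; the soft target is 0.7 * one_hot(y) + 0.3 * one_hot(flip(y)).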
+
+
+def rand_bbox(img_shape, lam, margin=0., count=None):
+    """ Standard CutMix bounding-box
+    Generates a random square bbox based on lambda value. This impl includes
+    support for enforcing a border margin as percent of bbox dimensions.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        lam (float): Cutmix lambda value
+        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+        count (int): Number of bbox to generate
+    """
+    ratio = np.sqrt(1 - lam)
+    img_h, img_w = img_shape[-2:]
+    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+    yl = np.clip(cy - cut_h // 2, 0, img_h)
+    yh = np.clip(cy + cut_h // 2, 0, img_h)
+    xl = np.clip(cx - cut_w // 2, 0, img_w)
+    xh = np.clip(cx + cut_w // 2, 0, img_w)
+    return yl, yh, xl, xh
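+
+# Worked example: lam=0.75 gives ratio=0.5, so the cut box covers roughly 25% of
+# the image area (less when the box is clipped at the borders).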
+
+
+def rand_bbox_minmax(img_shape, minmax, count=None):
+    """ Min-Max CutMix bounding-box
+    Inspired by Darknet cutmix impl, generates a random rectangular bbox
+    based on min/max percent values applied to each dimension of the input image.
+
+    Typical defaults for minmax are in the .2-.3 range for min and the .8-.9 range for max.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+        count (int): Number of bbox to generate
+    """
+    assert len(minmax) == 2
+    img_h, img_w = img_shape[-2:]
+    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+    yl = np.random.randint(0, img_h - cut_h, size=count)
+    xl = np.random.randint(0, img_w - cut_w, size=count)
+    yu = yl + cut_h
+    xu = xl + cut_w
+    return yl, yu, xl, xu
+
+
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+    """ Generate bbox and apply lambda correction.
+    """
+    if ratio_minmax is not None:
+        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+    else:
+        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+    if correct_lam or ratio_minmax is not None:
+        bbox_area = (yu - yl) * (xu - xl)
+        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+    return (yl, yu, xl, xu), lam
+
+
+class Mixup:
+    """ Mixup/Cutmix that applies different params to each element or whole batch
+
+    Args:
+        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+        prob (float): probability of applying mixup or cutmix per batch or element
+        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), or 'elem' (element))
+        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+        label_smoothing (float): apply label smoothing to the mixed target tensor
+        num_classes (int): number of classes for target
+    """
+
+    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000, mix_steps=0.):
+        self.mixup_alpha = mixup_alpha
+        self.cutmix_alpha = cutmix_alpha
+        self.cutmix_minmax = cutmix_minmax
+        if self.cutmix_minmax is not None:
+            assert len(self.cutmix_minmax) == 2
+            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+            self.cutmix_alpha = 1.0
+        self.mix_prob = prob
+        self.switch_prob = switch_prob
+        self.label_smoothing = label_smoothing
+        self.num_classes = num_classes
+        self.mode = mode
+        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
+        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by the train loop)
+        self.mix_steps = int(mix_steps)
+        print(f"==> self.mix_steps = {mix_steps}")
+        self.mix_step = 0.
+        self.print = P.Print()
+
+    def _params_per_elem(self, batch_size):
+        """_params_per_elem"""
+        lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=np.bool_)
+        if self.mixup_enabled:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                lam_mix = np.where(
+                    use_cutmix,
+                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size, dtype=np.bool_)
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+        return lam, use_cutmix
+
+    def _params_per_batch(self):
+        """_params_per_batch"""
+        lam = 1.
+        use_cutmix = False
+        if self.mixup_enabled and np.random.rand() < self.mix_prob:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand() < self.switch_prob
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = True
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = float(lam_mix)
+        return lam, use_cutmix
+
+    def _mix_elem(self, x):
+        """_mix_elem"""
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        x_orig = x.copy()  # need to keep an unmodified original for mixing source (numpy arrays use copy(), not clone())
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+        return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)
+
+    def _mix_pair(self, x):
+        """_mix_pair"""
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        x_orig = x.copy()  # need to keep an unmodified original for mixing source (numpy arrays use copy(), not clone())
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)
+
+    def _mix_batch(self, x):
+        """_mix_batch"""
+        lam, use_cutmix = self._params_per_batch()
+        if lam == 1.:
+            return 1.
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+            x[:, :, yl:yh, xl:xh] = np.flip(x, axis=0)[:, :, yl:yh, xl:xh]
+        else:
+            x_flipped = np.flip(x, axis=0) * (1. - lam)
+            x *= lam
+            x += x_flipped
+        return lam
+
+    def __call__(self, x, target):
+        """mix target"""
+        # the same to image, label
+        if self.mix_steps != 0 and self.mixup_enabled:
+            self.mix_step += 1
+            if self.mix_step == self.mix_steps:
+                self.mixup_enabled = False
+        assert len(x) % 2 == 0, 'Batch size should be even when using this'
+        if self.mode == 'elem':
+            lam = self._mix_elem(x)
+        elif self.mode == 'pair':
+            lam = self._mix_pair(x)
+        else:
+            lam = self._mix_batch(x)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+        return x.astype(np.float32), target.astype(np.float32)
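+
+
+# Usage sketch (assumes NCHW float images and integer labels as numpy arrays; the
+# values mirror this repo's configs):
+#   mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
+#                    mode='batch', label_smoothing=0.1, num_classes=1000)
+#   images, targets = mixup_fn(images, labels)  # targets become soft one-hot vectors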
diff --git a/research/cv/cait/src/data/augment/random_erasing.py b/research/cv/cait/src/data/augment/random_erasing.py
new file mode 100644
index 0000000000000000000000000000000000000000..6430b302ed7875920d32f734d984a3bb577e2a0a
--- /dev/null
+++ b/research/cv/cait/src/data/augment/random_erasing.py
@@ -0,0 +1,113 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Random Erasing (Cutout)
+
+Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
+Copyright Zhun Zhong & Liang Zheng
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import random
+
+import numpy as np
+
+
+def _get_pixels(per_pixel, rand_color, patch_size, dtype=np.float32):
+    """_get_pixels"""
+    if per_pixel:
+        func = np.random.normal(size=patch_size).astype(dtype)
+    elif rand_color:
+        func = np.random.normal(size=(patch_size[0], 1, 1)).astype(dtype)
+    else:
+        func = np.zeros((patch_size[0], 1, 1), dtype=dtype)
+    return func
+
+
+class RandomErasing:
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+
+        This variant of RandomErasing is intended to be applied to either a batch
+        or single image tensor after it has been normalized by dataset mean and std.
+    Args:
+         probability: Probability that the Random Erasing operation will be performed.
+         min_area: Minimum percentage of erased area wrt input image area.
+         max_area: Maximum percentage of erased area wrt input image area.
+         min_aspect: Minimum aspect ratio of erased area.
+         mode: pixel color mode, one of 'const', 'rand', or 'pixel'
+            'const' - erase block is constant color of 0 for all channels
+            'rand'  - erase block is same per-channel random (normal) color
+            'pixel' - erase block is per-pixel random (normal) color
+        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
+            per-image count is randomly chosen between 1 and this value.
+    """
+
+    def __init__(self, probability=0.5, min_area=0.02, max_area=1 / 3, min_aspect=0.3,
+                 max_aspect=None, mode='const', min_count=1, max_count=None, num_splits=0):
+        self.probability = probability
+        self.min_area = min_area
+        self.max_area = max_area
+        max_aspect = max_aspect or 1 / min_aspect
+        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+        self.min_count = min_count
+        self.max_count = max_count or min_count
+        self.num_splits = num_splits
+        mode = mode.lower()
+        self.rand_color = False
+        self.per_pixel = False
+        if mode == 'rand':
+            self.rand_color = True  # per block random normal
+        elif mode == 'pixel':
+            self.per_pixel = True  # per pixel random normal
+        else:
+            assert not mode or mode == 'const'
+
+    def _erase(self, img, chan, img_h, img_w, dtype):
+        """Erase up to `count` random rectangles in img in-place and return it."""
+        if random.random() > self.probability:
+            return img
+        area = img_h * img_w
+        count = self.min_count if self.min_count == self.max_count else \
+            random.randint(self.min_count, self.max_count)
+        for _ in range(count):
+            for _ in range(10):  # up to 10 attempts to sample a box that fits inside the image
+                target_area = random.uniform(self.min_area, self.max_area) * area / count
+                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+                h = int(round(math.sqrt(target_area * aspect_ratio)))
+                w = int(round(math.sqrt(target_area / aspect_ratio)))
+                if w < img_w and h < img_h:
+                    top = random.randint(0, img_h - h)
+                    left = random.randint(0, img_w - w)
+                    img[:, top:top + h, left:left + w] = _get_pixels(
+                        self.per_pixel, self.rand_color, (chan, h, w),
+                        dtype=dtype)
+                    break
+        return img
+
+    def __call__(self, x):
+        """RandomErasing apply"""
+        if len(x.shape) == 3:
+            output = self._erase(x, *x.shape, x.dtype)
+        else:
+            output = x.copy()  # keep the skipped (clean) portion of the batch unmodified instead of zeroing it
+            batch_size, chan, img_h, img_w = x.shape
+            # skip first slice of batch if num_splits is set (for clean portion of samples)
+            batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
+            for i in range(batch_start, batch_size):
+                output[i] = self._erase(x[i], chan, img_h, img_w, x.dtype)
+        return output
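+
+
+# Usage sketch, mirroring the re_prob/re_mode/re_count values in this repo's configs:
+#   eraser = RandomErasing(probability=0.25, mode='pixel', max_count=1)
+#   erased = eraser(img)  # img: normalized numpy array, CHW (or NCHW for a batch)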
diff --git a/research/cv/cait/src/data/augment/transforms.py b/research/cv/cait/src/data/augment/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a920a70e3ba8d8d57107f7d4fba928c8b97cd17
--- /dev/null
+++ b/research/cv/cait/src/data/augment/transforms.py
@@ -0,0 +1,270 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""image transformer"""
+import math
+import random
+import warnings
+
+import numpy as np
+from PIL import Image
+from PIL import ImageEnhance, ImageOps
+
+
+class ShearX:
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
+            Image.BICUBIC, fillcolor=self.fillcolor)
+
+
+class ShearY:
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
+            Image.BICUBIC, fillcolor=self.fillcolor)
+
+
+class TranslateX:
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0),
+            fillcolor=self.fillcolor)
+
+
+class TranslateY:
+    def __init__(self, fillcolor=(128, 128, 128)):
+        self.fillcolor = fillcolor
+
+    def __call__(self, x, magnitude):
+        return x.transform(
+            x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])),
+            fillcolor=self.fillcolor)
+
+
+class Rotate:
+    # from https://stackoverflow.com/questions/
+    # 5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
+    def __call__(self, x, magnitude):
+        rot = x.convert("RGBA").rotate(magnitude)
+        return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode)
+
+
+class Color:
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Posterize:
+    def __call__(self, x, magnitude):
+        return ImageOps.posterize(x, magnitude)
+
+
+class Solarize:
+    def __call__(self, x, magnitude):
+        return ImageOps.solarize(x, magnitude)
+
+
+class Contrast:
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Sharpness:
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class Brightness:
+    def __call__(self, x, magnitude):
+        return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1]))
+
+
+class AutoContrast:
+    def __call__(self, x, magnitude):
+        return ImageOps.autocontrast(x)
+
+
+class Equalize:
+    def __call__(self, x, magnitude):
+        return ImageOps.equalize(x)
+
+
+class Invert:
+    def __call__(self, x, magnitude):
+        return ImageOps.invert(x)
+
+
+class ToNumpy:
+
+    def __call__(self, pil_img):
+        np_img = np.array(pil_img, dtype=np.uint8)
+        if np_img.ndim < 3:
+            np_img = np.expand_dims(np_img, axis=-1)
+        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+        return np_img
+
+
+_pil_interpolation_to_str = {
+    Image.NEAREST: 'PIL.Image.NEAREST',
+    Image.BILINEAR: 'PIL.Image.BILINEAR',
+    Image.BICUBIC: 'PIL.Image.BICUBIC',
+    Image.LANCZOS: 'PIL.Image.LANCZOS',
+    Image.HAMMING: 'PIL.Image.HAMMING',
+    Image.BOX: 'PIL.Image.BOX',
+}
+
+
+def _pil_interp(method):
+    """_pil_interp"""
+    if method == 'bicubic':
+        output = Image.BICUBIC
+    elif method == 'lanczos':
+        output = Image.LANCZOS
+    elif method == 'hamming':
+        output = Image.HAMMING
+    else:
+        # default bilinear, do we want to allow nearest?
+        output = Image.BILINEAR
+    return output
+
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+class Resize:
+    """Resize"""
+
+    def __init__(self, size, interpolation='bilinear'):
+        if isinstance(size, (list, tuple)):
+            self.size = tuple(size)
+        else:
+            self.size = (size, size)
+        self.interpolation = _pil_interp(interpolation)
+
+    def __call__(self, img):
+        img = img.resize(self.size, self.interpolation)
+        return img
+
+
+class RandomResizedCropAndInterpolation:
+    """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+
+    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+    is finally resized to given size.
+    This is popularly used to train the Inception networks.
+
+    Args:
+        size: expected output size of each edge
+        scale: range of size of the origin size cropped
+        ratio: range of aspect ratio of the origin aspect ratio cropped
+        interpolation: Default: PIL.Image.BILINEAR
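+
+    Example (illustrative only; assumes a 256x256 RGB input)::
+
+        >>> from PIL import Image
+        >>> rrc = RandomResizedCropAndInterpolation(size=224, interpolation='random')
+        >>> out = rrc(Image.new('RGB', (256, 256)))  # 224x224 PIL image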
+    """
+
+    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
+                 interpolation='bilinear'):
+        if isinstance(size, (list, tuple)):
+            self.size = tuple(size)
+        else:
+            self.size = (size, size)
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            warnings.warn("range should be of kind (min, max)")
+
+        if interpolation == 'random':
+            self.interpolation = _RANDOM_INTERPOLATION
+        else:
+            self.interpolation = _pil_interp(interpolation)
+        self.scale = scale
+        self.ratio = ratio
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        """Get parameters for ``crop`` for a random sized crop.
+
+        Args:
+            img (PIL Image): Image to be cropped.
+            scale (tuple): range of size of the origin size cropped
+            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                sized crop.
+        """
+        area = img.size[0] * img.size[1]
+
+        for _ in range(10):
+            target_area = random.uniform(*scale) * area
+            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+            aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if w <= img.size[0] and h <= img.size[1]:
+                i = random.randint(0, img.size[1] - h)
+                j = random.randint(0, img.size[0] - w)
+                return j, i, w, h
+
+        # Fallback to central crop
+        in_ratio = img.size[0] / img.size[1]
+        if in_ratio < min(ratio):
+            w = img.size[0]
+            h = int(round(w / min(ratio)))
+        elif in_ratio > max(ratio):
+            h = img.size[1]
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = img.size[0]
+            h = img.size[1]
+        i = (img.size[1] - h) // 2
+        j = (img.size[0] - w) // 2
+        return j, i, w, h
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be cropped and resized.
+
+        Returns:
+            PIL Image: Randomly cropped and resized image.
+        """
+        left, top, width, height = self.get_params(img, self.scale, self.ratio)
+        if isinstance(self.interpolation, (tuple, list)):
+            interpolation = random.choice(self.interpolation)
+        else:
+            interpolation = self.interpolation
+        img = img.crop((left, top, left + width, top + height))
+        img = img.resize(self.size, interpolation)
+        return img
+
+    def __repr__(self):
+        if isinstance(self.interpolation, (tuple, list)):
+            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
+        else:
+            interpolate_str = _pil_interpolation_to_str[self.interpolation]
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+        format_string += ', interpolation={0})'.format(interpolate_str)
+        return format_string
diff --git a/research/cv/cait/src/data/data_utils/__init__.py b/research/cv/cait/src/data/data_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/cait/src/data/data_utils/moxing_adapter.py b/research/cv/cait/src/data/data_utils/moxing_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3f1652470395edf49a6d6d0c84f4c9d2f3f3b55
--- /dev/null
+++ b/research/cv/cait/src/data/data_utils/moxing_adapter.py
@@ -0,0 +1,72 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import os
+
+_global_sync_count = 0
+
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id else "default"
+    return job_id
+
+
+def sync_data(from_path, to_path, threads=16):
+    """
+    Download data from remote OBS to a local directory when from_path is a remote
+    URL; conversely, upload a local directory to remote OBS when to_path is remote.
+    """
+    import moxing as mox
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server hosts at most 8 devices; only the first device on each server
+    # performs the copy, while the others wait for the lock file below.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path, threads=threads)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
diff --git a/research/cv/cait/src/data/imagenet.py b/research/cv/cait/src/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1539daa2c63b26d34e290c4abf9f9a5606c5602f
--- /dev/null
+++ b/research/cv/cait/src/data/imagenet.py
@@ -0,0 +1,160 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Data operations used in train.py and eval.py
+"""
+import os
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as C
+import mindspore.dataset.vision.c_transforms as vision
+import mindspore.dataset.vision.py_transforms as py_vision
+
+from src.data.augment.auto_augment import rand_augment_transform
+from src.data.augment.mixup import Mixup
+from src.data.augment.random_erasing import RandomErasing
+from src.data.augment.transforms import RandomResizedCropAndInterpolation, Resize, _pil_interp
+from .data_utils.moxing_adapter import sync_data
+
+
+class ImageNet:
+    """ImageNet Define"""
+
+    def __init__(self, args, training=True):
+        if args.run_modelarts:
+            print('Download data.')
+            local_data_path = '/cache/data'
+            sync_data(args.data_url, local_data_path, threads=128)
+            print('Create train and evaluate dataset.')
+            train_dir = os.path.join(local_data_path, "train")
+            val_dir = os.path.join(local_data_path, "val")
+            self.train_dataset = create_dataset_imagenet(train_dir, training=True, args=args)
+            self.val_dataset = create_dataset_imagenet(val_dir, training=False, args=args)
+        else:
+            train_dir = os.path.join(args.data_url, "train")
+            val_dir = os.path.join(args.data_url, "val")
+            if training:
+                self.train_dataset = create_dataset_imagenet(train_dir, training=True, args=args)
+            self.val_dataset = create_dataset_imagenet(val_dir, training=False, args=args)
+
+
+def create_dataset_imagenet(dataset_dir, args, repeat_num=1, training=True):
+    """
+    Create a train or eval ImageNet2012 dataset for CaiT.
+
+    Args:
+        dataset_dir(string): the path of dataset.
+        args: global arguments holding batch size and augmentation settings.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        training(bool): whether dataset is used for train or eval. Default: True
+
+    Returns:
+        dataset
+    """
+
+    device_num, rank_id = _get_rank_info()
+    shuffle = bool(training)
+    if device_num == 1 or not training:
+        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers,
+                                         shuffle=shuffle)
+    else:
+        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers, shuffle=shuffle,
+                                         num_shards=device_num, shard_id=rank_id)
+
+    image_size = args.image_size
+
+    # define map operations
+    # BICUBIC: 3
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    if training:
+        aa_params = dict(
+            translate_const=int(image_size * 0.45),
+            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+        )
+        interpolation = args.interpolation
+        auto_augment = args.auto_augment
+        if interpolation != "random":
+            aa_params["interpolation"] = _pil_interp(interpolation)
+        assert auto_augment.startswith('rand')
+        transform_img = [
+            vision.Decode(),
+            py_vision.ToPIL(),
+            RandomResizedCropAndInterpolation(size=args.image_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
+                                              interpolation=interpolation),
+            py_vision.RandomHorizontalFlip(prob=0.5),
+        ]
+        transform_img += [rand_augment_transform(auto_augment, aa_params)]
+        transform_img += [
+            py_vision.ToTensor(),
+            py_vision.Normalize(mean=mean, std=std)]
+        if args.re_prob > 0.:
+            transform_img += [RandomErasing(args.re_prob, mode=args.re_mode, max_count=args.re_count)]
+    else:
+        # eval transform: resize, center crop and normalize
+        transform_img = [
+            vision.Decode(),
+            py_vision.ToPIL(),
+            Resize(int(args.image_size / args.crop_pct), interpolation="bicubic"),
+            py_vision.CenterCrop(image_size),
+            py_vision.ToTensor(),
+            py_vision.Normalize(mean=mean, std=std)
+        ]
+
+    transform_label = C.TypeCast(mstype.int32)
+
+    data_set = data_set.map(input_columns="image", num_parallel_workers=args.num_parallel_workers,
+                            operations=transform_img)
+    data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
+                            operations=transform_label)
+    if (args.mix_up or args.cutmix) and not training:
+        # when mixup/cutmix is enabled, one-hot encode the eval labels so that
+        # train and eval labels share the same shape
+        one_hot = C.OneHot(num_classes=args.num_classes)
+        data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
+                                operations=one_hot)
+    # apply batch operations
+    data_set = data_set.batch(args.batch_size, drop_remainder=True,
+                              num_parallel_workers=args.num_parallel_workers)
+
+    if (args.mix_up or args.cutmix) and training:
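+        # Mixup/CutMix mix whole batches, so the op is mapped after batch();
+        # mix_steps converts mixup_off_epoch from epochs to training steps.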
+        mixup_fn = Mixup(
+            mixup_alpha=args.mix_up, cutmix_alpha=args.cutmix, cutmix_minmax=None,
+            prob=args.mixup_prob, switch_prob=args.switch_prob, mode=args.mixup_mode,
+            label_smoothing=args.label_smoothing, num_classes=args.num_classes,
+            mix_steps=args.mixup_off_epoch * data_set.get_dataset_size())
+
+        data_set = data_set.map(operations=mixup_fn, input_columns=["image", "label"],
+                                num_parallel_workers=args.num_parallel_workers)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+
+    if rank_size > 1:
+        from mindspore.communication.management import get_rank, get_group_size
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        rank_size = rank_id = None
+
+    return rank_size, rank_id
diff --git a/research/cv/cait/src/models/__init__.py b/research/cv/cait/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0e4da858b6405b562ad537641d60361c3f47b6b
--- /dev/null
+++ b/research/cv/cait/src/models/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init models"""
+from .cait import cait_XXS24_224
+
+__all__ = [
+    "cait_XXS24_224"
+]
diff --git a/research/cv/cait/src/models/cait/__init__.py b/research/cv/cait/src/models/cait/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1202b133456dd6e9c9ee2205da5384396183e3fa
--- /dev/null
+++ b/research/cv/cait/src/models/cait/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from .cait_models import cait_XXS24_224
+
+__all__ = [
+    'cait_XXS24_224',
+
+]
diff --git a/research/cv/cait/src/models/cait/cait_models.py b/research/cv/cait/src/models/cait/cait_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..8adece04fe06f7746e055d0a48a85c557caa9bba
--- /dev/null
+++ b/research/cv/cait/src/models/cait/cait_models.py
@@ -0,0 +1,326 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
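+"""CaiT model definition (class-attention image transformer)"""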
+import numpy as np
+import mindspore.common.initializer as weight_init
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+from mindspore import Parameter
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import ops
+
+from src.models.cait.misc import to_2tuple, Identity, DropPath1D
+
+
+class PatchEmbed(nn.Cell):
+    """ 2D Image to Patch Embedding
+    """
+    def __init__(self, image_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+        super().__init__()
+        image_size = to_2tuple(image_size)
+        patch_size = to_2tuple(patch_size)
+        self.in_chans = in_chans
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.flatten = flatten
+        self.proj = nn.Dense(in_chans * patch_size[0] * patch_size[1], embed_dim, has_bias=False)
+        self.norm = norm_layer((embed_dim,), epsilon=1e-8) if norm_layer else Identity()
+
+    def construct(self, x):
+        B, C, H, W = x.shape
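+        # Patchify with reshape/transpose instead of a strided convolution:
+        # (B, C, H, W) -> (B, num_patches, C * ph * pw), then project each
+        # flattened patch to embed_dim with the Dense layer below.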
+        x = ops.Reshape()(x, (B, C, H // self.patch_size[0], self.patch_size[0], W // self.patch_size[1],
+                              self.patch_size[1]))
+        x = ops.Transpose()(x, (0, 2, 4, 1, 3, 5))
+        x = ops.Reshape()(x, (B, self.num_patches, -1))
+        x = self.proj(x)
+        x = self.norm(x)
+        return x
+
+
+class Mlp(nn.Cell):
+    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Dense(in_channels=in_features, out_channels=hidden_features, has_bias=False)
+        self.act = act_layer()
+        self.fc2 = nn.Dense(in_channels=hidden_features, out_channels=out_features, has_bias=False)
+        self.drop = nn.Dropout(keep_prob=1.0 - drop)
+
+    def construct(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Class_Attention(nn.Cell):
+    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+    # with slight modifications to do CA
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.q = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
+        self.k = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
+        self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
+        self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
+        self.proj = nn.Dense(in_channels=dim, out_channels=dim)
+        self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
+        self.softmax = nn.Softmax(axis=-1)
+
+    def construct(self, x):
+        B, N, C = x.shape
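+        # Class attention: only the class token x[:, 0] forms the query, so the
+        # attention map has shape (B, heads, 1, N) and cost linear in N.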
+        q = P.Reshape()(self.q(x[:, 0]), (B, 1, self.num_heads, C // self.num_heads))
+        q = P.Transpose()(q, (0, 2, 1, 3))
+        k = P.Reshape()(self.k(x), (B, N, self.num_heads, C // self.num_heads))
+        k = P.Transpose()(k, (0, 2, 3, 1))
+        q = q * self.scale
+        v = P.Reshape()(self.v(x), (B, N, self.num_heads, C // self.num_heads))
+        v = P.Transpose()(v, (0, 2, 1, 3))
+
+        attn = ops.BatchMatMul()(q, k)
+        attn = self.softmax(attn)
+        attn = self.attn_drop(attn)
+
+        x_cls = ops.BatchMatMul()(attn, v)
+        x_cls = P.Transpose()(x_cls, (0, 2, 1, 3))
+        x_cls = P.Reshape()(x_cls, (B, 1, C))
+        x_cls = self.proj(x_cls)
+        x_cls = self.proj_drop(x_cls)
+
+        return x_cls
+
+
+class LayerScale_Block_CA(nn.Cell):
+    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+    # with slight modifications to add CA and LayerScale
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Class_Attention,
+                 Mlp_block=Mlp, init_values=1e-4):
+        super().__init__()
+        self.norm1 = norm_layer((dim,), epsilon=1e-8)
+        self.attn = Attention_block(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath1D(drop_path) if drop_path > 0. else Identity()
+        self.norm2 = norm_layer((dim,), epsilon=1e-8)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.gamma_1 = Parameter(Tensor(init_values * np.ones([1, 1, dim]), mstype.float32))
+        self.gamma_2 = Parameter(Tensor(init_values * np.ones([1, 1, dim]), mstype.float32))
+
+    def construct(self, x, x_cls):
+        u = P.Concat(1)((x_cls, x))
+        x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))
+
+        x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))
+        return x_cls
+
+
+class Attention_talking_head(nn.Cell):
+    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+    # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+
+        self.num_heads = num_heads
+
+        head_dim = dim // num_heads
+
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.q = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
+        self.k = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
+        self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
+
+        self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
+
+        self.proj = nn.Dense(in_channels=dim, out_channels=dim, has_bias=False)
+
+        self.proj_l = nn.Dense(in_channels=num_heads, out_channels=num_heads, has_bias=False)
+        self.proj_w = nn.Dense(in_channels=num_heads, out_channels=num_heads, has_bias=False)
+
+        self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
+        self.softmax = nn.Softmax(axis=-1)
+
+    def construct(self, x):
+        B_, N, C = x.shape
+        q = ops.Reshape()(self.q(x), (B_, N, self.num_heads, C // self.num_heads))
+        q = ops.Transpose()(q, (0, 2, 1, 3)) * self.scale
+        k = ops.Reshape()(self.k(x), (B_, N, self.num_heads, C // self.num_heads))
+        k = ops.Transpose()(k, (0, 2, 3, 1))
+        v = ops.Reshape()(self.v(x), (B_, N, self.num_heads, C // self.num_heads))
+        v = ops.Transpose()(v, (0, 2, 1, 3))
+        attn = ops.BatchMatMul()(q, k)
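+        # Talking heads: linearly mix the attention maps across the head
+        # dimension both before (proj_l) and after (proj_w) the softmax.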
+        attn = P.Transpose()(self.proj_l(P.Transpose()(attn, (0, 2, 3, 1,))), (0, 3, 1, 2))
+        attn = self.softmax(attn)
+        attn = P.Transpose()(self.proj_w(P.Transpose()(attn, (0, 2, 3, 1,))), (0, 3, 1, 2))
+        attn = self.attn_drop(attn)
+        x = P.Transpose()(ops.BatchMatMul()(attn, v), (0, 2, 1, 3))
+        x = P.Reshape()(x, (B_, N, C))
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class LayerScale_Block(nn.Cell):
+    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+    # with slight modifications to add layerScale
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=Attention_talking_head,
+                 Mlp_block=Mlp, init_values=1e-4):
+        super().__init__()
+        self.norm1 = norm_layer((dim,), epsilon=1e-8)
+        self.attn = Attention_block(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath1D(drop_path) if drop_path > 0. else Identity()
+        self.norm2 = norm_layer((dim,), epsilon=1e-8)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.gamma_1 = Parameter(Tensor(init_values * np.ones([1, 1, dim]), mstype.float32))
+        self.gamma_2 = Parameter(Tensor(init_values * np.ones([1, 1, dim]), mstype.float32))
+
+    def construct(self, x):
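+        # LayerScale: gamma_1/gamma_2 (initialized near zero) scale each residual
+        # branch per channel, keeping deep blocks close to identity early on.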
+        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class cait_models(nn.Cell):
+    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+    # with slight modifications to adapt to our cait models
+    def __init__(self, image_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm,
+                 block_layers=LayerScale_Block,
+                 block_layers_token=LayerScale_Block_CA,
+                 Patch_layer=PatchEmbed, act_layer=nn.GELU,
+                 Attention_block=Attention_talking_head, Mlp_block=Mlp,
+                 init_scale=1e-4,
+                 Attention_block_token_only=Class_Attention,
+                 Mlp_block_token_only=Mlp,
+                 depth_token_only=2,
+                 mlp_ratio_clstk=4.0):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim
+
+        self.patch_embed = Patch_layer(
+            image_size=image_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = Parameter(Tensor(np.zeros([1, 1, embed_dim]), mstype.float32))
+        self.pos_embed = Parameter(Tensor(np.zeros([1, num_patches, embed_dim]), mstype.float32))
+        self.pos_drop = nn.Dropout(keep_prob=1.0 - drop_rate)
+
+        dpr = [drop_path_rate for _ in range(depth)]
+        self.blocks = nn.CellList([
+            block_layers(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                act_layer=act_layer, Attention_block=Attention_block, Mlp_block=Mlp_block, init_values=init_scale)
+            for i in range(depth)])
+
+        self.blocks_token_only = nn.CellList([
+            block_layers_token(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
+                act_layer=act_layer, Attention_block=Attention_block_token_only,
+                Mlp_block=Mlp_block_token_only, init_values=init_scale)
+            for _ in range(depth_token_only)])
+
+        self.norm = norm_layer((embed_dim,), epsilon=1e-8)
+        channel_mask = np.zeros([1, num_patches + 1, 1])
+        channel_mask[0] = 1
+        self.channel_mask = Tensor(channel_mask, mstype.float32)
+
+        self.head = nn.Dense(in_channels=embed_dim, out_channels=num_classes, has_bias=False) if num_classes > 0 else \
+            Identity()
+
+        self.pos_embed.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
+                                                        self.pos_embed.shape,
+                                                        self.pos_embed.dtype))
+        self.cls_token.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
+                                                        self.cls_token.shape,
+                                                        self.cls_token.dtype))
+        self.init_weights()
+
+    def init_weights(self):
+        """init_weights"""
+        for _, cell in self.cells_and_names():
+            if isinstance(cell, nn.Dense):
+                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+                if cell.bias is not None:
+                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
+                                                               cell.bias.shape,
+                                                               cell.bias.dtype))
+            elif isinstance(cell, nn.LayerNorm):
+                cell.gamma.set_data(weight_init.initializer(weight_init.One(),
+                                                            cell.gamma.shape,
+                                                            cell.gamma.dtype))
+                cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
+                                                           cell.beta.shape,
+                                                           cell.beta.dtype))
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+        cls_tokens = P.Tile()(self.cls_token, (B, 1, 1))
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
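+        # The class token is kept out of the self-attention stage above and only
+        # interacts with patch tokens in the class-attention blocks below.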
+        for blk in self.blocks_token_only:
+            cls_tokens = blk(x, cls_tokens)
+        x = P.Concat(1)((cls_tokens, x))
+
+        x = self.norm(x)
+        return x[:, 0]
+
+    def construct(self, x):
+        x = self.forward_features(x)
+
+        x = self.head(x)
+
+        return x
+
+
+def cait_XXS24_224(args):
+    num_classes = args.num_classes
+    drop_path_rate = args.drop_path_rate
+    image_size = args.image_size
+    assert image_size == 224
+    model = cait_models(
+        image_size=image_size, patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=False,
+        norm_layer=nn.LayerNorm,
+        init_scale=1e-5,
+        depth_token_only=2, num_classes=num_classes, drop_path_rate=drop_path_rate)
+
+    return model
diff --git a/research/cv/cait/src/models/cait/misc.py b/research/cv/cait/src/models/cait/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d708fcf709f8c67c5e821b4990c4f8fb6e3a26
--- /dev/null
+++ b/research/cv/cait/src/models/cait/misc.py
@@ -0,0 +1,67 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Misc function for cait"""
+import collections.abc
+from itertools import repeat
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import ops
+
+
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_2tuple = _ntuple(2)
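+# e.g. to_2tuple(16) -> (16, 16); an iterable such as (16, 16) passes through unchanged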
+
+
+class Identity(nn.Cell):
+    """Identity"""
+
+    def construct(self, x):
+        return x
+
+
+class DropPath(nn.Cell):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob, ndim):
+        super(DropPath, self).__init__()
+        self.drop = nn.Dropout(keep_prob=1 - drop_prob)
+        shape = (1,) + (1,) * (ndim + 1)
+        self.ndim = ndim
+        self.mask = Tensor(np.ones(shape), dtype=mstype.float32)
+
+    def construct(self, x):
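+        # Stochastic depth: apply Dropout to an all-ones mask with one entry per
+        # sample, so whole samples are zeroed and survivors are rescaled by
+        # 1 / keep_prob during training.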
+        if not self.training:
+            return x
+        mask = ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1))
+        out = self.drop(mask)
+        out = out * x
+        return out
+
+
+class DropPath1D(DropPath):
+    def __init__(self, drop_prob):
+        super(DropPath1D, self).__init__(drop_prob=drop_prob, ndim=1)
diff --git a/research/cv/cait/src/tools/__init__.py b/research/cv/cait/src/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/cait/src/tools/callback.py b/research/cv/cait/src/tools/callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce83c3d1dcac3293acc958f7d9decfc9984cd23
--- /dev/null
+++ b/research/cv/cait/src/tools/callback.py
@@ -0,0 +1,48 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""callback function"""
+
+from mindspore.train.callback import Callback
+
+from src.args import args
+
+
+class EvaluateCallBack(Callback):
+    """EvaluateCallBack"""
+
+    def __init__(self, model, eval_dataset, src_url, train_url, save_freq=50):
+        super(EvaluateCallBack, self).__init__()
+        self.model = model
+        self.eval_dataset = eval_dataset
+        self.src_url = src_url
+        self.train_url = train_url
+        self.save_freq = save_freq
+        self.best_acc = 0.
+
+    def epoch_end(self, run_context):
+        """
+            Test when epoch end, save best model with best.ckpt.
+        """
+        cb_params = run_context.original_args()
+        cur_epoch_num = cb_params.cur_epoch_num
+        result = self.model.eval(self.eval_dataset)
+        if result["acc"] > self.best_acc:
+            self.best_acc = result["acc"]
+        print("epoch: %s acc: %s, best acc is %s" %
+              (cb_params.cur_epoch_num, result["acc"], self.best_acc), flush=True)
+        if args.run_modelarts:
+            import moxing as mox
+            if cur_epoch_num % self.save_freq == 0:
+                mox.file.copy_parallel(src_url=self.src_url, dst_url=self.train_url)
diff --git a/research/cv/cait/src/tools/cell.py b/research/cv/cait/src/tools/cell.py
new file mode 100644
index 0000000000000000000000000000000000000000..21cec9575493496523813b3ce13e7d338fa90a54
--- /dev/null
+++ b/research/cv/cait/src/tools/cell.py
@@ -0,0 +1,60 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Functions of cells"""
+import mindspore.nn as nn
+from mindspore import dtype as mstype
+from mindspore.ops import functional as F
+
+from src.args import args
+
+
+class OutputTo16(nn.Cell):
+    "Wrap cell for amp. Cast network output back to float16"
+
+    def __init__(self, op):
+        super(OutputTo16, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, x):
+        return F.cast(self._op(x), mstype.float16)
+
+
+def do_keep_fp32(network, cell_types):
+    """Cast cell to fp32 if cell in cell_types"""
+    for name, cell in network.cells_and_names():
+        if isinstance(cell, cell_types):
+            cell.to_float(mstype.float32)
+            print(f'cast {name} to fp32')
+
+
+def cast_amp(net):
+    """cast network amp_level"""
+    assert args.amp_level in ("O0", "O1", "O2", "O3")
+    if args.amp_level == "O2":
+        cell_types = (nn.LayerNorm, nn.Softmax, nn.BatchNorm2d, nn.BatchNorm1d, nn.GELU, nn.Conv2d, nn.MaxPool2d,)
+        print(f"=> using amp_level {args.amp_level}\n"
+              f"=> change {args.arch}to fp16")
+        net.to_float(mstype.float16)
+        do_keep_fp32(net, cell_types)
+        print(f"cast {cell_types} to fp32 back")
+    elif args.amp_level == "O3":
+        print(f"=> using amp_level {args.amp_level}\n"
+              f"=> change {args.arch} to fp16")
+        net.to_float(mstype.float16)
+    else:
+        print(f"=> using amp_level {args.amp_level}")
+        args.loss_scale = 1.
+        args.is_dynamic_loss_scale = 0
+        print(f"=> When amp_level is O0, using fixed loss_scale with {args.loss_scale}")
diff --git a/research/cv/cait/src/tools/criterion.py b/research/cv/cait/src/tools/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa92b6199185b39e612a041a40b85755115a6714
--- /dev/null
+++ b/research/cv/cait/src/tools/criterion.py
@@ -0,0 +1,95 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""functions of criterion"""
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import ops
+from mindspore.common import dtype as mstype
+from mindspore.nn.loss.loss import _Loss
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+
+class SoftTargetCrossEntropy(_Loss):
+    """SoftTargetCrossEntropy for MixUp Augment"""
+
+    def __init__(self):
+        super(SoftTargetCrossEntropy, self).__init__()
+        self.mean_ops = P.ReduceMean(keep_dims=False)
+        self.sum_ops = P.ReduceSum(keep_dims=False)
+        self.log_softmax = P.LogSoftmax()
+
+    def construct(self, logit, label):
+        logit = P.Cast()(logit, mstype.float32)
+        label = P.Cast()(label, mstype.float32)
+        loss = self.sum_ops(-label * self.log_softmax(logit), -1)
+        return self.mean_ops(loss)
+
+
+class CrossEntropySmooth(_Loss):
+    """CrossEntropy"""
+
+    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
+        super(CrossEntropySmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.sparse = sparse
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
+        self.cast = ops.Cast()
+
+    def construct(self, logit, label):
+        if self.sparse:
+            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
+        logit = P.Cast()(logit, mstype.float32)
+        label = P.Cast()(label, mstype.float32)
+        loss2 = self.ce(logit, label)
+        return loss2
+
+
+def get_criterion(args):
+    """Get loss function from args.label_smooth and args.mix_up"""
+    assert 0. <= args.label_smoothing <= 1.
+
+    if args.mix_up > 0. or args.cutmix > 0.:
+        print(25 * "=" + "Using MixBatch" + 25 * "=")
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif args.label_smoothing > 0.:
+        print(25 * "=" + "Using label smoothing" + 25 * "=")
+        criterion = CrossEntropySmooth(sparse=True, reduction="mean",
+                                       smooth_factor=args.label_smoothing,
+                                       num_classes=args.num_classes)
+    else:
+        print(25 * "=" + "Using Simple CE" + 25 * "=")
+        criterion = CrossEntropySmooth(sparse=True, reduction="mean", num_classes=args.num_classes)
+
+    return criterion
+
+
+class NetWithLoss(nn.Cell):
+    """
+       NetWithLoss: Only supports classification networks
+    """
+
+    def __init__(self, model, criterion):
+        super(NetWithLoss, self).__init__()
+        self.model = model
+        self.criterion = criterion
+
+    def construct(self, data, label):
+        predict = self.model(data)
+        loss = self.criterion(predict, label)
+        return loss
diff --git a/research/cv/cait/src/tools/get_misc.py b/research/cv/cait/src/tools/get_misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba19dc0dfc60cf9c4495f6ee1de2cebe5a5d14c1
--- /dev/null
+++ b/research/cv/cait/src/tools/get_misc.py
@@ -0,0 +1,120 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""misc functions for program"""
+import os
+
+from mindspore import context
+from mindspore import nn
+from mindspore.communication.management import init, get_rank
+from mindspore.context import ParallelMode
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src import models, data
+from src.data.data_utils.moxing_adapter import sync_data
+from src.trainers import TrainClipGrad
+
+
+def set_device(args):
+    """Set device and ParallelMode(if device_num > 1)"""
+    rank = 0
+    # set context and device
+    device_target = args.device_target
+    device_num = int(os.environ.get("DEVICE_NUM", 1))
+
+    if device_target == "Ascend":
+        if device_num > 1:
+            context.set_context(device_id=int(os.environ["DEVICE_ID"]))
+            init(backend_name='hccl')
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            rank = get_rank()
+        else:
+            context.set_context(device_id=args.device_id)
+    elif device_target == "GPU":
+        if device_num > 1:
+            init(backend_name='nccl')
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            rank = get_rank()
+        else:
+            context.set_context(device_id=args.device_id)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    return rank
+
+
+def get_dataset(args, training=True):
+    """"Get model according to args.set"""
+    print(f"=> Getting {args.set} dataset")
+    dataset = getattr(data, args.set)(args, training)
+
+    return dataset
+
+
+def get_model(args):
+    """"Get model according to args.arch"""
+    print("==> Creating model '{}'".format(args.arch))
+    model = models.__dict__[args.arch](args=args)
+
+    return model
+
+
+def pretrained(args, model):
+    """"Load pretrained weights if args.pretrained is given"""
+    if args.run_modelarts:
+        print('Download data.')
+        local_data_path = '/cache/weight'
+        name = args.pretrained.split('/')[-1]
+        path = f"/".join(args.pretrained.split("/")[:-1])
+        sync_data(path, local_data_path, threads=128)
+        args.pretrained = os.path.join(local_data_path, name)
+        print("=> loading pretrained weights from '{}'".format(args.pretrained))
+        param_dict = load_checkpoint(args.pretrained)
+        for key, value in param_dict.copy().items():
+            if 'head' in key:
+                if value.shape[0] != args.num_classes:
+                    print(f'==> removing {key} with shape {value.shape}')
+                    param_dict.pop(key)
+        load_param_into_net(model, param_dict)
+    elif os.path.isfile(args.pretrained):
+        print("=> loading pretrained weights from '{}'".format(args.pretrained))
+        param_dict = load_checkpoint(args.pretrained)
+        for key, value in param_dict.copy().items():
+            if 'head' in key:
+                if value.shape[0] != args.num_classes:
+                    print(f'==> removing {key} with shape {value.shape}')
+                    param_dict.pop(key)
+        load_param_into_net(model, param_dict)
+    else:
+        print("=> no pretrained weights found at '{}'".format(args.pretrained))
+
+
+def get_train_one_step(args, net_with_loss, optimizer):
+    """get_train_one_step cell"""
+    if args.is_dynamic_loss_scale:
+        print(f"=> Using DynamicLossScaleUpdateCell")
+        scale_sense = nn.wrap.loss_scale.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, scale_factor=2,
+                                                                    scale_window=1000)
+    else:
+        print(f"=> Using FixedLossScaleUpdateCell, loss_scale_value:{args.loss_scale}")
+        scale_sense = nn.wrap.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
+    net_with_loss = TrainClipGrad(net_with_loss, optimizer, scale_sense=scale_sense,
+                                  clip_global_norm_value=args.clip_global_norm_value,
+                                  clip_global_norm=args.clip_global_norm)
+    print("clip_global_norm", args.clip_global_norm)
+    return net_with_loss
diff --git a/research/cv/cait/src/tools/optimizer.py b/research/cv/cait/src/tools/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa5c05e2792b9b33703e0661ed11a6520a5d6c1d
--- /dev/null
+++ b/research/cv/cait/src/tools/optimizer.py
@@ -0,0 +1,84 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Functions of optimizer"""
+
+from mindspore.nn.optim import AdamWeightDecay
+from mindspore.nn.optim.momentum import Momentum
+
+from .schedulers import get_policy
+
+
+def get_learning_rate(args, batch_num):
+    """Get learning rate"""
+    return get_policy(args.lr_scheduler)(args, batch_num)
+
+
+def get_optimizer(args, model, batch_num):
+    """Get optimizer for training"""
+    print(f"=> When using train_wrapper, using optimizer {args.optimizer}")
+    args.start_epoch = int(args.start_epoch)
+    optim_type = args.optimizer.lower()
+    params = get_param_groups(model)
+    learning_rate = get_learning_rate(args, batch_num)
+    step = int(args.start_epoch * batch_num)
+    accumulation_step = int(args.accumulation_step)
+    train_step = len(learning_rate)
+    print(f"=> Get LR from epoch: {args.start_epoch}\n"
+          f"=> Start step: {step}\n"
+          f"=> Total step: {train_step}\n"
+          f"=> Accumulation step:{accumulation_step}")
+    if accumulation_step > 1:
+        learning_rate = learning_rate * accumulation_step
+
+    if optim_type == "momentum":
+        optim = Momentum(
+            params=params,
+            learning_rate=learning_rate,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay
+        )
+    elif optim_type == "adamw":
+        optim = AdamWeightDecay(
+            params=params,
+            learning_rate=learning_rate,
+            beta1=args.beta[0],
+            beta2=args.beta[1],
+            eps=args.eps,
+            weight_decay=args.weight_decay
+        )
+    else:
+        raise ValueError(f"optimizer {optim_type} is not supported")
+
+    return optim
+
+
+def get_param_groups(network):
+    """ get param groups """
+    decay_params = []
+    no_decay_params = []
+    for x in network.trainable_params():
+        parameter_name = x.name
+        if parameter_name.endswith(".gamma") or parameter_name.endswith(".beta") or \
+                parameter_name.endswith(".bias"):
+            # norm gamma/beta and all bias parameters do not use weight decay
+            print(f"=> no decay {parameter_name}")
+            no_decay_params.append(x)
+        else:
+            # Dense/Conv weights use weight decay
+            print(f"=> decay {parameter_name}")
+            decay_params.append(x)
+
+    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
diff --git a/research/cv/cait/src/tools/schedulers.py b/research/cv/cait/src/tools/schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f3150f0eb753e4ed091f95921f2f44106ae4a20
--- /dev/null
+++ b/research/cv/cait/src/tools/schedulers.py
@@ -0,0 +1,113 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LearningRate scheduler functions"""
+import numpy as np
+
+__all__ = ["multistep_lr", "cosine_lr", "constant_lr", "get_policy", "exp_lr"]
+
+
+def get_policy(name):
+    """get lr policy from name"""
+    if name is None:
+        return constant_lr
+
+    out_dict = {
+        "constant_lr": constant_lr,
+        "cosine_lr": cosine_lr,
+        "multistep_lr": multistep_lr,
+        "exp_lr": exp_lr,
+    }
+
+    return out_dict[name]
+
+
+def constant_lr(args, batch_num):
+    """Get constant lr"""
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        if epoch < args.warmup_length:
+            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
+        else:
+            lr = args.base_lr
+
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
+    return learning_rate
+
+
+def exp_lr(args, batch_num):
+    """Get exp lr """
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        if epoch < args.warmup_length:
+            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
+        else:
+            lr = args.base_lr * args.lr_gamma ** epoch
+
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
+    return learning_rate
+
+
+def cosine_lr(args, batch_num):
+    """Get cosine lr"""
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        if epoch < args.warmup_length:
+            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
+        else:
+            e = epoch - args.warmup_length
+            es = args.epochs - args.warmup_length
+            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * (args.base_lr - args.min_lr) + args.min_lr
+
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.array(learning_rate)
+    return learning_rate
+
+
+def multistep_lr(args, batch_num):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        lr = args.base_lr * (args.lr_gamma ** (epoch / args.lr_adjust))
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
+    return learning_rate
+
+
+def _warmup_lr(warmup_lr, base_lr, warmup_length, epoch):
+    """Linear warmup"""
+    return epoch / warmup_length * (base_lr - warmup_lr) + warmup_lr
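+
+
+if __name__ == "__main__":
+    # Minimal self-check sketch (an illustrative addition): a SimpleNamespace
+    # stands in for the parsed `args`, carrying only the fields cosine_lr
+    # reads. Prints the warmup start, post-warmup peak and final value of the
+    # per-step schedule; the hyper-parameter values are assumptions.
+    from types import SimpleNamespace
+
+    dummy_args = SimpleNamespace(epochs=10, warmup_length=2, warmup_lr=1e-6,
+                                 base_lr=5e-4, min_lr=1e-5)
+    lr = get_policy("cosine_lr")(dummy_args, batch_num=100)
+    print(f"start {lr[0]:.2e}, peak {max(lr):.2e}, final {lr[-1]:.2e}")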
diff --git a/research/cv/cait/src/trainers/__init__.py b/research/cv/cait/src/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5fecee77553f3cdcb049cc30b7a014f59fca3b4
--- /dev/null
+++ b/research/cv/cait/src/trainers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init train one step"""
+from .train_one_step_with_scale_and_clip_global_norm \
+    import TrainOneStepWithLossScaleCellGlobalNormClip as TrainClipGrad
diff --git a/research/cv/cait/src/trainers/train_one_step_with_scale_and_clip_global_norm.py b/research/cv/cait/src/trainers/train_one_step_with_scale_and_clip_global_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ecf72a704d31ce910cd49498513dd3dfd3a0b5c
--- /dev/null
+++ b/research/cv/cait/src/trainers/train_one_step_with_scale_and_clip_global_norm.py
@@ -0,0 +1,87 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""TrainOneStepWithLossScaleCellGlobalNormClip"""
+import mindspore.nn as nn
+from mindspore.common import RowTensor
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
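+# _grad_scale dispatches on the gradient's concrete type: the two overloads
+# registered below rescale dense Tensor gradients and sparse RowTensor
+# gradients back by the reciprocal of the loss scale.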
+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))
+
+
+@_grad_scale.register("Tensor", "RowTensor")
+def tensor_grad_scale_row_tensor(scale, grad):
+    return RowTensor(grad.indices,
+                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
+                     grad.dense_shape)
+
+
+class TrainOneStepWithLossScaleCellGlobalNormClip(nn.TrainOneStepWithLossScaleCell):
+    """
+    Encapsulation class of SSD network training.
+
+    Append an optimizer to the training network after that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        sens (Number): The adjust parameter. Default: 1.0.
+        use_global_nrom(bool): Whether apply global norm before optimizer. Default: False
+    """
+
+    def __init__(self, network, optimizer,
+                 scale_sense=1.0, clip_global_norm=True,
+                 clip_global_norm_value=1.0):
+        super(TrainOneStepWithLossScaleCellGlobalNormClip, self).__init__(network, optimizer, scale_sense)
+        self.clip_global_norm = clip_global_norm
+        self.clip_global_norm_value = clip_global_norm_value
+        self.print = P.Print()
+
+    def construct(self, *inputs):
+        """construct"""
+        weights = self.weights
+        loss = self.network(*inputs)
+        scaling_sens = self.scale_sense
+
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
+
+        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        # get the overflow buffer
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
+        # if there is no overflow, do optimize
+        if not overflow:
+            if self.clip_global_norm:
+                grads = C.clip_by_global_norm(grads, clip_norm=self.clip_global_norm_value)
+            loss = F.depend(loss, self.optimizer(grads))
+        else:
+            self.print("=============Overflow, skipping=============")
+        return loss
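+
+
+# Usage sketch (illustrative, not invoked by the scripts in this repository):
+# `net_with_loss` is a cell that already wraps the loss and `optimizer` is any
+# MindSpore optimizer. A DynamicLossScaleUpdateCell supplies the dynamic loss
+# scale that construct() checks for overflow before clipping and applying the
+# gradients:
+#
+#     scale_sense = nn.DynamicLossScaleUpdateCell(
+#         loss_scale_value=1024.0, scale_factor=2, scale_window=1000)
+#     train_net = TrainOneStepWithLossScaleCellGlobalNormClip(
+#         net_with_loss, optimizer, scale_sense=scale_sense,
+#         clip_global_norm=True, clip_global_norm_value=1.0)
+#     loss = train_net(images, labels)  # one forward/backward/update step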
diff --git a/research/cv/cait/train.py b/research/cv/cait/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d58cb557880ca5b186018c283ec2f18f15ed9f9
--- /dev/null
+++ b/research/cv/cait/train.py
@@ -0,0 +1,107 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+train
+code:
+    https://github.com/facebookresearch/deit/blob/main/cait_models.py
+
+paper:
+    https://arxiv.org/abs/2103.17239
+
+Acc: ImageNet1k-77.6%
+"""
+import os
+
+import numpy as np
+from mindspore import Model
+from mindspore import context
+from mindspore import nn
+from mindspore.common import set_seed
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+
+from src.args import args
+from src.tools.callback import EvaluateCallBack
+from src.tools.cell import cast_amp
+from src.tools.criterion import get_criterion, NetWithLoss
+from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
+from src.tools.optimizer import get_optimizer
+
+
+def main():
+    set_seed(args.seed)
+    mode = {
+        0: context.GRAPH_MODE,
+        1: context.PYNATIVE_MODE
+    }
+    context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
+    context.set_context(enable_graph_kernel=False)
+    if args.device_target == "Ascend":
+        context.set_context(enable_auto_mixed_precision=True)
+    rank = set_device(args)
+
+    # get model and cast amp_level
+    net = get_model(args)
+    params_num = 0
+    for param in net.trainable_params():
+        params_num += np.prod(param.shape)
+    print(f"=> params_num: {params_num}")
+
+    cast_amp(net)
+    criterion = get_criterion(args)
+    net_with_loss = NetWithLoss(net, criterion)
+    if args.pretrained:
+        pretrained(args, net)
+
+    data = get_dataset(args)
+    batch_num = data.train_dataset.get_dataset_size()
+    optimizer = get_optimizer(args, net, batch_num)
+
+    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
+
+    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
+    eval_indexes = [0, 1, 2]
+    model = Model(net_with_loss, metrics={"acc", "loss"},
+                  eval_network=eval_network,
+                  eval_indexes=eval_indexes)
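+    # Note (based on MindSpore's Model API): nn.WithEvalCell outputs
+    # (loss, logits, label); eval_indexes=[0, 1, 2] maps output 0 to the
+    # "loss" metric and outputs 1 and 2 to the prediction/label pair used by "acc".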
+
+    config_ck = CheckpointConfig(save_checkpoint_steps=data.train_dataset.get_dataset_size(),
+                                 keep_checkpoint_max=args.keep_checkpoint_max)
+    time_cb = TimeMonitor(data_size=data.train_dataset.get_dataset_size())
+
+    ckpt_save_dir = "./ckpt_" + str(rank)
+    if args.run_modelarts:
+        ckpt_save_dir = "/cache/ckpt_" + str(rank)
+
+    ckpoint_cb = ModelCheckpoint(prefix=args.arch + str(rank), directory=ckpt_save_dir,
+                                 config=config_ck)
+    loss_cb = LossMonitor()
+    eval_cb = EvaluateCallBack(model, eval_dataset=data.val_dataset, src_url=ckpt_save_dir,
+                               train_url=os.path.join(args.train_url, "ckpt_" + str(rank)),
+                               save_freq=args.save_every)
+
+    print("begin train")
+    model.train(int(args.epochs - args.start_epoch), data.train_dataset,
+                callbacks=[time_cb, ckpoint_cb, loss_cb, eval_cb],
+                dataset_sink_mode=True)
+    print("train success")
+
+    if args.run_modelarts:
+        import moxing as mox
+        mox.file.copy_parallel(src_url=ckpt_save_dir, dst_url=os.path.join(args.train_url, "ckpt_" + str(rank)))
+
+
+if __name__ == '__main__':
+    main()