diff --git a/research/cv/ViG/README_CN.md b/research/cv/ViG/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..2091d5bdb266076ad6d433bdbe92ac1916ff9c00
--- /dev/null
+++ b/research/cv/ViG/README_CN.md
@@ -0,0 +1,215 @@
+# Contents
+
+<!-- TOC -->
+
+- [Contents](#contents)
+- [Description](#description)
+- [Dataset](#dataset)
+- [Features](#features)
+    - [Mixed Precision](#mixed-precision)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+- [Training and Testing](#training-and-testing)
+    - [Training Results](#training-results)
+        - [Results](#results)
+    - [Export Process](#export-process)
+        - [Export](#export)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+<!-- /TOC -->
+
+# [Description](#contents)
+
+Graph neural networks (GNNs) were originally designed for tasks on graph-structured data. This paper is the first to propose a purely GNN-based model for general computer vision problems such as image recognition and object detection. Vision GNN (ViG) treats an image as a graph whose nodes are image patches, and uses a GNN to exchange information between nodes and transform their features. By stacking ViG blocks, the authors build ViG models for image recognition.
+
+Paper: Kai Han, Yunhe Wang, Jianyuan Guo, Yehui Tang, Enhua Wu. Vision GNN: An Image is Worth Graph of Nodes. 2022. [paper link](https://arxiv.org/abs/2206.00272)
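+
+The sketch below is only a conceptual illustration of this idea (plain NumPy, not this repository's MindSpore code): image patches become graph nodes, each node is connected to its k nearest neighbours in feature space, and neighbour features are aggregated. The shapes and the aggregation rule are illustrative; ViG uses learned graph convolutions for this step.
+
+```python
+import numpy as np
+
+
+def knn_graph_aggregate(patch_features, k=9):
+    """patch_features: (N, C) array with one row per image patch (graph node)."""
+    # pairwise squared distances between patch features
+    d2 = ((patch_features[:, None, :] - patch_features[None, :, :]) ** 2).sum(-1)
+    # k nearest neighbours of each node, excluding the node itself
+    neighbours = np.argsort(d2, axis=1)[:, 1:k + 1]
+    # simple max aggregation over neighbours; ViG replaces this with a learned graph convolution
+    return patch_features[neighbours].max(axis=1)
+
+
+patches = np.random.randn(196, 320).astype(np.float32)  # e.g. 14x14 patches of a 224x224 image
+print(knn_graph_aggregate(patches).shape)  # (196, 320)
+```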
+
+# [Dataset](#contents)
+
+Dataset used: [ImageNet2012](http://www.image-net.org/)
+
+- Dataset size: 1,000 classes, 224*224 color images
+    - Training set: 1,281,167 images
+    - Test set: 50,000 images
+- Data format: JPEG
+    - Note: the data is processed in `src/data/imagenet.py`.
+- Download the dataset; the directory structure is as follows:
+
+ ```text
+└─dataset
+    ├─train                 # training set
+    └─val                   # evaluation set
+ ```
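+
+As a hedged illustration only (not part of this repository), a folder laid out as above can be read with MindSpore's built-in `ImageFolderDataset`; the actual pipeline of this model is implemented in `src/data/imagenet.py`, so the transform values below are assumptions.
+
+```python
+import mindspore.dataset as ds
+import mindspore.dataset.vision.c_transforms as C
+
+
+def build_val_dataset(data_dir="dataset/val", image_size=224, batch_size=128):
+    """Read the class-per-subfolder layout shown above and build a simple eval pipeline."""
+    dataset = ds.ImageFolderDataset(data_dir, shuffle=False)
+    transforms = [
+        C.Decode(),
+        C.Resize(int(image_size / 0.875)),      # resize the short side, then center crop
+        C.CenterCrop(image_size),
+        C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+                    std=[0.229 * 255, 0.224 * 255, 0.225 * 255]),
+        C.HWC2CHW(),
+    ]
+    dataset = dataset.map(operations=transforms, input_columns="image")
+    return dataset.batch(batch_size, drop_remainder=False)
+```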
+
+# [Features](#contents)
+
+## Mixed Precision
+
+[Mixed precision](https://www.mindspore.cn/tutorials/experts/zh-CN/master/others/mixed_precision.html) training uses both single-precision and half-precision data to speed up the training of deep neural networks while preserving the accuracy that full single-precision training can reach. Mixed precision speeds up computation and reduces memory usage, and makes it possible to train larger models or use larger batch sizes on supported hardware.
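+
+A minimal sketch of how an `amp_level` such as `O2` is typically passed to MindSpore's `Model` is shown below; this repository wires mixed precision up through `src/tools`, so the snippet is illustrative rather than the project's actual code, and the loss/optimizer choices are assumptions.
+
+```python
+from mindspore import Model, nn
+
+
+def build_amp_model(net, lr=0.0005):
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+    opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=lr, weight_decay=0.05)
+    # amp_level="O2" casts the network to float16 while keeping BatchNorm in float32
+    # and enabling loss scaling, matching the amp_level entry in the YAML config.
+    return Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")
+```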
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend)
+    - Set up the hardware environment with Ascend processors.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, see the following resources:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/r1.3/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/r1.3/index.html)
+
+# [Script Description](#contents)
+
+## Script and Sample Code
+
+```text
+├── ViG
+  ├── README_CN.md                        // ViG description
+  ├── scripts
+      ├──run_standalone_train_ascend.sh   // single-device Ascend 910 training script
+      ├──run_distribute_train_ascend.sh   // multi-device Ascend 910 training script
+      └──run_eval_ascend.sh               // evaluation script
+  ├── src
+      ├──configs                          // ViG configuration files
+      ├──data                             // dataset configuration files
+          ├──imagenet.py                  // ImageNet configuration file
+          ├──augment                      // data augmentation functions
+          └──data_utils                   // dataset copy functions for running on ModelArts
+      ├──models                           // model definitions
+          └──ViG                          // ViG definition file
+      ├──trainers                         // custom TrainOneStep file
+      └──tools                            // utility folder
+          ├──callback.py                  // custom callbacks, evaluation at the end of training
+          ├──cell.py                      // general cell-related utility functions
+          ├──criterion.py                 // loss function utilities
+          ├──get_misc.py                  // miscellaneous utility functions
+          ├──optimizer.py                 // optimizer and parameter functions
+          └──schedulers.py                // learning rate decay utilities
+  ├── train.py                            // training script
+  ├── eval.py                             // evaluation script
+  ├── export.py                           // model export script
+  ├── postprocess.py                      // accuracy computation for inference
+  ├── preprocess.py                       // image preprocessing for inference
+```
+
+## Script Parameters
+
+Both training and evaluation parameters can be configured in vig_s_patch16_224.yaml.
+
+- Configuration for ViG on the ImageNet-1k dataset.
+
+  ```yaml
+    # Architecture
+    arch: vig_s_patch16_224             # ViG architecture
+    # ===== Dataset ===== #
+    data_url: ./data/imagenet           # dataset path
+    set: ImageNet                       # dataset name
+    num_classes: 1000                   # number of classes
+    mix_up: 0.8                         # MixUp augmentation parameter
+    cutmix: 1.0                         # CutMix augmentation parameter
+    auto_augment: rand-m9-mstd0.5-inc1  # AutoAugment parameters
+    interpolation: bicubic              # image resize interpolation method
+    re_prob: 0.25                       # random erasing probability
+    re_mode: pixel                      # random erasing mode
+    re_count: 1                         # random erasing count
+    mixup_prob: 1.                      # probability of applying MixUp/CutMix
+    switch_prob: 0.5                    # probability of switching from MixUp to CutMix
+    mixup_mode: batch                   # how MixUp/CutMix is applied
+    # ===== Learning Rate Policy ======== #
+    optimizer: adamw                    # optimizer type
+    base_lr: 0.0005                     # base learning rate
+    warmup_lr: 0.00000007               # initial warmup learning rate
+    min_lr: 0.000006                    # minimum learning rate
+    lr_scheduler: cosine_lr             # learning rate decay policy
+    warmup_length: 20                   # number of warmup epochs
+    image_size: 224                     # image size
+    # ===== Network training config ===== #
+    amp_level: O2                       # mixed precision level
+    beta: [ 0.9, 0.999 ]                # AdamW betas
+    clip_global_norm_value: 5.          # global gradient norm clipping threshold
+    is_dynamic_loss_scale: True         # whether to use dynamic loss scaling
+    epochs: 300                         # number of training epochs
+    label_smoothing: 0.1                # label smoothing factor
+    weight_decay: 0.05                  # weight decay
+    momentum: 0.9                       # optimizer momentum
+    batch_size: 128                     # batch size
+    # ===== Hardware setup ===== #
+    num_parallel_workers: 16            # number of data preprocessing workers
+    device_target: Ascend               # GPU or Ascend
+  ```
+
+For more configuration details, see the configuration file `vig_s_patch16_224.yaml`. After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
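+
+For reference, the configuration file is merged with the command-line arguments by `src/args.py`: values from the YAML override the argparse defaults, and flags given explicitly on the command line override the YAML. A simplified sketch of that mechanism (names here are illustrative; the real logic lives in `src/args.py`):
+
+```python
+import yaml
+
+
+def load_config(yaml_path, cli_overrides):
+    """cli_overrides: dict of options that were explicitly passed on the command line."""
+    with open(yaml_path) as f:
+        cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
+    cfg.update(cli_overrides)  # explicit command-line flags win over YAML values
+    return cfg
+
+
+# e.g. load_config("./src/configs/vig_s_patch16_224.yaml", {"batch_size": 64})
+```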
+
+# [Training and Testing](#contents)
+
+- Running on Ascend
+
+  ```bash
+  # run standalone training with python
+  python train.py --device_id 0 --device_target Ascend --vig_config ./src/configs/vig_s_patch16_224.yaml \
+  > train.log 2>&1 &
+
+  # run standalone training with the shell script
+  bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH]
+
+  # run distributed training with the shell script
+  bash ./scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [CONFIG_PATH]
+
+  # run standalone evaluation with python
+  python eval.py --device_id 0 --device_target Ascend --vig_config ./src/configs/vig_s_patch16_224.yaml \
+  --pretrained ./ckpt_0/vig_s_patch16_224.ckpt > ./eval.log 2>&1 &
+
+  # run standalone evaluation with the shell script
+  bash ./scripts/run_eval_ascend.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]
+  ```
+
+  For distributed training, an hccl configuration file in JSON format needs to be created in advance.
+
+  Please follow the instructions in the link below:
+
+[hccl tools](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)
+
+## Training Results
+
+### Results
+
+Training and evaluating on the ImageNet-1k dataset should produce results similar to the following.
+
+  ```shell
+  # result
+  Top1 acc:  0.804
+  Top5 acc:  0.952
+  ```
+
+| Parameters                 | Ascend                                                       |
+| -------------------------- | ------------------------------------------------------------ |
+| Model                      | ViG                                                          |
+| Model version              | vig_s_patch16_224                                            |
+| Resource                   | Ascend 910                                                   |
+| Upload date                | 2022-06-06                                                   |
+| MindSpore version          | 1.7.0                                                        |
+| Dataset                    | ImageNet-1k Train, 1,281,167 images                          |
+| Training parameters        | epoch=300, batch_size=128                                    |
+| Optimizer                  | AdamWeightDecay                                              |
+| Loss function              | SoftTargetCrossEntropy                                       |
+| Loss                       | 0.9680                                                       |
+| Output                     | Probability                                                  |
+| Classification accuracy    | 8 devices: top1: 80.4%, top5: 95.2%                          |
+| Speed                      | 8 devices: 1755 ms/step                                      |
+| Total training time        | 212 h 40 min (run on ModelArts)                              |
+
+## Export Process
+
+### Export
+
+  ```shell
+  python export.py --pretrained [CKPT_FILE] --vig_config [CONFIG_PATH] --device_target [DEVICE_TARGET]
+  ```
+
+The exported model is named after the model architecture and saved in the current directory.
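+
+A hedged sketch of loading an exported MINDIR file back for inference is shown below (illustrative only; the file name follows the architecture name, e.g. `vig_s_patch16_224.mindir`):
+
+```python
+import numpy as np
+import mindspore as ms
+from mindspore import Tensor, nn
+
+graph = ms.load("vig_s_patch16_224.mindir")     # load the exported graph
+net = nn.GraphCell(graph)                       # wrap it as a callable cell
+dummy = Tensor(np.zeros([1, 3, 224, 224], np.float32))
+logits = net(dummy)
+print(logits.shape)                             # expected: (1, 1000)
+```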
+
+# ModelZoo Homepage
+
+Please check out the official [homepage](https://gitee.com/mindspore/models).
\ No newline at end of file
diff --git a/research/cv/ViG/eval.py b/research/cv/ViG/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..69964df77c43a944883aaf5e8ec2e19681e5ab40
--- /dev/null
+++ b/research/cv/ViG/eval.py
@@ -0,0 +1,70 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""eval"""
+
+from mindspore import Model
+from mindspore import context
+from mindspore import nn
+from mindspore.common import set_seed
+
+from src.args import args
+from src.tools.cell import cast_amp
+from src.tools.criterion import get_criterion, NetWithLoss
+from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
+from src.tools.optimizer import get_optimizer
+
+set_seed(args.seed)
+
+
+def main():
+    mode = {
+        0: context.GRAPH_MODE,
+        1: context.PYNATIVE_MODE
+    }
+    context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
+    context.set_context(enable_graph_kernel=False)
+    set_device(args)
+
+    # get model
+    net = get_model(args)
+    cast_amp(net)
+    criterion = get_criterion(args)
+
+    net_with_loss = NetWithLoss(net, criterion)
+    if args.pretrained:
+        pretrained(args, net)
+
+    data = get_dataset(args, training=False)
+    batch_num = data.val_dataset.get_dataset_size()
+    optimizer = get_optimizer(args, net, batch_num)
+    # save a yaml file to read to record parameters
+
+    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
+    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
+    eval_indexes = [0, 1, 2]
+    eval_metrics = {'Loss': nn.Loss(),
+                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
+                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
+    model = Model(net_with_loss, metrics=eval_metrics,
+                  eval_network=eval_network,
+                  eval_indexes=eval_indexes)
+    print(f"=> begin eval")
+    results = model.eval(data.val_dataset)
+    print(f"=> eval results:{results}")
+    print(f"=> eval success")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/research/cv/ViG/export.py b/research/cv/ViG/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b990331a9e57eb3c01a0196f8a4904d03fcd36
--- /dev/null
+++ b/research/cv/ViG/export.py
@@ -0,0 +1,48 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+##############export checkpoint file into air or mindir model#################
+python export.py --pretrained [CKPT_FILE] --vig_config [CONFIG_PATH] --device_target [DEVICE_TARGET]
+"""
+
+import numpy as np
+from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
+from mindspore import dtype as mstype
+
+from src.args import args
+from src.tools.cell import cast_amp
+from src.tools.criterion import get_criterion, NetWithLoss
+from src.tools.get_misc import get_model
+
+context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+
+if args.device_target in ["Ascend", "GPU"]:
+    context.set_context(device_id=args.device_id)
+
+if __name__ == '__main__':
+    net = get_model(args)
+    criterion = get_criterion(args)
+    cast_amp(net)
+    net_with_loss = NetWithLoss(net, criterion)
+    assert args.pretrained is not None, "checkpoint_path is None."
+
+    param_dict = load_checkpoint(args.pretrained)
+    load_param_into_net(net, param_dict)
+
+    net.set_train(False)
+    net.to_float(mstype.float32)
+
+    input_arr = Tensor(np.zeros([1, 3, args.image_size, args.image_size], np.float32))
+    export(net, input_arr, file_name=args.arch, file_format=args.file_format)
diff --git a/research/cv/ViG/postprocess.py b/research/cv/ViG/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..65b78267de3ad7b456ac10c7b19e1c45b45c783d
--- /dev/null
+++ b/research/cv/ViG/postprocess.py
@@ -0,0 +1,50 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""postprocess for 310 inference"""
+import argparse
+import json
+import os
+
+import numpy as np
+from mindspore.nn import Top1CategoricalAccuracy, Top5CategoricalAccuracy
+
+parser = argparse.ArgumentParser(description="postprocess")
+parser.add_argument("--result_dir", type=str, default="./result_Files", help="result files path.")
+parser.add_argument('--dataset_name', type=str, choices=["imagenet2012"], default="imagenet2012")
+args = parser.parse_args()
+
+def calcul_acc(lab, preds):
+    return sum(1 for x, y in zip(lab, preds) if x == y) / len(lab)
+
+
+if __name__ == '__main__':
+    batch_size = 1
+    top1_acc = Top1CategoricalAccuracy()
+    rst_path = args.result_dir
+    label_list = []
+    pred_list = []
+    file_list = os.listdir(rst_path)
+    top5_acc = Top5CategoricalAccuracy()
+    with open('./preprocess_Result/imagenet_label.json', "r") as label:
+        labels = json.load(label)
+    for f in file_list:
+        label = f.split("_0.bin")[0] + ".JPEG"
+        label_list.append(labels[label])
+        pred = np.fromfile(os.path.join(rst_path, f), np.float32)
+        pred = pred.reshape(batch_size, int(pred.shape[0] / batch_size))
+        top1_acc.update(pred, [labels[label],])
+        top5_acc.update(pred, [labels[label],])
+    print("Top1 acc: ", top1_acc.eval())
+    print("Top5 acc: ", top5_acc.eval())
diff --git a/research/cv/ViG/preprocess.py b/research/cv/ViG/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..a124cdb0f8ad167a803b3f1cd209067907076ffb
--- /dev/null
+++ b/research/cv/ViG/preprocess.py
@@ -0,0 +1,46 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""preprocess"""
+import argparse
+import json
+import os
+
+parser = argparse.ArgumentParser('preprocess')
+parser.add_argument('--dataset_name', type=str, choices=["imagenet2012"], default="imagenet2012")
+parser.add_argument('--data_path', type=str, default='', help='eval data dir')
+
+
+def create_label(result_path, dir_path):
+    """
+    create_label
+    """
+    file_list = sorted(os.listdir(dir_path))
+    total = 0
+    img_label = {}
+    for i, file_dir in enumerate(file_list):
+        files = os.listdir(os.path.join(dir_path, file_dir))
+        for f in files:
+            img_label[f] = i
+        total += len(files)
+    json_file = os.path.join(result_path, "imagenet_label.json")
+    with open(json_file, "w+") as label:
+        json.dump(img_label, label)
+    print("[INFO] Completed! Total {} data.".format(total))
+
+args = parser.parse_args()
+if __name__ == "__main__":
+    create_label('./preprocess_Result/', args.data_path)
diff --git a/research/cv/ViG/requirements.txt b/research/cv/ViG/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/ViG/scripts/run_distribute_train_ascend.sh b/research/cv/ViG/scripts/run_distribute_train_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b80fa33066956bb43ecdc4c7f38cfeab9d3b5cca
--- /dev/null
+++ b/research/cv/ViG/scripts/run_distribute_train_ascend.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# -lt 2 ]
+then
+    echo "Usage: bash ./scripts/run_distribute_train_ascend.sh [RANK_TABLE_FILE] [CONFIG_PATH]"
+exit 1
+fi
+export RANK_TABLE_FILE=$1
+CONFIG_PATH=$2
+export RANK_SIZE=8
+export DEVICE_NUM=8
+
+
+cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
+echo "the number of logical core" $cores
+avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
+core_gap=`expr $avg_core_per_rank \- 1`
+echo "avg_core_per_rank" $avg_core_per_rank
+echo "core_gap" $core_gap
+for((i=0;i<RANK_SIZE;i++))
+do
+    start=`expr $i \* $avg_core_per_rank`
+    export DEVICE_ID=$i
+    export RANK_ID=$i
+    export DEPLOY_MODE=0
+    export GE_USE_STATIC_MEMORY=1
+    end=`expr $start \+ $core_gap`
+    cmdopt=$start"-"$end
+
+    rm -rf train_parallel$i
+    mkdir ./train_parallel$i
+    cp -r ./src ./train_parallel$i
+    cp  *.py ./train_parallel$i
+    cd ./train_parallel$i || exit
+    echo "start training for rank $i, device $DEVICE_ID rank_id $RANK_ID"
+    env > env.log
+    taskset -c $cmdopt python -u ../train.py \
+    --device_target Ascend \
+    --device_id $i \
+    --vig_config=$CONFIG_PATH > log.txt 2>&1 &
+    cd ../
+done
diff --git a/research/cv/ViG/scripts/run_eval_ascend.sh b/research/cv/ViG/scripts/run_eval_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2e48fe46485adec122730c5ed52841aea52e56c6
--- /dev/null
+++ b/research/cv/ViG/scripts/run_eval_ascend.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH] [CHECKPOINT_PATH]"
+exit 1
+fi
+
+export DEVICE_ID=$1
+CONFIG_PATH=$2
+CHECKPOINT_PATH=$3
+export RANK_SIZE=1
+export DEVICE_NUM=1
+
+rm -rf evaluation_ascend
+mkdir ./evaluation_ascend
+cd ./evaluation_ascend || exit
+echo  "start training for device id $DEVICE_ID"
+env > env.log
+python ../eval.py --device_target=Ascend --device_id=$DEVICE_ID --vig_config=$CONFIG_PATH --pretrained=$CHECKPOINT_PATH > eval.log 2>&1 &
+cd ../
diff --git a/research/cv/ViG/scripts/run_standalone_train_ascend.sh b/research/cv/ViG/scripts/run_standalone_train_ascend.sh
new file mode 100644
index 0000000000000000000000000000000000000000..817aceb26b8876684b8cf5c27368780757f45d97
--- /dev/null
+++ b/research/cv/ViG/scripts/run_standalone_train_ascend.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 2 ]
+then
+    echo "Usage: bash ./scripts/run_standalone_train_ascend.sh [DEVICE_ID] [CONFIG_PATH]"
+exit 1
+fi
+
+export RANK_SIZE=1
+export DEVICE_NUM=1
+export DEVICE_ID=$1
+CONFIG_PATH=$2
+
+rm -rf train_standalone
+mkdir ./train_standalone
+cd ./train_standalone || exit
+echo  "start training for device id $DEVICE_ID"
+env > env.log
+python -u ../train.py \
+    --device_id=$DEVICE_ID \
+    --device_target="Ascend" \
+    --vig_config=$CONFIG_PATH > log.txt 2>&1 &
+cd ../
diff --git a/research/cv/ViG/src/args.py b/research/cv/ViG/src/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e5d238230e9c4b146f43a09875ac86b2e97b76
--- /dev/null
+++ b/research/cv/ViG/src/args.py
@@ -0,0 +1,125 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""global args for Vision GNN (ViG)"""
+import argparse
+import ast
+import os
+import sys
+
+import yaml
+
+from src.configs import parser as _parser
+
+args = None
+
+
+def parse_arguments():
+    """parse_arguments"""
+    global args
+    parser = argparse.ArgumentParser(description="MindSpore ViG Training")
+
+    parser.add_argument("-a", "--arch", metavar="ARCH", default="ResNet18", help="model architecture")
+    parser.add_argument("--accumulation_step", default=1, type=int, help="accumulation step")
+    parser.add_argument("--amp_level", default="O2", choices=["O0", "O2", "O3"], help="AMP Level")
+    parser.add_argument("--batch_size", default=128, type=int, metavar="N",
+                        help="mini-batch size (default: 256), this is the total "
+                             "batch size of all GPUs on the current node when "
+                             "using Data Parallel or Distributed Data Parallel")
+    parser.add_argument("--beta", default=[0.9, 0.999], type=lambda x: [float(a) for a in x.split(",")],
+                        help="beta for optimizer")
+    parser.add_argument("--clip_global_norm_value", default=5., type=float, help="Clip grad value")
+    parser.add_argument('--data_url', default="./data", help='Location of data.')
+    parser.add_argument("--device_id", default=0, type=int, help="Device Id")
+    parser.add_argument("--device_num", default=1, type=int, help="device num")
+    parser.add_argument("--device_target", default="GPU", choices=["GPU", "Ascend", "CPU"], type=str)
+    parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
+    parser.add_argument("--eps", default=1e-8, type=float)
+    parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
+    parser.add_argument("--in_channel", default=3, type=int)
+    parser.add_argument("--is_dynamic_loss_scale", default=1, type=int, help="is_dynamic_loss_scale ")
+    parser.add_argument("--keep_checkpoint_max", default=20, type=int, help="keep checkpoint max num")
+    parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd")
+    parser.add_argument("--set", help="name of dataset", type=str, default="ImageNet")
+    parser.add_argument("--graph_mode", default=1, type=int, help="graph mode with 0, python with 1")
+    parser.add_argument("--mix_up", default=0., type=float, help="mix up")
+    parser.add_argument("--mlp_ratio", help="mlp ", default=4., type=float)
+    parser.add_argument("-j", "--num_parallel_workers", default=20, type=int, metavar="N",
+                        help="number of data loading workers (default: 20)")
+    parser.add_argument("--start_epoch", default=0, type=int, metavar="N",
+                        help="manual epoch number (useful on restarts)")
+    parser.add_argument("--warmup_length", default=0, type=int, help="Number of warmup iterations")
+    parser.add_argument("--warmup_lr", default=5e-7, type=float, help="warm up learning rate")
+    parser.add_argument("--wd", "--weight_decay", default=0.05, type=float, metavar="W",
+                        help="weight decay (default: 1e-4)", dest="weight_decay")
+    parser.add_argument("--loss_scale", default=1024, type=int, help="loss_scale")
+    parser.add_argument("--lr", "--learning_rate", default=2e-3, type=float, help="initial lr", dest="lr")
+    parser.add_argument("--lr_scheduler", default="cosine_annealing", help="Schedule for the learning rate.")
+    parser.add_argument("--lr_adjust", default=30, type=float, help="Interval to drop lr")
+    parser.add_argument("--lr_gamma", default=0.97, type=int, help="Multistep multiplier")
+    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+    parser.add_argument("--num_classes", default=1000, type=int)
+    parser.add_argument("--pretrained", dest="pretrained", default=None, type=str, help="use pre-trained model")
+    parser.add_argument("--vig_config", help="Config file to use (see configs dir)", default=None, required=True)
+    parser.add_argument("--seed", default=0, type=int, help="seed for initializing training. ")
+    parser.add_argument("--save_every", default=2, type=int, help="Save every ___ epochs(default:2)")
+    parser.add_argument("--label_smoothing", type=float, help="Label smoothing to use, default 0.0", default=0.1)
+    parser.add_argument("--image_size", default=224, help="Image Size.", type=int)
+    parser.add_argument('--train_url', default="./", help='Location of training outputs.')
+    parser.add_argument("--run_modelarts", type=ast.literal_eval, default=False, help="Whether run on modelarts")
+    args = parser.parse_args()
+
+    # Allow for use from notebook without config file
+    if len(sys.argv) > 1:
+        get_config()
+
+
+def get_config():
+    """get_config"""
+    global args
+    override_args = _parser.argv_to_vars(sys.argv)
+    # load yaml file
+    if args.run_modelarts:
+        import moxing as mox
+        if not args.vig_config.startswith("obs:/"):
+            args.vig_config = "obs:/" + args.vig_config
+        with mox.file.File(args.vig_config, 'r') as f:
+            yaml_txt = f.read()
+    else:
+        yaml_txt = open(args.vig_config).read()
+
+    # override args
+    loaded_yaml = yaml.load(yaml_txt, Loader=yaml.FullLoader)
+
+    for v in override_args:
+        loaded_yaml[v] = getattr(args, v)
+
+    print(f"=> Reading YAML config from {args.vig_config}")
+
+    args.__dict__.update(loaded_yaml)
+    print(args)
+
+    if "DEVICE_NUM" not in os.environ.keys():
+        os.environ["DEVICE_NUM"] = str(args.device_num)
+        os.environ["RANK_SIZE"] = str(args.device_num)
+
+
+def run_args():
+    """run and get args"""
+    global args
+    if args is None:
+        parse_arguments()
+
+
+run_args()
diff --git a/research/cv/ViG/src/configs/parser.py b/research/cv/ViG/src/configs/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..75ddb83bfab6b502220250e7322d9e6e1ddc7372
--- /dev/null
+++ b/research/cv/ViG/src/configs/parser.py
@@ -0,0 +1,39 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""parser function"""
+USABLE_TYPES = set([float, int])
+
+
+def trim_preceding_hyphens(st):
+    i = 0
+    while st[i] == "-":
+        i += 1
+
+    return st[i:]
+
+
+def arg_to_varname(st: str):
+    st = trim_preceding_hyphens(st)
+    st = st.replace("-", "_")
+
+    return st.split("=")[0]
+
+
+def argv_to_vars(argv):
+    var_names = []
+    for arg in argv:
+        if arg.startswith("-") and arg_to_varname(arg) != "vig_config":
+            var_names.append(arg_to_varname(arg))
+    return var_names
diff --git a/research/cv/ViG/src/configs/vig_s_patch16_224.yaml b/research/cv/ViG/src/configs/vig_s_patch16_224.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5a31c2787cc08c6344b915ba8448a1abf3bfe848
--- /dev/null
+++ b/research/cv/ViG/src/configs/vig_s_patch16_224.yaml
@@ -0,0 +1,45 @@
+# Architecture
+arch: vig_s_patch16_224
+
+# ===== Dataset ===== #
+data_url: ../data/imagenet
+set: ImageNet
+num_classes: 1000
+mix_up: 0.8
+cutmix: 1.0
+auto_augment: rand-m9-mstd0.5-inc1
+interpolation: bicubic
+re_prob: 0.25
+re_mode: pixel
+re_count: 1
+mixup_prob: 1.
+switch_prob: 0.5
+mixup_mode: batch
+image_size: 224
+
+
+# ===== Learning Rate Policy ======== #
+optimizer: adamw
+base_lr: 0.002
+drop_path_rate: 0.1
+warmup_lr: 0.00000007
+min_lr: 0.000006
+lr_scheduler: cosine_lr
+warmup_length: 20
+
+
+# ===== Network training config ===== #
+amp_level: O2
+keep_bn_fp32: True
+beta: [ 0.9, 0.999 ]
+clip_global_norm_value: 5.
+is_dynamic_loss_scale: True
+epochs: 300
+label_smoothing: 0.1
+weight_decay: 0.05
+momentum: 0.9
+batch_size: 16
+
+# ===== Hardware setup ===== #
+num_parallel_workers: 8
+device_target: Ascend
\ No newline at end of file
diff --git a/research/cv/ViG/src/data/__init__.py b/research/cv/ViG/src/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c031d10103182a72f3b348677b378dbcfc2a72d
--- /dev/null
+++ b/research/cv/ViG/src/data/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init datasets"""
+from .imagenet import ImageNet
diff --git a/research/cv/ViG/src/data/augment/__init__.py b/research/cv/ViG/src/data/augment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a85287a9dcbd95414aa8ccba5e842439a3d0423
--- /dev/null
+++ b/research/cv/ViG/src/data/augment/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init augment"""
+from .auto_augment import _pil_interp, rand_augment_transform
+from .mixup import Mixup
+from .random_erasing import RandomErasing
diff --git a/research/cv/ViG/src/data/augment/auto_augment.py b/research/cv/ViG/src/data/augment/auto_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..5779a6ea54cb44ae2ac5433a4c09659c1697b6a0
--- /dev/null
+++ b/research/cv/ViG/src/data/augment/auto_augment.py
@@ -0,0 +1,896 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" AutoAugment, RandAugment, and AugMix for MindSpore
+
+This code implements the searched ImageNet policies with various tweaks and improvements and
+does not include any of the search code.
+
+AA and RA Implementation adapted from:
+    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+
+AugMix adapted from:
+    https://github.com/google-research/augmix
+
+Papers:
+    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 paper
+    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 paper
+    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 paper
+    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 paper
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import random
+import re
+
+import numpy as np
+import PIL
+from PIL import Image, ImageOps, ImageEnhance
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+_HPARAMS_DEFAULT = dict(
+    translate_const=250,
+    img_mean=_FILL,
+)
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _pil_interp(method):
+    """Interpolation method selection"""
+    if method == 'bicubic':
+        func = Image.BICUBIC
+    elif method == 'lanczos':
+        func = Image.LANCZOS
+    elif method == 'hamming':
+        func = Image.HAMMING
+    else:
+        func = Image.BILINEAR
+    return func
+
+
+def _interpolation(kwargs):
+    """_interpolation"""
+    interpolation = kwargs.pop('resample', Image.BILINEAR)
+    interpolation = random.choice(interpolation) \
+        if isinstance(interpolation, (list, tuple)) else interpolation
+    return interpolation
+
+def _check_args_tf(kwargs):
+    """_check_args_tf"""
+    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+        kwargs.pop('fillcolor')
+    kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+    """shear_x"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+    """shear_y"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+    """translate_x_rel"""
+    pixels = pct * img.size[0]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+    """translate_y_rel"""
+    pixels = pct * img.size[1]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+    """translate_x_abs"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+    """translate_y_abs"""
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+    """rotate"""
+    _check_args_tf(kwargs)
+    if _PIL_VER >= (5, 2):
+        func = img.rotate(degrees, **kwargs)
+    elif _PIL_VER >= (5, 0):
+        w, h = img.size
+        post_trans = (0, 0)
+        rotn_center = (w / 2.0, h / 2.0)
+        angle = -math.radians(degrees)
+        matrix = [
+            round(math.cos(angle), 15),
+            round(math.sin(angle), 15),
+            0.0,
+            round(-math.sin(angle), 15),
+            round(math.cos(angle), 15),
+            0.0,
+        ]
+
+        def transform(x, y, matrix):
+            (a, b, c, d, e, f) = matrix
+            return a * x + b * y + c, d * x + e * y + f
+
+        matrix[2], matrix[5] = transform(
+            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
+        )
+        matrix[2] += rotn_center[0]
+        matrix[5] += rotn_center[1]
+        func = img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+    else:
+        func = img.rotate(degrees, resample=kwargs['resample'])
+    return func
+
+
+def auto_contrast(img, **__):
+    """auto_contrast"""
+    return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+    """invert"""
+    return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+    """equalize"""
+    return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+    """solarize"""
+    return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+    """solarize_add"""
+    lut = []
+    for i in range(256):
+        if i < thresh:
+            lut.append(min(255, i + add))
+        else:
+            lut.append(i)
+    if img.mode in ("L", "RGB"):
+        if img.mode == "RGB" and len(lut) == 256:
+            lut = lut + lut + lut
+        func = img.point(lut)
+    else:
+        func = img
+    return func
+
+
+def posterize(img, bits_to_keep, **__):
+    """posterize"""
+    if bits_to_keep >= 8:
+        func = img
+    else:
+        func = ImageOps.posterize(img, bits_to_keep)
+    return func
+
+
+def contrast(img, factor, **__):
+    """contrast"""
+    return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+    """color"""
+    return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+    """brightness"""
+    return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+    """sharpness"""
+    return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+    """With 50% prob, negate the value"""
+    return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+    """_randomly_negate"""
+    # range [-30, 30]
+    level = (level / _MAX_LEVEL) * 30.
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _enhance_level_to_arg(level, _hparams):
+    """_enhance_level_to_arg"""
+    # range [0.1, 1.9]
+    return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+    """_enhance_increasing_level_to_arg"""
+    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+    # range [0.1, 1.9]
+    level = (level / _MAX_LEVEL) * .9
+    level = 1.0 + _randomly_negate(level)
+    return (level,)
+
+
+def _shear_level_to_arg(level, _hparams):
+    """_shear_level_to_arg"""
+    # range [-0.3, 0.3]
+    level = (level / _MAX_LEVEL) * 0.3
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _translate_abs_level_to_arg(level, hparams):
+    """_translate_abs_level_to_arg"""
+    translate_const = hparams['translate_const']
+    level = (level / _MAX_LEVEL) * float(translate_const)
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _translate_rel_level_to_arg(level, hparams):
+    """_translate_rel_level_to_arg"""
+    # default range [-0.45, 0.45]
+    translate_pct = hparams.get('translate_pct', 0.45)
+    level = (level / _MAX_LEVEL) * translate_pct
+    level = _randomly_negate(level)
+    return (level,)
+
+
+def _posterize_level_to_arg(level, _hparams):
+    """_posterize_level_to_arg"""
+    # As per Tensorflow TPU EfficientNet impl
+    # range [0, 4], 'keep 0 up to 4 MSB of original image'
+    # intensity/severity of augmentation decreases with level
+    return (int((level / _MAX_LEVEL) * 4),)
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+    """_posterize_increasing_level_to_arg"""
+    # As per Tensorflow models research and UDA impl
+    # range [4, 0], 'keep 4 down to 0 MSB of original image',
+    # intensity/severity of augmentation increases with level
+    return (4 - _posterize_level_to_arg(level, hparams)[0],)
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+    """_posterize_original_level_to_arg"""
+    # As per original AutoAugment paper description
+    # range [4, 8], 'keep 4 up to 8 MSB of image'
+    # intensity/severity of augmentation decreases with level
+    return (int((level / _MAX_LEVEL) * 4) + 4,)
+
+
+def _solarize_level_to_arg(level, _hparams):
+    """_solarize_level_to_arg"""
+    # range [0, 256]
+    # intensity/severity of augmentation decreases with level
+    return (int((level / _MAX_LEVEL) * 256),)
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+    """_solarize_increasing_level_to_arg"""
+    # range [0, 256]
+    # intensity/severity of augmentation increases with level
+    return (256 - _solarize_level_to_arg(level, _hparams)[0],)
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+    """_solarize_add_level_to_arg"""
+    # range [0, 110]
+    return (int((level / _MAX_LEVEL) * 110),)
+
+
+LEVEL_TO_ARG = {
+    'AutoContrast': None,
+    'Equalize': None,
+    'Invert': None,
+    'Rotate': _rotate_level_to_arg,
+    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+    'Posterize': _posterize_level_to_arg,
+    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
+    'PosterizeOriginal': _posterize_original_level_to_arg,
+    'Solarize': _solarize_level_to_arg,
+    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
+    'SolarizeAdd': _solarize_add_level_to_arg,
+    'Color': _enhance_level_to_arg,
+    'ColorIncreasing': _enhance_increasing_level_to_arg,
+    'Contrast': _enhance_level_to_arg,
+    'ContrastIncreasing': _enhance_increasing_level_to_arg,
+    'Brightness': _enhance_level_to_arg,
+    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
+    'Sharpness': _enhance_level_to_arg,
+    'SharpnessIncreasing': _enhance_increasing_level_to_arg,
+    'ShearX': _shear_level_to_arg,
+    'ShearY': _shear_level_to_arg,
+    'TranslateX': _translate_abs_level_to_arg,
+    'TranslateY': _translate_abs_level_to_arg,
+    'TranslateXRel': _translate_rel_level_to_arg,
+    'TranslateYRel': _translate_rel_level_to_arg,
+}
+
+NAME_TO_OP = {
+    'AutoContrast': auto_contrast,
+    'Equalize': equalize,
+    'Invert': invert,
+    'Rotate': rotate,
+    'Posterize': posterize,
+    'PosterizeIncreasing': posterize,
+    'PosterizeOriginal': posterize,
+    'Solarize': solarize,
+    'SolarizeIncreasing': solarize,
+    'SolarizeAdd': solarize_add,
+    'Color': color,
+    'ColorIncreasing': color,
+    'Contrast': contrast,
+    'ContrastIncreasing': contrast,
+    'Brightness': brightness,
+    'BrightnessIncreasing': brightness,
+    'Sharpness': sharpness,
+    'SharpnessIncreasing': sharpness,
+    'ShearX': shear_x,
+    'ShearY': shear_y,
+    'TranslateX': translate_x_abs,
+    'TranslateY': translate_y_abs,
+    'TranslateXRel': translate_x_rel,
+    'TranslateYRel': translate_y_rel,
+}
+
+
+class AugmentOp:
+    """AugmentOp"""
+
+    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+        hparams = hparams or _HPARAMS_DEFAULT
+        self.aug_fn = NAME_TO_OP[name]
+        self.level_fn = LEVEL_TO_ARG[name]
+        self.prob = prob
+        self.magnitude = magnitude
+        self.hparams = hparams.copy()
+        self.kwargs = dict(
+            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+        )
+
+        # If magnitude_std is > 0, we introduce some randomness
+        # in the usually fixed policy and sample magnitude from a normal distribution
+        # with mean `magnitude` and std-dev of `magnitude_std`.
+        # NOTE This is my own hack, being tested, not in papers or reference impls.
+        # If magnitude_std is inf, we sample magnitude from a uniform distribution
+        self.magnitude_std = self.hparams.get('magnitude_std', 0)
+
+    def __call__(self, img):
+        """apply augment"""
+        if self.prob < 1.0 and random.random() > self.prob:
+            return img
+        magnitude = self.magnitude
+        if self.magnitude_std:
+            if self.magnitude_std == float('inf'):
+                magnitude = random.uniform(0, magnitude)
+            elif self.magnitude_std > 0:
+                magnitude = random.gauss(magnitude, self.magnitude_std)
+        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
+        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
+        return self.aug_fn(img, *level_args, **self.kwargs)
+
+
+def auto_augment_policy_v0(hparams):
+    """auto_augment_policy_v0"""
+    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
+    policy = [
+        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_v0r(hparams):
+    """auto_augment_policy_v0r"""
+    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
+    # in Google research implementation (number of bits discarded increases with magnitude)
+    policy = [
+        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
+        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
+        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_original(hparams):
+    """auto_augment_policy_original"""
+    # ImageNet policy from https://arxiv.org/abs/1805.09501 paper
+    policy = [
+        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
+        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy_originalr(hparams):
+    """auto_augment_policy_originalr"""
+    # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
+    policy = [
+        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
+        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
+        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+    ]
+    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+    return pc
+
+
+def auto_augment_policy(name='v0', hparams=None):
+    """auto_augment_policy"""
+    hparams = hparams or _HPARAMS_DEFAULT
+    if name == 'original':
+        func = auto_augment_policy_original(hparams)
+    elif name == 'originalr':
+        func = auto_augment_policy_originalr(hparams)
+    elif name == 'v0':
+        func = auto_augment_policy_v0(hparams)
+    elif name == 'v0r':
+        func = auto_augment_policy_v0r(hparams)
+    else:
+        assert False, 'Unknown AA policy (%s)' % name
+    return func
+
+class AutoAugment:
+    """AutoAugment"""
+    def __init__(self, policy):
+        self.policy = policy
+
+    def __call__(self, img):
+        """apply autoaugment"""
+        sub_policy = random.choice(self.policy)
+        for op in sub_policy:
+            img = op(img)
+        return img
+
+
+def auto_augment_transform(config_str, hparams):
+    """
+    Create a AutoAugment transform
+
+    :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
+    dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
+    The remaining sections, not order specific determine
+        'mstd' -  float std deviation of magnitude noise applied
+    Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
+
+    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
+
+    :return: A MindSpore compatible Transform
+    """
+    config = config_str.split('-')
+    policy_name = config[0]
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        else:
+            assert False, 'Unknown AutoAugment config section'
+    aa_policy = auto_augment_policy(policy_name, hparams=hparams)
+    return AutoAugment(aa_policy)
+
+
+_RAND_TRANSFORMS = [
+    'AutoContrast',
+    'Equalize',
+    'Invert',
+    'Rotate',
+    'Posterize',
+    'Solarize',
+    'SolarizeAdd',
+    'Color',
+    'Contrast',
+    'Brightness',
+    'Sharpness',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+    # 'Cutout'  # NOTE I've implemented this as random erasing separately
+]
+
+_RAND_INCREASING_TRANSFORMS = [
+    'AutoContrast',
+    'Equalize',
+    'Invert',
+    'Rotate',
+    'PosterizeIncreasing',
+    'SolarizeIncreasing',
+    'SolarizeAdd',
+    'ColorIncreasing',
+    'ContrastIncreasing',
+    'BrightnessIncreasing',
+    'SharpnessIncreasing',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+    # 'Cutout'  # NOTE I've implemented this as random erasing separately
+]
+
+# These experimental weights are based loosely on the relative improvements mentioned in the paper.
+# They may not result in increased performance, but could likely be tuned to do so.
+_RAND_CHOICE_WEIGHTS_0 = {
+    'Rotate': 0.3,
+    'ShearX': 0.2,
+    'ShearY': 0.2,
+    'TranslateXRel': 0.1,
+    'TranslateYRel': 0.1,
+    'Color': .025,
+    'Sharpness': 0.025,
+    'AutoContrast': 0.025,
+    'Solarize': .005,
+    'SolarizeAdd': .005,
+    'Contrast': .005,
+    'Brightness': .005,
+    'Equalize': .005,
+    'Posterize': 0,
+    'Invert': 0,
+}
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+    """_select_rand_weights"""
+    transforms = transforms or _RAND_TRANSFORMS
+    assert weight_idx == 0  # only one set of weights currently
+    rand_weights = _RAND_CHOICE_WEIGHTS_0
+    probs = np.array([rand_weights[k] for k in transforms], dtype=np.float64)
+    probs /= probs.sum()
+    return probs
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+    """rand_augment_ops"""
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _RAND_TRANSFORMS
+    return [AugmentOp(
+        name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class RandAugment:
+    """RandAugment"""
+    def __init__(self, ops, num_layers=2, choice_weights=None):
+        self.ops = ops
+        self.num_layers = num_layers
+        self.choice_weights = choice_weights
+
+    def __call__(self, img):
+        # no replacement when using weighted choice
+        ops = np.random.choice(
+            self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
+        for op in ops:
+            img = op(img)
+        return img
+
+
+def rand_augment_transform(config_str, hparams):
+    """
+    Create a RandAugment transform
+
+    :param config_str: String defining the configuration of random augmentation. Consists of multiple sections
+    separated by dashes ('-'). The first section defines the specific variant of rand augment (currently only
+    'rand'). The remaining sections (order not significant) determine
+        'm' - integer magnitude of rand augment
+        'n' - integer num layers (number of transform ops selected per image)
+        'w' - integer probability weight index (index of a set of weights to influence choice of op)
+        'mstd' - float std deviation of magnitude noise applied
+        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
+    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+
+    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+
+    :return: A MindSpore compatible Transform
+    """
+    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
+    num_layers = 2  # default to 2 ops per image
+    weight_idx = None  # default to no probability weights for op choice
+    transforms = _RAND_TRANSFORMS
+    config = config_str.split('-')
+    assert config[0] == 'rand'
+    # [rand, m9, mstd0.5, inc1]
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        elif key == 'inc':
+            # parse the flag as an integer: bool('0') would otherwise be truthy
+            if int(val):
+                transforms = _RAND_INCREASING_TRANSFORMS
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'n':
+            num_layers = int(val)
+        elif key == 'w':
+            weight_idx = int(val)
+        else:
+            assert False, 'Unknown RandAugment config section'
+    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
+    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
+
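+# Illustrative usage (mirrors how the training pipeline builds its transform; values are placeholders):
+# 'rand-m9-mstd0.5-inc1' selects RandAugment with magnitude 9, magnitude_std 0.5 and the "increasing" set.
+#   ra = rand_augment_transform('rand-m9-mstd0.5-inc1', {'translate_const': 100})
+#   img = ra(img)  # img is a PIL.Image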
+
+_AUGMIX_TRANSFORMS = [
+    'AutoContrast',
+    'ColorIncreasing',  # not in paper
+    'ContrastIncreasing',  # not in paper
+    'BrightnessIncreasing',  # not in paper
+    'SharpnessIncreasing',  # not in paper
+    'Equalize',
+    'Rotate',
+    'PosterizeIncreasing',
+    'SolarizeIncreasing',
+    'ShearX',
+    'ShearY',
+    'TranslateXRel',
+    'TranslateYRel',
+]
+
+
+def augmix_ops(magnitude=10, hparams=None, transforms=None):
+    """augmix_ops"""
+    hparams = hparams or _HPARAMS_DEFAULT
+    transforms = transforms or _AUGMIX_TRANSFORMS
+    return [AugmentOp(
+        name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class AugMixAugment:
+    """ AugMix Transform
+    Adapted and improved from impl here (https://github.com/google-research/augmix/blob/master/imagenet.py)
+    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty'
+    (https://arxiv.org/abs/1912.02781)
+    """
+
+    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
+        self.ops = ops
+        self.alpha = alpha
+        self.width = width
+        self.depth = depth
+        self.blended = blended  # blended mode is faster but not well tested
+
+    def _calc_blended_weights(self, ws, m):
+        """_calc_blended_weights"""
+        ws = ws * m
+        cump = 1.
+        rws = []
+        for w in ws[::-1]:
+            alpha = w / cump
+            cump *= (1 - alpha)
+            rws.append(alpha)
+        return np.array(rws[::-1], dtype=np.float32)
+
+    def _apply_blended(self, img, mixing_weights, m):
+        """_apply_blended"""
+        # This is a first attempt at implementing a slightly faster mixed augmentation. Instead
+        # of accumulating the mix for each chain in a Numpy array and then blending with the original,
+        # it recomputes the blending coefficients and applies one PIL image blend per chain.
+        # TODO the results appear in the right ballpark but they differ by more than rounding.
+        img_orig = img.copy()
+        ws = self._calc_blended_weights(mixing_weights, m)
+        for w in ws:
+            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+            ops = np.random.choice(self.ops, depth, replace=True)
+            img_aug = img_orig  # no ops are in-place, deep copy not necessary
+            for op in ops:
+                img_aug = op(img_aug)
+            img = Image.blend(img, img_aug, w)
+        return img
+
+    def _apply_basic(self, img, mixing_weights, m):
+        """_apply_basic"""
+        # This is a literal adaptation of the paper/official implementation without normalizations and
+        # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
+        # typical augmentation transforms, could use a GPU / Kornia implementation.
+        img_shape = img.size[0], img.size[1], len(img.getbands())
+        mixed = np.zeros(img_shape, dtype=np.float32)
+        for mw in mixing_weights:
+            depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+            ops = np.random.choice(self.ops, depth, replace=True)
+            img_aug = img  # no ops are in-place, deep copy not necessary
+            for op in ops:
+                img_aug = op(img_aug)
+            mixed += mw * np.asarray(img_aug, dtype=np.float32)
+        np.clip(mixed, 0, 255., out=mixed)
+        mixed = Image.fromarray(mixed.astype(np.uint8))
+        return Image.blend(img, mixed, m)
+
+    def __call__(self, img):
+        """AugMixAugment apply"""
+        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
+        m = np.float32(np.random.beta(self.alpha, self.alpha))
+        if self.blended:
+            mixed = self._apply_blended(img, mixing_weights, m)
+        else:
+            mixed = self._apply_basic(img, mixing_weights, m)
+        return mixed
+
+
+def augment_and_mix_transform(config_str, hparams):
+    """ Create AugMix MindSpore transform
+
+    :param config_str: String defining the configuration of the augmentation mix. Consists of multiple sections
+    separated by dashes ('-'). The first section defines the variant (currently only 'augmix'). The remaining
+    sections (order not significant) determine
+        'm' - integer magnitude (severity) of augmentation mix (default: 3)
+        'w' - integer width of augmentation chain (default: 3)
+        'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
+        'mstd' -  float std deviation of magnitude noise applied (default: 0)
+    Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
+
+    :param hparams: Other hparams (kwargs) for the Augmentation transforms
+
+    :return: A MindSpore compatible Transform
+    """
+    magnitude = 3
+    width = 3
+    depth = -1
+    alpha = 1.
+    blended = False
+    hparams['magnitude_std'] = float('inf')
+    config = config_str.split('-')
+    assert config[0] == 'augmix'
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # magnitude_std was pre-set to inf above, so setdefault would be a no-op; assign directly
+            hparams['magnitude_std'] = float(val)
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'w':
+            width = int(val)
+        elif key == 'd':
+            depth = int(val)
+        elif key == 'a':
+            alpha = float(val)
+        elif key == 'b':
+            blended = bool(int(val))  # parse as an integer: bool('0') would otherwise be truthy
+        else:
+            assert False, 'Unknown AugMix config section'
+    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
+    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
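+
+# Illustrative usage (values are the docstring example, shown here only for reference):
+# 'augmix-m5-w4-d2' selects AugMix with severity 5, chain width 4 and chain depth 2.
+#   am = augment_and_mix_transform('augmix-m5-w4-d2', {'translate_const': 100})
+#   img = am(img)  # img is a PIL.Image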
diff --git a/research/cv/ViG/src/data/augment/mixup.py b/research/cv/ViG/src/data/augment/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d29897dfb1ce18264cd4e405dd27563e2517a11
--- /dev/null
+++ b/research/cv/ViG/src/data/augment/mixup.py
@@ -0,0 +1,247 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import numpy as np
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import ops as P
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0.):
+    """one hot to label"""
+    x = x.reshape(-1)
+    x = np.eye(num_classes)[x]
+    x = np.clip(x, a_min=off_value, a_max=on_value, dtype=np.float32)
+    return x
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
+    """mixup_target"""
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(np.flip(target, axis=0), num_classes, on_value=on_value, off_value=off_value)
+    return y1 * lam + y2 * (1. - lam)
+
+
+def rand_bbox(img_shape, lam, margin=0., count=None):
+    """ Standard CutMix bounding-box
+    Generates a random square bbox based on lambda value. This impl includes
+    support for enforcing a border margin as percent of bbox dimensions.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        lam (float): Cutmix lambda value
+        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+        count (int): Number of bbox to generate
+    """
+    ratio = np.sqrt(1 - lam)
+    img_h, img_w = img_shape[-2:]
+    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+    yl = np.clip(cy - cut_h // 2, 0, img_h)
+    yh = np.clip(cy + cut_h // 2, 0, img_h)
+    xl = np.clip(cx - cut_w // 2, 0, img_w)
+    xh = np.clip(cx + cut_w // 2, 0, img_w)
+    return yl, yh, xl, xh
+
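+# Worked example (for intuition only): with lam = 0.75 the cut ratio is sqrt(1 - 0.75) = 0.5, so the
+# cut box spans half of each image dimension, i.e. roughly a quarter of the area, matching the
+# 1 - lam share of pixels taken from the mixed-in image.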
+
+def rand_bbox_minmax(img_shape, minmax, count=None):
+    """ Min-Max CutMix bounding-box
+    Inspired by Darknet cutmix impl, generates a random rectangular bbox
+    based on min/max percent values applied to each dimension of the input image.
+
+    Typical defaults for minmax are in the 0.2-0.3 range for min and 0.8-0.9 for max.
+
+    Args:
+        img_shape (tuple): Image shape as tuple
+        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+        count (int): Number of bbox to generate
+    """
+    assert len(minmax) == 2
+    img_h, img_w = img_shape[-2:]
+    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+    yl = np.random.randint(0, img_h - cut_h, size=count)
+    xl = np.random.randint(0, img_w - cut_w, size=count)
+    yu = yl + cut_h
+    xu = xl + cut_w
+    return yl, yu, xl, xu
+
+
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+    """ Generate bbox and apply lambda correction.
+    """
+    if ratio_minmax is not None:
+        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+    else:
+        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+    if correct_lam or ratio_minmax is not None:
+        bbox_area = (yu - yl) * (xu - xl)
+        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+    return (yl, yu, xl, xu), lam
+
+
+class Mixup:
+    """ Mixup/Cutmix that applies different params to each element or whole batch
+
+    Args:
+        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+        prob (float): probability of applying mixup or cutmix per batch or element
+        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
+        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+        label_smoothing (float): apply label smoothing to the mixed target tensor
+        num_classes (int): number of classes for target
+    """
+
+    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
+        self.mixup_alpha = mixup_alpha
+        self.cutmix_alpha = cutmix_alpha
+        self.cutmix_minmax = cutmix_minmax
+        if self.cutmix_minmax is not None:
+            assert len(self.cutmix_minmax) == 2
+            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+            self.cutmix_alpha = 1.0
+        self.mix_prob = prob
+        self.switch_prob = switch_prob
+        self.label_smoothing = label_smoothing
+        self.num_classes = num_classes
+        self.mode = mode
+        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
+        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by the train loop)
+
+    def _params_per_elem(self, batch_size):
+        """_params_per_elem"""
+        lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=bool)  # plain bool: np.bool is deprecated in recent NumPy
+        if self.mixup_enabled:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                lam_mix = np.where(
+                    use_cutmix,
+                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size, dtype=bool)
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+        return lam, use_cutmix
+
+    def _params_per_batch(self):
+        """_params_per_batch"""
+        lam = 1.
+        use_cutmix = False
+        if self.mixup_enabled and np.random.rand() < self.mix_prob:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand() < self.switch_prob
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = True
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = float(lam_mix)
+        return lam, use_cutmix
+
+    def _mix_elem(self, x):
+        """_mix_elem"""
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        x_orig = x.copy()  # need to keep an unmodified original for mixing source (x is a NumPy array)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+        return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)
+
+    def _mix_pair(self, x):
+        """_mix_pair"""
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        x_orig = x.copy()  # need to keep an unmodified original for mixing source (x is a NumPy array)
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return P.ExpandDims()(Tensor(lam_batch, dtype=mstype.float32), 1)
+
+    def _mix_batch(self, x):
+        """_mix_batch"""
+        lam, use_cutmix = self._params_per_batch()
+        if lam == 1.:
+            return 1.
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+            x[:, :, yl:yh, xl:xh] = np.flip(x, axis=0)[:, :, yl:yh, xl:xh]
+        else:
+            x_flipped = np.flip(x, axis=0) * (1. - lam)
+            x *= lam
+            x += x_flipped
+        return lam
+
+    def __call__(self, x, target):
+        """Mixup apply"""
+        # the same mixing is applied jointly to the images and their labels
+        assert len(x) % 2 == 0, 'Batch size should be even when using this'
+        if self.mode == 'elem':
+            lam = self._mix_elem(x)
+        elif self.mode == 'pair':
+            lam = self._mix_pair(x)
+        else:
+            lam = self._mix_batch(x)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+        return x.astype(np.float32), target.astype(np.float32)
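+
+# Illustrative usage (batch mode; the batch size must be even because samples are mixed with the
+# flipped batch). The alpha values below are illustrative, not prescriptive:
+#   mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
+#                    mode='batch', label_smoothing=0.1, num_classes=1000)
+#   images, targets = mixup_fn(images, targets)  # images: NCHW numpy array, targets: int labels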
diff --git a/research/cv/ViG/src/data/augment/random_erasing.py b/research/cv/ViG/src/data/augment/random_erasing.py
new file mode 100644
index 0000000000000000000000000000000000000000..6430b302ed7875920d32f734d984a3bb577e2a0a
--- /dev/null
+++ b/research/cv/ViG/src/data/augment/random_erasing.py
@@ -0,0 +1,113 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Random Erasing (Cutout)
+
+Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
+Copyright Zhun Zhong & Liang Zheng
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import random
+
+import numpy as np
+
+
+def _get_pixels(per_pixel, rand_color, patch_size, dtype=np.float32):
+    """_get_pixels"""
+    if per_pixel:
+        func = np.random.normal(size=patch_size).astype(dtype)
+    elif rand_color:
+        func = np.random.normal(size=(patch_size[0], 1, 1)).astype(dtype)
+    else:
+        func = np.zeros((patch_size[0], 1, 1), dtype=dtype)
+    return func
+
+
+class RandomErasing:
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+
+        This variant of RandomErasing is intended to be applied to either a batch
+        or single image tensor after it has been normalized by dataset mean and std.
+    Args:
+         probability: Probability that the Random Erasing operation will be performed.
+         min_area: Minimum percentage of erased area wrt input image area.
+         max_area: Maximum percentage of erased area wrt input image area.
+         min_aspect: Minimum aspect ratio of erased area.
+         mode: pixel color mode, one of 'const', 'rand', or 'pixel'
+            'const' - erase block is constant color of 0 for all channels
+            'rand'  - erase block is same per-channel random (normal) color
+            'pixel' - erase block is per-pixel random (normal) color
+        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
+            per-image count is randomly chosen between 1 and this value.
+    """
+
+    def __init__(self, probability=0.5, min_area=0.02, max_area=1 / 3, min_aspect=0.3,
+                 max_aspect=None, mode='const', min_count=1, max_count=None, num_splits=0):
+        self.probability = probability
+        self.min_area = min_area
+        self.max_area = max_area
+        max_aspect = max_aspect or 1 / min_aspect
+        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+        self.min_count = min_count
+        self.max_count = max_count or min_count
+        self.num_splits = num_splits
+        mode = mode.lower()
+        self.rand_color = False
+        self.per_pixel = False
+        if mode == 'rand':
+            self.rand_color = True  # per block random normal
+        elif mode == 'pixel':
+            self.per_pixel = True  # per pixel random normal
+        else:
+            assert not mode or mode == 'const'
+
+    def _erase(self, img, chan, img_h, img_w, dtype):
+        """_erase"""
+        if random.random() > self.probability:
+            return img
+        area = img_h * img_w
+        count = self.min_count if self.min_count == self.max_count else \
+            random.randint(self.min_count, self.max_count)
+        for _ in range(count):
+            for _ in range(10):
+                target_area = random.uniform(self.min_area, self.max_area) * area / count
+                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+                h = int(round(math.sqrt(target_area * aspect_ratio)))
+                w = int(round(math.sqrt(target_area / aspect_ratio)))
+                if w < img_w and h < img_h:
+                    top = random.randint(0, img_h - h)
+                    left = random.randint(0, img_w - w)
+                    img[:, top:top + h, left:left + w] = _get_pixels(
+                        self.per_pixel, self.rand_color, (chan, h, w),
+                        dtype=dtype)
+                    break
+        return img
+
+    def __call__(self, x):
+        """RandomErasing apply"""
+        if len(x.shape) == 3:
+            output = self._erase(x, *x.shape, x.dtype)
+        else:
+            output = x.copy()  # keep originals so the clean (non-erased) portion of the batch is preserved
+            batch_size, chan, img_h, img_w = x.shape
+            # skip first slice of batch if num_splits is set (for clean portion of samples)
+            batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
+            for i in range(batch_start, batch_size):
+                output[i] = self._erase(x[i], chan, img_h, img_w, x.dtype)
+        return output
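+
+# Illustrative usage (applied to a normalized CHW numpy array, as in the training pipeline;
+# the parameter values below are illustrative):
+#   eraser = RandomErasing(probability=0.25, mode='pixel', max_count=1)
+#   img = eraser(img)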
diff --git a/research/cv/ViG/src/data/data_utils/__init__.py b/research/cv/ViG/src/data/data_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/ViG/src/data/data_utils/moxing_adapter.py b/research/cv/ViG/src/data/data_utils/moxing_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3f1652470395edf49a6d6d0c84f4c9d2f3f3b55
--- /dev/null
+++ b/research/cv/ViG/src/data/data_utils/moxing_adapter.py
@@ -0,0 +1,72 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import os
+
+_global_sync_count = 0
+
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id else "default"
+    return job_id
+
+
+def sync_data(from_path, to_path, threads=16):
+    """
+    Download data from remote OBS to a local directory if the first URL is a remote URL and the second one is a
+    local path; otherwise upload data from the local directory to remote OBS.
+    """
+    import moxing as mox
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains at most 8 devices.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path, threads=threads)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
diff --git a/research/cv/ViG/src/data/imagenet.py b/research/cv/ViG/src/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e4ad3790da0789afa2272b7ae551f3ad33292c0
--- /dev/null
+++ b/research/cv/ViG/src/data/imagenet.py
@@ -0,0 +1,160 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Data operations, will be used in train.py and eval.py
+"""
+import os
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as C
+import mindspore.dataset.vision.c_transforms as vision
+import mindspore.dataset.vision.py_transforms as py_vision
+from mindspore.dataset.vision.utils import Inter
+
+from src.data.augment.auto_augment import _pil_interp, rand_augment_transform
+from src.data.augment.mixup import Mixup
+from src.data.augment.random_erasing import RandomErasing
+from .data_utils.moxing_adapter import sync_data
+
+
+class ImageNet:
+    """ImageNet Define"""
+
+    def __init__(self, args, training=True):
+        if args.run_modelarts:
+            print('Download data.')
+            local_data_path = '/cache/data'
+            sync_data(args.data_url, local_data_path, threads=128)
+            print('Create train and evaluate dataset.')
+            train_dir = os.path.join(local_data_path, "train")
+            val_dir = os.path.join(local_data_path, "val")
+            self.train_dataset = create_dataset_imagenet(train_dir, training=True, args=args)
+            self.val_dataset = create_dataset_imagenet(val_dir, training=False, args=args)
+        else:
+            train_dir = os.path.join(args.data_url, "train")
+            val_dir = os.path.join(args.data_url, "val")
+            if training:
+                self.train_dataset = create_dataset_imagenet(train_dir, training=True, args=args)
+            self.val_dataset = create_dataset_imagenet(val_dir, training=False, args=args)
+
+
+def create_dataset_imagenet(dataset_dir, args, repeat_num=1, training=True):
+    """
+    create a train or eval imagenet2012 dataset for ViG
+
+    Args:
+        dataset_dir(string): the path of dataset.
+        args: configuration namespace with dataset and augmentation settings (batch size, image size, etc.).
+        repeat_num(int): the repeat times of dataset. Default: 1
+        training(bool): whether dataset is used for train or eval. Default: True
+
+    Returns:
+        dataset
+    """
+
+    device_num, rank_id = _get_rank_info()
+    shuffle = bool(training)
+    if device_num == 1 or not training:
+        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers,
+                                         shuffle=shuffle)
+    else:
+        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers, shuffle=shuffle,
+                                         num_shards=device_num, shard_id=rank_id)
+
+    image_size = args.image_size
+
+    # define map operations
+    # BICUBIC: 3
+
+    if training:
+        mean = [0.485, 0.456, 0.406]
+        std = [0.229, 0.224, 0.225]
+        aa_params = dict(
+            translate_const=int(image_size * 0.45),
+            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+        )
+        interpolation = args.interpolation
+        auto_augment = args.auto_augment
+        assert auto_augment.startswith('rand')
+        aa_params['interpolation'] = _pil_interp(interpolation)
+
+        transform_img = [
+            vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
+                                          interpolation=Inter.BICUBIC),
+            vision.RandomHorizontalFlip(prob=0.5),
+            py_vision.ToPIL()
+        ]
+        transform_img += [rand_augment_transform(auto_augment, aa_params)]
+        transform_img += [
+            py_vision.ToTensor(),
+            py_vision.Normalize(mean=mean, std=std),
+            RandomErasing(args.re_prob, mode=args.re_mode, max_count=args.re_count)
+        ]
+    else:
+        mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+        std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+        # test transform complete
+        transform_img = [
+            vision.Decode(),
+            vision.Resize(int(256 / 224 * image_size), interpolation=Inter.BICUBIC),
+            vision.CenterCrop(image_size),
+            vision.Normalize(mean=mean, std=std),
+            vision.HWC2CHW()
+        ]
+
+    transform_label = C.TypeCast(mstype.int32)
+
+    data_set = data_set.map(input_columns="image", num_parallel_workers=args.num_parallel_workers,
+                            operations=transform_img)
+    data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
+                            operations=transform_label)
+    if (args.mix_up > 0. or args.cutmix > 0.) and not training:
+        # when mixup/cutmix is enabled, one-hot encode the eval labels so they match the training label shape
+        one_hot = C.OneHot(num_classes=args.num_classes)
+        data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
+                                operations=one_hot)
+    # apply batch operations
+    data_set = data_set.batch(args.batch_size, drop_remainder=True,
+                              num_parallel_workers=args.num_parallel_workers)
+
+    if (args.mix_up > 0. or args.cutmix > 0.) and training:
+        mixup_fn = Mixup(
+            mixup_alpha=args.mix_up, cutmix_alpha=args.cutmix, cutmix_minmax=None,
+            prob=args.mixup_prob, switch_prob=args.switch_prob, mode=args.mixup_mode,
+            label_smoothing=args.label_smoothing, num_classes=args.num_classes)
+
+        data_set = data_set.map(operations=mixup_fn, input_columns=["image", "label"],
+                                num_parallel_workers=args.num_parallel_workers)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+
+    if rank_size > 1:
+        from mindspore.communication.management import get_rank, get_group_size
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        rank_size = rank_id = None
+
+    return rank_size, rank_id
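+
+# Illustrative usage (assumes an `args` namespace providing the fields read above, e.g. data_url,
+# image_size, batch_size, num_parallel_workers, auto_augment, mix_up, cutmix, ...):
+#   data = ImageNet(args, training=True)
+#   train_iter = data.train_dataset.create_tuple_iterator()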
diff --git a/research/cv/ViG/src/models/__init__.py b/research/cv/ViG/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca0ae33e78d46095804b11f691e0f6f1810b3293
--- /dev/null
+++ b/research/cv/ViG/src/models/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init model"""
+from .vig import vig_ti_patch16_224
+from .vig import vig_s_patch16_224
+from .vig import vig_b_patch16_224
+
+__all__ = [
+    "vig_ti_patch16_224",
+    "vig_s_patch16_224",
+    "vig_b_patch16_224",
+]
diff --git a/research/cv/ViG/src/models/vig/__init__.py b/research/cv/ViG/src/models/vig/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc45c92aa413607a179f1a2e1d64954c91fa4401
--- /dev/null
+++ b/research/cv/ViG/src/models/vig/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""import vig models"""
+from .vig import vig_ti_patch16_224
+from .vig import vig_s_patch16_224
+from .vig import vig_b_patch16_224
diff --git a/research/cv/ViG/src/models/vig/gcn_lib.py b/research/cv/ViG/src/models/vig/gcn_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..41d8aac743b28a4996053f8c2c72f5239d9b2cf6
--- /dev/null
+++ b/research/cv/ViG/src/models/vig/gcn_lib.py
@@ -0,0 +1,82 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"graph conv functions for vig"
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import ops
+
+
+def pairwise_distance(x):
+    """
+    Compute pairwise squared Euclidean distances via ||x_i - x_j||^2 = ||x_i||^2 - 2 * x_i . x_j + ||x_j||^2.
+    """
+    x_inner = -2 * ops.matmul(x, ops.Transpose()(x, (0, 2, 1)))
+    x_square = ops.ReduceSum(True)(ops.mul(x, x), axis=-1)
+    return x_square + x_inner + ops.Transpose()(x_square, (0, 2, 1))
+
+
+def dense_knn_matrix(x, k=16):
+    """Get kNN based on the pairwise distance.
+    """
+    x = ops.Transpose()(x, (0, 2, 1, 3)).squeeze(-1)
+    batch_size, n_points, _ = x.shape
+    dist = pairwise_distance(x)
+    _, nn_idx = ops.TopK()(-dist, k)
+    center_idx = Tensor(np.arange(0, n_points), mstype.int32)
+    center_idx = ops.Tile()(center_idx, (batch_size, k, 1))
+    center_idx = ops.Transpose()(center_idx, (0, 2, 1))
+    return ops.Stack(axis=0)((nn_idx, center_idx))
+
+
+def batched_index_select(x, idx):
+    """fetches neighbors features from a given neighbor idx.
+    """
+    batch_size, num_dims, num_vertices = x.shape[:3]
+    k = idx.shape[-1]
+    idx_base = Tensor(np.arange(0, batch_size), mstype.int32).view(-1, 1, 1) * num_vertices
+    idx = idx + idx_base
+    idx = idx.view(-1)
+
+    x = ops.Transpose()(x, (0, 2, 1, 3))
+    feature = x.view(batch_size * num_vertices, -1)[idx, :]
+    feature = feature.view(batch_size, num_vertices, k, num_dims)
+    feature = ops.Transpose()(feature, (0, 3, 1, 2))
+    return feature
+
+
+class MRGraphConv2d(nn.Cell):
+    """Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751).
+    """
+
+    def __init__(self, in_channels, out_channels, k=9, dilation=1, bias=True):
+        super(MRGraphConv2d, self).__init__()
+        self.k = k
+        self.dilation = dilation
+        self.nn = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1, group=4, has_bias=bias)
+
+    def construct(self, x):
+        b, c, h, w = x.shape
+        x = x.view(b, c, -1, 1)
+
+        edge_index = dense_knn_matrix(x, self.k)
+        edge_index = edge_index[:, :, :, ::self.dilation]
+
+        x_i = batched_index_select(x, edge_index[1])
+        x_j = batched_index_select(x, edge_index[0])
+        x_j = ops.ReduceMax(True)(x_j - x_i, -1)
+        x = ops.Concat(axis=1)([x, x_j])
+        return self.nn(x).view(b, -1, h, w)
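+
+# Note: MRGraphConv2d implements max-relative aggregation: for every node x_i it takes the element-wise
+# maximum of (x_j - x_i) over its k nearest neighbours, concatenates the result with x_i along the channel
+# axis, and applies a 1x1 grouped convolution as the update function.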
diff --git a/research/cv/ViG/src/models/vig/misc.py b/research/cv/ViG/src/models/vig/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..95b4007cc652365e5fbdfaf1cb7fb54b3bbda2e8
--- /dev/null
+++ b/research/cv/ViG/src/models/vig/misc.py
@@ -0,0 +1,84 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"misc functions for vig"
+import collections.abc
+from itertools import repeat
+
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import dtype as mstype
+from mindspore import ops
+from scipy.stats import truncnorm
+
+
+def trunc_array(shape, sigma=0.02):
+    """output truncnormal array in shape"""
+    return truncnorm.rvs(-2, 2, loc=0, scale=sigma, size=shape, random_state=None)
+
+
+def _ntuple(n):
+    "get _ntuple"
+
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_2tuple = _ntuple(2)
+
+
+class Identity(nn.Cell):
+    """Identity"""
+
+    def construct(self, x):
+        return x
+
+
+class DropPath(nn.Cell):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob, ndim):
+        super(DropPath, self).__init__()
+        self.drop = nn.Dropout(keep_prob=1 - drop_prob)
+        shape = (1,) + (1,) * (ndim + 1)
+        self.ndim = ndim
+        self.mask = Tensor(np.ones(shape), dtype=mstype.float32)
+
+    def construct(self, x):
+        if not self.training:
+            return x
+        mask = ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1))
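+        # Note: nn.Dropout(keep_prob) rescales surviving entries by 1/keep_prob at train time, so samples
+        # kept by the mask below are already compensated and no extra scaling is needed.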
+        out = self.drop(mask)
+        out = out * x
+        return out
+
+
+class DropPath1D(DropPath):
+    """DropPath1D"""
+
+    def __init__(self, drop_prob):
+        super(DropPath1D, self).__init__(drop_prob=drop_prob, ndim=1)
+
+
+class DropPath2D(DropPath):
+    """DropPath2D"""
+
+    def __init__(self, drop_prob):
+        super(DropPath2D, self).__init__(drop_prob=drop_prob, ndim=2)
diff --git a/research/cv/ViG/src/models/vig/vig.py b/research/cv/ViG/src/models/vig/vig.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeda35dae0c1f22126333a1694177c9466eaa370
--- /dev/null
+++ b/research/cv/ViG/src/models/vig/vig.py
@@ -0,0 +1,248 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Vision GNN (ViG)"""
+import numpy as np
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+from mindspore import Parameter
+from mindspore import Tensor
+from mindspore import dtype as mstype
+
+from .misc import DropPath2D, Identity, trunc_array
+from .gcn_lib import MRGraphConv2d
+
+
+class Grapher(nn.Cell):
+    """Grapher"""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_path=0.,
+                 k=9, dilation=1):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.SequentialCell([
+            nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, has_bias=False),
+            nn.BatchNorm2d(in_features),
+        ])
+        self.graph_conv = nn.SequentialCell([
+            MRGraphConv2d(in_features, hidden_features, k=k, dilation=dilation),
+            nn.BatchNorm2d(hidden_features),
+            nn.GELU(),
+        ])
+        self.fc2 = nn.SequentialCell([
+            nn.Conv2d(in_channels=hidden_features, out_channels=out_features, kernel_size=1, has_bias=False),
+            nn.BatchNorm2d(out_features),
+        ])
+        self.drop_path = DropPath2D(drop_path) if drop_path > 0. else Identity()
+
+    def construct(self, x):
+        shortcut = x
+        x = self.fc1(x)
+        x = self.graph_conv(x)
+        x = self.fc2(x)
+        x = shortcut + self.drop_path(x)
+        return x
+
+
+class Mlp(nn.Cell):
+    """Mlp"""
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop_path=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.SequentialCell([
+            nn.Conv2d(in_channels=in_features, out_channels=hidden_features, kernel_size=1, has_bias=False),
+            nn.BatchNorm2d(hidden_features),
+            nn.GELU(),
+        ])
+        self.fc2 = nn.SequentialCell([
+            nn.Conv2d(in_channels=hidden_features, out_channels=out_features, kernel_size=1, has_bias=False),
+            nn.BatchNorm2d(out_features),
+        ])
+        self.drop_path = DropPath2D(drop_path) if drop_path > 0. else Identity()
+
+    def construct(self, x):
+        shortcut = x
+        x = self.fc1(x)
+        x = self.fc2(x)
+        x = shortcut + self.drop_path(x)
+        return x
+
+
+class Block(nn.Cell):
+    """ ViG Block"""
+
+    def __init__(self, dim, mlp_ratio=4., drop_path=0., act_layer=nn.GELU, k=9, dilation=1):
+        super().__init__()
+        self.grapher = Grapher(dim, hidden_features=int(dim * 2), out_features=dim,
+                               act_layer=act_layer, drop_path=drop_path, k=k, dilation=dilation)
+        self.mlp = Mlp(dim, hidden_features=int(dim * mlp_ratio), out_features=dim,
+                       act_layer=act_layer, drop_path=drop_path)
+
+    def construct(self, x):
+        x = self.grapher(x)
+        x = self.mlp(x)
+        return x
+
+
+class PatchEmbed(nn.Cell):
+    """ Image to Visual Embeddings
+    """
+
+    def __init__(self, dim=768):
+        super().__init__()
+        self.conv1 = nn.SequentialCell([
+            nn.Conv2d(in_channels=3, out_channels=dim//8, kernel_size=3, stride=2,
+                      pad_mode='pad', padding=1, has_bias=False),
+            nn.BatchNorm2d(dim//8),
+            nn.GELU(),
+        ])
+        self.conv2 = nn.SequentialCell([
+            nn.Conv2d(in_channels=dim//8, out_channels=dim//4, kernel_size=3, stride=2,
+                      pad_mode='pad', padding=1, has_bias=False),
+            nn.BatchNorm2d(dim//4),
+            nn.GELU(),
+        ])
+        self.conv3 = nn.SequentialCell([
+            nn.Conv2d(in_channels=dim//4, out_channels=dim//2, kernel_size=3, stride=2,
+                      pad_mode='pad', padding=1, has_bias=False),
+            nn.BatchNorm2d(dim//2),
+            nn.GELU(),
+        ])
+        self.conv4 = nn.SequentialCell([
+            nn.Conv2d(in_channels=dim//2, out_channels=dim, kernel_size=3, stride=2,
+                      pad_mode='pad', padding=1, has_bias=False),
+            nn.BatchNorm2d(dim),
+            nn.GELU(),
+        ])
+        self.conv5 = nn.SequentialCell([
+            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, stride=1,
+                      pad_mode='pad', padding=1, has_bias=False),
+            nn.BatchNorm2d(dim),
+        ])
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = self.conv5(x)
+        return x
+
+
+class ViG(nn.Cell):
+    """ ViG (Visioin GNN)
+    """
+
+    def __init__(self, num_classes=1000, dim=768, depth=12, mlp_ratio=4., drop_path_rate=0., k=9, **kwargs):
+        super().__init__()
+        self.num_classes = num_classes
+        self.dim = dim
+
+        self.patch_embed = PatchEmbed(dim)
+
+        self.pos_embed = Parameter(Tensor(trunc_array([1, dim, 14, 14]), dtype=mstype.float32),
+                                   name="pos_embed")
+
+        dpr = [x for x in np.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        num_knn = [int(x) for x in np.linspace(k, 2 * k, depth)]  # number of knn's k
+        max_dilation = 196 // max(num_knn)
+        blocks = []
+        for i in range(depth):
+            blocks.append(Block(
+                dim, mlp_ratio=mlp_ratio, drop_path=dpr[i], k=num_knn[i], dilation=min(max_dilation, i//4+1)))
+        self.blocks = nn.CellList(blocks)
+
+        # Classifier head
+        self.head = nn.SequentialCell([
+            nn.Conv2d(in_channels=dim, out_channels=1024, kernel_size=1, has_bias=True),
+            nn.BatchNorm2d(1024),
+            nn.GELU(),
+            nn.Conv2d(in_channels=1024, out_channels=num_classes, kernel_size=1, has_bias=True),
+        ])
+
+        self.init_weights()
+        print("================================success================================")
+
+    def init_weights(self):
+        """init_weights"""
+        for _, m in self.cells_and_names():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
+                                                          m.weight.data.shape).astype("float32")))
+                if m.bias is not None:
+                    m.bias.set_data(
+                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.gamma.set_data(
+                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
+                m.beta.set_data(
+                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
+            elif isinstance(m, nn.Dense):
+                m.weight.set_data(Tensor(np.random.normal(
+                    0, 0.01, m.weight.data.shape).astype("float32")))
+                if m.bias is not None:
+                    m.bias.set_data(
+                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
+
+    def forward_features(self, x):
+        """ViG forward_features"""
+        x = self.patch_embed(x) + self.pos_embed
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = P.ReduceMean(True)(x, [2, 3])
+        return x
+
+    def construct(self, x):
+        x = self.forward_features(x)
+        x = self.head(x).squeeze(-1).squeeze(-1)
+        return x
+
+
+def vig_ti_patch16_224(args):
+    """vig_ti_patch16_224"""
+    num_classes = args.num_classes
+    dim = 192
+    depth = 12
+    mlp_ratio = 4
+    drop_path_rate = args.drop_path_rate
+    model = ViG(num_classes, dim, depth, mlp_ratio, drop_path_rate, k=9)
+    return model
+
+
+def vig_s_patch16_224(args):
+    """vig_s_patch16_224"""
+    num_classes = args.num_classes
+    dim = 320
+    depth = 16
+    mlp_ratio = 4
+    drop_path_rate = args.drop_path_rate
+    model = ViG(num_classes, dim, depth, mlp_ratio, drop_path_rate, k=9)
+    return model
+
+
+def vig_b_patch16_224(args):
+    """vig_b_patch16_224"""
+    num_classes = args.num_classes
+    dim = 640
+    depth = 16
+    mlp_ratio = 4
+    drop_path_rate = args.drop_path_rate
+    model = ViG(num_classes, dim, depth, mlp_ratio, drop_path_rate, k=9)
+    return model
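+
+# Illustrative usage (assumes an `args` namespace providing num_classes and drop_path_rate):
+#   net = vig_s_patch16_224(args)   # ViG-S: dim=320, depth=16, 224x224 input
+#   logits = net(images)            # images: NCHW float tensor, logits: (N, num_classes)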
diff --git a/research/cv/ViG/src/tools/__init__.py b/research/cv/ViG/src/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/research/cv/ViG/src/tools/callback.py b/research/cv/ViG/src/tools/callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce83c3d1dcac3293acc958f7d9decfc9984cd23
--- /dev/null
+++ b/research/cv/ViG/src/tools/callback.py
@@ -0,0 +1,48 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""callback function"""
+
+from mindspore.train.callback import Callback
+
+from src.args import args
+
+
+class EvaluateCallBack(Callback):
+    """EvaluateCallBack"""
+
+    def __init__(self, model, eval_dataset, src_url, train_url, save_freq=50):
+        super(EvaluateCallBack, self).__init__()
+        self.model = model
+        self.eval_dataset = eval_dataset
+        self.src_url = src_url
+        self.train_url = train_url
+        self.save_freq = save_freq
+        self.best_acc = 0.
+
+    def epoch_end(self, run_context):
+        """
+            Evaluate at the end of each epoch, track the best accuracy and, on ModelArts, periodically sync outputs to OBS.
+        """
+        cb_params = run_context.original_args()
+        cur_epoch_num = cb_params.cur_epoch_num
+        result = self.model.eval(self.eval_dataset)
+        if result["acc"] > self.best_acc:
+            self.best_acc = result["acc"]
+        print("epoch: %s acc: %s, best acc is %s" %
+              (cb_params.cur_epoch_num, result["acc"], self.best_acc), flush=True)
+        if args.run_modelarts:
+            import moxing as mox
+            if cur_epoch_num % self.save_freq == 0:
+                mox.file.copy_parallel(src_url=self.src_url, dst_url=self.train_url)
diff --git a/research/cv/ViG/src/tools/cell.py b/research/cv/ViG/src/tools/cell.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7ee3252063987778f7c2d5a3394d74e8a31c3b
--- /dev/null
+++ b/research/cv/ViG/src/tools/cell.py
@@ -0,0 +1,56 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Functions of cells"""
+import mindspore.nn as nn
+from mindspore import dtype as mstype
+from mindspore.ops import functional as F
+
+from src.args import args
+
+
+class OutputTo16(nn.Cell):
+    "Wrap cell for amp. Cast network output back to float16"
+
+    def __init__(self, op):
+        super(OutputTo16, self).__init__(auto_prefix=False)
+        self._op = op
+
+    def construct(self, x):
+        return F.cast(self._op(x), mstype.float16)
+
+
+def do_keep_fp16(network, cell_types):
+    """Cast cell to fp32 if cell in cell_types"""
+    for _, cell in network.cells_and_names():
+        if isinstance(cell, cell_types):
+            cell.to_float(mstype.float16)
+
+
+def cast_amp(net):
+    """cast network amp_level"""
+    if args.amp_level == "O2":
+        cell_types = (nn.Dense,)
+        print(f"=> using amp_level {args.amp_level}\n"
+              f"=> change {args.arch}'s {cell_types}to fp16")
+        do_keep_fp16(net, cell_types)
+    elif args.amp_level == "O3":
+        print(f"=> using amp_level {args.amp_level}\n"
+              f"=> change {args.arch} to fp16")
+        net.to_float(mstype.float16)
+    else:
+        print(f"=> using amp_level {args.amp_level}")
+        args.loss_scale = 1.
+        args.is_dynamic_loss_scale = 0
+        print(f"=> When amp_level is O0, using fixed loss_scale with {args.loss_scale}")
diff --git a/research/cv/ViG/src/tools/criterion.py b/research/cv/ViG/src/tools/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eef93734155608378ba8ac43e8ed71de7fde0b3
--- /dev/null
+++ b/research/cv/ViG/src/tools/criterion.py
@@ -0,0 +1,93 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""functions of criterion"""
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import ops
+from mindspore.common import dtype as mstype
+from mindspore.nn.loss.loss import LossBase
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+
+class SoftTargetCrossEntropy(LossBase):
+    """SoftTargetCrossEntropy for MixUp Augment"""
+
+    def __init__(self):
+        super(SoftTargetCrossEntropy, self).__init__()
+        self.mean_ops = P.ReduceMean(keep_dims=False)
+        self.sum_ops = P.ReduceSum(keep_dims=False)
+        self.log_softmax = P.LogSoftmax()
+
+    def construct(self, logit, label):
+        logit = P.Cast()(logit, mstype.float32)
+        label = P.Cast()(label, mstype.float32)
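+        # soft-target cross-entropy: mean over the batch of -sum(label * log_softmax(logit))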
+        loss = self.sum_ops(-label * self.log_softmax(logit), -1)
+        return self.mean_ops(loss)
+
+
+class CrossEntropySmooth(LossBase):
+    """CrossEntropy"""
+
+    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
+        super(CrossEntropySmooth, self).__init__()
+        self.onehot = P.OneHot()
+        self.sparse = sparse
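+        # label smoothing: the true class gets 1 - smooth_factor, the rest share smooth_factor / (num_classes - 1)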
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
+        self.cast = ops.Cast()
+
+    def construct(self, logit, label):
+        if self.sparse:
+            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
+        loss2 = self.ce(logit, label)
+        return loss2
+
+
+def get_criterion(args):
+    """Get loss function from args.label_smooth and args.mix_up"""
+    assert 0. <= args.label_smoothing <= 1.
+
+    if args.mix_up > 0. or args.cutmix > 0.:
+        print(25 * "=" + "Using MixBatch" + 25 * "=")
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif args.label_smoothing > 0.:
+        print(25 * "=" + "Using label smoothing" + 25 * "=")
+        criterion = CrossEntropySmooth(sparse=True, reduction="mean",
+                                       smooth_factor=args.label_smoothing,
+                                       num_classes=args.num_classes)
+    else:
+        print(25 * "=" + "Using Simple CE" + 25 * "=")
+        criterion = CrossEntropySmooth(sparse=True, reduction="mean", num_classes=args.num_classes)
+
+    return criterion
+
+
+class NetWithLoss(nn.Cell):
+    """
+       NetWithLoss: Only support Network with Classfication
+    """
+
+    def __init__(self, model, criterion):
+        super(NetWithLoss, self).__init__()
+        self.model = model
+        self.criterion = criterion
+
+    def construct(self, data, label):
+        predict = self.model(data)
+        loss = self.criterion(predict, label)
+        return loss
diff --git a/research/cv/ViG/src/tools/get_misc.py b/research/cv/ViG/src/tools/get_misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..cde4ff72f197cfeb813a6748f1676735ab90b1a0
--- /dev/null
+++ b/research/cv/ViG/src/tools/get_misc.py
@@ -0,0 +1,120 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""misc functions for program"""
+import os
+
+from mindspore import context
+from mindspore import nn
+from mindspore.communication.management import init, get_rank
+from mindspore.context import ParallelMode
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src import models, data
+from src.data.data_utils.moxing_adapter import sync_data
+from src.trainers import TrainClipGrad
+
+
+def set_device(args):
+    """Set device and ParallelMode(if device_num > 1)"""
+    rank = 0
+    # set context and device
+    device_target = args.device_target
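+    # number of devices, taken from the DEVICE_NUM environment variable (defaults to 1)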
+    device_num = int(os.environ.get("DEVICE_NUM", 1))
+
+    if device_target == "Ascend":
+        if device_num > 1:
+            context.set_context(device_id=int(os.environ["DEVICE_ID"]))
+            init(backend_name='hccl')
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+
+            rank = get_rank()
+        else:
+            context.set_context(device_id=args.device_id)
+    elif device_target == "GPU":
+        if device_num > 1:
+            init(backend_name='nccl')
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            rank = get_rank()
+        else:
+            context.set_context(device_id=args.device_id)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    return rank
+
+
+def get_dataset(args, training=True):
+    """"Get model according to args.set"""
+    print(f"=> Getting {args.set} dataset")
+    dataset = getattr(data, args.set)(args, training)
+
+    return dataset
+
+
+def get_model(args):
+    """"Get model according to args.arch"""
+    print("==> Creating model '{}'".format(args.arch))
+    model = models.__dict__[args.arch](args)
+
+    return model
+
+
+def pretrained(args, model):
+    """"Load pretrained weights if args.pretrained is given"""
+    if args.run_modelarts:
+        print('Download data.')
+        local_data_path = '/cache/weight'
+        name = args.pretrained.split('/')[-1]
+        path = f"/".join(args.pretrained.split("/")[:-1])
+        sync_data(path, local_data_path, threads=128)
+        args.pretrained = os.path.join(local_data_path, name)
+        print("=> loading pretrained weights from '{}'".format(args.pretrained))
+        param_dict = load_checkpoint(args.pretrained)
+        for key, value in param_dict.copy().items():
+            if 'head' in key:
+                if value.shape[0] != args.num_classes:
+                    print(f'==> removing {key} with shape {value.shape}')
+                    param_dict.pop(key)
+        load_param_into_net(model, param_dict)
+    elif os.path.isfile(args.pretrained):
+        print("=> loading pretrained weights from '{}'".format(args.pretrained))
+        param_dict = load_checkpoint(args.pretrained)
+        for key, value in param_dict.copy().items():
+            if 'head' in key:
+                if value.shape[0] != args.num_classes:
+                    print(f'==> removing {key} with shape {value.shape}')
+                    param_dict.pop(key)
+        load_param_into_net(model, param_dict)
+    else:
+        print("=> no pretrained weights found at '{}'".format(args.pretrained))
+
+
+def get_train_one_step(args, net_with_loss, optimizer):
+    """get_train_one_step cell"""
+    if args.is_dynamic_loss_scale:
+        print(f"=> Using DynamicLossScaleUpdateCell")
+        scale_sense = nn.wrap.loss_scale.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 24, scale_factor=2,
+                                                                    scale_window=2000)
+    else:
+        print(f"=> Using FixedLossScaleUpdateCell, loss_scale_value:{args.loss_scale}")
+        scale_sense = nn.wrap.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
+    net_with_loss = TrainClipGrad(net_with_loss, optimizer, scale_sense=scale_sense,
+                                  clip_global_norm_value=args.clip_global_norm_value,
+                                  use_global_norm=True)
+    return net_with_loss
diff --git a/research/cv/ViG/src/tools/optimizer.py b/research/cv/ViG/src/tools/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c5519984935c6bd8f89d6e1d7c4ffea6855dcd
--- /dev/null
+++ b/research/cv/ViG/src/tools/optimizer.py
@@ -0,0 +1,84 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Functions of optimizer"""
+import os
+
+from mindspore.nn.optim import AdamWeightDecay
+from mindspore.nn.optim.momentum import Momentum
+
+from .schedulers import get_policy
+
+
+def get_learning_rate(args, batch_num):
+    """Get learning rate"""
+    return get_policy(args.lr_scheduler)(args, batch_num)
+
+
+def get_optimizer(args, model, batch_num):
+    """Get optimizer for training"""
+    print(f"=> When using train_wrapper, using optimizer {args.optimizer}")
+    args.start_epoch = int(args.start_epoch)
+    optim_type = args.optimizer.lower()
+    params = get_param_groups(model)
+    learning_rate = get_learning_rate(args, batch_num)
+    step = int(args.start_epoch * batch_num)
+    accumulation_step = int(args.accumulation_step)
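+    # resume the schedule from start_epoch and keep one LR value per optimizer update when accumulating gradients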
+    learning_rate = learning_rate[step::accumulation_step]
+    train_step = len(learning_rate)
+    print(f"=> Get LR from epoch: {args.start_epoch}\n"
+          f"=> Start step: {step}\n"
+          f"=> Total step: {train_step}\n"
+          f"=> Accumulation step:{accumulation_step}")
+    learning_rate = learning_rate * args.batch_size * int(os.getenv("DEVICE_NUM", args.device_num)) / 512.
+    if accumulation_step > 1:
+        learning_rate = learning_rate * accumulation_step
+
+    if optim_type == "momentum":
+        optim = Momentum(
+            params=params,
+            learning_rate=learning_rate,
+            momentum=args.momentum,
+            weight_decay=args.weight_decay
+        )
+    elif optim_type == "adamw":
+        optim = AdamWeightDecay(
+            params=params,
+            learning_rate=learning_rate,
+            beta1=args.beta[0],
+            beta2=args.beta[1],
+            eps=args.eps,
+            weight_decay=args.weight_decay
+        )
+    else:
+        raise ValueError(f"optimizer {optim_type} is not supported")
+
+    return optim
+
+
+def get_param_groups(network):
+    """ get param groups """
+    decay_params = []
+    no_decay_params = []
+    for x in network.trainable_params():
+        parameter_name = x.name
+        if parameter_name.endswith(".weight"):
+            # Dense or Conv's weight using weight decay
+            decay_params.append(x)
+        else:
+            # biases do not use weight decay
+            # BatchNorm/LayerNorm gamma and beta also skip decay since their names do not end with ".weight"
+            no_decay_params.append(x)
+
+    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
diff --git a/research/cv/ViG/src/tools/schedulers.py b/research/cv/ViG/src/tools/schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b67307679b062b9db822442d7602da7b4dc9618
--- /dev/null
+++ b/research/cv/ViG/src/tools/schedulers.py
@@ -0,0 +1,112 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LearningRate scheduler functions"""
+import numpy as np
+
+__all__ = ["multistep_lr", "cosine_lr", "constant_lr", "get_policy", "exp_lr"]
+
+
+def get_policy(name):
+    """get lr policy from name"""
+    if name is None:
+        return constant_lr
+
+    out_dict = {
+        "constant_lr": constant_lr,
+        "cosine_lr": cosine_lr,
+        "multistep_lr": multistep_lr,
+        "exp_lr": exp_lr,
+    }
+
+    return out_dict[name]
+
+
+def constant_lr(args, batch_num):
+    """Get constant lr"""
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        if epoch < args.warmup_length:
+            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
+        else:
+            lr = args.base_lr
+
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
+    return learning_rate
+
+
+def exp_lr(args, batch_num):
+    """Get exp lr """
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        if epoch < args.warmup_length:
+            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
+        else:
+            lr = args.base_lr * args.lr_gamma ** epoch
+
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
+    return learning_rate
+
+
+def cosine_lr(args, batch_num):
+    """Get cosine lr"""
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        if epoch < args.warmup_length:
+            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
+        else:
+            e = epoch - args.warmup_length
+            es = args.epochs - args.warmup_length
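+            # cosine annealing from base_lr towards 0 over the epochs remaining after warmup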
+            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * args.base_lr
+
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
+    return learning_rate
+
+
+def multistep_lr(args, batch_num):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    learning_rate = []
+
+    def _lr_adjuster(epoch):
+        lr = args.base_lr * (args.lr_gamma ** (epoch / args.lr_adjust))
+        return lr
+
+    for epoch in range(args.epochs):
+        for batch in range(batch_num):
+            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
+    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
+    return learning_rate
+
+
+def _warmup_lr(warmup_lr, base_lr, warmup_length, epoch):
+    """Linear warmup"""
+    return epoch / warmup_length * (base_lr - warmup_lr) + warmup_lr
diff --git a/research/cv/ViG/src/trainers/__init__.py b/research/cv/ViG/src/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5fecee77553f3cdcb049cc30b7a014f59fca3b4
--- /dev/null
+++ b/research/cv/ViG/src/trainers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""init train one step"""
+from .train_one_step_with_scale_and_clip_global_norm \
+    import TrainOneStepWithLossScaleCellGlobalNormClip as TrainClipGrad
diff --git a/research/cv/ViG/src/trainers/train_one_step_with_scale_and_clip_global_norm.py b/research/cv/ViG/src/trainers/train_one_step_with_scale_and_clip_global_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f7b6aef92438e192560c2a29bf48641aedf4b31
--- /dev/null
+++ b/research/cv/ViG/src/trainers/train_one_step_with_scale_and_clip_global_norm.py
@@ -0,0 +1,87 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""TrainOneStepWithLossScaleCellGlobalNormClip"""
+import mindspore.nn as nn
+from mindspore.common import RowTensor
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+reciprocal = P.Reciprocal()
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
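+    # un-scale the gradient by multiplying with the reciprocal of the loss scale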
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))
+
+
+@_grad_scale.register("Tensor", "RowTensor")
+def tensor_grad_scale_row_tensor(scale, grad):
+    return RowTensor(grad.indices,
+                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
+                     grad.dense_shape)
+
+
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+
+class TrainOneStepWithLossScaleCellGlobalNormClip(nn.TrainOneStepWithLossScaleCell):
+    """
+    Encapsulation class for training with loss scaling and global-norm gradient clipping.
+
+    Appends an optimizer to the training network; after that, the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. The loss function should already be wrapped in.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_sense (Cell or Number): Loss-scaling cell or a fixed scale value. Default: 1.0.
+        use_global_norm (bool): Whether to clip gradients by global norm before the optimizer step. Default: True.
+        clip_global_norm_value (float): Global-norm threshold used for clipping. Default: 1.0.
+    """
+
+    def __init__(self, network, optimizer,
+                 scale_sense=1.0, use_global_norm=True,
+                 clip_global_norm_value=1.0):
+        super(TrainOneStepWithLossScaleCellGlobalNormClip, self).__init__(network, optimizer, scale_sense)
+        self.use_global_norm = use_global_norm
+        self.clip_global_norm_value = clip_global_norm_value
+        self.print = P.Print()
+
+    def construct(self, *inputs):
+        """construct"""
+        weights = self.weights
+        loss = self.network(*inputs)
+        scaling_sens = self.scale_sense
+
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
+
+        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        # get the overflow buffer
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
+        # if there is no overflow, do optimize
+        if not overflow:
+            if self.use_global_norm:
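+                # clip the global gradient norm to clip_global_norm_value before applying the update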
+                grads = C.clip_by_global_norm(grads, clip_norm=self.clip_global_norm_value)
+            self.optimizer(grads)
+        else:
+            self.print("=============Over Flow, skipping=============")
+        return loss
diff --git a/research/cv/ViG/train.py b/research/cv/ViG/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc98369bd35679bfc7f29ce73a597114dccefe98
--- /dev/null
+++ b/research/cv/ViG/train.py
@@ -0,0 +1,92 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train"""
+import os
+
+from mindspore import Model
+from mindspore import context
+from mindspore import nn
+from mindspore.common import set_seed
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+
+from src.args import args
+from src.tools.callback import EvaluateCallBack
+from src.tools.cell import cast_amp
+from src.tools.criterion import get_criterion, NetWithLoss
+from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
+from src.tools.optimizer import get_optimizer
+
+
+def main():
+    set_seed(args.seed)
+    mode = {
+        0: context.GRAPH_MODE,
+        1: context.PYNATIVE_MODE
+    }
+    context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
+    context.set_context(enable_graph_kernel=False)
+    if args.device_target == "Ascend":
+        context.set_context(enable_auto_mixed_precision=True)
+    rank = set_device(args)
+
+    # get model and cast amp_level
+    net = get_model(args)
+    cast_amp(net)
+    criterion = get_criterion(args)
+    net_with_loss = NetWithLoss(net, criterion)
+    if args.pretrained:
+        pretrained(args, net)
+
+    data = get_dataset(args)
+    batch_num = data.train_dataset.get_dataset_size()
+    optimizer = get_optimizer(args, net, batch_num)
+    # save a yaml file to record the training parameters
+
+    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
+
+    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
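+    # eval_indexes gives the positions of loss, prediction and label in the eval network's outputs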
+    eval_indexes = [0, 1, 2]
+    model = Model(net_with_loss, metrics={"acc", "loss"},
+                  eval_network=eval_network,
+                  eval_indexes=eval_indexes)
+
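+    # save a checkpoint once per epoch (every dataset_size steps), keeping at most save_every checkpoints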
+    config_ck = CheckpointConfig(save_checkpoint_steps=data.train_dataset.get_dataset_size(),
+                                 keep_checkpoint_max=args.save_every)
+    time_cb = TimeMonitor(data_size=data.train_dataset.get_dataset_size())
+
+    ckpt_save_dir = "./ckpt_" + str(rank)
+    if args.run_modelarts:
+        ckpt_save_dir = "/cache/ckpt_" + str(rank)
+
+    ckpoint_cb = ModelCheckpoint(prefix=args.arch + str(rank), directory=ckpt_save_dir,
+                                 config=config_ck)
+    loss_cb = LossMonitor()
+    eval_cb = EvaluateCallBack(model, eval_dataset=data.val_dataset, src_url=ckpt_save_dir,
+                               train_url=os.path.join(args.train_url, "ckpt_" + str(rank)),
+                               save_freq=args.save_every)
+
+    print("begin train")
+    model.train(int(args.epochs - args.start_epoch), data.train_dataset,
+                callbacks=[time_cb, ckpoint_cb, loss_cb, eval_cb],
+                dataset_sink_mode=True)
+    print("train success")
+
+    if args.run_modelarts:
+        import moxing as mox
+        mox.file.copy_parallel(src_url=ckpt_save_dir, dst_url=os.path.join(args.train_url, "ckpt_" + str(rank)))
+
+
+if __name__ == '__main__':
+    main()