diff --git a/research/cv/CMT/README.md b/research/cv/CMT/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f308db188bb9351dda3e35f61b718497fc76d13c
--- /dev/null
+++ b/research/cv/CMT/README.md
@@ -0,0 +1,91 @@
+# Contents
+
+- [Contents](#contents)
+    - [CMT Description](#cmt-description)
+    - [Model architecture](#model-architecture)
+    - [Dataset](#dataset)
+    - [Environment Requirements](#environment-requirements)
+    - [Script description](#script-description)
+        - [Script and sample code](#script-and-sample-code)
+    - [Eval process](#eval-process)
+        - [Usage](#usage)
+        - [Launch](#launch)
+        - [Result](#result)
+    - [Description of Random Situation](#description-of-random-situation)
+    - [ModelZoo Homepage](#modelzoo-homepage)
+
+## [CMT Description](#contents)
+
+This paper aims to develop a network that can outperform not only canonical transformers but also high-performance convolutional models. We propose a new transformer-based hybrid network that takes advantage of transformers to capture long-range dependencies and of CNNs to model local features. Furthermore, we scale it to obtain a family of models, called CMTs, which achieve much better accuracy and efficiency than previous convolution-based and transformer-based models.
+
+[Paper](https://arxiv.org/pdf/2107.06263.pdf): Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, Chunjing Xu, Yunhe Wang. CMT: Convolutional Neural Networks Meet Vision Transformers. Accepted in CVPR 2022.
+
+## [Model architecture](#contents)
+
+A block of CMT is shown below:
+
+![image-20211026160438718](./fig/CMT.PNG)
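+
+As a quick sanity check of the full network, a random 224*224 input can be pushed through `cmt_s` (a minimal sketch; it only assumes a working MindSpore installation and the `src` package from this directory):
+
+```python
+# Minimal smoke test for the CMT-Small backbone defined in src/cmt.py.
+import numpy as np
+import mindspore as ms
+from src.cmt import cmt_s
+
+net = cmt_s()
+net.set_train(False)
+x = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
+logits = net(x)
+print(logits.shape)  # expected: (1, 1000)
+```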
+
+## [Dataset](#contents)
+
+Dataset used: [ImageNet2012](http://www.image-net.org/)
+
+- Dataset size: 224*224 color images in 1,000 classes
+    - Train: 1,281,167 images
+    - Test: 50,000 images
+- Data format: JPEG
+    - Note: data will be processed in src/dataset.py (a usage sketch follows below).
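+
+A minimal sketch of how the evaluation split is loaded through `src/dataset.py` (the ImageNet root with `train/` and `val/` subfolders is assumed from `create_dataset`):
+
+```python
+# Minimal sketch; the dataset root is a placeholder and must contain train/ and val/ subfolders.
+from src.dataset import create_dataset
+
+# do_train=False selects val/ and applies Decode -> Resize(235) -> CenterCrop(224) -> Normalize.
+eval_ds = create_dataset("/path/to/imagenet", do_train=False, batch_size=128)
+print("eval batches:", eval_ds.get_dataset_size())
+```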
+
+## [Environment Requirements](#contents)
+
+- Hardware (Ascend/GPU)
+    - Prepare hardware environment with Ascend or GPU.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
+
+## [Script description](#contents)
+
+### [Script and sample code](#contents)
+
+```bash
+CMT
+├── eval.py                 # inference entry
+├── fig
+│   └── CMT.PNG             # the illustration of the CMT network
+├── README.md               # readme
+├── scripts
+│   └── run_cmt_eval.sh     # evaluation launch script
+└── src
+    ├── dataset.py          # dataset loader
+    └── cmt.py              # CMT network
+```
+
+## [Eval process](#contents)
+
+### Usage
+
+After installing MindSpore via the official website, you can start evaluation as follows:
+
+### Launch
+
+```bash
+# CMT infer example
+  GPU: python eval.py --model cmt --dataset_path [DATASET_PATH] --platform GPU --checkpoint_path [CHECKPOINT_PATH]
+```
+
+> The checkpoint can be downloaded at https://download.mindspore.cn/model_zoo/.
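+
+For reference, the core of `eval.py` boils down to the following (a condensed sketch; the checkpoint and dataset paths are placeholders):
+
+```python
+# Condensed from eval.py; paths are placeholders.
+from mindspore import nn
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.cmt import cmt_s
+from src.dataset import create_dataset
+
+net = cmt_s()
+load_param_into_net(net, load_checkpoint("[CHECKPOINT_PATH]"))
+net.set_train(False)
+
+loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+dataset = create_dataset("[DATASET_PATH]", do_train=False, batch_size=128)
+model = Model(net, loss_fn=loss, metrics={'acc'})
+print(model.eval(dataset, dataset_sink_mode=False))
+```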
+
+### Result
+
+```bash
+result: {'acc': 0.832} ckpt= ./cmt_s_ms.ckpt
+```
+
+## [Description of Random Situation](#contents)
+
+In dataset.py, we set the seed inside the `create_dataset` function. We also use a random seed in train.py.
+
+## [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/models).
\ No newline at end of file
diff --git a/research/cv/CMT/eval.py b/research/cv/CMT/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..229138a80317c3af1c59cdc855841e778728c927
--- /dev/null
+++ b/research/cv/CMT/eval.py
@@ -0,0 +1,63 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+eval.
+"""
+import os
+import argparse
+from mindspore import context
+from mindspore import nn
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.dataset import create_dataset
+from src.cmt import cmt_s
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
+parser.add_argument('--model', type=str, default='cmt', help='eval model')
+args_opt = parser.parse_args()
+
+
+if __name__ == '__main__':
+    config_platform = None
+    if args_opt.platform == "Ascend":
+        device_id = int(os.getenv('DEVICE_ID', '0'))
+        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
+                            device_id=device_id, save_graphs=False)
+    elif args_opt.platform == "GPU":
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target=args_opt.platform, save_graphs=False)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    if args_opt.model == 'cmt':
+        net = cmt_s()
+    else:
+        raise ValueError("Unsupported model.")
+
+    if args_opt.checkpoint_path:
+        param_dict = load_checkpoint(args_opt.checkpoint_path)
+        load_param_into_net(net, param_dict)
+    net.set_train(False)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=128)
+
+    model = Model(net, loss_fn=loss, metrics={'acc'})
+    res = model.eval(dataset, dataset_sink_mode=False)
+    print("result:", res, "ckpt=", args_opt.checkpoint_path)
diff --git a/research/cv/CMT/fig/CMT.PNG b/research/cv/CMT/fig/CMT.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..d08dd3095c353d08c838c596e313e417e0237c5b
Binary files /dev/null and b/research/cv/CMT/fig/CMT.PNG differ
diff --git a/research/cv/CMT/scripts/run_cmt_eval.sh b/research/cv/CMT/scripts/run_cmt_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5b8013332f1848e188e7c431d126085b6f6977a6
--- /dev/null
+++ b/research/cv/CMT/scripts/run_cmt_eval.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash ./scripts/run_cmt_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]"
+exit 1
+fi
+
+DATA_PATH=$1
+PLATFORM=$2
+CHECKPOINT_PATH=$3
+
+rm -rf evaluation
+mkdir ./evaluation
+cd ./evaluation || exit
+echo  "start training for device id $DEVICE_ID"
+env > env.log
+python ../eval.py --model cmt --dataset_path=$DATA_PATH --platform=$PLATFORM --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 &
+cd ../
\ No newline at end of file
diff --git a/research/cv/CMT/src/cmt.py b/research/cv/CMT/src/cmt.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2419680b698dd7c17015103b1646d4e9eb5f17e
--- /dev/null
+++ b/research/cv/CMT/src/cmt.py
@@ -0,0 +1,444 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import collections.abc
+from itertools import repeat
+import mindspore
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore import Parameter
+from mindspore.ops import operations as P
+import mindspore.common.initializer as weight_init
+import numpy as np
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .96, 'interpolation': 'bicubic',
+        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'classifier': 'head',
+        **kwargs
+    }
+
+
+def to_2tuple(x):
+    if isinstance(x, collections.abc.Iterable):
+        return x
+    return tuple(repeat(x, 2))
+
+
+class DropPath(nn.Cell):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None, seed=0):
+        super(DropPath, self).__init__()
+        self.keep_prob = 1 - drop_prob
+        seed = min(seed, 0)  # always be 0
+        # seed must be 0, if set to other value, it's not rand for multiple call
+        self.rand = P.UniformReal(seed=seed)
+        self.shape = P.Shape()
+        self.floor = P.Floor()
+
+    def construct(self, x):
+        if self.training:
+            x_shape = self.shape(x)  # B N C
+            random_tensor = self.rand((x_shape[0], 1, 1))
+            random_tensor = random_tensor + self.keep_prob
+            random_tensor = self.floor(random_tensor)
+            x = x / self.keep_prob
+            x = x * random_tensor
+        return x
+
+
+def swish(x):
+    return x * P.Sigmoid()(x)
+
+
+class Mlp(nn.Cell):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.conv1 = nn.SequentialCell([
+            nn.Conv2d(in_features, hidden_features, 1, 1, has_bias=True),
+            nn.GELU(),
+            nn.BatchNorm2d(hidden_features),
+        ])
+        self.proj = nn.Conv2d(
+            hidden_features, hidden_features, 3, 1, pad_mode='pad', padding=1, group=hidden_features, has_bias=True)
+        self.proj_act = nn.GELU()
+        self.proj_bn = nn.BatchNorm2d(hidden_features)
+        self.conv2 = nn.SequentialCell([
+            nn.Conv2d(hidden_features, out_features, 1, 1, has_bias=True),
+            nn.BatchNorm2d(out_features),
+        ])
+        self.drop = nn.Dropout(1. - drop)
+
+    def construct(self, x, H, W):
+        B, _, C = x.shape
+        x = ops.Transpose()(x, (0, 2, 1)).reshape(B, C, H, W)
+        x = self.conv1(x)
+        x = self.drop(x)
+        x = self.proj(x) + x
+        x = self.proj_act(x)
+        x = self.proj_bn(x)
+        x = self.conv2(x)
+        x = ops.Transpose()(x.reshape(B, C, -1), (0, 2, 1))
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Cell):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
+                 attn_drop=0., proj_drop=0., qk_ratio=1, sr_ratio=1):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+        self.qk_dim = dim // qk_ratio
+
+        self.q = nn.Dense(dim, self.qk_dim, has_bias=qkv_bias)
+        self.k = nn.Dense(dim, self.qk_dim, has_bias=qkv_bias)
+        self.v = nn.Dense(dim, dim, has_bias=qkv_bias)
+        self.attn_drop = nn.Dropout(1. - attn_drop)
+        self.proj = nn.Dense(dim, dim)
+        self.proj_drop = nn.Dropout(1. - proj_drop)
+
+        self.sr_ratio = sr_ratio
+        # Exactly same as PVTv1
+        if self.sr_ratio > 1:
+            self.sr = nn.SequentialCell([
+                nn.Conv2d(dim, dim, kernel_size=sr_ratio,
+                          stride=sr_ratio, group=dim, has_bias=True),
+                nn.BatchNorm2d(dim),
+            ])
+
+        self.softmax = nn.Softmax(axis=-1)
+
+    def construct(self, x, H, W, relative_pos):
+        B, N, C = x.shape
+        q = self.q(x).reshape(B, N, self.num_heads,
+                              self.qk_dim // self.num_heads)
+        q = ops.Transpose()(q, (0, 2, 1, 3))
+
+        if self.sr_ratio > 1:
+            x_ = ops.Transpose()(x, (0, 2, 1)).reshape(B, C, H, W)
+            x_ = self.sr(x_).reshape(B, C, -1)
+            x_ = ops.Transpose()(x_, (0, 2, 1))
+            k = self.k(x_).reshape(B, -1, self.num_heads,
+                                   self.qk_dim // self.num_heads)
+            k = ops.Transpose()(k, (0, 2, 1, 3))
+            v = self.v(x_).reshape(B, -1, self.num_heads, C // self.num_heads)
+            v = ops.Transpose()(v, (0, 2, 1, 3))
+        else:
+            k = self.k(x).reshape(B, N, self.num_heads,
+                                  self.qk_dim // self.num_heads)
+            k = ops.Transpose()(k, (0, 2, 1, 3))
+            v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads)
+            v = ops.Transpose()(v, (0, 2, 1, 3))
+
+        attn = mindspore.ops.matmul(q, ops.Transpose()(
+            k, (0, 1, 3, 2))) * self.scale + relative_pos
+
+        attn = self.softmax(attn)
+        attn = self.attn_drop(attn)
+        x = mindspore.ops.matmul(attn, v)
+        x = ops.Transpose()(x, (0, 2, 1, 3)).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Cell):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qk_ratio=1, sr_ratio=1):
+        super().__init__()
+        self.norm1 = norm_layer([dim])
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, qk_ratio=qk_ratio, sr_ratio=sr_ratio)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else ops.Identity()
+        self.norm2 = norm_layer([dim])
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer, drop=drop)
+        self.proj = nn.Conv2d(dim, dim, 3, 1, pad_mode='pad',
+                              padding=1, group=dim, has_bias=True)
+
+    def construct(self, x, H, W, relative_pos):
+        B, _, C = x.shape
+        cnn_feat = ops.Transpose()(x, (0, 2, 1)).reshape(B, C, H, W)
+        x = self.proj(cnn_feat) + cnn_feat
+        x = ops.Transpose()(x.reshape(B, C, H*W), (0, 2, 1))
+        x = x + self.drop_path(self.attn(self.norm1(x), H, W, relative_pos))
+        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+        return x
+
+
+class PatchEmbed(nn.Cell):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * \
+            (img_size[0] // patch_size[0])
+
+        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
+            f"img_size {img_size} should be divided by patch_size {patch_size}."
+
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
+        self.norm = nn.LayerNorm([embed_dim])
+
+    def construct(self, x):
+        _, _, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        _B, _C, _H, _W = x.shape
+        x = ops.Transpose()(x.reshape(_B, _C, _H*_W), (0, 2, 1))
+        x = self.norm(x)
+
+        H, W = H // self.patch_size[0], W // self.patch_size[1]
+        return x, (H, W)
+
+
+class CMT(nn.Cell):
+    def __init__(self, img_size=224, in_chans=3, num_classes=1000, embed_dims=None, stem_channel=16,
+                 fc_dim=1280, num_heads=None, mlp_ratios=None, qkv_bias=True, qk_scale=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
+                 depths=None, qk_ratio=1, sr_ratios=None, dp=0.1):
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dims[-1]
+        norm_layer = norm_layer or nn.LayerNorm
+
+        self.stem_conv1 = nn.Conv2d(
+            3, stem_channel, kernel_size=3, stride=2, pad_mode='pad', padding=1, has_bias=True)
+        self.stem_relu1 = nn.GELU()
+        self.stem_norm1 = nn.BatchNorm2d(stem_channel)
+
+        self.stem_conv2 = nn.Conv2d(
+            stem_channel, stem_channel, kernel_size=3, stride=1, pad_mode='pad', padding=1, has_bias=True)
+        self.stem_relu2 = nn.GELU()
+        self.stem_norm2 = nn.BatchNorm2d(stem_channel)
+
+        self.stem_conv3 = nn.Conv2d(
+            stem_channel, stem_channel, kernel_size=3, stride=1, pad_mode='pad', padding=1, has_bias=True)
+        self.stem_relu3 = nn.GELU()
+        self.stem_norm3 = nn.BatchNorm2d(stem_channel)
+
+        self.patch_embed_a = PatchEmbed(
+            img_size=img_size//2, patch_size=2, in_chans=stem_channel, embed_dim=embed_dims[0])
+        self.patch_embed_b = PatchEmbed(
+            img_size=img_size//4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
+        self.patch_embed_c = PatchEmbed(
+            img_size=img_size//8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
+        self.patch_embed_d = PatchEmbed(
+            img_size=img_size//16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
+
+        self.relative_pos_a = P.Zeros()(
+            (num_heads[0], self.patch_embed_a.num_patches, self.patch_embed_a.num_patches//sr_ratios[0]//sr_ratios[0]),
+            mindspore.float32)
+        self.relative_pos_a = Parameter(self.relative_pos_a)
+        self.relative_pos_b = P.Zeros()(
+            (num_heads[1], self.patch_embed_b.num_patches, self.patch_embed_b.num_patches//sr_ratios[1]//sr_ratios[1]),
+            mindspore.float32)
+        self.relative_pos_b = Parameter(self.relative_pos_b)
+        self.relative_pos_c = P.Zeros()(
+            (num_heads[2], self.patch_embed_c.num_patches, self.patch_embed_c.num_patches//sr_ratios[2]//sr_ratios[2]),
+            mindspore.float32)
+        self.relative_pos_c = Parameter(self.relative_pos_c)
+        self.relative_pos_d = P.Zeros()(
+            (num_heads[3], self.patch_embed_d.num_patches, self.patch_embed_d.num_patches//sr_ratios[3]//sr_ratios[3]),
+            mindspore.float32)
+        self.relative_pos_d = Parameter(self.relative_pos_d)
+
+        # stochastic depth decay rule
+        dpr = [x.item() for x in np.linspace(0, drop_path_rate, sum(depths))]
+        cur = 0
+        self.blocks_a = nn.CellList([
+            Block(
+                dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
+                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[
+                    cur+i],
+                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[0])
+            for i in range(depths[0])])
+        cur += depths[0]
+        self.blocks_b = nn.CellList([
+            Block(
+                dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
+                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[
+                    cur+i],
+                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[1])
+            for i in range(depths[1])])
+        cur += depths[1]
+        self.blocks_c = nn.CellList([
+            Block(
+                dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
+                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[
+                    cur+i],
+                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[2])
+            for i in range(depths[2])])
+        cur += depths[2]
+        self.blocks_d = nn.CellList([
+            Block(
+                dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
+                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[
+                    cur+i],
+                norm_layer=norm_layer, qk_ratio=qk_ratio, sr_ratio=sr_ratios[3])
+            for i in range(depths[3])])
+
+        # Classifier head
+        self._fc = nn.Conv2d(
+            embed_dims[-1], fc_dim, kernel_size=1, has_bias=True)
+        self._bn = nn.BatchNorm2d(fc_dim)
+        self._drop = nn.Dropout(1. - dp)
+        self.head = nn.Dense(
+            fc_dim, num_classes) if num_classes > 0 else ops.Identity()
+
+    def _init_weights(self):
+        for _, cell in self.cells_and_names():
+            if isinstance(cell, nn.Dense):
+                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+                if isinstance(cell, nn.Dense) and cell.bias is not None:
+                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
+                                                               cell.bias.shape,
+                                                               cell.bias.dtype))
+            elif isinstance(cell, (nn.LayerNorm, nn.BatchNorm2d)):
+                cell.gamma.set_data(weight_init.initializer(weight_init.One(),
+                                                            cell.gamma.shape,
+                                                            cell.gamma.dtype))
+                cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
+                                                           cell.beta.shape,
+                                                           cell.beta.dtype))
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.stem_conv1(x)
+        x = self.stem_relu1(x)
+        x = self.stem_norm1(x)
+
+        x = self.stem_conv2(x)
+        x = self.stem_relu2(x)
+        x = self.stem_norm2(x)
+
+        x = self.stem_conv3(x)
+        x = self.stem_relu3(x)
+        x = self.stem_norm3(x)
+
+        x, (H, W) = self.patch_embed_a(x)
+        for _, blk in enumerate(self.blocks_a):
+            x = blk(x, H, W, self.relative_pos_a)
+
+        x = ops.Transpose()(x.reshape(B, H, W, -1), (0, 3, 1, 2))
+        x, (H, W) = self.patch_embed_b(x)
+        for _, blk in enumerate(self.blocks_b):
+            x = blk(x, H, W, self.relative_pos_b)
+
+        x = ops.Transpose()(x.reshape(B, H, W, -1), (0, 3, 1, 2))
+        x, (H, W) = self.patch_embed_c(x)
+        for _, blk in enumerate(self.blocks_c):
+            x = blk(x, H, W, self.relative_pos_c)
+
+        x = ops.Transpose()(x.reshape(B, H, W, -1), (0, 3, 1, 2))
+        x, (H, W) = self.patch_embed_d(x)
+        for _, blk in enumerate(self.blocks_d):
+            x = blk(x, H, W, self.relative_pos_d)
+
+        B, _, C = x.shape
+
+        x = self._fc(ops.Transpose()(x, (0, 2, 1)).reshape(B, C, H, W))
+        x = self._bn(x)
+        x = swish(x)
+        x = ops.AdaptiveAvgPool2D(output_size=1)(x).squeeze(2).squeeze(2)
+        x = self._drop(x)
+        return x
+
+    def construct(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def _create_cmt_model(pretrained=False, distilled=False, **kwargs):
+    default_cfg = _cfg()
+    default_num_classes = default_cfg['num_classes']
+    default_img_size = default_cfg['input_size'][-1]
+
+    num_classes = kwargs.pop('num_classes', default_num_classes)
+    img_size = kwargs.pop('img_size', default_img_size)
+
+    model = CMT(img_size=img_size, num_classes=num_classes, **kwargs)
+    model.default_cfg = default_cfg
+
+    return model
+
+
+def cmt_ti(pretrained=False, **kwargs):
+    """
+    CMT-Tiny
+    """
+    model_kwargs = dict(
+        embed_dims=[46, 92, 184, 368], num_heads=[1, 2, 4, 8], mlp_ratios=[3.6, 3.6, 3.6, 3.6],
+        depths=[2, 2, 10, 2], sr_ratios=[8, 4, 2, 1], qkv_bias=True, **kwargs)
+    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)
+    return model
+
+
+def cmt_xs(pretrained=False, **kwargs):
+    """
+    CMT-XS: dim x 0.9, depth x 0.8, input 192
+    """
+    model_kwargs = dict(
+        qkv_bias=True, embed_dims=[52, 104, 208, 416], stem_channel=16, num_heads=[1, 2, 4, 8],
+        depths=[3, 3, 12, 3], mlp_ratios=[3.77, 3.77, 3.77, 3.77], qk_ratio=1, sr_ratios=[8, 4, 2, 1], **kwargs)
+    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)
+    return model
+
+
+def cmt_s(pretrained=False, **kwargs):
+    """
+    CMT-Small
+    """
+    model_kwargs = dict(
+        qkv_bias=True, embed_dims=[64, 128, 256, 512], stem_channel=32, num_heads=[1, 2, 4, 8],
+        depths=[3, 3, 16, 3], mlp_ratios=[4, 4, 4, 4], qk_ratio=1, sr_ratios=[8, 4, 2, 1], **kwargs)
+    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)
+    return model
+
+
+def cmt_b(pretrained=False, **kwargs):
+    """
+    CMT-Base
+    """
+    model_kwargs = dict(
+        qkv_bias=True, embed_dims=[76, 152, 304, 608], stem_channel=38, num_heads=[1, 2, 4, 8],
+        depths=[4, 4, 20, 4], mlp_ratios=[4, 4, 4, 4], qk_ratio=1, sr_ratios=[8, 4, 2, 1], dp=0.3, **kwargs)
+    model = _create_cmt_model(pretrained=pretrained, **model_kwargs)
+    return model
diff --git a/research/cv/CMT/src/dataset.py b/research/cv/CMT/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27c6ecb2d2ece99900b800efad9b7879ec0b80b
--- /dev/null
+++ b/research/cv/CMT/src/dataset.py
@@ -0,0 +1,79 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Data operations, will be used in train.py and eval.py"""
+import os
+import mindspore.common.dtype as mstype
+import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.c_transforms as C2
+import mindspore.dataset.vision.py_transforms as pytrans
+
+from mindspore.dataset.transforms.py_transforms import Compose
+import mindspore.dataset.vision.c_transforms as C
+
+
+def create_dataset(dataset_path, do_train, repeat_num=1, infer_910=True, device_id=0, batch_size=128):
+    """
+    create a train or eval dataset
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1.
+        infer_910(bool): currently unused, kept for interface compatibility. Default: True.
+        device_id(int): currently unused, kept for interface compatibility. Default: 0.
+        batch_size(int): batch size of each step. Default: 128.
+
+    Returns:
+        dataset
+    """
+
+    if not do_train:
+        dataset_path = os.path.join(dataset_path, 'val')
+    else:
+        dataset_path = os.path.join(dataset_path, 'train')
+
+    ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=1, shard_id=0)
+
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    # define map operations
+    if do_train:
+        trans = [
+            C.RandomCropDecodeResize(224),
+            C.RandomHorizontalFlip(prob=0.5),
+            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
+        ]
+    else:
+        trans = [
+            pytrans.Decode(),
+            pytrans.Resize(235),
+            pytrans.CenterCrop(224)
+        ]
+    trans += [
+        pytrans.ToTensor(),
+        pytrans.Normalize(mean=mean, std=std),
+    ]
+    trans = Compose(trans)
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
+    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
+
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
+    ds = ds.repeat(repeat_num)
+    return ds
diff --git a/research/cv/HireMLP/README.md b/research/cv/HireMLP/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a651bf846d9a382e4d4f82d9d9eea0ce5c3ef662
--- /dev/null
+++ b/research/cv/HireMLP/README.md
@@ -0,0 +1,91 @@
+# Contents
+
+- [Contents](#contents)
+    - [HireMLP Description](#hiremlp-description)
+    - [Model architecture](#model-architecture)
+    - [Dataset](#dataset)
+    - [Environment Requirements](#environment-requirements)
+    - [Script description](#script-description)
+        - [Script and sample code](#script-and-sample-code)
+    - [Eval process](#eval-process)
+        - [Usage](#usage)
+        - [Launch](#launch)
+        - [Result](#result)
+    - [Description of Random Situation](#description-of-random-situation)
+    - [ModelZoo Homepage](#modelzoo-homepage)
+
+## [HireMLP Description](#contents)
+
+This paper presents Hire-MLP, a simple yet competitive vision MLP architecture built on hierarchical rearrangement, which contains two levels of rearrangement. Specifically, the inner-region rearrangement captures local information inside a spatial region, and the cross-region rearrangement enables information communication between different regions and captures global context by circularly shifting all tokens along the spatial directions.
+
+[Paper](https://arxiv.org/pdf/2108.13341.pdf): Jianyuan Guo, Yehui Tang, Kai Han, Xinghao Chen, Han Wu, Chao Xu, Chang Xu, Yunhe Wang. Hire-MLP: Vision MLP via Hierarchical Rearrangement. Accepted in CVPR 2022.
+
+## [Model architecture](#contents)
+
+A block of HireMLP is shown below:
+
+![image-20211026160438718](./fig/HireMLP.PNG)
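+
+The cross-region rearrangement is realized in `src/hire_mlp.py` with `Slice`/`Concat` pairs; below is a minimal sketch of that circular shift along the height axis (the shift size is illustrative):
+
+```python
+# Minimal sketch of a circular shift along H via Slice + Concat, mirroring HireMLP.construct.
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+
+x = ms.Tensor(np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4), ms.float32)
+B, C, H, W = x.shape
+step = 1  # illustrative shift size
+head = ops.Slice()(x, (0, 0, 0, 0), (B, C, step, W))  # first `step` rows
+x = ops.Concat(axis=2)((x, head))                      # append them at the bottom
+x = ops.Slice()(x, (0, 0, step, 0), (B, C, H, W))      # drop them from the top
+print(x.shape)  # (2, 3, 4, 4): rows circularly shifted by `step`
+```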
+
+## [Dataset](#contents)
+
+Dataset used: [ImageNet2012](http://www.image-net.org/)
+
+- Dataset size: 224*224 color images in 1,000 classes
+    - Train: 1,281,167 images
+    - Test: 50,000 images
+- Data format: JPEG
+    - Note: data will be processed in src/dataset.py.
+
+## [Environment Requirements](#contents)
+
+- Hardware (Ascend/GPU)
+    - Prepare hardware environment with Ascend or GPU.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
+
+## [Script description](#contents)
+
+### [Script and sample code](#contents)
+
+```bash
+HireMLP
+├── eval.py                  # inference entry
+├── fig
+│   └── HireMLP.PNG          # the illustration of the HireMLP network
+├── README.md                # readme
+├── scripts
+│   └── run_hire_mlp_eval.sh # evaluation launch script
+└── src
+    ├── dataset.py           # dataset loader
+    └── hire_mlp.py          # HireMLP network
+```
+
+## [Eval process](#contents)
+
+### Usage
+
+After installing MindSpore via the official website, you can start evaluation as follows:
+
+### Launch
+
+```bash
+# HireMLP infer example
+  GPU: python eval.py --dataset_path [DATASET_PATH] --platform GPU --checkpoint_path [CHECKPOINT_PATH]
+```
+
+> The checkpoint can be downloaded at https://download.mindspore.cn/model_zoo/.
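+
+A minimal sketch of restoring the released weights into the Hire-MLP-Tiny backbone before evaluation (the checkpoint filename is a placeholder):
+
+```python
+# Minimal sketch; the checkpoint filename is a placeholder.
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.hire_mlp import hire_mlp_tiny
+
+net = hire_mlp_tiny()
+load_param_into_net(net, load_checkpoint("./hire_tiny_ms.ckpt"))
+net.set_train(False)
+```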
+
+### Result
+
+```bash
+result: {'acc': 0.788} ckpt= ./hire_tiny_ms.ckpt
+```
+
+## [Description of Random Situation](#contents)
+
+In dataset.py, we set the seed inside the `create_dataset` function. We also use a random seed in train.py.
+
+## [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/models).
\ No newline at end of file
diff --git a/research/cv/HireMLP/eval.py b/research/cv/HireMLP/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5b0c49210a434986767eb34a5730a0de978517c
--- /dev/null
+++ b/research/cv/HireMLP/eval.py
@@ -0,0 +1,60 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+eval.
+"""
+import os
+import argparse
+from mindspore import context
+from mindspore import nn
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.dataset import create_dataset
+from src.hire_mlp import hire_mlp_tiny
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
+args_opt = parser.parse_args()
+
+
+if __name__ == '__main__':
+    config_platform = None
+    if args_opt.platform == "Ascend":
+        device_id = int(os.getenv('DEVICE_ID', '0'))
+        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
+                            device_id=device_id, save_graphs=False)
+    elif args_opt.platform == "GPU":
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target=args_opt.platform, save_graphs=False)
+    else:
+        raise ValueError("Unsupported platform.")
+
+    # Hire-Tiny for example
+    net = hire_mlp_tiny()
+
+    if args_opt.checkpoint_path:
+        param_dict = load_checkpoint(args_opt.checkpoint_path)
+        load_param_into_net(net, param_dict)
+    net.set_train(False)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=128)
+
+    model = Model(net, loss_fn=loss, metrics={'acc'})
+    res = model.eval(dataset, dataset_sink_mode=False)
+    print("result:", res, "ckpt=", args_opt.checkpoint_path)
diff --git a/research/cv/HireMLP/fig/HireMLP.PNG b/research/cv/HireMLP/fig/HireMLP.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..08349b71d82a7fc3031db00868a7c828e2a73ee7
Binary files /dev/null and b/research/cv/HireMLP/fig/HireMLP.PNG differ
diff --git a/research/cv/HireMLP/scripts/run_hire_mlp_eval.sh b/research/cv/HireMLP/scripts/run_hire_mlp_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d72eefbac3c1286c061e9be14b857c71c13bd76
--- /dev/null
+++ b/research/cv/HireMLP/scripts/run_hire_mlp_eval.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash ./scripts/run_hire_mlp_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]"
+exit 1
+fi
+
+DATA_PATH=$1
+PLATFORM=$2
+CHECKPOINT_PATH=$3
+
+rm -rf evaluation
+mkdir ./evaluation
+cd ./evaluation || exit
+echo  "start training for device id $DEVICE_ID"
+env > env.log
+python ../eval.py --dataset_path=$DATA_PATH --platform=$PLATFORM --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 &
+cd ../
\ No newline at end of file
diff --git a/research/cv/HireMLP/src/dataset.py b/research/cv/HireMLP/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27c6ecb2d2ece99900b800efad9b7879ec0b80b
--- /dev/null
+++ b/research/cv/HireMLP/src/dataset.py
@@ -0,0 +1,79 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Data operations, will be used in train.py and eval.py"""
+import os
+import mindspore.common.dtype as mstype
+import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.c_transforms as C2
+import mindspore.dataset.vision.py_transforms as pytrans
+
+from mindspore.dataset.transforms.py_transforms import Compose
+import mindspore.dataset.vision.c_transforms as C
+
+
+def create_dataset(dataset_path, do_train, repeat_num=1, infer_910=True, device_id=0, batch_size=128):
+    """
+    create a train or eval dataset
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1.
+        infer_910(bool): currently unused, kept for interface compatibility. Default: True.
+        device_id(int): currently unused, kept for interface compatibility. Default: 0.
+        batch_size(int): batch size of each step. Default: 128.
+
+    Returns:
+        dataset
+    """
+
+    if not do_train:
+        dataset_path = os.path.join(dataset_path, 'val')
+    else:
+        dataset_path = os.path.join(dataset_path, 'train')
+
+    ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=1, shard_id=0)
+
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    # define map operations
+    if do_train:
+        trans = [
+            C.RandomCropDecodeResize(224),
+            C.RandomHorizontalFlip(prob=0.5),
+            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
+        ]
+    else:
+        trans = [
+            pytrans.Decode(),
+            pytrans.Resize(235),
+            pytrans.CenterCrop(224)
+        ]
+    trans += [
+        pytrans.ToTensor(),
+        pytrans.Normalize(mean=mean, std=std),
+    ]
+    trans = Compose(trans)
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
+    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
+
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
+    ds = ds.repeat(repeat_num)
+    return ds
diff --git a/research/cv/HireMLP/src/hire_mlp.py b/research/cv/HireMLP/src/hire_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6a0948a275c22329904ad4584a3e223a475843
--- /dev/null
+++ b/research/cv/HireMLP/src/hire_mlp.py
@@ -0,0 +1,403 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import collections.abc
+from itertools import repeat
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore.ops import operations as P
+import mindspore.common.initializer as weight_init
+
+
+def to_2tuple(x):
+    if isinstance(x, collections.abc.Iterable):
+        return x
+    return tuple(repeat(x, 2))
+
+
+class DropPath(nn.Cell):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None, seed=0):
+        super(DropPath, self).__init__()
+        self.keep_prob = 1 - drop_prob
+        seed = min(seed, 0)                  # always be 0
+        # seed must be 0, if set to other value, it's not rand for multiple call
+        self.rand = P.UniformReal(seed=seed)
+        self.shape = P.Shape()
+        self.floor = P.Floor()
+
+    def construct(self, x):
+        if self.training:
+            x_shape = self.shape(x)
+            random_tensor = self.rand((x_shape[0], 1, 1))
+            random_tensor = random_tensor + self.keep_prob
+            random_tensor = self.floor(random_tensor)
+            x = x / self.keep_prob
+            x = x * random_tensor
+        return x
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .96, 'interpolation': 'bicubic',
+        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'classifier': 'head',
+        **kwargs
+    }
+
+
+class Mlp(nn.Cell):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.act = act_layer()
+        self.drop = nn.Dropout(1. - drop)
+        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, has_bias=True)
+        self.fc2 = nn.Conv2d(
+            hidden_features, out_features, 1, 1, has_bias=True)
+
+    def construct(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class HireMLP(nn.Cell):
+    def __init__(self, dim, attn_drop=0., proj_drop=0., pixel=2, step=1, step_pad_mode='c', pixel_pad_mode='c'):
+        super().__init__()
+        self.pixel = pixel
+        self.step = step
+        self.step_pad_mode = step_pad_mode
+        self.pixel_pad_mode = pixel_pad_mode
+        print('pixel: {} pad mode: {} step: {} pad mode: {}'.format(
+            pixel, pixel_pad_mode, step, step_pad_mode))
+
+        self.mlp_h1 = nn.Conv2d(dim*pixel, dim//2, 1, has_bias=False)
+        self.mlp_h1_norm = nn.BatchNorm2d(dim//2)
+        self.mlp_h2 = nn.Conv2d(dim//2, dim*pixel, 1, has_bias=True)
+        self.mlp_w1 = nn.Conv2d(dim*pixel, dim//2, 1, has_bias=False)
+        self.mlp_w1_norm = nn.BatchNorm2d(dim//2)
+        self.mlp_w2 = nn.Conv2d(dim//2, dim*pixel, 1, has_bias=True)
+        self.mlp_c = nn.Conv2d(dim, dim, 1, has_bias=True)
+
+        self.act = nn.ReLU()
+
+        self.reweight = Mlp(dim, dim // 4, dim * 3)
+
+        self.proj = nn.Conv2d(dim, dim, 1, has_bias=True)
+        self.proj_drop = nn.Dropout(1. - proj_drop)
+
+    def construct(self, x):
+        """
+        h: H x W x C -> H/pixel x W x C*pixel
+        w: H x W x C -> H x W/pixel x C*pixel
+        """
+
+        B, C, H, W = x.shape
+
+        pad_h, pad_w = (
+            self.pixel - H % self.pixel) % self.pixel, (self.pixel - W % self.pixel) % self.pixel
+        h, w = x.copy(), x.copy()
+
+        if self.step:
+            if self.step_pad_mode == 'c':
+                if self.step > 0:
+                    h_slice = ops.Slice()(h, (0, 0, 0, 0), (B, C, self.step, W))
+                    h = ops.Concat(axis=2)((h, h_slice))
+                    h = ops.Slice()(h, (0, 0, self.step, 0), (B, C, H, W))
+                    w_slice = ops.Slice()(w, (0, 0, 0, 0), (B, C, H, self.step))
+                    w = ops.Concat(axis=3)((w, w_slice))
+                    w = ops.Slice()(w, (0, 0, 0, self.step), (B, C, H, W))
+            else:
+                raise NotImplementedError("Invalid pad mode.")
+
+        if self.pixel_pad_mode == '0':
+            h = nn.Pad(paddings=((0, 0), (0, 0), (0, pad_h), (0, 0)), mode='CONSTANT')(h)
+            w = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, pad_w)), mode='CONSTANT')(w)
+        elif self.pixel_pad_mode == 'c':
+            if pad_h > 0:
+                h_slice = ops.Slice()(h, (0, 0, 0, 0), (B, C, pad_h, W))
+                h = ops.Concat(axis=2)((h, h_slice))
+            if pad_w > 0:
+                w_slice = ops.Slice()(w, (0, 0, 0, 0), (B, C, H, pad_w))
+                w = ops.Concat(axis=3)((w, w_slice))
+        else:
+            raise NotImplementedError("Invalid pad mode.")
+
+        h = (ops.Transpose()(h.reshape(B, C, (H + pad_h) // self.pixel, self.pixel, W), (0, 1, 3, 2, 4))).reshape(
+            B, C*self.pixel, (H + pad_h) // self.pixel, W)
+        w = (ops.Transpose()(w.reshape(B, C, H, (W + pad_w) // self.pixel, self.pixel), (0, 1, 4, 2, 3))).reshape(
+            B, C*self.pixel, H, (W + pad_w) // self.pixel)
+
+        h = self.mlp_h1(h)
+        h = self.mlp_h1_norm(h)
+        h = self.act(h)
+        h = self.mlp_h2(h)
+
+        w = self.mlp_w1(w)
+        w = self.mlp_w1_norm(w)
+        w = self.act(w)
+        w = self.mlp_w2(w)
+
+        h = (ops.Transpose()(h.reshape(B, C, self.pixel, (H + pad_h) // self.pixel, W), (0, 1, 3, 2, 4))).reshape(
+            B, C, H + pad_h, W)
+        w = (ops.Transpose()(w.reshape(B, C, self.pixel, H, (W + pad_w) // self.pixel), (0, 1, 3, 4, 2))).reshape(
+            B, C, H, W + pad_w)
+
+        h = ops.Slice()(h, (0, 0, 0, 0), (B, C, H, W))
+        w = ops.Slice()(w, (0, 0, 0, 0), (B, C, H, W))
+
+        if self.step and self.step_pad_mode == 'c':
+            _, _, H_, W_ = h.shape
+            h_slice = ops.Slice()(h, (0, 0, H_-self.step, 0), (B, C, self.step, W))
+            h = ops.Concat(axis=2)((h_slice, h))
+            h = ops.Slice()(h, (0, 0, 0, 0), (B, C, H, W))
+            w_slice = ops.Slice()(w, (0, 0, 0, W_-self.step), (B, C, H, self.step))
+            w = ops.Concat(axis=3)((w_slice, w))
+            w = ops.Slice()(w, (0, 0, 0, 0), (B, C, H, W))
+
+        c = self.mlp_c(x)
+
+        a = ops.AdaptiveAvgPool2D(output_size=(1, 1))(h + w + c)
+        a = ops.ExpandDims()(ops.ExpandDims()(
+            ops.Softmax(axis=0)(ops.Transpose()(self.reweight(a).reshape(B, C, 3), (2, 0, 1))), -1), -1)
+
+        x = h * a[0] + w * a[1] + c * a[2]
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+class HireBlock(nn.Cell):
+
+    def __init__(self, dim, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
+                 pixel=2, step=1, step_pad_mode='c', pixel_pad_mode='c'):
+        super().__init__()
+        self.norm1 = nn.BatchNorm2d(dim)
+        self.attn = HireMLP(dim, attn_drop=attn_drop, pixel=pixel, step=step,
+                            step_pad_mode=step_pad_mode, pixel_pad_mode=pixel_pad_mode)
+        self.norm2 = nn.BatchNorm2d(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim,
+                       hidden_features=mlp_hidden_dim, drop=drop)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else ops.Identity()
+
+    def construct(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbedOverlapping(nn.Cell):
+    def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d,
+                 groups=1, use_norm=True):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        stride = to_2tuple(stride)
+        self.patch_size = patch_size
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=stride,
+            padding=(padding, padding, padding, padding), group=groups, pad_mode='pad', has_bias=True)
+        self.norm = norm_layer(embed_dim) if use_norm else ops.Identity()
+        self.act = nn.ReLU()
+
+    def construct(self, x):
+        x = self.proj(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+class Downsample(nn.Cell):
+    def __init__(self, in_embed_dim, out_embed_dim, patch_size, norm_layer=nn.BatchNorm2d, use_norm=True):
+        super().__init__()
+        assert patch_size == 2, patch_size
+        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=(
+            3, 3), stride=(2, 2), padding=1, pad_mode='pad', has_bias=True)
+        self.norm = norm_layer(
+            out_embed_dim) if use_norm else ops.Identity()
+        self.act = nn.ReLU()
+
+    def construct(self, x):
+        x = self.proj(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+def basic_blocks(dim, index, layers, mlp_ratio=4., attn_drop=0., drop_path_rate=0., pixel=2, step_stride=2,
+                 step_dilation=1, step_pad_mode='c', pixel_pad_mode='c', **kwargs):
+    blocks = []
+    for block_idx in range(layers[index]):
+        block_dpr = drop_path_rate * \
+            (block_idx + sum(layers[:index])) / (sum(layers) - 1)
+        blocks.append(HireBlock(
+            dim, mlp_ratio=mlp_ratio, attn_drop=attn_drop, drop_path=block_dpr, pixel=pixel,
+            step=(block_idx % step_stride) * step_dilation, step_pad_mode=step_pad_mode, pixel_pad_mode=pixel_pad_mode))
+    blocks = nn.SequentialCell(*blocks)
+    return blocks
+
+
+class HireMLPNet(nn.Cell):
+    def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+                 embed_dims=None, mlp_ratios=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+                 pixel=None, step_stride=None, step_dilation=None,
+                 step_pad_mode='c', pixel_pad_mode='c'):
+        super().__init__()
+        self.print = ops.Print()
+
+        self.num_classes = num_classes
+
+        self.patch_embed = PatchEmbedOverlapping(
+            patch_size=7, stride=4, padding=2, in_chans=3, embed_dim=embed_dims[0])
+
+        network = []
+        for i in range(len(layers)):
+            stage = basic_blocks(
+                embed_dims[i], i, layers, mlp_ratio=mlp_ratios[i],
+                attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate, pixel=pixel[i],
+                step_stride=step_stride[i], step_dilation=step_dilation[i],
+                step_pad_mode=step_pad_mode, pixel_pad_mode=pixel_pad_mode)
+            network.append(stage)
+            if i >= len(layers) - 1:
+                break
+            network.append(Downsample(embed_dims[i], embed_dims[i+1], 2))
+
+        self.network = nn.SequentialCell(network)
+
+        self.norm = nn.BatchNorm2d(embed_dims[-1])
+        self.head = nn.Dense(
+            embed_dims[-1], num_classes) if num_classes > 0 else ops.Identity()
+
+    def cls_init_weights(self):
+        for _, cell in self.cells_and_names():
+            if isinstance(cell, nn.Dense):
+                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+                if isinstance(cell, nn.Dense) and cell.bias is not None:
+                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
+                                                               cell.bias.shape,
+                                                               cell.bias.dtype))
+            elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm)):
+                cell.gamma.set_data(weight_init.initializer(weight_init.One(),
+                                                            cell.gamma.shape,
+                                                            cell.gamma.dtype))
+                cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
+                                                           cell.beta.shape,
+                                                           cell.beta.dtype))
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Dense(
+            self.embed_dim, num_classes) if num_classes > 0 else ops.Identity()
+
+    def forward_embeddings(self, x):
+        x = self.patch_embed(x)
+        return x
+
+    def forward_tokens(self, x):
+        for _, block in enumerate(self.network):
+            x = block(x)
+        return x
+
+    def construct(self, x):
+        x = self.forward_embeddings(x)
+        x = self.forward_tokens(x)
+        x = self.norm(x)
+        cls_out = self.head(ops.Squeeze()(
+            (ops.AdaptiveAvgPool2D(output_size=1)(x))))
+        return cls_out
+
+
+def hire_mlp_tiny(pretrained=False, **kwargs):
+    layers = [2, 2, 4, 2]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [64, 128, 320, 512]
+    pixel = [4, 3, 3, 2]
+    step_stride = [2, 2, 3, 2]
+    step_dilation = [2, 2, 1, 1]
+    step_pad_mode = 'c'
+    pixel_pad_mode = 'c'
+    model = HireMLPNet(
+        layers, embed_dims=embed_dims, patch_size=7, mlp_ratios=mlp_ratios, pixel=pixel,
+        step_stride=step_stride, step_dilation=step_dilation,
+        step_pad_mode=step_pad_mode, pixel_pad_mode=pixel_pad_mode, **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+
+def hire_mlp_small(pretrained=False, **kwargs):
+    layers = [3, 4, 10, 3]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [64, 128, 320, 512]
+    pixel = [4, 3, 3, 2]
+    step_stride = [2, 2, 3, 2]
+    step_dilation = [2, 2, 1, 1]
+    step_pad_mode = 'c'
+    pixel_pad_mode = 'c'
+    model = HireMLPNet(
+        layers, embed_dims=embed_dims, patch_size=7, mlp_ratios=mlp_ratios, pixel=pixel,
+        step_stride=step_stride, step_dilation=step_dilation,
+        step_pad_mode=step_pad_mode, pixel_pad_mode=pixel_pad_mode, **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+
+def hire_mlp_base(pretrained=False, **kwargs):
+    layers = [4, 6, 24, 3]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [64, 128, 320, 512]
+    pixel = [4, 3, 3, 2]
+    step_stride = [2, 2, 3, 2]
+    step_dilation = [2, 2, 1, 1]
+    step_pad_mode = 'c'
+    pixel_pad_mode = 'c'
+    model = HireMLPNet(
+        layers, embed_dims=embed_dims, patch_size=7, mlp_ratios=mlp_ratios, pixel=pixel,
+        step_stride=step_stride, step_dilation=step_dilation,
+        step_pad_mode=step_pad_mode, pixel_pad_mode=pixel_pad_mode, **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
+
+def hire_mlp_large(pretrained=False, **kwargs):
+    layers = [4, 6, 24, 3]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [96, 192, 384, 768]
+    pixel = [4, 3, 3, 2]
+    step_stride = [2, 2, 3, 2]
+    step_dilation = [2, 2, 1, 1]
+    step_pad_mode = 'c'
+    pixel_pad_mode = 'c'
+    model = HireMLPNet(
+        layers, embed_dims=embed_dims, patch_size=7, mlp_ratios=mlp_ratios, pixel=pixel,
+        step_stride=step_stride, step_dilation=step_dilation,
+        step_pad_mode=step_pad_mode, pixel_pad_mode=pixel_pad_mode, **kwargs)
+    model.default_cfg = _cfg()
+    return model
diff --git a/research/cv/wave_mlp/eval.py b/research/cv/wave_mlp/eval.py
index bcccd15ad9232929465f654c16a66082187f7a45..34050655f88ae93b6abc3ff46fbceb27cbe55bb2 100644
--- a/research/cv/wave_mlp/eval.py
+++ b/research/cv/wave_mlp/eval.py
@@ -36,7 +36,7 @@ if __name__ == '__main__':
         context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                             device_id=device_id, save_graphs=False)
     elif args_opt.platform == "GPU":
-        context.set_context(mode=context.PYNATIVE_MODE,
+        context.set_context(mode=context.GRAPH_MODE,
                             device_target="GPU", save_graphs=False)
     else:
         raise ValueError("Unsupported platform.")