diff --git a/research/cv/wave_mlp/README.md b/research/cv/wave_mlp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ceb2a53c8ee22f89e894ca1e5d8dfaa100b007cb
--- /dev/null
+++ b/research/cv/wave_mlp/README.md
@@ -0,0 +1,111 @@
+# Contents
+
+- [Contents](#contents)
+    - [Wave-MLP Description](#wave-mlp-description)
+    - [Model architecture](#model-architecture)
+    - [Dataset](#dataset)
+    - [Environment Requirements](#environment-requirements)
+    - [Script description](#script-description)
+        - [Script and sample code](#script-and-sample-code)
+    - [Eval process](#eval-process)
+        - [Usage](#usage)
+        - [Launch](#launch)
+        - [Result](#result)
+    - [Description of Random Situation](#description-of-random-situation)
+    - [ModelZoo Homepage](#modelzoo-homepage)
+
+## [Wave-MLP Description](#contents)
+
+To dynamically aggregate tokens, Wave-MLP represents each token as a wave function with two parts: an amplitude and a phase. The amplitude is the original feature, while the phase term modulates it according to the semantic contents of the input images. A toy NumPy sketch of this representation is included near the end of this README.
+
+[Paper](https://arxiv.org/pdf/2111.12294.pdf): Yehui Tang, Kai Han, Jianyuan Guo, Chang Xu, Yanxi Li, Chao Xu, Yunhe Wang. An Image Patch is a Wave: Phase-Aware Vision MLP. arXiv:2111.12294.
+
+## [Model architecture](#contents)
+
+A block of Wave-MLP is shown below:
+
+![wavemlp](./fig/wavemlp.png)
+
+## [Dataset](#contents)
+
+Dataset used: [ImageNet2012](http://www.image-net.org/)
+
+- Dataset size: 224*224 color images in 1000 classes
+    - Train: 1,281,167 images
+    - Test: 50,000 images
+- Data format: jpeg
+    - Note: Data will be processed in dataset.py
+
+## [Environment Requirements](#contents)
+
+- Hardware (Ascend/GPU)
+    - Prepare the hardware environment with Ascend or GPU.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
+
+## [Script description](#contents)
+
+### [Script and sample code](#contents)
+
+```bash
+wave_mlp
+├── eval.py                # inference entry
+├── fig
+│   └── wavemlp.png        # illustration of the wave_mlp network
+├── README.md              # readme
+├── scripts
+│   └── run_eval.sh        # evaluation launch script
+└── src
+    ├── dataset.py         # dataset loader
+    └── wave_mlp.py        # wave_mlp network
+```
+
+## [Eval process](#contents)
+
+### Usage
+
+After installing MindSpore via the official website, you can start evaluation as follows:
+
+### Launch
+
+```bash
+# infer example
+  # python
+  GPU: python eval.py --dataset_path dataset --platform GPU --checkpoint_path [CHECKPOINT_PATH]
+  # shell
+  bash ./scripts/run_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]
+```
+
+> The checkpoint can be downloaded at https://download.mindspore.cn/model_zoo/research/cv/wavemlp/.
+
+### Result
+
+```bash
+result: {'acc': 0.807} ckpt= ./WaveMLP_T.ckpt
+```
+
+### Inference Performance
+
+#### Wave-MLP inference on ImageNet2012
+
+| Parameters        | Ascend                      |
+| ----------------- | --------------------------- |
+| Model Version     | Wave-MLP                    |
+| Resource          | Ascend 910; OS Euler2.8     |
+| Uploaded Date     | 08/03/2022 (month/day/year) |
+| MindSpore Version | 1.6.0                       |
+| Dataset           | ImageNet2012                |
+| batch_size        | 1024                        |
+| outputs           | probability                 |
+| Accuracy          | 1pc: 80.7%                  |
+| Speed             | 1pc: 11.72 s/step           |
+| Total time        | 1pc: 562.96 s               |
+
+## [Description of Random Situation](#contents)
+
+In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
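+
+## [Wave Representation Sketch](#contents)
+
+The following toy NumPy sketch (illustrative only; the array names are made up and this is not the network code in src/wave_mlp.py) shows the idea behind the wave representation: each token is unfolded into real and imaginary parts by a per-token phase, so tokens can reinforce or cancel each other when aggregated.
+
+```python
+import numpy as np
+
+# Each token is a wave: an amplitude (the original feature) plus a phase
+# angle estimated from the input. Toy shapes: 4 tokens, 8 channels.
+amplitude = np.random.rand(4, 8)
+theta = np.random.uniform(0.0, 2.0 * np.pi, (4, 8))
+
+# Unfold each token into real and imaginary parts, as the PATM module does
+# with Cos/Sin before its token-mixing convolutions.
+real = amplitude * np.cos(theta)
+imag = amplitude * np.sin(theta)
+
+# When tokens are summed, the phases decide whether they reinforce or
+# cancel, so the aggregation result depends on the input content.
+mixed = real.sum(axis=0) + imag.sum(axis=0)
+print(mixed.shape)  # (8,)
+```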
+
+## [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/models).
\ No newline at end of file
diff --git a/research/cv/wave_mlp/eval.py b/research/cv/wave_mlp/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcccd15ad9232929465f654c16a66082187f7a45
--- /dev/null
+++ b/research/cv/wave_mlp/eval.py
@@ -0,0 +1,55 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Evaluation entry for Wave-MLP.
+"""
+import os
+import argparse
+from mindspore import context
+from mindspore import nn
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.dataset import create_dataset
+from src.wave_mlp import WaveMLP_T
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
+parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
+args_opt = parser.parse_args()
+
+if __name__ == '__main__':
+    if args_opt.platform == "Ascend":
+        device_id = int(os.getenv('DEVICE_ID', '0'))  # default to device 0 when DEVICE_ID is unset
+        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
+                            device_id=device_id, save_graphs=False)
+    elif args_opt.platform == "GPU":
+        context.set_context(mode=context.PYNATIVE_MODE,
+                            device_target="GPU", save_graphs=False)
+    else:
+        raise ValueError("Unsupported platform.")
+    net = WaveMLP_T()
+    if args_opt.checkpoint_path:
+        param_dict = load_checkpoint(args_opt.checkpoint_path)
+        load_param_into_net(net, param_dict)
+    net.set_train(False)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=1024)
+
+    model = Model(net, loss_fn=loss, metrics={'acc'})
+    res = model.eval(dataset, dataset_sink_mode=False)
+    print("result:", res, "ckpt=", args_opt.checkpoint_path)
diff --git a/research/cv/wave_mlp/fig/wavemlp.png b/research/cv/wave_mlp/fig/wavemlp.png
new file mode 100644
index 0000000000000000000000000000000000000000..ae784e36f645a2ee6528865e2855642a71aa1cb2
Binary files /dev/null and b/research/cv/wave_mlp/fig/wavemlp.png differ
diff --git a/research/cv/wave_mlp/scripts/run_eval.sh b/research/cv/wave_mlp/scripts/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4261a3c5dc8df9ac555e0305fce3d217304112a1
--- /dev/null
+++ b/research/cv/wave_mlp/scripts/run_eval.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# -lt 3 ]
+then
+    echo "Usage: bash ./scripts/run_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]"
+    exit 1
+fi
+
+DATA_PATH=$1
+PLATFORM=$2
+CHECKPOINT_PATH=$3
+
+rm -rf evaluation
+mkdir ./evaluation
+cd ./evaluation || exit
+echo "start evaluation for device id $DEVICE_ID"
+env > env.log
+python ../eval.py --dataset_path=$DATA_PATH --platform=$PLATFORM --checkpoint_path=$CHECKPOINT_PATH > eval.log 2>&1 &
+cd ../
diff --git a/research/cv/wave_mlp/src/dataset.py b/research/cv/wave_mlp/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..89c44b54105c8b8a1f3d18ed50c5418f90a3be30
--- /dev/null
+++ b/research/cv/wave_mlp/src/dataset.py
@@ -0,0 +1,101 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Data operations, used in train.py and eval.py."""
+import os
+import numpy as np
+import mindspore.common.dtype as mstype
+import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.c_transforms as C2
+import mindspore.dataset.vision.py_transforms as pytrans
+import mindspore.dataset.transforms.py_transforms as py_transforms
+
+from mindspore.dataset.transforms.py_transforms import Compose
+import mindspore.dataset.vision.c_transforms as C
+
+
+class ToNumpy(py_transforms.PyTensorOperation):
+    """Convert a PIL Image or numpy.ndarray to a channel-first uint8 numpy.ndarray."""
+
+    def __init__(self, output_type=np.float32):
+        super().__init__()
+        self.output_type = output_type
+        self.random = False
+
+    def __call__(self, img):
+        """
+        Call method.
+
+        Args:
+            img (Union[PIL Image, numpy.ndarray]): PIL Image or numpy.ndarray to be converted.
+
+        Returns:
+            numpy.ndarray, uint8 array in CHW layout.
+        """
+        np_img = np.array(img, dtype=np.uint8)
+        if np_img.ndim < 3:
+            np_img = np.expand_dims(np_img, axis=-1)
+        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+        return np_img
+
+
+def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=128):
+    """
+    Create a train or eval dataset.
+
+    Args:
+        dataset_path(string): the path of the dataset.
+        do_train(bool): whether the dataset is used for training or evaluation.
+        repeat_num(int): the repeat times of the dataset. Default: 1.
+        batch_size(int): the batch size of each step. Default: 128.
+
+    Returns:
+        dataset
+    """
+
+    if not do_train:
+        dataset_path = os.path.join(dataset_path, 'val')
+        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False, num_shards=1, shard_id=0)
+    else:
+        dataset_path = os.path.join(dataset_path, 'train')
+        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=1, shard_id=0)
+
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    # define map operations
+    if do_train:
+        trans = [
+            C.RandomCropDecodeResize(224),
+            C.RandomHorizontalFlip(prob=0.5),
+            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
+        ]
+    else:
+        trans = [
+            pytrans.Decode(),
+            pytrans.Resize(235),
+            pytrans.CenterCrop(224)
+        ]
+    trans += [
+        pytrans.ToTensor(),
+        pytrans.Normalize(mean=mean, std=std),
+    ]
+    trans = Compose(trans)
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
+    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
+
+    # apply batch operations
+    ds = ds.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
+    ds = ds.repeat(repeat_num)
+    return ds
diff --git a/research/cv/wave_mlp/src/wave_mlp.py b/research/cv/wave_mlp/src/wave_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..dddac1fa8e7e7c4dfe3473f2ab3020891f369c22
--- /dev/null
+++ b/research/cv/wave_mlp/src/wave_mlp.py
@@ -0,0 +1,354 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+from itertools import repeat
+import collections.abc
+import mindspore.nn as nn
+import mindspore.ops as ops
+from mindspore.ops import operations as P
+import mindspore.common.initializer as weight_init
+
+
+def to_2tuple(x):
+    if isinstance(x, collections.abc.Iterable):
+        return x
+    return tuple(repeat(x, 2))
+
+
+class DropPath(nn.Cell):
+    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
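+
+    During training, each sample's residual branch is zeroed out with
+    probability 1 - keep_prob, and surviving samples are rescaled by
+    1 / keep_prob so that the expected output is unchanged; at inference
+    time the input passes through unmodified.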
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.keep_prob = 1 - drop_prob + self.rand = P.UniformReal(seed=0) # seed must be 0, if set to other value, it's not rand for multiple call + self.shape = P.Shape() + self.floor = P.Floor() + + def construct(self, x): + if self.training: + x_shape = self.shape(x) # B N C + random_tensor = self.rand((x_shape[0], 1, 1)) + random_tensor = random_tensor + self.keep_prob + random_tensor = self.floor(random_tensor) + x = x / self.keep_prob + x = x * random_tensor + return x + + +def _cfg(url='', crop_pct=.96): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': crop_pct, 'interpolation': 'bicubic', + 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'classifier': 'head' + } + + +default_cfgs = { + 'wave_T': _cfg(crop_pct=0.9), + 'wave_S': _cfg(crop_pct=0.9), + 'wave_M': _cfg(crop_pct=0.9), + 'wave_B': _cfg(crop_pct=0.875), +} + + +class Mlp(nn.Cell): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.act = act_layer() + self.drop = nn.Dropout(1. - drop) + self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, has_bias=True) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, has_bias=True) + + def construct(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PATM(nn.Cell): + def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., mode='fc'): + super().__init__() + + self.fc_h = nn.Conv2d(dim, dim, 1, 1, has_bias=qkv_bias) + self.fc_w = nn.Conv2d(dim, dim, 1, 1, has_bias=qkv_bias) + self.fc_c = nn.Conv2d(dim, dim, 1, 1, has_bias=qkv_bias) + self.tfc_h = nn.Conv2d(2 * dim, dim, (1, 7), stride=1, padding=(0, 0, 7 // 2, 7 // 2), group=dim, + has_bias=False, pad_mode='pad') + self.tfc_w = nn.Conv2d(2 * dim, dim, (7, 1), stride=1, padding=(7 // 2, 7 // 2, 0, 0), group=dim, + has_bias=False, pad_mode='pad') + self.reweight = Mlp(dim, dim // 4, dim * 3) + self.proj = nn.Conv2d(dim, dim, 1, 1, has_bias=True) + self.proj_drop = nn.Dropout(1. 
+        self.mode = mode
+
+        if mode == 'fc':
+            self.theta_h_conv = nn.SequentialCell(nn.Conv2d(dim, dim, 1, 1, has_bias=True), nn.BatchNorm2d(dim),
+                                                  nn.ReLU())
+            self.theta_w_conv = nn.SequentialCell(nn.Conv2d(dim, dim, 1, 1, has_bias=True), nn.BatchNorm2d(dim),
+                                                  nn.ReLU())
+        else:
+            # pad_mode='pad' is required for non-zero padding in MindSpore Conv2d
+            self.theta_h_conv = nn.SequentialCell(
+                nn.Conv2d(dim, dim, 3, stride=1, padding=1, group=dim, has_bias=False, pad_mode='pad'),
+                nn.BatchNorm2d(dim), nn.ReLU())
+            self.theta_w_conv = nn.SequentialCell(
+                nn.Conv2d(dim, dim, 3, stride=1, padding=1, group=dim, has_bias=False, pad_mode='pad'),
+                nn.BatchNorm2d(dim), nn.ReLU())
+
+    def construct(self, x):
+
+        B, C, _, _ = x.shape
+        theta_h = self.theta_h_conv(x)
+        theta_w = self.theta_w_conv(x)
+
+        x_h = self.fc_h(x)
+        x_w = self.fc_w(x)
+        # unfold tokens into real/imaginary parts: amplitude * (cos(theta), sin(theta))
+        x_h = ops.Concat(axis=1)((x_h * (ops.Cos()(theta_h)), x_h * (ops.Sin()(theta_h))))
+        x_w = ops.Concat(axis=1)((x_w * (ops.Cos()(theta_w)), x_w * (ops.Sin()(theta_w))))
+        h = self.tfc_h(x_h)
+        w = self.tfc_w(x_w)
+        c = self.fc_c(x)
+        a = ops.AdaptiveAvgPool2D(output_size=(1, 1))(h + w + c)
+        a = ops.ExpandDims()(
+            ops.ExpandDims()(ops.Softmax(axis=0)(ops.Transpose()(self.reweight(a).reshape(B, C, 3), (2, 0, 1))), -1),
+            -1)
+        x = h * a[0] + w * a[1] + c * a[2]
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class WaveBlock(nn.Cell):
+
+    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, mode='fc'):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = PATM(dim, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, mode=mode)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else ops.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
+
+    def construct(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbedOverlapping(nn.Cell):
+    def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d,
+                 groups=1, use_norm=True):
+        # note: super(PatchEmbedOverlapping).__init__() would skip nn.Cell.__init__
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        stride = to_2tuple(stride)
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
+                              padding=(padding, padding, padding, padding),
+                              group=groups, pad_mode='pad', has_bias=True)
+        self.norm = norm_layer(embed_dim) if use_norm else ops.Identity()
+
+    def construct(self, x):
+        x = self.proj(x)
+        x = self.norm(x)
+        return x
+
+
+class Downsample(nn.Cell):
+    def __init__(self, in_embed_dim, out_embed_dim, patch_size, norm_layer=nn.BatchNorm2d, use_norm=True):
+        super().__init__()
+        assert patch_size == 2, patch_size
+        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=1,
+                              pad_mode='pad', has_bias=True)
+        self.norm = norm_layer(out_embed_dim) if use_norm else ops.Identity()
+
+    def construct(self, x):
+        x = self.proj(x)
+        x = self.norm(x)
+        return x
+
+
+def basic_blocks(dim, index, layers, mlp_ratio=3., qkv_bias=False, qk_scale=None, attn_drop=0.,
+                 drop_path_rate=0., norm_layer=nn.BatchNorm2d, mode='fc', **kwargs):
+    blocks = []
+    for block_idx in range(layers[index]):
+        # stochastic depth rate grows linearly with the global block index
+        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
+        blocks.append(WaveBlock(dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                attn_drop=attn_drop,
+                                drop_path=block_dpr, norm_layer=norm_layer, mode=mode))
+    blocks = nn.SequentialCell(*blocks)
+    return blocks
+
+
+class WaveNet(nn.Cell):
+    def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+                 embed_dims=None, transitions=None, mlp_ratios=None,
+                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+                 norm_layer=nn.BatchNorm2d, fork_feat=False, mode='fc', ds_use_norm=True, args=None):
+
+        super().__init__()
+
+        if not fork_feat:
+            self.num_classes = num_classes
+        self.fork_feat = fork_feat
+        self.embed_dim = embed_dims[-1]  # recorded for reset_classifier
+
+        self.patch_embed = PatchEmbedOverlapping(patch_size=7, stride=4, padding=2, in_chans=3,
+                                                 embed_dim=embed_dims[0], norm_layer=norm_layer,
+                                                 use_norm=ds_use_norm)
+
+        network = []
+        for i in range(len(layers)):
+            stage = basic_blocks(embed_dims[i], i, layers, mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
+                                 qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate,
+                                 norm_layer=norm_layer, mode=mode)
+            network.append(stage)
+            if i >= len(layers) - 1:
+                break
+            if transitions[i] or embed_dims[i] != embed_dims[i + 1]:
+                patch_size = 2 if transitions[i] else 1
+                network.append(Downsample(embed_dims[i], embed_dims[i + 1], patch_size, norm_layer=norm_layer,
+                                          use_norm=ds_use_norm))
+
+        self.network = nn.SequentialCell(network)
+
+        if self.fork_feat:
+            # add a norm layer for each output
+            self.out_indices = [0, 2, 4, 6]
+            for i_emb, i_layer in enumerate(self.out_indices):
+                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
+                    layer = ops.Identity()
+                else:
+                    layer = norm_layer(embed_dims[i_emb])
+                layer_name = f'norm{i_layer}'
+                # nn.Cell has no add_module; insert_child_to_cell registers the child cell
+                self.insert_child_to_cell(layer_name, layer)
+        else:
+            self.norm = norm_layer(embed_dims[-1])
+            self.head = nn.Dense(embed_dims[-1], num_classes) if num_classes > 0 else ops.Identity()
+
+    def cls_init_weights(self):
+        for _, cell in self.cells_and_names():
+            if isinstance(cell, nn.Dense):
+                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+            if isinstance(cell, nn.Dense) and cell.bias is not None:
+                cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
+                                                           cell.bias.shape,
+                                                           cell.bias.dtype))
+            elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm)):
+                cell.gamma.set_data(weight_init.initializer(weight_init.One(),
+                                                            cell.gamma.shape,
+                                                            cell.gamma.dtype))
+                cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
+                                                           cell.beta.shape,
+                                                           cell.beta.dtype))
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else ops.Identity()
+
+    def forward_embeddings(self, x):
+        x = self.patch_embed(x)
+        return x
+
+    def forward_tokens(self, x):
+        outs = []
+        for idx, block in enumerate(self.network):
+            x = block(x)
+            if self.fork_feat and idx in self.out_indices:
+                norm_layer = getattr(self, f'norm{idx}')
+                x_out = norm_layer(x)
+                outs.append(x_out)
+        if self.fork_feat:
+            return outs
+        return x
+
+    def construct(self, x):
+        x = self.forward_embeddings(x)
+        x = self.forward_tokens(x)
+        if self.fork_feat:
+            return x
+        x = self.norm(x)
+        cls_out = self.head(ops.Squeeze()((ops.AdaptiveAvgPool2D(output_size=1)(x))))
+        return cls_out
+
+
+def GroupNorm(dim):
+    return nn.GroupNorm(1, dim)
+
+
+def WaveMLP_T_dw(pretrained=False, **kwargs):
+    transitions = [True, True, True, True]
+    layers = [2, 2, 4, 2]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [64, 128, 320, 512]
+    model = WaveNet(layers,
+                    embed_dims=embed_dims, patch_size=7, transitions=transitions,
+                    mlp_ratios=mlp_ratios, mode='depthwise', **kwargs)
+    model.default_cfg = default_cfgs['wave_T']
+    return model
+
+
+def WaveMLP_T(pretrained=False, **kwargs):
+    transitions = [True, True, True, True]
+    layers = [2, 2, 4, 2]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [64, 128, 320, 512]
+    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
+                    mlp_ratios=mlp_ratios, **kwargs)
+    model.default_cfg = default_cfgs['wave_T']
+    return model
+
+
+def WaveMLP_S(pretrained=False, **kwargs):
+    transitions = [True, True, True, True]
+    layers = [2, 3, 10, 3]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [64, 128, 320, 512]
+    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
+                    mlp_ratios=mlp_ratios, norm_layer=GroupNorm, **kwargs)
+    model.default_cfg = default_cfgs['wave_S']
+    return model
+
+
+def WaveMLP_M(pretrained=False, **kwargs):
+    transitions = [True, True, True, True]
+    layers = [3, 4, 18, 3]
+    mlp_ratios = [8, 8, 4, 4]
+    embed_dims = [64, 128, 320, 512]
+    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
+                    mlp_ratios=mlp_ratios, norm_layer=GroupNorm, ds_use_norm=False, **kwargs)
+    model.default_cfg = default_cfgs['wave_M']
+    return model
+
+
+def WaveMLP_B(pretrained=False, **kwargs):
+    transitions = [True, True, True, True]
+    layers = [2, 2, 18, 2]
+    mlp_ratios = [4, 4, 4, 4]
+    embed_dims = [96, 192, 384, 768]
+    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
+                    mlp_ratios=mlp_ratios, norm_layer=GroupNorm, ds_use_norm=False, **kwargs)
+    model.default_cfg = default_cfgs['wave_B']
+    return model
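+
+
+if __name__ == '__main__':
+    # Minimal smoke test (an illustrative sketch, not part of the published
+    # model): build the tiny variant and check the output shape on a random
+    # batch. Batch size 2 is used because construct() squeezes all size-1
+    # dimensions after global pooling.
+    import numpy as np
+    from mindspore import Tensor, context
+
+    context.set_context(mode=context.PYNATIVE_MODE)
+    net = WaveMLP_T()
+    net.set_train(False)
+    dummy = Tensor(np.random.randn(2, 3, 224, 224).astype(np.float32))
+    logits = net(dummy)
+    print(logits.shape)  # expected: (2, 1000)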