Commit 097cd4a3 authored by i-robot, committed by Gitee

!2205 add wave_mlp to models

Merge pull request !2205 from 185******25/master
parents 9deabd70 92f8dc50
# Contents
- [Contents](#contents)
- [Wave-MLP Description](#wave-mlp-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
  - [Script and sample code](#script-and-sample-code)
- [Eval process](#eval-process)
  - [Usage](#usage)
  - [Launch](#launch)
  - [Result](#result)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
## [Wave-MLP Description](#contents)
To dynamically aggregate tokens, Wave-MLP proposes to represent each token as a wave function with two parts, amplitude and phase. Amplitude is the original feature and the phase term is a complex value changing according to the semantic contents of input images.
[Paper](https://arxiv.org/pdf/2111.12294.pdf): Yehui Tang, Kai Han, Jianyuan Guo, Chang Xu, Yanxi Li, Chao Xu, Yunhe Wang. An Image Patch is a Wave: Phase-Aware Vision MLP. arxiv 2111.12294.
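The wave representation described above can be illustrated with a minimal sketch in plain Python. It assumes a scalar amplitude and a scalar phase per token and uses a simple unweighted sum; the function names `wave_token`, `unfold` and `aggregate` are hypothetical and do not appear in the repository. It only shows why the phase matters: tokens with matching phases reinforce each other, while tokens with opposite phases cancel.

```python
import cmath
import math


def wave_token(amplitude, phase):
    """Return the wave-like token amplitude * e^{i * phase} as a complex number."""
    return amplitude * cmath.exp(1j * phase)


def unfold(amplitude, phase):
    """Euler's formula: (real, imaginary) = (amplitude * cos(phase), amplitude * sin(phase))."""
    return amplitude * math.cos(phase), amplitude * math.sin(phase)


def aggregate(tokens):
    """Sum wave tokens given as (amplitude, phase) pairs and return the magnitude of the sum."""
    return abs(sum(wave_token(a, p) for a, p in tokens))


in_phase = aggregate([(1.0, 0.0), (1.0, 0.0)])      # phases agree: amplitudes add up
opposite = aggregate([(1.0, 0.0), (1.0, math.pi)])  # phases oppose: amplitudes cancel
print(in_phase, opposite)
```

In `src/wave_mlp.py` the same unfolding appears as the channel-wise concatenation of `x * cos(theta)` and `x * sin(theta)`, where both the amplitude and the phase are produced by learned layers rather than fixed numbers.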
## [Model architecture](#contents)
A block of Wave-MLP is shown below:
![image-wavemlp](./fig/wavemlp.png)
## [Dataset](#contents)
Dataset used: [ImageNet2012]
- Dataset size: 224*224 color images in 1000 classes
  - Train: 1,281,167 images
  - Test: 50,000 images
- Data format: JPEG
  - Note: Data will be processed in dataset.py
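As a rough illustration of the evaluation preprocessing in `dataset.py` (resize to 235, center crop to 224, then per-channel normalization), here is a minimal sketch in plain Python. The mean and standard deviation values are the ones used in `dataset.py`; the helper names `center_crop_box` and `normalize_pixel` are hypothetical, and the exact rounding of the crop window inside MindSpore may differ from the floor division assumed here.

```python
MEAN = (0.485, 0.456, 0.406)  # per-channel mean from dataset.py
STD = (0.229, 0.224, 0.225)   # per-channel std from dataset.py


def center_crop_box(height, width, crop=224):
    """Return (top, left, bottom, right) of a centered crop window (floor division assumed)."""
    top = (height - crop) // 2
    left = (width - crop) // 2
    return top, left, top + crop, left + crop


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


box = center_crop_box(235, 235)           # a hypothetical 235x235 resized image
white = normalize_pixel((255, 255, 255))  # brightest possible input pixel
print(box, white)
```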
## [Environment Requirements](#contents)
- Hardware (Ascend/GPU)
  - Prepare hardware environment with Ascend or GPU.
- Framework
  - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
  - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
  - [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
## [Script description](#contents)
### [Script and sample code](#contents)
```bash
WaveMlp
├── eval.py # inference entry
├── fig
│   └── wavemlp.png # the illustration of wave_mlp network
├── readme.md # Readme
├── scripts
│   └── run_eval.sh # evaluation launch script
└── src
    ├── dataset.py # dataset loader
    └── wave_mlp.py # wave_mlp network
```
## [Eval process](#contents)
### Usage
After installing MindSpore via the official website, you can start evaluation as follows:
### Launch
```bash
# infer example
# python (GPU)
python eval.py --dataset_path dataset --platform GPU --checkpoint_path [CHECKPOINT_PATH]
# shell
bash ./scripts/run_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]
```
> The checkpoint can be downloaded from https://download.mindspore.cn/model_zoo/research/cv/wavemlp/
### Result
```bash
result: {'acc': 0.807} ckpt= ./WaveMLP_T.ckpt
```
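When the evaluation is launched through the shell script, the result line ends up in `eval.log`. A minimal sketch for reading it back is shown below; it assumes the line has exactly the format printed above, and the helper name `parse_result_line` is hypothetical.

```python
import ast


def parse_result_line(line):
    """Split a "result: {...} ckpt= path" line into (metrics dict, checkpoint path)."""
    body = line.split("result:", 1)[1]
    metrics_text, ckpt = body.split("ckpt=", 1)
    # The metrics part is a Python dict literal, so literal_eval parses it safely.
    return ast.literal_eval(metrics_text.strip()), ckpt.strip()


metrics, ckpt = parse_result_line("result: {'acc': 0.807} ckpt= ./WaveMLP_T.ckpt")
print(f"top-1 accuracy: {metrics['acc']:.1%} ({ckpt})")
```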
### Inference Performance
#### WaveMlp infer on ImageNet2012
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | WaveMlp |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 08/03/2022 (month/day/year) |
| MindSpore Version | 1.6.0 |
| Dataset | ImageNet2012 |
| batch_size | 1024 |
| outputs | probability |
| Accuracy | 1pc: 80.7% |
| Speed | 1pc: 11.72 s/step |
| Total time | 1pc: 562.96 s |
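The speed and total time figures can be sanity-checked with a short calculation. The sketch below assumes the 50,000 validation images from the Dataset section, the batch size of 1024 from the table, and `drop_remainder=True` as set in `dataset.py`, which discards the last partial batch. The estimate it produces is close to the reported total of 562.96.

```python
VAL_IMAGES = 50_000       # ImageNet2012 validation images
BATCH_SIZE = 1024         # batch_size used by eval.py
SECONDS_PER_STEP = 11.72  # "Speed" row of the table

steps = VAL_IMAGES // BATCH_SIZE  # drop_remainder=True keeps only full batches
evaluated = steps * BATCH_SIZE    # images that actually contribute to the accuracy
dropped = VAL_IMAGES - evaluated  # images in the discarded partial batch
estimated_total = steps * SECONDS_PER_STEP

print(steps, evaluated, dropped, round(estimated_total, 2))
```

Because of the dropped partial batch, the reported accuracy is computed on slightly fewer than the full 50,000 validation images.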
## [Description of Random Situation](#contents)
In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
from mindspore import context
from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.wave_mlp import WaveMLP_T
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--platform', type=str, default='Ascend', help='run platform')
args_opt = parser.parse_args()
if __name__ == '__main__':
    if args_opt.platform == "Ascend":
        # DEVICE_ID may be unset on a single-device machine; fall back to device 0.
        device_id = int(os.getenv('DEVICE_ID', '0'))
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                            device_id=device_id, save_graphs=False)
    elif args_opt.platform == "GPU":
        context.set_context(mode=context.PYNATIVE_MODE,
                            device_target="GPU", save_graphs=False)
    else:
        raise ValueError("Unsupported platform.")

    net = WaveMLP_T()
    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(net, param_dict)
    net.set_train(False)

    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=1024)
    model = Model(net, loss_fn=loss, metrics={'acc'})
    res = model.eval(dataset, dataset_sink_mode=False)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)
research/cv/wave_mlp/fig/wavemlp.png

70.6 KiB

#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 3 ]
then
    echo "Usage: bash ./scripts/run_eval.sh [DATA_PATH] [PLATFORM] [CHECKPOINT_PATH]"
    exit 1
fi

DATA_PATH=$1
PLATFORM=$2
CHECKPOINT_PATH=$3

rm -rf evaluation
mkdir ./evaluation
cd ./evaluation || exit
echo "start evaluation for device id $DEVICE_ID"
env > env.log
# eval.py lives in the parent directory. The command runs inside ./evaluation,
# so DATA_PATH and CHECKPOINT_PATH should be absolute paths.
python ../eval.py --dataset_path="$DATA_PATH" --platform="$PLATFORM" --checkpoint_path="$CHECKPOINT_PATH" > eval.log 2>&1 &
cd ../
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py and eval.py"""
import os
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.py_transforms as pytrans
import mindspore.dataset.transforms.py_transforms as py_transforms
from mindspore.dataset.transforms.py_transforms import Compose
import mindspore.dataset.vision.c_transforms as C


class ToNumpy(py_transforms.PyTensorOperation):
    def __init__(self, output_type=np.float32):
        self.output_type = output_type
        self.random = False

    def __call__(self, img):
        """
        Call method.

        Args:
            img (Union[PIL Image, numpy.ndarray]): PIL Image or numpy.ndarray to be type converted.

        Returns:
            numpy.ndarray, converted numpy.ndarray with desired type.
        """
        np_img = np.array(img, dtype=np.uint8)
        if np_img.ndim < 3:
            np_img = np.expand_dims(np_img, axis=-1)
        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
        return np_img


def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=128):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1.
        batch_size(int): the number of samples per batch. Default: 128.

    Returns:
        dataset
    """
    if not do_train:
        dataset_path = os.path.join(dataset_path, 'val')
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=False, num_shards=1, shard_id=0)
    else:
        dataset_path = os.path.join(dataset_path, 'train')
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, num_shards=1, shard_id=0)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(224),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        trans = [
            pytrans.Decode(),
            pytrans.Resize(235),
            pytrans.CenterCrop(224)
        ]
    trans += [
        pytrans.ToTensor(),
        pytrans.Normalize(mean=mean, std=std),
    ]
    trans = Compose(trans)

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True, num_parallel_workers=8)
    ds = ds.repeat(repeat_num)
    return ds
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from itertools import repeat
import collections.abc
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import operations as P
import mindspore.common.initializer as weight_init


def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable):
        return x
    return tuple(repeat(x, 2))


class DropPath(nn.Cell):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.keep_prob = 1 - drop_prob
        self.rand = P.UniformReal(seed=0)  # seed must be 0, if set to other value, it's not rand for multiple call
        self.shape = P.Shape()
        self.floor = P.Floor()

    def construct(self, x):
        if self.training:
            x_shape = self.shape(x)
            # One random value per sample, with trailing singleton axes so that the mask
            # broadcasts over every non-batch dimension (the blocks pass B, C, H, W tensors).
            random_tensor = self.rand((x_shape[0],) + (1,) * (len(x_shape) - 1))
            random_tensor = random_tensor + self.keep_prob
            random_tensor = self.floor(random_tensor)
            x = x / self.keep_prob
            x = x * random_tensor
        return x


def _cfg(url='', crop_pct=.96):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': crop_pct, 'interpolation': 'bicubic',
        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'classifier': 'head'
    }


default_cfgs = {
    'wave_T': _cfg(crop_pct=0.9),
    'wave_S': _cfg(crop_pct=0.9),
    'wave_M': _cfg(crop_pct=0.9),
    'wave_B': _cfg(crop_pct=0.875),
}


class Mlp(nn.Cell):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.act = act_layer()
        self.drop = nn.Dropout(1. - drop)
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, has_bias=True)
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, has_bias=True)

    def construct(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class PATM(nn.Cell):
    def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., mode='fc'):
        super().__init__()
        # Amplitude projections for the height, width and channel branches.
        self.fc_h = nn.Conv2d(dim, dim, 1, 1, has_bias=qkv_bias)
        self.fc_w = nn.Conv2d(dim, dim, 1, 1, has_bias=qkv_bias)
        self.fc_c = nn.Conv2d(dim, dim, 1, 1, has_bias=qkv_bias)
        # Token mixing along the width (1 x 7 kernel) and the height (7 x 1 kernel).
        self.tfc_h = nn.Conv2d(2 * dim, dim, (1, 7), stride=1, padding=(0, 0, 7 // 2, 7 // 2), group=dim,
                               has_bias=False, pad_mode='pad')
        self.tfc_w = nn.Conv2d(2 * dim, dim, (7, 1), stride=1, padding=(7 // 2, 7 // 2, 0, 0), group=dim,
                               has_bias=False, pad_mode='pad')
        self.reweight = Mlp(dim, dim // 4, dim * 3)
        self.proj = nn.Conv2d(dim, dim, 1, 1, has_bias=True)
        self.proj_drop = nn.Dropout(1. - proj_drop)
        self.mode = mode
        # Phase estimation modules for the height and width branches.
        if mode == 'fc':
            self.theta_h_conv = nn.SequentialCell(nn.Conv2d(dim, dim, 1, 1, has_bias=True), nn.BatchNorm2d(dim),
                                                  nn.ReLU())
            self.theta_w_conv = nn.SequentialCell(nn.Conv2d(dim, dim, 1, 1, has_bias=True), nn.BatchNorm2d(dim),
                                                  nn.ReLU())
        else:
            self.theta_h_conv = nn.SequentialCell(
                nn.Conv2d(dim, dim, 3, stride=1, padding=1, group=dim, has_bias=False),
                nn.BatchNorm2d(dim), nn.ReLU())
            self.theta_w_conv = nn.SequentialCell(
                nn.Conv2d(dim, dim, 3, stride=1, padding=1, group=dim, has_bias=False),
                nn.BatchNorm2d(dim), nn.ReLU())

    def construct(self, x):
        B, C, _, _ = x.shape
        theta_h = self.theta_h_conv(x)
        theta_w = self.theta_w_conv(x)
        x_h = self.fc_h(x)
        x_w = self.fc_w(x)
        # Wave representation: concatenate amplitude * cos(phase) and amplitude * sin(phase) along channels.
        x_h = ops.Concat(axis=1)((x_h * (ops.Cos()(theta_h)), x_h * (ops.Sin()(theta_h))))
        x_w = ops.Concat(axis=1)((x_w * (ops.Cos()(theta_w)), x_w * (ops.Sin()(theta_w))))
        h = self.tfc_h(x_h)
        w = self.tfc_w(x_w)
        c = self.fc_c(x)
        # Softmax over the three branches gives per-channel weights of shape (3, B, C, 1, 1).
        a = ops.AdaptiveAvgPool2D(output_size=(1, 1))(h + w + c)
        a = ops.ExpandDims()(
            ops.ExpandDims()(ops.Softmax(axis=0)(ops.Transpose()(self.reweight(a).reshape(B, C, 3), (2, 0, 1))), -1),
            -1)
        x = h * a[0] + w * a[1] + c * a[2]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class WaveBlock(nn.Cell):
    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.BatchNorm2d, mode='fc'):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = PATM(dim, qkv_bias=qkv_bias, qk_scale=None, attn_drop=attn_drop, mode=mode)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else ops.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

    def construct(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbedOverlapping(nn.Cell):
    def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d,
                 groups=1, use_norm=True):
        # super(PatchEmbedOverlapping).__init__() would not run nn.Cell.__init__ on this instance.
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(padding, padding, padding, padding),
                              group=groups, pad_mode='pad', has_bias=True)
        self.norm = norm_layer(embed_dim) if use_norm else ops.Identity()

    def construct(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class Downsample(nn.Cell):
    def __init__(self, in_embed_dim, out_embed_dim, patch_size, norm_layer=nn.BatchNorm2d, use_norm=True):
        # super(Downsample).__init__() would not run nn.Cell.__init__ on this instance.
        super().__init__()
        assert patch_size == 2, patch_size
        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=1, pad_mode='pad',
                              has_bias=True)
        self.norm = norm_layer(out_embed_dim) if use_norm else ops.Identity()

    def construct(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


def basic_blocks(dim, index, layers, mlp_ratio=3., qkv_bias=False, qk_scale=None, attn_drop=0.,
                 drop_path_rate=0., norm_layer=nn.BatchNorm2d, mode='fc', **kwargs):
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(WaveBlock(dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                attn_drop=attn_drop, drop_path=block_dpr, norm_layer=norm_layer, mode=mode))
    blocks = nn.SequentialCell(*blocks)
    return blocks


class WaveNet(nn.Cell):
    def __init__(self, layers, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dims=None, transitions=None, mlp_ratios=None,
                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=nn.BatchNorm2d, fork_feat=False, mode='fc', ds_use_norm=True, args=None):
        super().__init__()
        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat
        # Width of the final stage; reset_classifier() reads it to rebuild the head.
        self.embed_dim = embed_dims[-1]
        self.patch_embed = PatchEmbedOverlapping(patch_size=7, stride=4, padding=2, in_chans=3, embed_dim=embed_dims[0],
                                                 norm_layer=norm_layer, use_norm=ds_use_norm)
        network = []
        for i in range(len(layers)):
            stage = basic_blocks(embed_dims[i], i, layers, mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                                 qk_scale=qk_scale, attn_drop=attn_drop_rate, drop_path_rate=drop_path_rate,
                                 norm_layer=norm_layer, mode=mode)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if transitions[i] or embed_dims[i] != embed_dims[i + 1]:
                patch_size = 2 if transitions[i] else 1
                network.append(Downsample(embed_dims[i], embed_dims[i + 1], patch_size, norm_layer=norm_layer,
                                          use_norm=ds_use_norm))
        self.network = nn.SequentialCell(network)
        if self.fork_feat:
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    layer = ops.Identity()
                else:
                    layer = norm_layer(embed_dims[i_emb])
                layer_name = f'norm{i_layer}'
                # nn.Cell has no add_module(); plain attribute assignment stores the layer
                # under the name that forward_tokens() looks up with getattr().
                setattr(self, layer_name, layer)
        else:
            self.norm = norm_layer(embed_dims[-1])
            self.head = nn.Dense(embed_dims[-1], num_classes) if num_classes > 0 else ops.Identity()

    def cls_init_weights(self):
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))
                if isinstance(cell, nn.Dense) and cell.bias is not None:
                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
                                                               cell.bias.shape,
                                                               cell.bias.dtype))
            elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm)):
                cell.gamma.set_data(weight_init.initializer(weight_init.One(),
                                                            cell.gamma.shape,
                                                            cell.gamma.dtype))
                cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
                                                           cell.beta.shape,
                                                           cell.beta.dtype))

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else ops.Identity()

    def forward_embeddings(self, x):
        x = self.patch_embed(x)
        return x

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        if self.fork_feat:
            return outs
        return x

    def construct(self, x):
        x = self.forward_embeddings(x)
        x = self.forward_tokens(x)
        if self.fork_feat:
            return x
        x = self.norm(x)
        cls_out = self.head(ops.Squeeze()((ops.AdaptiveAvgPool2D(output_size=1)(x))))
        return cls_out


def GroupNorm(dim):
    return nn.GroupNorm(1, dim)


def WaveMLP_T_dw(pretrained=False, **kwargs):
    transitions = [True, True, True, True]
    layers = [2, 2, 4, 2]
    mlp_ratios = [4, 4, 4, 4]
    embed_dims = [64, 128, 320, 512]
    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                    mlp_ratios=mlp_ratios, mode='depthwise', **kwargs)
    model.default_cfg = default_cfgs['wave_T']
    return model


def WaveMLP_T(pretrained=False, **kwargs):
    transitions = [True, True, True, True]
    layers = [2, 2, 4, 2]
    mlp_ratios = [4, 4, 4, 4]
    embed_dims = [64, 128, 320, 512]
    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                    mlp_ratios=mlp_ratios, **kwargs)
    model.default_cfg = default_cfgs['wave_T']
    return model


def WaveMLP_S(pretrained=False, **kwargs):
    transitions = [True, True, True, True]
    layers = [2, 3, 10, 3]
    mlp_ratios = [4, 4, 4, 4]
    embed_dims = [64, 128, 320, 512]
    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                    mlp_ratios=mlp_ratios, norm_layer=GroupNorm, **kwargs)
    model.default_cfg = default_cfgs['wave_S']
    return model


def WaveMLP_M(pretrained=False, **kwargs):
    transitions = [True, True, True, True]
    layers = [3, 4, 18, 3]
    mlp_ratios = [8, 8, 4, 4]
    embed_dims = [64, 128, 320, 512]
    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                    mlp_ratios=mlp_ratios, norm_layer=GroupNorm, ds_use_norm=False, **kwargs)
    model.default_cfg = default_cfgs['wave_M']
    return model


def WaveMLP_B(pretrained=False, **kwargs):
    transitions = [True, True, True, True]
    layers = [2, 2, 18, 2]
    mlp_ratios = [4, 4, 4, 4]
    embed_dims = [96, 192, 384, 768]
    model = WaveNet(layers, embed_dims=embed_dims, patch_size=7, transitions=transitions,
                    mlp_ratios=mlp_ratios, norm_layer=GroupNorm, ds_use_norm=False, **kwargs)
    model.default_cfg = default_cfgs['wave_B']
    return model