diff --git a/research/cv/AdaBin/README.md b/research/cv/AdaBin/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a67e61f59872f77e2449ffcad77a4102d4cdefe4
--- /dev/null
+++ b/research/cv/AdaBin/README.md
@@ -0,0 +1,122 @@
+# Contents
+
+- [AdaBin Description](#AdaBin-description)
+- [Dataset](#dataset)
+- [Features](#features)
+    - [Mixed Precision](#mixed-precision)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Evaluation Process](#evaluation-process)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Evaluation Performance](#evaluation-performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+- [Reference](#reference)
+
+# [AdaBin Description](#contents)
+
+This paper studies Binary Neural Networks (BNNs), in which weights and activations are both binarized into 1-bit values, greatly reducing memory usage and computational complexity. Since modern deep neural networks have sophisticated architectures designed for accuracy, the distributions of weights and activations are highly diverse, and the conventional sign function cannot effectively binarize full-precision values in BNNs. To this end, we present a simple yet effective approach called AdaBin that adaptively obtains the optimal binary set {b1, b2} (b1, b2 ∈ R) of weights and activations for each layer, instead of a fixed set (i.e., {-1, +1}). In this way, the proposed method can better fit different distributions and increase the representation ability of binarized features. In practice, we use the center position and distance of the 1-bit values to define a new binary quantization function. For the weights, we propose an equalization method that aligns the symmetric center of the binary distribution with the real-valued distribution and minimizes their Kullback-Leibler divergence. Meanwhile, we introduce a gradient-based optimization method to obtain these two parameters for activations, which are trained jointly in an end-to-end manner. Experimental results on benchmark models and datasets demonstrate that the proposed AdaBin achieves state-of-the-art performance. For instance, we obtain a 66.4% Top-1 accuracy on ImageNet using the ResNet-18 architecture, and a 69.4 mAP on PASCAL VOC using SSD300.
+
+[Paper](https://arxiv.org/abs/2208.08084): Zhijun Tu, Xinghao Chen, Pengju Ren, Yunhe Wang. AdaBin: Improving Binary Neural Networks with Adaptive Binary Sets. Accepted by ECCV 2022.
+
+# [Dataset](#contents)
+
+- Dataset used: [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)
+    - Dataset size: 60000 color images in 10 classes
+        - Train: 50000 images
+        - Test: 10000 images
+    - Data format: RGB images
+    - Note: Data will be processed in src/dataset.py
+
+# [Features](#contents)
+
+## [Mixed Precision](#contents)
+
+The mixed precision training method accelerates the training of deep neural networks by using both single-precision and half-precision data formats, while maintaining the accuracy achieved by single-precision training. Mixed precision training accelerates computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.
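+As a reference, mixed precision can be enabled through the `amp_level` argument of MindSpore's `Model` wrapper; a minimal sketch is shown below (the optimizer and loss function here are illustrative placeholders, not the recipe used to train the released checkpoint):
+
+```python
+import mindspore.nn as nn
+from mindspore.train.model import Model
+
+from src.resnet import resnet20
+
+net = resnet20()
+loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+# amp_level="O2" casts the network to FP16 while keeping BatchNorm and the loss in FP32
+model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O2")
+```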
+For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for 'reduce precision'.
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend/GPU/CPU)
+    - Prepare the hardware environment with an Ascend, GPU or CPU processor.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/en/r0.5/index.html)
+    - [MindSpore API](https://www.mindspore.cn/api/en/0.1.0-alpha/index.html)
+
+# [Script Description](#contents)
+
+## [Script and Sample Code](#contents)
+
+```text
+├── AdaBin
+    ├── README.md              # readme
+    ├── mindspore_hub_conf.py  # hub config
+    ├── src
+    │   ├── loss.py            # label smoothing cross-entropy loss
+    │   ├── dataset.py         # dataset creation
+    │   ├── resnet.py          # ResNet architecture
+    │   └── binarylib.py       # binary quantizer
+    └── eval.py                # evaluation script
+```
+
+## [Evaluation Process](#contents)
+
+### Usage
+
+After installing MindSpore via the official website, you can start evaluation as follows:
+
+### Launch
+
+```bash
+# infer example
+GPU: python eval.py --data_path path/to/cifar10 --platform GPU --ckpt [CHECKPOINT_PATH]
+```
+
+The checkpoint can be found at [adabin_ascend_v180_cifar10_research_cv_acc88.15.ckpt](https://download.mindspore.cn/models/r1.8/adabin_ascend_v180_cifar10_research_cv_acc88.15.ckpt).
+
+### Result
+
+```text
+result on cifar-10-verify-bin:
+{'Validation-Loss': 1.5793, 'Top1-Acc': 0.9212, 'Top5-Acc': 0.9986}
+result on complete cifar-10 test set:
+{'Validation-Loss': 0.3264, 'Top1-Acc': 0.8815}
+```
+
+# [Model Description](#contents)
+
+## [Performance](#contents)
+
+### Evaluation Performance
+
+#### AdaBin on CIFAR-10
+
+| Parameters                 | Value                       |
+| -------------------------- | --------------------------- |
+| Model Version              | AdaBin                      |
+| Uploaded Date              | 08/16/2022 (month/day/year) |
+| Device                     | GPU                         |
+| MindSpore Version          | 1.8.0                       |
+| Dataset                    | CIFAR-10                    |
+| Input size                 | 32x32                       |
+| Validation Loss            | 0.326                       |
+| Training Time (min)        | 350                         |
+| Training Time per step (s) | 0.18                        |
+| Accuracy (Top1, %)         | 88.151                      |
+
+# [Description of Random Situation](#contents)
+
+In dataset.py, we set the seed inside the `create_dataset` function. A random seed is also used in train.py.
+
+# [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/models).
+
+# [Reference](#contents)
+
+[FDA-BNN](https://gitee.com/mindspore/models.git)
+
diff --git a/research/cv/AdaBin/eval.py b/research/cv/AdaBin/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad7b066598c038956438c05a5ef0938f241ebe6
--- /dev/null
+++ b/research/cv/AdaBin/eval.py
@@ -0,0 +1,83 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Inference Interface"""
+import sys
+import argparse
+import logging
+
+from easydict import EasyDict as edict
+
+from mindspore import context
+from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src.loss import LabelSmoothingCrossEntropy
+from src.dataset import create_dataset_cifar10
+from src.resnet import resnet20
+
+root = logging.getLogger()
+root.setLevel(logging.DEBUG)
+
+parser = argparse.ArgumentParser(description='Evaluation')
+parser.add_argument('--data_path', type=str, default='/data/',
+                    metavar='DIR', help='path to dataset')
+parser.add_argument('--num-classes', type=int, default=10, metavar='N',
+                    help='number of label classes (default: 10)')
+parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
+                    help='input batch size for evaluation (default: 32)')
+parser.add_argument('--smoothing', type=float, default=0.1,
+                    help='label smoothing (default: 0.1)')
+parser.add_argument('--platform', type=str, default='CPU',
+                    help='platform to run')
+parser.add_argument('--ckpt', type=str, default='./adabin.ckpt',
+                    help='model checkpoint to load')
+parser.add_argument('--image_size', type=int, default=32,
+                    help='input image size')
+
+
+def main():
+    """Main entrance for evaluation"""
+    args = parser.parse_args()
+    print(sys.argv)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.platform, save_graphs=False)
+
+    net = resnet20()
+    cfg = edict({
+        'image_height': args.image_size,
+        'image_width': args.image_size,
+    })
+    cfg.batch_size = args.batch_size
+    val_data_url = args.data_path
+    val_dataset = create_dataset_cifar10(val_data_url, repeat_num=1, training=False, cifar_cfg=cfg)
+    loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing,
+                                      num_classes=args.num_classes)
+
+    # keep the loss computation in FP32
+    loss.add_flags_recursive(fp32=True, fp16=False)
+    eval_metrics = {'Validation-Loss': Loss(),
+                    'Top1-Acc': Top1CategoricalAccuracy(),
+                    'Top5-Acc': Top5CategoricalAccuracy()}
+    ckpt = load_checkpoint(args.ckpt)
+    load_param_into_net(net, ckpt)
+    net.set_train(False)
+
+    model = Model(net, loss, metrics=eval_metrics)
+    metrics = model.eval(val_dataset, dataset_sink_mode=False)
+    print(metrics)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/research/cv/AdaBin/mindspore_hub_conf.py b/research/cv/AdaBin/mindspore_hub_conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..3215e3df926f841c14d98d228f1bc860234c13f4
--- /dev/null
+++ b/research/cv/AdaBin/mindspore_hub_conf.py
@@ -0,0 +1,24 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""hub config."""
+from src.resnet import resnet20
+
+
+def create_network(name, *args, **kwargs):
+    """ create network """
+    if name == 'resnet20':
+        return resnet20(*args, **kwargs)
+    raise NotImplementedError(f"{name} is not implemented in the repo")
\ No newline at end of file
diff --git a/research/cv/AdaBin/src/binarylib.py b/research/cv/AdaBin/src/binarylib.py
new file mode 100644
index 0000000000000000000000000000000000000000..a274f1ffced4015f7d29a7546fd20e0995d055b7
--- /dev/null
+++ b/research/cv/AdaBin/src/binarylib.py
@@ -0,0 +1,104 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+''' binary library'''
+import numpy as np
+import mindspore.nn as nn
+from mindspore import dtype as mstype
+from mindspore.ops import operations as P
+from mindspore import Tensor, Parameter, ops
+
+Signer = ops.Sign()
+
+
+class BinaryActivation(nn.Cell):
+    '''
+    Binarize activations with an adaptive binary set
+    '''
+    def __init__(self):
+        super(BinaryActivation, self).__init__()
+        # alpha_a is the distance (scale) and beta_a the center of the
+        # adaptive binary set; both are learned jointly with the network
+        self.alpha_a = Parameter(Tensor(1., dtype=mstype.float32), name="alpha_a", requires_grad=True)
+        self.beta_a = Parameter(Tensor(0., dtype=mstype.float32), name="beta_a", requires_grad=True)
+
+        self.hardtanh = nn.Hardtanh(min_val=-1.0, max_val=1.0)
+
+    def construct(self, x):
+        x_norm = (x - self.beta_a) / self.alpha_a
+        # clip to [-1, 1] before taking the sign
+        x_norm = self.hardtanh(x_norm)
+        x_bin = Signer(x_norm)
+        # map the 1-bit values back to the learned binary set
+        x_adabin = (x_bin + self.beta_a) * self.alpha_a
+        return x_adabin
+
+
+def BinaryWeight(weight):
+    '''
+    Binarize weights with an adaptive binary set
+    '''
+    # the per-output-channel mean and standard deviation serve as the
+    # center and distance of the weight binary set
+    beta_w = weight.mean((1, 2, 3)).view(-1, 1, 1, 1)
+    alpha_w = weight.std((1, 2, 3)).view(-1, 1, 1, 1)
+
+    w_norm = (weight - beta_w) / alpha_w
+    w_bin = Signer(w_norm)
+    w_adabin = w_bin * alpha_w + beta_w
+    return w_adabin
+
+
+class AdaBinConv2d(nn.Conv2d):
+    '''
+    2D convolution with AdaBin 1-bit weights and activations
+    '''
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode="same", padding=0, dilation=1,
+                 group=1, has_bias=False, weight_init="normal", bias_init="zeros", data_format="NCHW",
+                 a_bit=1, w_bit=1):
+        super(AdaBinConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding,
+                                           dilation, group, has_bias, weight_init, bias_init, data_format)
+        self.a_bit = a_bit
+        self.w_bit = w_bit
+        self.binary_a = BinaryActivation()
+        self.conv2d = P.Conv2D(out_channel=out_channels,
+                               kernel_size=kernel_size,
+                               mode=1,
+                               pad_mode=pad_mode,
+                               pad=padding,
+                               stride=stride,
+                               dilation=dilation,
+                               group=group)
+
+    def construct(self, x):
+        # binarize activations and weights only in the 1-bit configuration
+        if self.a_bit == 1:
+            x = self.binary_a(x)
+        if self.w_bit == 1:
+            weight = BinaryWeight(self.weight)
+        else:
+            weight = self.weight
+        output = self.conv2d(x, weight)
+        return output
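+
+
+# Maxout below is the learnable nonlinearity applied after each binary
+# convolution: f(x) = pos_scale * max(x, 0) - neg_scale * max(-x, 0), i.e. a
+# PReLU-like function with independent per-channel scales for the positive
+# and negative halves, both trained together with the network.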
+class Maxout(nn.Cell):
+    '''
+    Nonlinear function
+    '''
+    def __init__(self, channel, neg_init=0.25, pos_init=1.0):
+        super(Maxout, self).__init__()
+        self.neg_scale = Parameter(Tensor(neg_init*np.ones((1, channel, 1, 1)),
+                                          dtype=mstype.float32), name="neg_scale", requires_grad=True)
+        self.pos_scale = Parameter(Tensor(pos_init*np.ones((1, channel, 1, 1)),
+                                          dtype=mstype.float32), name="pos_scale", requires_grad=True)
+        self.relu = nn.ReLU()
+
+    def construct(self, x):
+        ''' Maxout '''
+        x = self.pos_scale*self.relu(x) - self.neg_scale*self.relu(-x)
+        return x
diff --git a/research/cv/AdaBin/src/dataset.py b/research/cv/AdaBin/src/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9739b8047913a837c330b39b7c26d408679cbdb5
--- /dev/null
+++ b/research/cv/AdaBin/src/dataset.py
@@ -0,0 +1,200 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Data operations, used in eval.py"""
+import math
+import os
+
+import numpy as np
+import mindspore.dataset.vision as vision
+import mindspore.dataset.transforms as data_trans
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+from mindspore.communication.management import get_rank, get_group_size
+from mindspore.dataset.vision import Inter
+
+
+# values that should remain constant
+DEFAULT_CROP_PCT = 0.875
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+# data preprocess configs
+SCALE = (0.08, 1.0)
+RATIO = (3. / 4., 4. / 3.)
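+
+# fix the dataset pipeline seed so that shuffling is reproducible across runs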
+ds.config.set_seed(1)
+
+
+def split_imgs_and_labels(imgs, labels, batch_info):
+    """split data into labels and images"""
+    # per_batch_map callables receive a BatchInfo object as the last
+    # argument; it is not needed here
+    ret_imgs = []
+    ret_labels = []
+
+    for i, image in enumerate(imgs):
+        ret_imgs.append(image)
+        ret_labels.append(labels[i])
+    return np.array(ret_imgs), np.array(ret_labels)
+
+
+def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
+                   input_size=224, color_jitter=0.4):
+    """Create ImageNet training dataset"""
+    if not os.path.exists(train_data_url):
+        raise ValueError('Path does not exist')
+    decode_op = vision.Decode(True)
+    type_cast_op = data_trans.TypeCast(mstype.int32)
+
+    random_resize_crop_bicubic = vision.RandomResizedCrop(size=(input_size, input_size),
+                                                          scale=SCALE, ratio=RATIO,
+                                                          interpolation=Inter.BICUBIC)
+    random_horizontal_flip_op = vision.RandomHorizontalFlip(0.5)
+    adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
+    random_color_jitter_op = vision.RandomColorAdjust(brightness=adjust_range,
+                                                      contrast=adjust_range,
+                                                      saturation=adjust_range)
+    to_tensor = vision.ToTensor()
+    normalize_op = vision.Normalize(
+        IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, is_hwc=False)
+
+    # assemble all the transforms
+    image_ops = data_trans.Compose([decode_op, random_resize_crop_bicubic,
+                                    random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
+
+    rank_id = get_rank() if distributed else 0
+    rank_size = get_group_size() if distributed else 1
+
+    dataset_train = ds.ImageFolderDataset(train_data_url,
+                                          num_parallel_workers=workers,
+                                          shuffle=True,
+                                          num_shards=rank_size,
+                                          shard_id=rank_id)
+
+    dataset_train = dataset_train.map(input_columns=["image"],
+                                      operations=image_ops,
+                                      num_parallel_workers=workers)
+
+    dataset_train = dataset_train.map(input_columns=["label"],
+                                      operations=type_cast_op,
+                                      num_parallel_workers=workers)
+
+    # batch dealing
+    ds_train = dataset_train.batch(batch_size,
+                                   per_batch_map=split_imgs_and_labels,
+                                   input_columns=["image", "label"],
+                                   num_parallel_workers=2,
+                                   drop_remainder=True)
+
+    ds_train = ds_train.repeat(1)
+    return ds_train
+
+
+def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
+                       input_size=224):
+    """Create ImageNet validation dataset"""
+    if not os.path.exists(val_data_url):
+        raise ValueError('Path does not exist')
+    rank_id = get_rank() if distributed else 0
+    rank_size = get_group_size() if distributed else 1
+    dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
+                                    num_shards=rank_size, shard_id=rank_id)
+    scale_size = None
+
+    if isinstance(input_size, tuple):
+        assert len(input_size) == 2
+        if input_size[-1] == input_size[-2]:
+            scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
+        else:
+            scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
+    else:
+        scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
+
+    type_cast_op = data_trans.TypeCast(mstype.int32)
+    decode_op = vision.Decode(True)
+    resize_op = vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
+    center_crop = vision.CenterCrop(size=input_size)
+    to_tensor = vision.ToTensor()
+    normalize_op = vision.Normalize(
+        IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, is_hwc=False)
+
+    image_ops = data_trans.Compose([decode_op, resize_op, center_crop,
+                                    to_tensor, normalize_op])
+
+    dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
+                          num_parallel_workers=workers)
+    dataset = dataset.map(input_columns=["image"], operations=image_ops,
+                          num_parallel_workers=workers)
+    dataset = dataset.batch(batch_size,
per_batch_map=split_imgs_and_labels, + input_columns=["image", "label"], + num_parallel_workers=2, + drop_remainder=True) + dataset = dataset.repeat(1) + return dataset + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = rank_id = None + + return rank_size, rank_id + + +def create_dataset_cifar10(data_home, repeat_num=1, training=True, cifar_cfg=None): + """Data operations.""" + data_dir = os.path.join(data_home, "cifar-10-batches-bin") + if not training: + data_dir = os.path.join(data_home, "cifar-10-verify-bin") + + rank_size, rank_id = _get_rank_info() + if training: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True) + else: + data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False) + + resize_height = cifar_cfg.image_height + resize_width = cifar_cfg.image_width + + # define map operations + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_horizontal_op = vision.RandomHorizontalFlip() + resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + rescale_op = vision.Rescale(1.0 / 255.0, 0.0) + normalize_op = vision.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), is_hwc=True) + changeswap_op = vision.HWC2CHW() + type_cast_op = data_trans.TypeCast(mstype.int32) + + c_trans = [] + if training: + c_trans = [random_crop_op, random_horizontal_op] + c_trans += [resize_op, rescale_op, normalize_op, changeswap_op] + + # apply map operations on images + data_set = data_set.map(operations=type_cast_op, input_columns="label") + data_set = data_set.map(operations=c_trans, input_columns="image") + + # apply batch operations + data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True) + + # apply repeat operations + data_set = data_set.repeat(repeat_num) + + return data_set diff --git a/research/cv/AdaBin/src/loss.py b/research/cv/AdaBin/src/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a1221745eea0e3de1627dbc705c3f23f67ad0b50 --- /dev/null +++ b/research/cv/AdaBin/src/loss.py @@ -0,0 +1,45 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""define loss function for network."""
+
+from mindspore.nn.loss.loss import LossBase
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+import mindspore.nn as nn
+
+
+class LabelSmoothingCrossEntropy(LossBase):
+    """cross-entropy with label smoothing"""
+
+    def __init__(self, smooth_factor=0.1, num_classes=1000):
+        super(LabelSmoothingCrossEntropy, self).__init__()
+        self.onehot = P.OneHot()
+        # the true class gets 1 - smooth_factor; the remaining probability
+        # mass is spread uniformly over the other num_classes - 1 labels
+        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
+        self.off_value = Tensor(1.0 * smooth_factor /
+                                (num_classes - 1), mstype.float32)
+        self.ce = nn.SoftmaxCrossEntropyWithLogits()
+        self.mean = P.ReduceMean(False)
+        self.cast = P.Cast()
+
+    def construct(self, logits, label):
+        """construct"""
+        label = self.cast(label, mstype.int32)
+        one_hot_label = self.onehot(label, F.shape(
+            logits)[1], self.on_value, self.off_value)
+        loss_logit = self.ce(logits, one_hot_label)
+        loss_logit = self.mean(loss_logit, 0)
+        return loss_logit
diff --git a/research/cv/AdaBin/src/resnet.py b/research/cv/AdaBin/src/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f574a7e3092cb5d675052d6e4028b676c8239f48
--- /dev/null
+++ b/research/cv/AdaBin/src/resnet.py
@@ -0,0 +1,143 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ResNet-20 with AdaBin binary convolutions."""
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+
+from src.binarylib import AdaBinConv2d, Maxout
+
+
+class LambdaLayer(nn.Cell):
+    """wrap an arbitrary function as a Cell"""
+    def __init__(self, lambd):
+        super(LambdaLayer, self).__init__()
+        self.lambd = lambd
+
+    def construct(self, x):
+        return self.lambd(x)
+
+
+class BasicBlock(nn.Cell):
+    """
+    ResNet basic cell definition.
+
+    Args:
+        in_planes (int): Input channel.
+        planes (int): Output channel.
+        stride (int): Stride size for the first convolutional layer.
+    Returns:
+        Tensor, output tensor.
+    """
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+
+        self.stride = stride
+        self.in_planes = in_planes
+        self.planes = planes
+
+        self.conv1 = AdaBinConv2d(in_planes, planes, kernel_size=3, stride=stride, pad_mode="pad", padding=1)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.nonlinear1 = Maxout(planes)
+
+        self.conv2 = AdaBinConv2d(planes, planes, kernel_size=3, stride=1, pad_mode="pad", padding=1)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.nonlinear2 = Maxout(planes)
+
+        self.pad = nn.SequentialCell()
+        if stride != 1 or in_planes != planes:
+            self.pad = nn.Pad(((0, 0), (planes // 4, planes // 4), (0, 0), (0, 0)))
+
+    def construct(self, x):
+        """ construct """
+
+        out = self.bn1(self.conv1(x))
+        if self.stride != 1 or self.in_planes != self.planes:
+            # 'option A' shortcut of the original CIFAR ResNet: downsample
+            # spatially by strided slicing and zero-pad the extra channels,
+            # keeping the shortcut free of binarized parameters
+            x = x[:, :, ::2, ::2]
+        out += self.pad(x)
+        out = self.nonlinear1(out)
+        x1 = out
+        out = self.bn2(self.conv2(out))
+        out += x1
+        out = self.nonlinear2(out)
+        return out
+
+
+class ResNet(nn.Cell):
+    """
+    ResNet architecture.
+
+    Args:
+        block (Cell): Block cell for the network.
+        num_blocks (list): Number of blocks in each of the three stages.
+        num_classes (int): The number of classes that the training images belong to.
+    Returns:
+        Tensor, output tensor.
+    """
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(ResNet, self).__init__()
+        self.in_planes = 16
+
+        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, pad_mode="pad", padding=1)
+        self.bn1 = nn.BatchNorm2d(16)
+        self.nonlinear1 = Maxout(16)
+
+        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
+
+        self.ap = P.ReduceMean(keep_dims=True)
+        self.flatten = nn.Flatten()
+        self.bn2 = nn.BatchNorm1d(64)
+
+        self.linear = nn.Dense(64, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        """
+        Make stage network of ResNet.
+
+        Args:
+            block (Cell): ResNet block.
+            planes (int): Output channel of the stage.
+            num_blocks (int): Number of blocks in the stage.
+            stride (int): Stride size for the first convolutional layer.
+
+        Returns:
+            SequentialCell, the output layer.
+
+        Examples:
+            >>> _make_layer(BasicBlock, 32, 3, 2)
+        """
+        strides = [stride] + [1]*(num_blocks-1)
+        layers = []
+        for s in strides:
+            layers.append(block(self.in_planes, planes, s))
+            self.in_planes = planes * block.expansion
+
+        return nn.SequentialCell(*layers)
+
+    def construct(self, x):
+        """construct"""
+
+        out = self.nonlinear1(self.bn1(self.conv1(x)))
+
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+
+        # global average pooling over the spatial dimensions
+        out = self.ap(out, (2, 3))
+        out = self.flatten(out)
+        out = self.bn2(out)
+        out = self.linear(out)
+
+        return out
+
+
+def resnet20():
+    """ resnet20 """
+    return ResNet(BasicBlock, [3, 3, 3])
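+
+
+if __name__ == "__main__":
+    # Minimal smoke test (an illustrative addition, not part of the original
+    # script): build the binarized ResNet-20 and check the output shape on a
+    # random CIFAR-sized batch.
+    import numpy as np
+    from mindspore import Tensor
+
+    network = resnet20()
+    network.set_train(False)
+    dummy = Tensor(np.random.randn(2, 3, 32, 32).astype(np.float32))
+    print(network(dummy).shape)  # expected: (2, 10)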