diff --git a/community/cv/snn/eval.py b/community/cv/snn/eval.py index b343c10d65c456b8a708e7ead8cdeb74b6f6cdd6..83defedf46482207b7f3ef1c79b555b0773626ba 100644 --- a/community/cv/snn/eval.py +++ b/community/cv/snn/eval.py @@ -20,11 +20,11 @@ from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper from src.dataset import create_dataset_cifar10 -import mindspore.ops as ops +import mindspore.nn as nn from mindspore import context +from mindspore.train import Model +from mindspore.nn.metrics import Accuracy from mindspore.train.serialization import load_checkpoint, load_param_into_net -import mindspore as ms -from mindspore.ops import operations as P def modelarts_process(): @@ -35,16 +35,10 @@ def snn_model_build(): build snn model for lenet and resnet50 """ if config.net_name == "resnet50": - if config.mode_name == 'GRAPH': - from src.snn_resnet import snn_resnet50_graph as snn_resnet50 - else: - from src.snn_resnet import snn_resnet50_pynative as snn_resnet50 + from src.snn_resnet import snn_resnet50 net = snn_resnet50(class_num=config.class_num) elif config.net_name == "lenet": - if config.mode_name == 'GRAPH': - from src.snn_lenet import snn_lenet_graph as snn_lenet - else: - from src.snn_lenet import snn_lenet_pynative as snn_lenet + from src.snn_lenet import snn_lenet net = snn_lenet(num_class=config.class_num) else: raise ValueError(f'config.model: {config.model_name} is not supported') @@ -57,8 +51,6 @@ def eval_net(): eval net """ print('eval with config: ', config) - correct = 0.0 - total = 0.0 if config.mode_name == 'GRAPH': context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) else: @@ -68,23 +60,12 @@ def eval_net(): if ds_eval.get_dataset_size() == 0: raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") network_eval = snn_model_build() + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + model = Model(network_eval, net_loss, metrics={"Accuracy": Accuracy()}) param_dict = load_checkpoint(config.ckpt_path) load_param_into_net(network_eval, param_dict) - network_eval.set_train(False) - print("============== Starting Testing ==============", flush=True) - for _, data in enumerate(ds_eval.create_dict_iterator()): - image = data['image'] - label = data['label'] - outspikes = network_eval(image) - predicted = ops.Argmax(output_type=ms.int32)(outspikes) - total += label.shape[0] - cast = P.Cast() - correct += cast((predicted == label), ms.float32).sum().asnumpy().item() - if config.mode_name == 'PYNATIVE': - network_eval.reset_net() - - accuracy = 100 * correct / total - print('Accuracy of the network is: %.4f %%' % accuracy, flush=True) + acc = model.eval(ds_eval) + print("============== {} ==============".format(acc)) if __name__ == "__main__": eval_net() diff --git a/community/cv/snn/src/ifnode.py b/community/cv/snn/src/ifnode.py index 7d763360bd40dd3754844a64584dd89deeda2c9a..369903e62b10010759ce12864e42a5957ae3ac05 100644 --- a/community/cv/snn/src/ifnode.py +++ b/community/cv/snn/src/ifnode.py @@ -37,7 +37,7 @@ class relusigmoid(nn.Cell): # must be a tuple return (grad_x,) -class IFNode_GRAPH(nn.Cell): +class IFNode(nn.Cell): """ integrate and fire cell for GRAPH mode, it will output spike value """ @@ -48,43 +48,10 @@ class IFNode_GRAPH(nn.Cell): self.surrogate_function = surrogate_function def construct(self, x, v): - """ neuronal_charge: v need to do add""" + """neuronal_charge: v need to do add""" v = v + x if self.fire: spike = self.surrogate_function(v - self.v_threshold) * self.v_threshold v -= spike return spike, v return v, v - - -class IFNode_PYNATIVE(nn.Cell): - """ - integrate and fire cell for PYNATIVE mode, it will output spike value - """ - def __init__(self, v_threshold=1.0, v_reset=0.0, fire=True, surrogate_function=relusigmoid()): - super().__init__() - self.v_threshold = v_threshold - if v_reset is None: - self.v_reset = 0.0 - else: - self.v_reset = v_reset - self.v = self.v_reset - self.fire = fire - self.surrogate_function = surrogate_function - - def construct(self, x): - """ neuronal_charge: self.v need to do add""" - self.v = self.v + x - # neuronal_fire - if self.fire: - spike = self.surrogate_function(self.v - self.v_threshold) * self.v_threshold - self.v -= spike - return spike - return self.v - - def reset(self): - """each batch should reset the accumulated value of the net such as self.v""" - if self.v_reset is None: - self.v = 0.0 - else: - self.v = self.v_reset diff --git a/community/cv/snn/src/snn_lenet.py b/community/cv/snn/src/snn_lenet.py index 25fc609b18d7693d9cecd17298e99c395a0d65cc..aedbcfcbf2f2be0b95a48305765f59332a2e8b07 100644 --- a/community/cv/snn/src/snn_lenet.py +++ b/community/cv/snn/src/snn_lenet.py @@ -16,7 +16,7 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore import Tensor -from src.ifnode import IFNode_GRAPH, IFNode_PYNATIVE +from src.ifnode import IFNode import numpy as np @@ -44,127 +44,87 @@ def init_dense_bias(inC, outC): return Tensor(weight) -class snn_lenet_graph(nn.Cell): +class Conv2d_Block(nn.Cell): """ - snn backbone for lenet with graph mode + block: conv2d + ifnode """ - def __init__(self, num_class=10, num_channel=3): - super(snn_lenet_graph, self).__init__() - self.T = 100 - self.conv1 = nn.Conv2d(num_channel, 16, 3, stride=1, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(num_channel, 16, 3), bias_init=init_bias(num_channel, 16, 3)) - self.ifnode1 = IFNode_GRAPH() - self.conv2 = nn.Conv2d(16, 16, 3, stride=2, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3)) - self.ifnode2 = IFNode_GRAPH() - self.conv3 = nn.Conv2d(16, 32, 3, stride=1, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3)) - self.ifnode3 = IFNode_GRAPH() - self.conv4 = nn.Conv2d(32, 32, 3, stride=2, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3)) - self.ifnode4 = IFNode_GRAPH() - self.conv5 = nn.Conv2d(32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3)) - self.ifnode5 = IFNode_GRAPH() - self.conv6 = nn.Conv2d(64, 64, 3, stride=2, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3)) - self.ifnode6 = IFNode_GRAPH() - self.fc1 = nn.Dense(64 * 4 * 4, 32, weight_init=init_dense_weight(64 * 4 * 4, 32), - bias_init=init_dense_bias(64 * 4 * 4, 32)) - self.ifnode7 = IFNode_GRAPH() - self.fc2 = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class), - bias_init=init_dense_bias(32, num_class)) - self.ifnode8 = IFNode_GRAPH(fire=False) + def __init__(self, in_channels, out_channels, weight_init, bias_init, kernel_size=3, stride=1, + pad_mode='pad', padding=1, has_bias=True): + super(Conv2d_Block, self).__init__() + self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, pad_mode=pad_mode, padding=padding, has_bias=has_bias, + weight_init=weight_init, bias_init=bias_init) + self.ifnode = IFNode() def construct(self, x_in): - """forward the snn-lenet block""" - x = x_in - v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = 0.0 - for _ in range(self.T): - x = self.conv1(x_in) - x, v1 = self.ifnode1(x, v1) - x = self.conv2(x) - x, v2 = self.ifnode2(x, v2) - x = self.conv3(x) - x, v3 = self.ifnode3(x, v3) - x = self.conv4(x) - x, v4 = self.ifnode4(x, v4) - x = self.conv5(x) - x, v5 = self.ifnode5(x, v5) - x = self.conv6(x) - x, v6 = self.ifnode6(x, v6) - x = P.Reshape()(x, (-1, 64 * 4 * 4)) - x = self.fc1(x) - x, v7 = self.ifnode7(x, v7) - x = self.fc2(x) - x, v8 = self.ifnode8(x, v8) - return x / self.T + x, v1 = x_in + out = self.conv2d(x) + out, v1 = self.ifnode(out, v1) + return (out, v1) -class snn_lenet_pynative(nn.Cell): +class Dense_Block(nn.Cell): """ - snn backbone for lenet with pynative mode + block: dense + ifnode + """ + def __init__(self, in_channels, out_channels, weight_init, bias_init): + super(Dense_Block, self).__init__() + self.dense = nn.Dense(in_channels=in_channels, out_channels=out_channels, + weight_init=weight_init, bias_init=bias_init) + self.ifnode = IFNode() + + def construct(self, x_in): + x, v1 = x_in + out = self.dense(x) + out, v1 = self.ifnode(out, v1) + return out, v1 + +class snn_lenet(nn.Cell): + """ + snn backbone for lenet with graph mode """ def __init__(self, num_class=10, num_channel=3): - super(snn_lenet_pynative, self).__init__() + super(snn_lenet, self).__init__() self.T = 100 - self.conv1 = nn.SequentialCell([nn.Conv2d(num_channel, 16, 3, stride=1, pad_mode='pad', padding=1, - has_bias=True, weight_init=init_weight(num_channel, 16, 3), - bias_init=init_bias(num_channel, 16, 3)), - IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)]) - - self.conv2 = nn.SequentialCell([nn.Conv2d(16, 16, 3, stride=2, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3)), - IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)]) + self.conv1 = Conv2d_Block(in_channels=num_channel, out_channels=16, + weight_init=init_weight(num_channel, 16, 3), bias_init=init_bias(num_channel, 16, 3)) - self.conv3 = nn.SequentialCell([nn.Conv2d(16, 32, 3, stride=1, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3)), - IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)]) + self.conv2 = Conv2d_Block(in_channels=16, out_channels=16, stride=2, + weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3)) - self.conv4 = nn.SequentialCell([nn.Conv2d(32, 32, 3, stride=2, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3)), - IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)]) + self.conv3 = Conv2d_Block(in_channels=16, out_channels=32, + weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3)) - self.conv5 = nn.SequentialCell([nn.Conv2d(32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3)), - IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)]) + self.conv4 = Conv2d_Block(in_channels=32, out_channels=32, stride=2, + weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3)) - self.conv6 = nn.SequentialCell([nn.Conv2d(64, 64, 3, stride=2, pad_mode='pad', padding=1, has_bias=True, - weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3)), - IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)]) + self.conv5 = Conv2d_Block(in_channels=32, out_channels=64, + weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3)) - self.fc1 = nn.SequentialCell([nn.Dense(64 * 4 * 4, 32, - weight_init=init_dense_weight(64 * 4 * 4, 32), - bias_init=init_dense_bias(64 * 4 * 4, 32)), - IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)]) + self.conv6 = Conv2d_Block(in_channels=64, out_channels=64, stride=2, + weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3)) - self.fc2 = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class), - bias_init=init_dense_bias(32, num_class)) + self.dense1 = Dense_Block(in_channels=64 * 4 * 4, out_channels=32, + weight_init=init_dense_weight(64 * 4 * 4, 32), + bias_init=init_dense_bias(64 * 4 * 4, 32)) - self.outlayer = IFNode_PYNATIVE(v_threshold=1.0, v_reset=None, fire=False) + self.fc = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class), + bias_init=init_dense_bias(32, num_class)) + self.end_ifnode = IFNode(fire=False) def construct(self, x_in): """forward the snn-lenet block""" x = x_in + v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = 0.0 for _ in range(self.T): - x = self.conv1(x_in) - x = self.conv2(x) - x = self.conv3(x) - x = self.conv4(x) - x = self.conv5(x) - x = self.conv6(x) + x, v1 = self.conv1((x_in, v1)) + x, v2 = self.conv2((x, v2)) + x, v3 = self.conv3((x, v3)) + x, v4 = self.conv4((x, v4)) + x, v5 = self.conv5((x, v5)) + x, v6 = self.conv6((x, v6)) x = P.Reshape()(x, (-1, 64 * 4 * 4)) - x = self.fc1(x) - x = self.fc2(x) - x = self.outlayer(x) + x, v7 = self.dense1((x, v7)) + x = self.fc(x) + x, v8 = self.end_ifnode(x, v8) return x / self.T - - def reset_net(self): - """each batch should reset the accumulated value of the net such as self.v""" - for item in self.cells(): - if isinstance(type(item), type(nn.SequentialCell())): - if hasattr(item[-1], 'reset'): - item[-1].reset() - else: - if hasattr(item, 'reset'): - item.reset() diff --git a/community/cv/snn/src/snn_resnet.py b/community/cv/snn/src/snn_resnet.py index 4836401d2360390c26f8c387dbb76d2d39a168fc..cab8eb739d371e2830b49d8f31668067781600d3 100644 --- a/community/cv/snn/src/snn_resnet.py +++ b/community/cv/snn/src/snn_resnet.py @@ -15,9 +15,10 @@ """ResNet_SNN.""" import math import mindspore.nn as nn +from mindspore import Tensor import mindspore.ops as ops from mindspore.common.initializer import HeNormal, HeUniform -from src.ifnode import IFNode_GRAPH, IFNode_PYNATIVE +from src.ifnode import IFNode def _conv3x3(in_channel, out_channel, stride=1): @@ -47,7 +48,7 @@ def _fc(in_channel, out_channel): bias_init=0) -class ResidualBlock_GRAPH(nn.Cell): +class ResidualBlock(nn.Cell): """ ResNet V1 residual block definition. @@ -60,21 +61,21 @@ class ResidualBlock_GRAPH(nn.Cell): Tensor, output tensor. Examples: - >>> ResidualBlock_GRAPH(3, 256, stride=2) + >>> ResidualBlock(3, 256, stride=2) """ expansion = 4 def __init__(self, in_channel, out_channel, stride=1): - super(ResidualBlock_GRAPH, self).__init__() + super(ResidualBlock, self).__init__() self.stride = stride channel = out_channel // self.expansion self.conv1 = _conv1x1(in_channel, channel, stride=1) self.bn1 = _bn(channel) - self.ifnode1 = IFNode_GRAPH() + self.ifnode1 = IFNode() self.conv2 = _conv3x3(channel, channel, stride=stride) self.bn2 = _bn(channel) - self.ifnode2 = IFNode_GRAPH() + self.ifnode2 = IFNode() self.conv3 = _conv1x1(channel, out_channel, stride=1) self.bn3 = _bn(out_channel) @@ -87,10 +88,10 @@ class ResidualBlock_GRAPH(nn.Cell): if self.down_sample: self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) - self.ifnode3 = IFNode_GRAPH() + self.ifnode3 = IFNode() def construct(self, x_in): - """ResidualBlock with graph mode""" + """ResidualBlock""" x, v1, v2, v3 = x_in identity = x @@ -113,7 +114,7 @@ class ResidualBlock_GRAPH(nn.Cell): return (out, v1, v2, v3) -class ResNet_SNN_GRAPH(nn.Cell): +class ResNet_SNN(nn.Cell): """ ResNet architecture. @@ -129,16 +130,16 @@ class ResNet_SNN_GRAPH(nn.Cell): Tensor, output tensor. Examples: - >>> ResNet_SNN_GRAPH(ResidualBlock, - >>> [3, 4, 6, 3], - >>> [64, 256, 512, 1024], - >>> [256, 512, 1024, 2048], - >>> [1, 2, 2, 2], - >>> 10) + >>> ResNet_SNN(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> [1, 2, 2, 2], + >>> 10) """ def __init__(self, block, layer_nums, in_channels, out_channels, strides, num_classes): - super(ResNet_SNN_GRAPH, self).__init__() + super(ResNet_SNN, self).__init__() if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") @@ -146,79 +147,48 @@ class ResNet_SNN_GRAPH(nn.Cell): self.T = 5 self.conv1 = _conv7x7(3, 64, stride=2) self.bn1 = _bn(64) - self.ifnode1 = IFNode_GRAPH() + self.ifnode1 = IFNode() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.layer_nums = layer_nums # layer_nums:[3, 4, 6, 3] - self.layer1_1 = self._make_layer_test1(block, in_channel=in_channels[0], - out_channel=out_channels[0], stride=strides[0]) - self.layer1_2 = self._make_layer_test2(block, out_channel=out_channels[0],) - self.layer1_3 = self._make_layer_test2(block, out_channel=out_channels[0],) - self.layer2_1 = self._make_layer_test1(block, in_channel=in_channels[1], - out_channel=out_channels[1], stride=strides[1]) - self.layer2_2 = self._make_layer_test2(block, out_channel=out_channels[1]) - self.layer2_3 = self._make_layer_test2(block, out_channel=out_channels[1]) - self.layer2_4 = self._make_layer_test2(block, out_channel=out_channels[1]) - self.layer3_1 = self._make_layer_test1(block, in_channel=in_channels[2], - out_channel=out_channels[2], stride=strides[2]) - self.layer3_2 = self._make_layer_test2(block, out_channel=out_channels[2]) - self.layer3_3 = self._make_layer_test2(block, out_channel=out_channels[2]) - self.layer3_4 = self._make_layer_test2(block, out_channel=out_channels[2]) - self.layer3_5 = self._make_layer_test2(block, out_channel=out_channels[2]) - self.layer3_6 = self._make_layer_test2(block, out_channel=out_channels[2]) - self.layer4_1 = self._make_layer_test1(block, in_channel=in_channels[3], - out_channel=out_channels[3], stride=strides[3]) - self.layer4_2 = self._make_layer_test2(block, out_channel=out_channels[3]) - self.layer4_3 = self._make_layer_test2(block, out_channel=out_channels[3]) + self.layer1 = self.make_layer(block, layer_nums[0], in_channel=in_channels[0], + out_channel=out_channels[0], stride=strides[0]) + self.layer2 = self.make_layer(block, layer_nums[1], in_channel=in_channels[1], + out_channel=out_channels[1], stride=strides[1]) + self.layer3 = self.make_layer(block, layer_nums[2], in_channel=in_channels[2], + out_channel=out_channels[2], stride=strides[2]) + self.layer4 = self.make_layer(block, layer_nums[3], in_channel=in_channels[3], + out_channel=out_channels[3], stride=strides[3]) self.mean = ops.ReduceMean(keep_dims=True) self.flatten = nn.Flatten() self.end_point = _fc(out_channels[3], num_classes) - self.end_ifnode = IFNode_GRAPH(fire=False) + self.end_ifnode = IFNode(fire=False) + self.layers = nn.CellList([self.layer1, self.layer2, self.layer3, self.layer4]) - def _make_layer_test1(self, block, in_channel, out_channel, stride): - """ - Make stage network of ResNet. - - Args: - block (Cell): Resnet block. - in_channel (int): Input channel. - out_channel (int): Output channel. - stride (int): Stride size for the first convolutional layer. - Returns: - SequentialCell, the output layer. - """ + def make_layer(self, block, layer_num, in_channel, out_channel, stride): layers = [] + resnet_block = block(in_channel, out_channel, stride=stride) layers.append(resnet_block) - return nn.SequentialCell(layers) - - def _make_layer_test2(self, block, out_channel): - """ - Make stage network of ResNet. - - Args: - block (Cell): Resnet block. - out_channel (int): Output channel. - Returns: - SequentialCell, the output layer. - """ - layers = [] - resnet_block = block(out_channel, out_channel, stride=1) - layers.append(resnet_block) - return nn.SequentialCell(layers) + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1) + layers.append(resnet_block) + + return nn.layer.CellList(layers) + def construct(self, x_in): """ResNet SNN block with graph mode""" out = x_in v1 = v_end = 0.0 - # layer_nums:[3, 4, 6, 3] - v1_1_1 = v1_1_2 = v1_1_3 = v1_2_1 = v1_2_2 = v1_2_3 = v1_3_1 = v1_3_2 = v1_3_3 = 0.0 - v2_1_1 = v2_1_2 = v2_1_3 = v2_2_1 = v2_2_2 = v2_2_3 = v2_3_1 = v2_3_2 = v2_3_3 = v2_4_1 = v2_4_2 = v2_4_3 = 0.0 - v3_1_1 = v3_1_2 = v3_1_3 = v3_2_1 = v3_2_2 = v3_2_3 = v3_3_1 = v3_3_2 = v3_3_3 = 0.0 - v3_4_1 = v3_4_2 = v3_4_3 = v3_5_1 = v3_5_2 = v3_5_3 = v3_6_1 = v3_6_2 = v3_6_3 = 0.0 - v4_1_1 = v4_1_2 = v4_1_3 = v4_2_1 = v4_2_2 = v4_2_3 = v4_3_1 = v4_3_2 = v4_3_3 = 0.0 + + V = [] + for layer_num in self.layer_nums: + for _ in range(layer_num): + V.append([Tensor(0.0), Tensor(0.0), Tensor(0.0)]) for _ in range(self.T): x = self.conv1(x_in) @@ -226,25 +196,20 @@ class ResNet_SNN_GRAPH(nn.Cell): x, v1 = self.ifnode1(x, v1) c1 = self.maxpool(x) - - c1_1, v1_1_1, v1_1_2, v1_1_3 = self.layer1_1((c1, v1_1_1, v1_1_2, v1_1_3)) - c1_2, v1_2_1, v1_2_2, v1_2_3 = self.layer1_2((c1_1, v1_2_1, v1_2_2, v1_2_3)) - c1_3, v1_3_1, v1_3_2, v1_3_3 = self.layer1_3((c1_2, v1_3_1, v1_3_2, v1_3_3)) - c2_1, v2_1_1, v2_1_2, v2_1_3 = self.layer2_1((c1_3, v2_1_1, v2_1_2, v2_1_3)) - c2_2, v2_2_1, v2_2_2, v2_2_3 = self.layer2_2((c2_1, v2_2_1, v2_2_2, v2_2_3)) - c2_3, v2_3_1, v2_3_2, v2_3_3 = self.layer2_3((c2_2, v2_3_1, v2_3_2, v2_3_3)) - c2_4, v2_4_1, v2_4_2, v2_4_3 = self.layer2_4((c2_3, v2_4_1, v2_4_2, v2_4_3)) - c3_1, v3_1_1, v3_1_2, v3_1_3 = self.layer3_1((c2_4, v3_1_1, v3_1_2, v3_1_3)) - c3_2, v3_2_1, v3_2_2, v3_2_3 = self.layer3_2((c3_1, v3_2_1, v3_2_2, v3_2_3)) - c3_3, v3_3_1, v3_3_2, v3_3_3 = self.layer3_3((c3_2, v3_3_1, v3_3_2, v3_3_3)) - c3_4, v3_4_1, v3_4_2, v3_4_3 = self.layer3_4((c3_3, v3_4_1, v3_4_2, v3_4_3)) - c3_5, v3_5_1, v3_5_2, v3_5_3 = self.layer3_5((c3_4, v3_5_1, v3_5_2, v3_5_3)) - c3_6, v3_6_1, v3_6_2, v3_6_3 = self.layer3_6((c3_5, v3_6_1, v3_6_2, v3_6_3)) - c4_1, v4_1_1, v4_1_2, v4_1_3 = self.layer4_1((c3_6, v4_1_1, v4_1_2, v4_1_3)) - c4_2, v4_2_1, v4_2_2, v4_2_3 = self.layer4_2((c4_1, v4_2_1, v4_2_2, v4_2_3)) - c4_3, v4_3_1, v4_3_2, v4_3_3 = self.layer4_3((c4_2, v4_3_1, v4_3_2, v4_3_3)) - - out = self.mean(c4_3, (2, 3)) + out = c1 + + index = 0 + ifnode_count = 0 + for row in self.layer_nums: + layers = self.layers[index] + for col in range(row): + block = layers[col] + out, V[ifnode_count + col][0], V[ifnode_count + col][1], V[ifnode_count + col][2] = \ + block((out, V[ifnode_count + col][0], V[ifnode_count + col][1], V[ifnode_count + col][2])) + ifnode_count += self.layer_nums[index] + index += 1 + + out = self.mean(out, (2, 3)) out = self.flatten(out) out = self.end_point(out) out, v_end = self.end_ifnode(out, v_end) @@ -252,201 +217,10 @@ class ResNet_SNN_GRAPH(nn.Cell): return out / self.T -class ResidualBlock_PYNATIVE(nn.Cell): - """ - ResNet V1 residual block definition. - - Args: - in_channel (int): Input channel. - out_channel (int): Output channel. - stride (int): Stride size for the first convolutional layer. Default: 1. - - Returns: - Tensor, output tensor. - - Examples: - >>> ResidualBlock_PYNATIVE(3, 256, stride=2) - """ - expansion = 4 - - def __init__(self, - in_channel, - out_channel, - stride=1): - super(ResidualBlock_PYNATIVE, self).__init__() - self.stride = stride - channel = out_channel // self.expansion - self.conv1 = _conv1x1(in_channel, channel, stride=1) - self.bn1 = _bn(channel) - self.ifnode1 = IFNode_PYNATIVE() - - self.conv2 = _conv3x3(channel, channel, stride=stride) - self.bn2 = _bn(channel) - self.ifnode2 = IFNode_PYNATIVE() - - self.conv3 = _conv1x1(channel, out_channel, stride=1) - self.bn3 = _bn(out_channel) - - self.down_sample = False - if stride != 1 or in_channel != out_channel: - self.down_sample = True - self.down_sample_layer = None - - if self.down_sample: - self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)]) - - self.ifnode3 = IFNode_PYNATIVE() - - def construct(self, x): - """ResidualBlock with pynative mode""" - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.ifnode1(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.ifnode2(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.down_sample: - identity = self.down_sample_layer(identity) - out = out + identity - out = self.ifnode3(out) - - return out - - -class ResNet_SNN_PYNATIVE(nn.Cell): - """ - ResNet architecture. - - Args: - block (Cell): Block for network. - layer_nums (list): Numbers of block in different layers. - in_channels (list): Input channel in each layer. - out_channels (list): Output channel in each layer. - strides (list): Stride size in each layer. - num_classes (int): The number of classes that the training images are belonging to. - - Returns: - Tensor, output tensor. - - Examples: - >>> ResNet_SNN_PYNATIVE(ResidualBlock, - >>> [3, 4, 6, 3], - >>> [64, 256, 512, 1024], - >>> [256, 512, 1024, 2048], - >>> [1, 2, 2, 2], - >>> 10) - """ - - def __init__(self, - block, - layer_nums, - in_channels, - out_channels, - strides, - num_classes): - super(ResNet_SNN_PYNATIVE, self).__init__() - - if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: - raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") - - self.T = 5 - self.conv1 = _conv7x7(3, 64, stride=2) - self.bn1 = _bn(64) - self.ifnode1 = IFNode_PYNATIVE() - - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") - - self.layer1 = self._make_layer(block, layer_nums[0], in_channel=in_channels[0], - out_channel=out_channels[0], stride=strides[0]) - self.layer2 = self._make_layer(block, layer_nums[1], in_channel=in_channels[1], - out_channel=out_channels[1], stride=strides[1]) - self.layer3 = self._make_layer(block, layer_nums[2], in_channel=in_channels[2], - out_channel=out_channels[2], stride=strides[2]) - self.layer4 = self._make_layer(block, layer_nums[3], in_channel=in_channels[3], - out_channel=out_channels[3], stride=strides[3]) - - self.mean = ops.ReduceMean(keep_dims=True) - self.flatten = nn.Flatten() - self.end_point = _fc(out_channels[3], num_classes) - self.end_ifnode = IFNode_PYNATIVE(fire=False) - - def construct(self, x_in): - """ResNet SNN block with pynative mode""" - out = x_in - for _ in range(self.T): - x = self.conv1(x_in) - x = self.bn1(x) - x = self.ifnode1(x) - - c1 = self.maxpool(x) - - c2 = self.layer1(c1) - c3 = self.layer2(c2) - c4 = self.layer3(c3) - c5 = self.layer4(c4) - - out = self.mean(c5, (2, 3)) - out = self.flatten(out) - out = self.end_point(out) - out = self.end_ifnode(out) - - return out / self.T - - def reset_net(self): - for item in self.cells(): - if isinstance(type(item), type(nn.SequentialCell())): - if hasattr(item[-1], 'reset'): - item[-1].reset() - else: - if hasattr(item, 'reset'): - item.reset() - - def _make_layer(self, block, layer_num, in_channel, out_channel, stride): - """ - Make stage network of ResNet. - - Args: - block (Cell): Resnet block. - layer_num (int): Layer number. - in_channel (int): Input channel. - out_channel (int): Output channel. - stride (int): Stride size for the first convolutional layer. - Returns: - SequentialCell, the output layer. - - Examples: - >>> _make_layer(ResidualBlock, 3, 128, 256, 2) - """ - layers = [] - - resnet_block = block(in_channel, out_channel, stride=stride) - layers.append(resnet_block) - for _ in range(1, layer_num): - resnet_block = block(out_channel, out_channel, stride=1) - layers.append(resnet_block) - - return nn.SequentialCell(layers) - -def snn_resnet50_graph(class_num=10): - return ResNet_SNN_GRAPH(ResidualBlock_GRAPH, - [3, 4, 6, 3], - [64, 256, 512, 1024], - [256, 512, 1024, 2048], - [1, 2, 2, 2], - class_num) - - -def snn_resnet50_pynative(class_num=10): - return ResNet_SNN_PYNATIVE(ResidualBlock_PYNATIVE, - [3, 4, 6, 3], - [64, 256, 512, 1024], - [256, 512, 1024, 2048], - [1, 2, 2, 2], - class_num) +def snn_resnet50(class_num=10): + return ResNet_SNN(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) diff --git a/community/cv/snn/train.py b/community/cv/snn/train.py index 3e9548756bc603825709cc36c881cb8258ddb3c7..ea111112458f5963c29759c548379e80d3467238 100644 --- a/community/cv/snn/train.py +++ b/community/cv/snn/train.py @@ -150,17 +150,11 @@ class AverageMeter: def snn_model_build(): """build snn model for resnet50 and lenet""" if config.net_name == "resnet50": - if config.mode_name == 'GRAPH': - from src.snn_resnet import snn_resnet50_graph as snn_resnet50 - else: - from src.snn_resnet import snn_resnet50_pynative as snn_resnet50 + from src.snn_resnet import snn_resnet50 net = snn_resnet50(class_num=config.class_num) init_weight(net=net) elif config.net_name == "lenet": - if config.mode_name == 'GRAPH': - from src.snn_lenet import snn_lenet_graph as snn_lenet - else: - from src.snn_lenet import snn_lenet_pynative as snn_lenet + from src.snn_lenet import snn_lenet net = snn_lenet(num_class=config.class_num) else: raise ValueError(f'config.model: {config.model_name} is not supported') @@ -223,8 +217,6 @@ def train_net(): label = onehot(label, config.class_num, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)) loss = network_train(images, label) loss_meter.update(loss.asnumpy()) - if config.mode_name == 'PYNATIVE': - net.reset_net() if config.save_checkpoint: cb_params.cur_epoch_num = epoch_idx + 1