Skip to content
Snippets Groups Projects
Unverified Commit c511ad28 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!3202 opt snn code

Merge pull request !3202 from 周莉莉/snn
parents 71b5dcab ac42b1e3
No related branches found
No related tags found
No related merge requests found
......@@ -20,11 +20,11 @@ from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.dataset import create_dataset_cifar10
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
from mindspore.ops import operations as P
def modelarts_process():
......@@ -35,16 +35,10 @@ def snn_model_build():
build snn model for lenet and resnet50
"""
if config.net_name == "resnet50":
if config.mode_name == 'GRAPH':
from src.snn_resnet import snn_resnet50_graph as snn_resnet50
else:
from src.snn_resnet import snn_resnet50_pynative as snn_resnet50
from src.snn_resnet import snn_resnet50
net = snn_resnet50(class_num=config.class_num)
elif config.net_name == "lenet":
if config.mode_name == 'GRAPH':
from src.snn_lenet import snn_lenet_graph as snn_lenet
else:
from src.snn_lenet import snn_lenet_pynative as snn_lenet
from src.snn_lenet import snn_lenet
net = snn_lenet(num_class=config.class_num)
else:
raise ValueError(f'config.model: {config.model_name} is not supported')
......@@ -57,8 +51,6 @@ def eval_net():
eval net
"""
print('eval with config: ', config)
correct = 0.0
total = 0.0
if config.mode_name == 'GRAPH':
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
else:
......@@ -68,23 +60,12 @@ def eval_net():
if ds_eval.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
network_eval = snn_model_build()
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
model = Model(network_eval, net_loss, metrics={"Accuracy": Accuracy()})
param_dict = load_checkpoint(config.ckpt_path)
load_param_into_net(network_eval, param_dict)
network_eval.set_train(False)
print("============== Starting Testing ==============", flush=True)
for _, data in enumerate(ds_eval.create_dict_iterator()):
image = data['image']
label = data['label']
outspikes = network_eval(image)
predicted = ops.Argmax(output_type=ms.int32)(outspikes)
total += label.shape[0]
cast = P.Cast()
correct += cast((predicted == label), ms.float32).sum().asnumpy().item()
if config.mode_name == 'PYNATIVE':
network_eval.reset_net()
accuracy = 100 * correct / total
print('Accuracy of the network is: %.4f %%' % accuracy, flush=True)
acc = model.eval(ds_eval)
print("============== {} ==============".format(acc))
if __name__ == "__main__":
eval_net()
......@@ -37,7 +37,7 @@ class relusigmoid(nn.Cell):
# must be a tuple
return (grad_x,)
class IFNode_GRAPH(nn.Cell):
class IFNode(nn.Cell):
"""
integrate and fire cell for GRAPH mode, it will output spike value
"""
......@@ -48,43 +48,10 @@ class IFNode_GRAPH(nn.Cell):
self.surrogate_function = surrogate_function
def construct(self, x, v):
""" neuronal_charge: v need to do add"""
"""neuronal_charge: v need to do add"""
v = v + x
if self.fire:
spike = self.surrogate_function(v - self.v_threshold) * self.v_threshold
v -= spike
return spike, v
return v, v
class IFNode_PYNATIVE(nn.Cell):
"""
integrate and fire cell for PYNATIVE mode, it will output spike value
"""
def __init__(self, v_threshold=1.0, v_reset=0.0, fire=True, surrogate_function=relusigmoid()):
super().__init__()
self.v_threshold = v_threshold
if v_reset is None:
self.v_reset = 0.0
else:
self.v_reset = v_reset
self.v = self.v_reset
self.fire = fire
self.surrogate_function = surrogate_function
def construct(self, x):
""" neuronal_charge: self.v need to do add"""
self.v = self.v + x
# neuronal_fire
if self.fire:
spike = self.surrogate_function(self.v - self.v_threshold) * self.v_threshold
self.v -= spike
return spike
return self.v
def reset(self):
"""each batch should reset the accumulated value of the net such as self.v"""
if self.v_reset is None:
self.v = 0.0
else:
self.v = self.v_reset
......@@ -16,7 +16,7 @@
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from src.ifnode import IFNode_GRAPH, IFNode_PYNATIVE
from src.ifnode import IFNode
import numpy as np
......@@ -44,127 +44,87 @@ def init_dense_bias(inC, outC):
return Tensor(weight)
class snn_lenet_graph(nn.Cell):
class Conv2d_Block(nn.Cell):
"""
snn backbone for lenet with graph mode
block: conv2d + ifnode
"""
def __init__(self, num_class=10, num_channel=3):
super(snn_lenet_graph, self).__init__()
self.T = 100
self.conv1 = nn.Conv2d(num_channel, 16, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(num_channel, 16, 3), bias_init=init_bias(num_channel, 16, 3))
self.ifnode1 = IFNode_GRAPH()
self.conv2 = nn.Conv2d(16, 16, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3))
self.ifnode2 = IFNode_GRAPH()
self.conv3 = nn.Conv2d(16, 32, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3))
self.ifnode3 = IFNode_GRAPH()
self.conv4 = nn.Conv2d(32, 32, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3))
self.ifnode4 = IFNode_GRAPH()
self.conv5 = nn.Conv2d(32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3))
self.ifnode5 = IFNode_GRAPH()
self.conv6 = nn.Conv2d(64, 64, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3))
self.ifnode6 = IFNode_GRAPH()
self.fc1 = nn.Dense(64 * 4 * 4, 32, weight_init=init_dense_weight(64 * 4 * 4, 32),
bias_init=init_dense_bias(64 * 4 * 4, 32))
self.ifnode7 = IFNode_GRAPH()
self.fc2 = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class),
bias_init=init_dense_bias(32, num_class))
self.ifnode8 = IFNode_GRAPH(fire=False)
def __init__(self, in_channels, out_channels, weight_init, bias_init, kernel_size=3, stride=1,
pad_mode='pad', padding=1, has_bias=True):
super(Conv2d_Block, self).__init__()
self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=stride, pad_mode=pad_mode, padding=padding, has_bias=has_bias,
weight_init=weight_init, bias_init=bias_init)
self.ifnode = IFNode()
def construct(self, x_in):
"""forward the snn-lenet block"""
x = x_in
v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = 0.0
for _ in range(self.T):
x = self.conv1(x_in)
x, v1 = self.ifnode1(x, v1)
x = self.conv2(x)
x, v2 = self.ifnode2(x, v2)
x = self.conv3(x)
x, v3 = self.ifnode3(x, v3)
x = self.conv4(x)
x, v4 = self.ifnode4(x, v4)
x = self.conv5(x)
x, v5 = self.ifnode5(x, v5)
x = self.conv6(x)
x, v6 = self.ifnode6(x, v6)
x = P.Reshape()(x, (-1, 64 * 4 * 4))
x = self.fc1(x)
x, v7 = self.ifnode7(x, v7)
x = self.fc2(x)
x, v8 = self.ifnode8(x, v8)
return x / self.T
x, v1 = x_in
out = self.conv2d(x)
out, v1 = self.ifnode(out, v1)
return (out, v1)
class snn_lenet_pynative(nn.Cell):
class Dense_Block(nn.Cell):
"""
snn backbone for lenet with pynative mode
block: dense + ifnode
"""
def __init__(self, in_channels, out_channels, weight_init, bias_init):
super(Dense_Block, self).__init__()
self.dense = nn.Dense(in_channels=in_channels, out_channels=out_channels,
weight_init=weight_init, bias_init=bias_init)
self.ifnode = IFNode()
def construct(self, x_in):
x, v1 = x_in
out = self.dense(x)
out, v1 = self.ifnode(out, v1)
return out, v1
class snn_lenet(nn.Cell):
"""
snn backbone for lenet with graph mode
"""
def __init__(self, num_class=10, num_channel=3):
super(snn_lenet_pynative, self).__init__()
super(snn_lenet, self).__init__()
self.T = 100
self.conv1 = nn.SequentialCell([nn.Conv2d(num_channel, 16, 3, stride=1, pad_mode='pad', padding=1,
has_bias=True, weight_init=init_weight(num_channel, 16, 3),
bias_init=init_bias(num_channel, 16, 3)),
IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
self.conv2 = nn.SequentialCell([nn.Conv2d(16, 16, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3)),
IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
self.conv1 = Conv2d_Block(in_channels=num_channel, out_channels=16,
weight_init=init_weight(num_channel, 16, 3), bias_init=init_bias(num_channel, 16, 3))
self.conv3 = nn.SequentialCell([nn.Conv2d(16, 32, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3)),
IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
self.conv2 = Conv2d_Block(in_channels=16, out_channels=16, stride=2,
weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3))
self.conv4 = nn.SequentialCell([nn.Conv2d(32, 32, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3)),
IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
self.conv3 = Conv2d_Block(in_channels=16, out_channels=32,
weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3))
self.conv5 = nn.SequentialCell([nn.Conv2d(32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3)),
IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
self.conv4 = Conv2d_Block(in_channels=32, out_channels=32, stride=2,
weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3))
self.conv6 = nn.SequentialCell([nn.Conv2d(64, 64, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3)),
IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
self.conv5 = Conv2d_Block(in_channels=32, out_channels=64,
weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3))
self.fc1 = nn.SequentialCell([nn.Dense(64 * 4 * 4, 32,
weight_init=init_dense_weight(64 * 4 * 4, 32),
bias_init=init_dense_bias(64 * 4 * 4, 32)),
IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
self.conv6 = Conv2d_Block(in_channels=64, out_channels=64, stride=2,
weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3))
self.fc2 = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class),
bias_init=init_dense_bias(32, num_class))
self.dense1 = Dense_Block(in_channels=64 * 4 * 4, out_channels=32,
weight_init=init_dense_weight(64 * 4 * 4, 32),
bias_init=init_dense_bias(64 * 4 * 4, 32))
self.outlayer = IFNode_PYNATIVE(v_threshold=1.0, v_reset=None, fire=False)
self.fc = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class),
bias_init=init_dense_bias(32, num_class))
self.end_ifnode = IFNode(fire=False)
def construct(self, x_in):
"""forward the snn-lenet block"""
x = x_in
v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = 0.0
for _ in range(self.T):
x = self.conv1(x_in)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = self.conv6(x)
x, v1 = self.conv1((x_in, v1))
x, v2 = self.conv2((x, v2))
x, v3 = self.conv3((x, v3))
x, v4 = self.conv4((x, v4))
x, v5 = self.conv5((x, v5))
x, v6 = self.conv6((x, v6))
x = P.Reshape()(x, (-1, 64 * 4 * 4))
x = self.fc1(x)
x = self.fc2(x)
x = self.outlayer(x)
x, v7 = self.dense1((x, v7))
x = self.fc(x)
x, v8 = self.end_ifnode(x, v8)
return x / self.T
def reset_net(self):
"""each batch should reset the accumulated value of the net such as self.v"""
for item in self.cells():
if isinstance(type(item), type(nn.SequentialCell())):
if hasattr(item[-1], 'reset'):
item[-1].reset()
else:
if hasattr(item, 'reset'):
item.reset()
......@@ -15,9 +15,10 @@
"""ResNet_SNN."""
import math
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops as ops
from mindspore.common.initializer import HeNormal, HeUniform
from src.ifnode import IFNode_GRAPH, IFNode_PYNATIVE
from src.ifnode import IFNode
def _conv3x3(in_channel, out_channel, stride=1):
......@@ -47,7 +48,7 @@ def _fc(in_channel, out_channel):
bias_init=0)
class ResidualBlock_GRAPH(nn.Cell):
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
......@@ -60,21 +61,21 @@ class ResidualBlock_GRAPH(nn.Cell):
Tensor, output tensor.
Examples:
>>> ResidualBlock_GRAPH(3, 256, stride=2)
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self, in_channel, out_channel, stride=1):
super(ResidualBlock_GRAPH, self).__init__()
super(ResidualBlock, self).__init__()
self.stride = stride
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1)
self.bn1 = _bn(channel)
self.ifnode1 = IFNode_GRAPH()
self.ifnode1 = IFNode()
self.conv2 = _conv3x3(channel, channel, stride=stride)
self.bn2 = _bn(channel)
self.ifnode2 = IFNode_GRAPH()
self.ifnode2 = IFNode()
self.conv3 = _conv1x1(channel, out_channel, stride=1)
self.bn3 = _bn(out_channel)
......@@ -87,10 +88,10 @@ class ResidualBlock_GRAPH(nn.Cell):
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
self.ifnode3 = IFNode_GRAPH()
self.ifnode3 = IFNode()
def construct(self, x_in):
"""ResidualBlock with graph mode"""
"""ResidualBlock"""
x, v1, v2, v3 = x_in
identity = x
......@@ -113,7 +114,7 @@ class ResidualBlock_GRAPH(nn.Cell):
return (out, v1, v2, v3)
class ResNet_SNN_GRAPH(nn.Cell):
class ResNet_SNN(nn.Cell):
"""
ResNet architecture.
......@@ -129,16 +130,16 @@ class ResNet_SNN_GRAPH(nn.Cell):
Tensor, output tensor.
Examples:
>>> ResNet_SNN_GRAPH(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
>>> ResNet_SNN(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self, block, layer_nums, in_channels, out_channels, strides, num_classes):
super(ResNet_SNN_GRAPH, self).__init__()
super(ResNet_SNN, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
......@@ -146,79 +147,48 @@ class ResNet_SNN_GRAPH(nn.Cell):
self.T = 5
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.ifnode1 = IFNode_GRAPH()
self.ifnode1 = IFNode()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer_nums = layer_nums
# layer_nums:[3, 4, 6, 3]
self.layer1_1 = self._make_layer_test1(block, in_channel=in_channels[0],
out_channel=out_channels[0], stride=strides[0])
self.layer1_2 = self._make_layer_test2(block, out_channel=out_channels[0],)
self.layer1_3 = self._make_layer_test2(block, out_channel=out_channels[0],)
self.layer2_1 = self._make_layer_test1(block, in_channel=in_channels[1],
out_channel=out_channels[1], stride=strides[1])
self.layer2_2 = self._make_layer_test2(block, out_channel=out_channels[1])
self.layer2_3 = self._make_layer_test2(block, out_channel=out_channels[1])
self.layer2_4 = self._make_layer_test2(block, out_channel=out_channels[1])
self.layer3_1 = self._make_layer_test1(block, in_channel=in_channels[2],
out_channel=out_channels[2], stride=strides[2])
self.layer3_2 = self._make_layer_test2(block, out_channel=out_channels[2])
self.layer3_3 = self._make_layer_test2(block, out_channel=out_channels[2])
self.layer3_4 = self._make_layer_test2(block, out_channel=out_channels[2])
self.layer3_5 = self._make_layer_test2(block, out_channel=out_channels[2])
self.layer3_6 = self._make_layer_test2(block, out_channel=out_channels[2])
self.layer4_1 = self._make_layer_test1(block, in_channel=in_channels[3],
out_channel=out_channels[3], stride=strides[3])
self.layer4_2 = self._make_layer_test2(block, out_channel=out_channels[3])
self.layer4_3 = self._make_layer_test2(block, out_channel=out_channels[3])
self.layer1 = self.make_layer(block, layer_nums[0], in_channel=in_channels[0],
out_channel=out_channels[0], stride=strides[0])
self.layer2 = self.make_layer(block, layer_nums[1], in_channel=in_channels[1],
out_channel=out_channels[1], stride=strides[1])
self.layer3 = self.make_layer(block, layer_nums[2], in_channel=in_channels[2],
out_channel=out_channels[2], stride=strides[2])
self.layer4 = self.make_layer(block, layer_nums[3], in_channel=in_channels[3],
out_channel=out_channels[3], stride=strides[3])
self.mean = ops.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes)
self.end_ifnode = IFNode_GRAPH(fire=False)
self.end_ifnode = IFNode(fire=False)
self.layers = nn.CellList([self.layer1, self.layer2, self.layer3, self.layer4])
def _make_layer_test1(self, block, in_channel, out_channel, stride):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
"""
def make_layer(self, block, layer_num, in_channel, out_channel, stride):
layers = []
resnet_block = block(in_channel, out_channel, stride=stride)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def _make_layer_test2(self, block, out_channel):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
out_channel (int): Output channel.
Returns:
SequentialCell, the output layer.
"""
layers = []
resnet_block = block(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.SequentialCell(layers)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.layer.CellList(layers)
def construct(self, x_in):
"""ResNet SNN block with graph mode"""
out = x_in
v1 = v_end = 0.0
# layer_nums:[3, 4, 6, 3]
v1_1_1 = v1_1_2 = v1_1_3 = v1_2_1 = v1_2_2 = v1_2_3 = v1_3_1 = v1_3_2 = v1_3_3 = 0.0
v2_1_1 = v2_1_2 = v2_1_3 = v2_2_1 = v2_2_2 = v2_2_3 = v2_3_1 = v2_3_2 = v2_3_3 = v2_4_1 = v2_4_2 = v2_4_3 = 0.0
v3_1_1 = v3_1_2 = v3_1_3 = v3_2_1 = v3_2_2 = v3_2_3 = v3_3_1 = v3_3_2 = v3_3_3 = 0.0
v3_4_1 = v3_4_2 = v3_4_3 = v3_5_1 = v3_5_2 = v3_5_3 = v3_6_1 = v3_6_2 = v3_6_3 = 0.0
v4_1_1 = v4_1_2 = v4_1_3 = v4_2_1 = v4_2_2 = v4_2_3 = v4_3_1 = v4_3_2 = v4_3_3 = 0.0
V = []
for layer_num in self.layer_nums:
for _ in range(layer_num):
V.append([Tensor(0.0), Tensor(0.0), Tensor(0.0)])
for _ in range(self.T):
x = self.conv1(x_in)
......@@ -226,25 +196,20 @@ class ResNet_SNN_GRAPH(nn.Cell):
x, v1 = self.ifnode1(x, v1)
c1 = self.maxpool(x)
c1_1, v1_1_1, v1_1_2, v1_1_3 = self.layer1_1((c1, v1_1_1, v1_1_2, v1_1_3))
c1_2, v1_2_1, v1_2_2, v1_2_3 = self.layer1_2((c1_1, v1_2_1, v1_2_2, v1_2_3))
c1_3, v1_3_1, v1_3_2, v1_3_3 = self.layer1_3((c1_2, v1_3_1, v1_3_2, v1_3_3))
c2_1, v2_1_1, v2_1_2, v2_1_3 = self.layer2_1((c1_3, v2_1_1, v2_1_2, v2_1_3))
c2_2, v2_2_1, v2_2_2, v2_2_3 = self.layer2_2((c2_1, v2_2_1, v2_2_2, v2_2_3))
c2_3, v2_3_1, v2_3_2, v2_3_3 = self.layer2_3((c2_2, v2_3_1, v2_3_2, v2_3_3))
c2_4, v2_4_1, v2_4_2, v2_4_3 = self.layer2_4((c2_3, v2_4_1, v2_4_2, v2_4_3))
c3_1, v3_1_1, v3_1_2, v3_1_3 = self.layer3_1((c2_4, v3_1_1, v3_1_2, v3_1_3))
c3_2, v3_2_1, v3_2_2, v3_2_3 = self.layer3_2((c3_1, v3_2_1, v3_2_2, v3_2_3))
c3_3, v3_3_1, v3_3_2, v3_3_3 = self.layer3_3((c3_2, v3_3_1, v3_3_2, v3_3_3))
c3_4, v3_4_1, v3_4_2, v3_4_3 = self.layer3_4((c3_3, v3_4_1, v3_4_2, v3_4_3))
c3_5, v3_5_1, v3_5_2, v3_5_3 = self.layer3_5((c3_4, v3_5_1, v3_5_2, v3_5_3))
c3_6, v3_6_1, v3_6_2, v3_6_3 = self.layer3_6((c3_5, v3_6_1, v3_6_2, v3_6_3))
c4_1, v4_1_1, v4_1_2, v4_1_3 = self.layer4_1((c3_6, v4_1_1, v4_1_2, v4_1_3))
c4_2, v4_2_1, v4_2_2, v4_2_3 = self.layer4_2((c4_1, v4_2_1, v4_2_2, v4_2_3))
c4_3, v4_3_1, v4_3_2, v4_3_3 = self.layer4_3((c4_2, v4_3_1, v4_3_2, v4_3_3))
out = self.mean(c4_3, (2, 3))
out = c1
index = 0
ifnode_count = 0
for row in self.layer_nums:
layers = self.layers[index]
for col in range(row):
block = layers[col]
out, V[ifnode_count + col][0], V[ifnode_count + col][1], V[ifnode_count + col][2] = \
block((out, V[ifnode_count + col][0], V[ifnode_count + col][1], V[ifnode_count + col][2]))
ifnode_count += self.layer_nums[index]
index += 1
out = self.mean(out, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
out, v_end = self.end_ifnode(out, v_end)
......@@ -252,201 +217,10 @@ class ResNet_SNN_GRAPH(nn.Cell):
return out / self.T
class ResidualBlock_PYNATIVE(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock_PYNATIVE(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1):
super(ResidualBlock_PYNATIVE, self).__init__()
self.stride = stride
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1)
self.bn1 = _bn(channel)
self.ifnode1 = IFNode_PYNATIVE()
self.conv2 = _conv3x3(channel, channel, stride=stride)
self.bn2 = _bn(channel)
self.ifnode2 = IFNode_PYNATIVE()
self.conv3 = _conv1x1(channel, out_channel, stride=1)
self.bn3 = _bn(out_channel)
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
self.ifnode3 = IFNode_PYNATIVE()
def construct(self, x):
"""ResidualBlock with pynative mode"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.ifnode1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.ifnode2(out)
out = self.conv3(out)
out = self.bn3(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = out + identity
out = self.ifnode3(out)
return out
class ResNet_SNN_PYNATIVE(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet_SNN_PYNATIVE(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes):
super(ResNet_SNN_PYNATIVE, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.T = 5
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.ifnode1 = IFNode_PYNATIVE()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block, layer_nums[0], in_channel=in_channels[0],
out_channel=out_channels[0], stride=strides[0])
self.layer2 = self._make_layer(block, layer_nums[1], in_channel=in_channels[1],
out_channel=out_channels[1], stride=strides[1])
self.layer3 = self._make_layer(block, layer_nums[2], in_channel=in_channels[2],
out_channel=out_channels[2], stride=strides[2])
self.layer4 = self._make_layer(block, layer_nums[3], in_channel=in_channels[3],
out_channel=out_channels[3], stride=strides[3])
self.mean = ops.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes)
self.end_ifnode = IFNode_PYNATIVE(fire=False)
def construct(self, x_in):
"""ResNet SNN block with pynative mode"""
out = x_in
for _ in range(self.T):
x = self.conv1(x_in)
x = self.bn1(x)
x = self.ifnode1(x)
c1 = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
out = self.end_ifnode(out)
return out / self.T
def reset_net(self):
for item in self.cells():
if isinstance(type(item), type(nn.SequentialCell())):
if hasattr(item[-1], 'reset'):
item[-1].reset()
else:
if hasattr(item, 'reset'):
item.reset()
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def snn_resnet50_graph(class_num=10):
return ResNet_SNN_GRAPH(ResidualBlock_GRAPH,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
def snn_resnet50_pynative(class_num=10):
return ResNet_SNN_PYNATIVE(ResidualBlock_PYNATIVE,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
def snn_resnet50(class_num=10):
return ResNet_SNN(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
......@@ -150,17 +150,11 @@ class AverageMeter:
def snn_model_build():
"""build snn model for resnet50 and lenet"""
if config.net_name == "resnet50":
if config.mode_name == 'GRAPH':
from src.snn_resnet import snn_resnet50_graph as snn_resnet50
else:
from src.snn_resnet import snn_resnet50_pynative as snn_resnet50
from src.snn_resnet import snn_resnet50
net = snn_resnet50(class_num=config.class_num)
init_weight(net=net)
elif config.net_name == "lenet":
if config.mode_name == 'GRAPH':
from src.snn_lenet import snn_lenet_graph as snn_lenet
else:
from src.snn_lenet import snn_lenet_pynative as snn_lenet
from src.snn_lenet import snn_lenet
net = snn_lenet(num_class=config.class_num)
else:
raise ValueError(f'config.model: {config.model_name} is not supported')
......@@ -223,8 +217,6 @@ def train_net():
label = onehot(label, config.class_num, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
loss = network_train(images, label)
loss_meter.update(loss.asnumpy())
if config.mode_name == 'PYNATIVE':
net.reset_net()
if config.save_checkpoint:
cb_params.cur_epoch_num = epoch_idx + 1
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment