diff --git a/community/cv/snn/eval.py b/community/cv/snn/eval.py
index b343c10d65c456b8a708e7ead8cdeb74b6f6cdd6..83defedf46482207b7f3ef1c79b555b0773626ba 100644
--- a/community/cv/snn/eval.py
+++ b/community/cv/snn/eval.py
@@ -20,11 +20,11 @@ from src.model_utils.config import config
 from src.model_utils.moxing_adapter import moxing_wrapper
 from src.dataset import create_dataset_cifar10
 
-import mindspore.ops as ops
+import mindspore.nn as nn
 from mindspore import context
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-import mindspore as ms
-from mindspore.ops import operations as P
 
 
 def modelarts_process():
@@ -35,16 +35,10 @@ def snn_model_build():
     build snn model for lenet and resnet50
     """
     if config.net_name == "resnet50":
-        if config.mode_name == 'GRAPH':
-            from src.snn_resnet import snn_resnet50_graph as snn_resnet50
-        else:
-            from src.snn_resnet import snn_resnet50_pynative as snn_resnet50
+        from src.snn_resnet import snn_resnet50
         net = snn_resnet50(class_num=config.class_num)
     elif config.net_name == "lenet":
-        if config.mode_name == 'GRAPH':
-            from src.snn_lenet import snn_lenet_graph as snn_lenet
-        else:
-            from src.snn_lenet import snn_lenet_pynative as snn_lenet
+        from src.snn_lenet import snn_lenet
         net = snn_lenet(num_class=config.class_num)
     else:
         raise ValueError(f'config.model: {config.model_name} is not supported')
@@ -57,8 +51,6 @@ def eval_net():
     eval net
     """
     print('eval with config: ', config)
-    correct = 0.0
-    total = 0.0
     if config.mode_name == 'GRAPH':
         context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
     else:
@@ -68,23 +60,12 @@ def eval_net():
     if ds_eval.get_dataset_size() == 0:
         raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
     network_eval = snn_model_build()
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+    model = Model(network_eval, net_loss, metrics={"Accuracy": Accuracy()})
     param_dict = load_checkpoint(config.ckpt_path)
     load_param_into_net(network_eval, param_dict)
-    network_eval.set_train(False)
-    print("============== Starting Testing ==============", flush=True)
-    for _, data in enumerate(ds_eval.create_dict_iterator()):
-        image = data['image']
-        label = data['label']
-        outspikes = network_eval(image)
-        predicted = ops.Argmax(output_type=ms.int32)(outspikes)
-        total += label.shape[0]
-        cast = P.Cast()
-        correct += cast((predicted == label), ms.float32).sum().asnumpy().item()
-        if config.mode_name == 'PYNATIVE':
-            network_eval.reset_net()
-
-    accuracy = 100 * correct / total
-    print('Accuracy of the network is: %.4f %%' % accuracy, flush=True)
+    acc = model.eval(ds_eval)
+    print("============== {} ==============".format(acc))
 
 if __name__ == "__main__":
     eval_net()
diff --git a/community/cv/snn/src/ifnode.py b/community/cv/snn/src/ifnode.py
index 7d763360bd40dd3754844a64584dd89deeda2c9a..369903e62b10010759ce12864e42a5957ae3ac05 100644
--- a/community/cv/snn/src/ifnode.py
+++ b/community/cv/snn/src/ifnode.py
@@ -37,7 +37,7 @@ class relusigmoid(nn.Cell):
         # must be a tuple
         return (grad_x,)
 
-class IFNode_GRAPH(nn.Cell):
+class IFNode(nn.Cell):
     """
     integrate and fire cell for GRAPH mode, it will output spike value
     """
@@ -48,43 +48,10 @@ class IFNode_GRAPH(nn.Cell):
         self.surrogate_function = surrogate_function
 
     def construct(self, x, v):
-        """ neuronal_charge: v need to do add"""
+        """neuronal_charge: v need to do add"""
         v = v + x
         if self.fire:
             spike = self.surrogate_function(v - self.v_threshold) * self.v_threshold
             v -= spike
             return spike, v
         return v, v
-
-
-class IFNode_PYNATIVE(nn.Cell):
-    """
-    integrate and fire cell for PYNATIVE mode, it will output spike value
-    """
-    def __init__(self, v_threshold=1.0, v_reset=0.0, fire=True, surrogate_function=relusigmoid()):
-        super().__init__()
-        self.v_threshold = v_threshold
-        if v_reset is None:
-            self.v_reset = 0.0
-        else:
-            self.v_reset = v_reset
-        self.v = self.v_reset
-        self.fire = fire
-        self.surrogate_function = surrogate_function
-
-    def construct(self, x):
-        """ neuronal_charge: self.v need to do add"""
-        self.v = self.v + x
-        # neuronal_fire
-        if self.fire:
-            spike = self.surrogate_function(self.v - self.v_threshold) * self.v_threshold
-            self.v -= spike
-            return spike
-        return self.v
-
-    def reset(self):
-        """each batch should reset the accumulated value of the net such as self.v"""
-        if self.v_reset is None:
-            self.v = 0.0
-        else:
-            self.v = self.v_reset
diff --git a/community/cv/snn/src/snn_lenet.py b/community/cv/snn/src/snn_lenet.py
index 25fc609b18d7693d9cecd17298e99c395a0d65cc..aedbcfcbf2f2be0b95a48305765f59332a2e8b07 100644
--- a/community/cv/snn/src/snn_lenet.py
+++ b/community/cv/snn/src/snn_lenet.py
@@ -16,7 +16,7 @@
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore import Tensor
-from src.ifnode import IFNode_GRAPH, IFNode_PYNATIVE
+from src.ifnode import IFNode
 import numpy as np
 
 
@@ -44,127 +44,87 @@ def init_dense_bias(inC, outC):
     return Tensor(weight)
 
 
-class snn_lenet_graph(nn.Cell):
+class Conv2d_Block(nn.Cell):
     """
-    snn backbone for lenet with graph mode
+    block: conv2d + ifnode
     """
-    def __init__(self, num_class=10, num_channel=3):
-        super(snn_lenet_graph, self).__init__()
-        self.T = 100
-        self.conv1 = nn.Conv2d(num_channel, 16, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
-                               weight_init=init_weight(num_channel, 16, 3), bias_init=init_bias(num_channel, 16, 3))
-        self.ifnode1 = IFNode_GRAPH()
-        self.conv2 = nn.Conv2d(16, 16, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
-                               weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3))
-        self.ifnode2 = IFNode_GRAPH()
-        self.conv3 = nn.Conv2d(16, 32, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
-                               weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3))
-        self.ifnode3 = IFNode_GRAPH()
-        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
-                               weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3))
-        self.ifnode4 = IFNode_GRAPH()
-        self.conv5 = nn.Conv2d(32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
-                               weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3))
-        self.ifnode5 = IFNode_GRAPH()
-        self.conv6 = nn.Conv2d(64, 64, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
-                               weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3))
-        self.ifnode6 = IFNode_GRAPH()
-        self.fc1 = nn.Dense(64 * 4 * 4, 32, weight_init=init_dense_weight(64 * 4 * 4, 32),
-                            bias_init=init_dense_bias(64 * 4 * 4, 32))
-        self.ifnode7 = IFNode_GRAPH()
-        self.fc2 = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class),
-                            bias_init=init_dense_bias(32, num_class))
-        self.ifnode8 = IFNode_GRAPH(fire=False)
+    def __init__(self, in_channels, out_channels, weight_init, bias_init, kernel_size=3, stride=1,
+                 pad_mode='pad', padding=1, has_bias=True):
+        super(Conv2d_Block, self).__init__()
+        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
+                                stride=stride, pad_mode=pad_mode, padding=padding, has_bias=has_bias,
+                                weight_init=weight_init, bias_init=bias_init)
+        self.ifnode = IFNode()
 
     def construct(self, x_in):
-        """forward the snn-lenet block"""
-        x = x_in
-        v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = 0.0
-        for _ in range(self.T):
-            x = self.conv1(x_in)
-            x, v1 = self.ifnode1(x, v1)
-            x = self.conv2(x)
-            x, v2 = self.ifnode2(x, v2)
-            x = self.conv3(x)
-            x, v3 = self.ifnode3(x, v3)
-            x = self.conv4(x)
-            x, v4 = self.ifnode4(x, v4)
-            x = self.conv5(x)
-            x, v5 = self.ifnode5(x, v5)
-            x = self.conv6(x)
-            x, v6 = self.ifnode6(x, v6)
-            x = P.Reshape()(x, (-1, 64 * 4 * 4))
-            x = self.fc1(x)
-            x, v7 = self.ifnode7(x, v7)
-            x = self.fc2(x)
-            x, v8 = self.ifnode8(x, v8)
-        return x / self.T
+        x, v1 = x_in
+        out = self.conv2d(x)
+        out, v1 = self.ifnode(out, v1)
+        return (out, v1)
 
 
-class snn_lenet_pynative(nn.Cell):
+class Dense_Block(nn.Cell):
     """
-    snn backbone for lenet with pynative mode
+    block: dense + ifnode
+    """
+    def __init__(self, in_channels, out_channels, weight_init, bias_init):
+        super(Dense_Block, self).__init__()
+        self.dense = nn.Dense(in_channels=in_channels, out_channels=out_channels,
+                              weight_init=weight_init, bias_init=bias_init)
+        self.ifnode = IFNode()
+
+    def construct(self, x_in):
+        x, v1 = x_in
+        out = self.dense(x)
+        out, v1 = self.ifnode(out, v1)
+        return out, v1
+
+class snn_lenet(nn.Cell):
+    """
+    snn backbone for lenet with graph mode
     """
     def __init__(self, num_class=10, num_channel=3):
-        super(snn_lenet_pynative, self).__init__()
+        super(snn_lenet, self).__init__()
         self.T = 100
-        self.conv1 = nn.SequentialCell([nn.Conv2d(num_channel, 16, 3, stride=1, pad_mode='pad', padding=1,
-                                                  has_bias=True, weight_init=init_weight(num_channel, 16, 3),
-                                                  bias_init=init_bias(num_channel, 16, 3)),
-                                        IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
-
-        self.conv2 = nn.SequentialCell([nn.Conv2d(16, 16, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
-                                                  weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3)),
-                                        IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
+        self.conv1 = Conv2d_Block(in_channels=num_channel, out_channels=16,
+                                  weight_init=init_weight(num_channel, 16, 3), bias_init=init_bias(num_channel, 16, 3))
 
-        self.conv3 = nn.SequentialCell([nn.Conv2d(16, 32, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
-                                                  weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3)),
-                                        IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
+        self.conv2 = Conv2d_Block(in_channels=16, out_channels=16, stride=2,
+                                  weight_init=init_weight(16, 16, 3), bias_init=init_bias(16, 16, 3))
 
-        self.conv4 = nn.SequentialCell([nn.Conv2d(32, 32, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
-                                                  weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3)),
-                                        IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
+        self.conv3 = Conv2d_Block(in_channels=16, out_channels=32,
+                                  weight_init=init_weight(16, 32, 3), bias_init=init_bias(16, 32, 3))
 
-        self.conv5 = nn.SequentialCell([nn.Conv2d(32, 64, 3, stride=1, pad_mode='pad', padding=1, has_bias=True,
-                                                  weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3)),
-                                        IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
+        self.conv4 = Conv2d_Block(in_channels=32, out_channels=32, stride=2,
+                                  weight_init=init_weight(32, 32, 3), bias_init=init_bias(32, 32, 3))
 
-        self.conv6 = nn.SequentialCell([nn.Conv2d(64, 64, 3, stride=2, pad_mode='pad', padding=1, has_bias=True,
-                                                  weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3)),
-                                        IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
+        self.conv5 = Conv2d_Block(in_channels=32, out_channels=64,
+                                  weight_init=init_weight(32, 64, 3), bias_init=init_bias(32, 64, 3))
 
-        self.fc1 = nn.SequentialCell([nn.Dense(64 * 4 * 4, 32,
-                                               weight_init=init_dense_weight(64 * 4 * 4, 32),
-                                               bias_init=init_dense_bias(64 * 4 * 4, 32)),
-                                      IFNode_PYNATIVE(v_threshold=1.0, v_reset=None)])
+        self.conv6 = Conv2d_Block(in_channels=64, out_channels=64, stride=2,
+                                  weight_init=init_weight(64, 64, 3), bias_init=init_bias(64, 64, 3))
 
-        self.fc2 = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class),
-                            bias_init=init_dense_bias(32, num_class))
+        self.dense1 = Dense_Block(in_channels=64 * 4 * 4, out_channels=32,
+                                  weight_init=init_dense_weight(64 * 4 * 4, 32),
+                                  bias_init=init_dense_bias(64 * 4 * 4, 32))
 
-        self.outlayer = IFNode_PYNATIVE(v_threshold=1.0, v_reset=None, fire=False)
+        self.fc = nn.Dense(32, num_class, weight_init=init_dense_weight(32, num_class),
+                           bias_init=init_dense_bias(32, num_class))
+        self.end_ifnode = IFNode(fire=False)
 
     def construct(self, x_in):
         """forward the snn-lenet block"""
         x = x_in
+        v1 = v2 = v3 = v4 = v5 = v6 = v7 = v8 = 0.0
         for _ in range(self.T):
-            x = self.conv1(x_in)
-            x = self.conv2(x)
-            x = self.conv3(x)
-            x = self.conv4(x)
-            x = self.conv5(x)
-            x = self.conv6(x)
+            x, v1 = self.conv1((x_in, v1))
+            x, v2 = self.conv2((x, v2))
+            x, v3 = self.conv3((x, v3))
+            x, v4 = self.conv4((x, v4))
+            x, v5 = self.conv5((x, v5))
+            x, v6 = self.conv6((x, v6))
             x = P.Reshape()(x, (-1, 64 * 4 * 4))
-            x = self.fc1(x)
-            x = self.fc2(x)
-            x = self.outlayer(x)
+            x, v7 = self.dense1((x, v7))
+            x = self.fc(x)
+            x, v8 = self.end_ifnode(x, v8)
         return x / self.T
-
-    def reset_net(self):
-        """each batch should reset the accumulated value of the net such as self.v"""
-        for item in self.cells():
-            if isinstance(type(item), type(nn.SequentialCell())):
-                if hasattr(item[-1], 'reset'):
-                    item[-1].reset()
-            else:
-                if hasattr(item, 'reset'):
-                    item.reset()
diff --git a/community/cv/snn/src/snn_resnet.py b/community/cv/snn/src/snn_resnet.py
index 4836401d2360390c26f8c387dbb76d2d39a168fc..cab8eb739d371e2830b49d8f31668067781600d3 100644
--- a/community/cv/snn/src/snn_resnet.py
+++ b/community/cv/snn/src/snn_resnet.py
@@ -15,9 +15,10 @@
 """ResNet_SNN."""
 import math
 import mindspore.nn as nn
+from mindspore import Tensor
 import mindspore.ops as ops
 from mindspore.common.initializer import HeNormal, HeUniform
-from src.ifnode import IFNode_GRAPH, IFNode_PYNATIVE
+from src.ifnode import IFNode
 
 
 def _conv3x3(in_channel, out_channel, stride=1):
@@ -47,7 +48,7 @@ def _fc(in_channel, out_channel):
                     bias_init=0)
 
 
-class ResidualBlock_GRAPH(nn.Cell):
+class ResidualBlock(nn.Cell):
     """
     ResNet V1 residual block definition.
 
@@ -60,21 +61,21 @@ class ResidualBlock_GRAPH(nn.Cell):
         Tensor, output tensor.
 
     Examples:
-        >>> ResidualBlock_GRAPH(3, 256, stride=2)
+        >>> ResidualBlock(3, 256, stride=2)
     """
     expansion = 4
 
     def __init__(self, in_channel, out_channel, stride=1):
-        super(ResidualBlock_GRAPH, self).__init__()
+        super(ResidualBlock, self).__init__()
         self.stride = stride
         channel = out_channel // self.expansion
         self.conv1 = _conv1x1(in_channel, channel, stride=1)
         self.bn1 = _bn(channel)
-        self.ifnode1 = IFNode_GRAPH()
+        self.ifnode1 = IFNode()
 
         self.conv2 = _conv3x3(channel, channel, stride=stride)
         self.bn2 = _bn(channel)
-        self.ifnode2 = IFNode_GRAPH()
+        self.ifnode2 = IFNode()
 
         self.conv3 = _conv1x1(channel, out_channel, stride=1)
         self.bn3 = _bn(out_channel)
@@ -87,10 +88,10 @@ class ResidualBlock_GRAPH(nn.Cell):
         if self.down_sample:
             self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
 
-        self.ifnode3 = IFNode_GRAPH()
+        self.ifnode3 = IFNode()
 
     def construct(self, x_in):
-        """ResidualBlock with graph mode"""
+        """ResidualBlock"""
         x, v1, v2, v3 = x_in
         identity = x
 
@@ -113,7 +114,7 @@ class ResidualBlock_GRAPH(nn.Cell):
         return (out, v1, v2, v3)
 
 
-class ResNet_SNN_GRAPH(nn.Cell):
+class ResNet_SNN(nn.Cell):
     """
     ResNet architecture.
 
@@ -129,16 +130,16 @@ class ResNet_SNN_GRAPH(nn.Cell):
         Tensor, output tensor.
 
     Examples:
-        >>> ResNet_SNN_GRAPH(ResidualBlock,
-        >>>                  [3, 4, 6, 3],
-        >>>                  [64, 256, 512, 1024],
-        >>>                  [256, 512, 1024, 2048],
-        >>>                  [1, 2, 2, 2],
-        >>>                  10)
+        >>> ResNet_SNN(ResidualBlock,
+        >>>            [3, 4, 6, 3],
+        >>>            [64, 256, 512, 1024],
+        >>>            [256, 512, 1024, 2048],
+        >>>            [1, 2, 2, 2],
+        >>>            10)
     """
 
     def __init__(self, block, layer_nums, in_channels, out_channels, strides, num_classes):
-        super(ResNet_SNN_GRAPH, self).__init__()
+        super(ResNet_SNN, self).__init__()
 
         if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
             raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
@@ -146,79 +147,48 @@ class ResNet_SNN_GRAPH(nn.Cell):
         self.T = 5
         self.conv1 = _conv7x7(3, 64, stride=2)
         self.bn1 = _bn(64)
-        self.ifnode1 = IFNode_GRAPH()
+        self.ifnode1 = IFNode()
 
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+        self.layer_nums = layer_nums
         # layer_nums:[3, 4, 6, 3]
-        self.layer1_1 = self._make_layer_test1(block, in_channel=in_channels[0],
-                                               out_channel=out_channels[0], stride=strides[0])
-        self.layer1_2 = self._make_layer_test2(block, out_channel=out_channels[0],)
-        self.layer1_3 = self._make_layer_test2(block, out_channel=out_channels[0],)
-        self.layer2_1 = self._make_layer_test1(block, in_channel=in_channels[1],
-                                               out_channel=out_channels[1], stride=strides[1])
-        self.layer2_2 = self._make_layer_test2(block, out_channel=out_channels[1])
-        self.layer2_3 = self._make_layer_test2(block, out_channel=out_channels[1])
-        self.layer2_4 = self._make_layer_test2(block, out_channel=out_channels[1])
-        self.layer3_1 = self._make_layer_test1(block, in_channel=in_channels[2],
-                                               out_channel=out_channels[2], stride=strides[2])
-        self.layer3_2 = self._make_layer_test2(block, out_channel=out_channels[2])
-        self.layer3_3 = self._make_layer_test2(block, out_channel=out_channels[2])
-        self.layer3_4 = self._make_layer_test2(block, out_channel=out_channels[2])
-        self.layer3_5 = self._make_layer_test2(block, out_channel=out_channels[2])
-        self.layer3_6 = self._make_layer_test2(block, out_channel=out_channels[2])
-        self.layer4_1 = self._make_layer_test1(block, in_channel=in_channels[3],
-                                               out_channel=out_channels[3], stride=strides[3])
-        self.layer4_2 = self._make_layer_test2(block, out_channel=out_channels[3])
-        self.layer4_3 = self._make_layer_test2(block, out_channel=out_channels[3])
+        self.layer1 = self.make_layer(block, layer_nums[0], in_channel=in_channels[0],
+                                      out_channel=out_channels[0], stride=strides[0])
+        self.layer2 = self.make_layer(block, layer_nums[1], in_channel=in_channels[1],
+                                      out_channel=out_channels[1], stride=strides[1])
+        self.layer3 = self.make_layer(block, layer_nums[2], in_channel=in_channels[2],
+                                      out_channel=out_channels[2], stride=strides[2])
+        self.layer4 = self.make_layer(block, layer_nums[3], in_channel=in_channels[3],
+                                      out_channel=out_channels[3], stride=strides[3])
 
         self.mean = ops.ReduceMean(keep_dims=True)
         self.flatten = nn.Flatten()
         self.end_point = _fc(out_channels[3], num_classes)
-        self.end_ifnode = IFNode_GRAPH(fire=False)
+        self.end_ifnode = IFNode(fire=False)
+        self.layers = nn.CellList([self.layer1, self.layer2, self.layer3, self.layer4])
 
 
-    def _make_layer_test1(self, block, in_channel, out_channel, stride):
-        """
-        Make stage network of ResNet.
-
-        Args:
-            block (Cell): Resnet block.
-            in_channel (int): Input channel.
-            out_channel (int): Output channel.
-            stride (int): Stride size for the first convolutional layer.
-        Returns:
-            SequentialCell, the output layer.
-        """
+    def make_layer(self, block, layer_num, in_channel, out_channel, stride):
         layers = []
+
         resnet_block = block(in_channel, out_channel, stride=stride)
         layers.append(resnet_block)
-        return nn.SequentialCell(layers)
-
-    def _make_layer_test2(self, block, out_channel):
-        """
-            Make stage network of ResNet.
-
-            Args:
-                block (Cell): Resnet block.
-                out_channel (int): Output channel.
-            Returns:
-                SequentialCell, the output layer.
-            """
-        layers = []
-        resnet_block = block(out_channel, out_channel, stride=1)
-        layers.append(resnet_block)
-        return nn.SequentialCell(layers)
+        for _ in range(1, layer_num):
+            resnet_block = block(out_channel, out_channel, stride=1)
+            layers.append(resnet_block)
+
+        return nn.layer.CellList(layers)
+
 
     def construct(self, x_in):
         """ResNet SNN block with graph mode"""
         out = x_in
         v1 = v_end = 0.0
-        # layer_nums:[3, 4, 6, 3]
-        v1_1_1 = v1_1_2 = v1_1_3 = v1_2_1 = v1_2_2 = v1_2_3 = v1_3_1 = v1_3_2 = v1_3_3 = 0.0
-        v2_1_1 = v2_1_2 = v2_1_3 = v2_2_1 = v2_2_2 = v2_2_3 = v2_3_1 = v2_3_2 = v2_3_3 = v2_4_1 = v2_4_2 = v2_4_3 = 0.0
-        v3_1_1 = v3_1_2 = v3_1_3 = v3_2_1 = v3_2_2 = v3_2_3 = v3_3_1 = v3_3_2 = v3_3_3 = 0.0
-        v3_4_1 = v3_4_2 = v3_4_3 = v3_5_1 = v3_5_2 = v3_5_3 = v3_6_1 = v3_6_2 = v3_6_3 = 0.0
-        v4_1_1 = v4_1_2 = v4_1_3 = v4_2_1 = v4_2_2 = v4_2_3 = v4_3_1 = v4_3_2 = v4_3_3 = 0.0
+
+        V = []
+        for layer_num in self.layer_nums:
+            for _ in range(layer_num):
+                V.append([Tensor(0.0), Tensor(0.0), Tensor(0.0)])
 
         for _ in range(self.T):
             x = self.conv1(x_in)
@@ -226,25 +196,20 @@ class ResNet_SNN_GRAPH(nn.Cell):
             x, v1 = self.ifnode1(x, v1)
 
             c1 = self.maxpool(x)
-
-            c1_1, v1_1_1, v1_1_2, v1_1_3 = self.layer1_1((c1, v1_1_1, v1_1_2, v1_1_3))
-            c1_2, v1_2_1, v1_2_2, v1_2_3 = self.layer1_2((c1_1, v1_2_1, v1_2_2, v1_2_3))
-            c1_3, v1_3_1, v1_3_2, v1_3_3 = self.layer1_3((c1_2, v1_3_1, v1_3_2, v1_3_3))
-            c2_1, v2_1_1, v2_1_2, v2_1_3 = self.layer2_1((c1_3, v2_1_1, v2_1_2, v2_1_3))
-            c2_2, v2_2_1, v2_2_2, v2_2_3 = self.layer2_2((c2_1, v2_2_1, v2_2_2, v2_2_3))
-            c2_3, v2_3_1, v2_3_2, v2_3_3 = self.layer2_3((c2_2, v2_3_1, v2_3_2, v2_3_3))
-            c2_4, v2_4_1, v2_4_2, v2_4_3 = self.layer2_4((c2_3, v2_4_1, v2_4_2, v2_4_3))
-            c3_1, v3_1_1, v3_1_2, v3_1_3 = self.layer3_1((c2_4, v3_1_1, v3_1_2, v3_1_3))
-            c3_2, v3_2_1, v3_2_2, v3_2_3 = self.layer3_2((c3_1, v3_2_1, v3_2_2, v3_2_3))
-            c3_3, v3_3_1, v3_3_2, v3_3_3 = self.layer3_3((c3_2, v3_3_1, v3_3_2, v3_3_3))
-            c3_4, v3_4_1, v3_4_2, v3_4_3 = self.layer3_4((c3_3, v3_4_1, v3_4_2, v3_4_3))
-            c3_5, v3_5_1, v3_5_2, v3_5_3 = self.layer3_5((c3_4, v3_5_1, v3_5_2, v3_5_3))
-            c3_6, v3_6_1, v3_6_2, v3_6_3 = self.layer3_6((c3_5, v3_6_1, v3_6_2, v3_6_3))
-            c4_1, v4_1_1, v4_1_2, v4_1_3 = self.layer4_1((c3_6, v4_1_1, v4_1_2, v4_1_3))
-            c4_2, v4_2_1, v4_2_2, v4_2_3 = self.layer4_2((c4_1, v4_2_1, v4_2_2, v4_2_3))
-            c4_3, v4_3_1, v4_3_2, v4_3_3 = self.layer4_3((c4_2, v4_3_1, v4_3_2, v4_3_3))
-
-            out = self.mean(c4_3, (2, 3))
+            out = c1
+
+            index = 0
+            ifnode_count = 0
+            for row in self.layer_nums:
+                layers = self.layers[index]
+                for col in range(row):
+                    block = layers[col]
+                    out, V[ifnode_count + col][0], V[ifnode_count + col][1], V[ifnode_count + col][2] = \
+                        block((out, V[ifnode_count + col][0], V[ifnode_count + col][1], V[ifnode_count + col][2]))
+                ifnode_count += self.layer_nums[index]
+                index += 1
+
+            out = self.mean(out, (2, 3))
             out = self.flatten(out)
             out = self.end_point(out)
             out, v_end = self.end_ifnode(out, v_end)
@@ -252,201 +217,10 @@ class ResNet_SNN_GRAPH(nn.Cell):
         return out / self.T
 
 
-class ResidualBlock_PYNATIVE(nn.Cell):
-    """
-    ResNet V1 residual block definition.
-
-    Args:
-        in_channel (int): Input channel.
-        out_channel (int): Output channel.
-        stride (int): Stride size for the first convolutional layer. Default: 1.
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> ResidualBlock_PYNATIVE(3, 256, stride=2)
-    """
-    expansion = 4
-
-    def __init__(self,
-                 in_channel,
-                 out_channel,
-                 stride=1):
-        super(ResidualBlock_PYNATIVE, self).__init__()
-        self.stride = stride
-        channel = out_channel // self.expansion
-        self.conv1 = _conv1x1(in_channel, channel, stride=1)
-        self.bn1 = _bn(channel)
-        self.ifnode1 = IFNode_PYNATIVE()
-
-        self.conv2 = _conv3x3(channel, channel, stride=stride)
-        self.bn2 = _bn(channel)
-        self.ifnode2 = IFNode_PYNATIVE()
-
-        self.conv3 = _conv1x1(channel, out_channel, stride=1)
-        self.bn3 = _bn(out_channel)
-
-        self.down_sample = False
-        if stride != 1 or in_channel != out_channel:
-            self.down_sample = True
-        self.down_sample_layer = None
-
-        if self.down_sample:
-            self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), _bn(out_channel)])
-
-        self.ifnode3 = IFNode_PYNATIVE()
-
-    def construct(self, x):
-        """ResidualBlock with pynative mode"""
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.ifnode1(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.ifnode2(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
-        if self.down_sample:
-            identity = self.down_sample_layer(identity)
-        out = out + identity
-        out = self.ifnode3(out)
-
-        return out
-
-
-class ResNet_SNN_PYNATIVE(nn.Cell):
-    """
-    ResNet architecture.
-
-    Args:
-        block (Cell): Block for network.
-        layer_nums (list): Numbers of block in different layers.
-        in_channels (list): Input channel in each layer.
-        out_channels (list): Output channel in each layer.
-        strides (list):  Stride size in each layer.
-        num_classes (int): The number of classes that the training images are belonging to.
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> ResNet_SNN_PYNATIVE(ResidualBlock,
-        >>>                  [3, 4, 6, 3],
-        >>>                  [64, 256, 512, 1024],
-        >>>                  [256, 512, 1024, 2048],
-        >>>                  [1, 2, 2, 2],
-        >>>                  10)
-    """
-
-    def __init__(self,
-                 block,
-                 layer_nums,
-                 in_channels,
-                 out_channels,
-                 strides,
-                 num_classes):
-        super(ResNet_SNN_PYNATIVE, self).__init__()
-
-        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
-            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
-
-        self.T = 5
-        self.conv1 = _conv7x7(3, 64, stride=2)
-        self.bn1 = _bn(64)
-        self.ifnode1 = IFNode_PYNATIVE()
-
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
-
-        self.layer1 = self._make_layer(block, layer_nums[0], in_channel=in_channels[0],
-                                       out_channel=out_channels[0], stride=strides[0])
-        self.layer2 = self._make_layer(block, layer_nums[1], in_channel=in_channels[1],
-                                       out_channel=out_channels[1], stride=strides[1])
-        self.layer3 = self._make_layer(block, layer_nums[2], in_channel=in_channels[2],
-                                       out_channel=out_channels[2], stride=strides[2])
-        self.layer4 = self._make_layer(block, layer_nums[3], in_channel=in_channels[3],
-                                       out_channel=out_channels[3], stride=strides[3])
-
-        self.mean = ops.ReduceMean(keep_dims=True)
-        self.flatten = nn.Flatten()
-        self.end_point = _fc(out_channels[3], num_classes)
-        self.end_ifnode = IFNode_PYNATIVE(fire=False)
-
-    def construct(self, x_in):
-        """ResNet SNN block with pynative mode"""
-        out = x_in
-        for _ in range(self.T):
-            x = self.conv1(x_in)
-            x = self.bn1(x)
-            x = self.ifnode1(x)
-
-            c1 = self.maxpool(x)
-
-            c2 = self.layer1(c1)
-            c3 = self.layer2(c2)
-            c4 = self.layer3(c3)
-            c5 = self.layer4(c4)
-
-            out = self.mean(c5, (2, 3))
-            out = self.flatten(out)
-            out = self.end_point(out)
-            out = self.end_ifnode(out)
-
-        return out / self.T
-
-    def reset_net(self):
-        for item in self.cells():
-            if isinstance(type(item), type(nn.SequentialCell())):
-                if hasattr(item[-1], 'reset'):
-                    item[-1].reset()
-            else:
-                if hasattr(item, 'reset'):
-                    item.reset()
-
-    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
-        """
-        Make stage network of ResNet.
-
-        Args:
-            block (Cell): Resnet block.
-            layer_num (int): Layer number.
-            in_channel (int): Input channel.
-            out_channel (int): Output channel.
-            stride (int): Stride size for the first convolutional layer.
-        Returns:
-            SequentialCell, the output layer.
-
-        Examples:
-            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
-        """
-        layers = []
-
-        resnet_block = block(in_channel, out_channel, stride=stride)
-        layers.append(resnet_block)
-        for _ in range(1, layer_num):
-            resnet_block = block(out_channel, out_channel, stride=1)
-            layers.append(resnet_block)
-
-        return nn.SequentialCell(layers)
-
-def snn_resnet50_graph(class_num=10):
-    return ResNet_SNN_GRAPH(ResidualBlock_GRAPH,
-                            [3, 4, 6, 3],
-                            [64, 256, 512, 1024],
-                            [256, 512, 1024, 2048],
-                            [1, 2, 2, 2],
-                            class_num)
-
-
-def snn_resnet50_pynative(class_num=10):
-    return ResNet_SNN_PYNATIVE(ResidualBlock_PYNATIVE,
-                               [3, 4, 6, 3],
-                               [64, 256, 512, 1024],
-                               [256, 512, 1024, 2048],
-                               [1, 2, 2, 2],
-                               class_num)
+def snn_resnet50(class_num=10):
+    return ResNet_SNN(ResidualBlock,
+                      [3, 4, 6, 3],
+                      [64, 256, 512, 1024],
+                      [256, 512, 1024, 2048],
+                      [1, 2, 2, 2],
+                      class_num)
diff --git a/community/cv/snn/train.py b/community/cv/snn/train.py
index 3e9548756bc603825709cc36c881cb8258ddb3c7..ea111112458f5963c29759c548379e80d3467238 100644
--- a/community/cv/snn/train.py
+++ b/community/cv/snn/train.py
@@ -150,17 +150,11 @@ class AverageMeter:
 def snn_model_build():
     """build snn model for resnet50 and lenet"""
     if config.net_name == "resnet50":
-        if config.mode_name == 'GRAPH':
-            from src.snn_resnet import snn_resnet50_graph as snn_resnet50
-        else:
-            from src.snn_resnet import snn_resnet50_pynative as snn_resnet50
+        from src.snn_resnet import snn_resnet50
         net = snn_resnet50(class_num=config.class_num)
         init_weight(net=net)
     elif config.net_name == "lenet":
-        if config.mode_name == 'GRAPH':
-            from src.snn_lenet import snn_lenet_graph as snn_lenet
-        else:
-            from src.snn_lenet import snn_lenet_pynative as snn_lenet
+        from src.snn_lenet import snn_lenet
         net = snn_lenet(num_class=config.class_num)
     else:
         raise ValueError(f'config.model: {config.model_name} is not supported')
@@ -223,8 +217,6 @@ def train_net():
                 label = onehot(label, config.class_num, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
             loss = network_train(images, label)
             loss_meter.update(loss.asnumpy())
-            if config.mode_name == 'PYNATIVE':
-                net.reset_net()
 
             if config.save_checkpoint:
                 cb_params.cur_epoch_num = epoch_idx + 1