diff --git a/research/cv/ICNet/src/models/icnet.py b/research/cv/ICNet/src/models/icnet.py index 9973816cac0784cf1d0712fa1805105d9f0c6839..5b589b1eb09944b562949fbe3a8be85327e5b8e9 100644 --- a/research/cv/ICNet/src/models/icnet.py +++ b/research/cv/ICNet/src/models/icnet.py @@ -24,7 +24,7 @@ __all__ = ['ICNet'] class ICNet(nn.Cell): """Image Cascade Network""" - def __init__(self, nclass=19, backbone='resnet50', pretrained_path='', istraining=True): + def __init__(self, nclass=19, backbone='resnet50', pretrained_path='', istraining=False): super(ICNet, self).__init__() self.conv_sub1 = nn.SequentialCell( _ConvBNReLU(3, 32, 3, 2), @@ -34,7 +34,7 @@ class ICNet(nn.Cell): self.istraining = istraining self.ppm = PyramidPoolingModule() - self.backbone = SegBaseModel(root=pretrained_path) + self.backbone = SegBaseModel(root=pretrained_path, istraining=istraining) self.head = _ICHead(nclass) @@ -228,11 +228,11 @@ class CascadeFeatureFusion24(nn.Cell): class SegBaseModel(nn.Cell): """Base Model for Semantic Segmentation""" - def __init__(self, nclass=19, backbone='resnet50', root=''): + def __init__(self, nclass=19, backbone='resnet50', root='', istraining=False): super(SegBaseModel, self).__init__() self.nclass = nclass if backbone == 'resnet50': - self.pretrained = get_resnet50v1b(ckpt_root=root) + self.pretrained = get_resnet50v1b(ckpt_root=root, istraining=istraining) def construct(self, x): """forwarding pre-trained network""" diff --git a/research/cv/ICNet/src/models/icnet_dc.py b/research/cv/ICNet/src/models/icnet_dc.py index 0ecd90ec82f7b7fbd92191a742648fb368a22ade..57d7bcc2136e5d6f442907f63cdeb4ae3cf6c9f2 100644 --- a/research/cv/ICNet/src/models/icnet_dc.py +++ b/research/cv/ICNet/src/models/icnet_dc.py @@ -34,7 +34,7 @@ class ICNetdc(nn.Cell): self.istraining = istraining self.ppm = PyramidPoolingModule() - self.backbone = SegBaseModel(root=pretrained_path) + self.backbone = SegBaseModel(root=pretrained_path, istraining=istraining) self.head = _ICHead(nclass, norm_layer=norm_layer) @@ -230,11 +230,11 @@ class CascadeFeatureFusion24(nn.Cell): class SegBaseModel(nn.Cell): """Base Model for Semantic Segmentation""" - def __init__(self, nclass=19, backbone='resnet50', root=""): + def __init__(self, nclass=19, backbone='resnet50', root="", istraining=False): super(SegBaseModel, self).__init__() self.nclass = nclass if backbone == 'resnet50': - self.pretrained = get_resnet50v1b(ckpt_root=root) + self.pretrained = get_resnet50v1b(ckpt_root=root, istraining=istraining) def construct(self, x): """forwarding pre-trained network""" diff --git a/research/cv/ICNet/src/models/resnet50_v1.py b/research/cv/ICNet/src/models/resnet50_v1.py index b2a934fe75475f1b77cd1ad771777360b01b9153..c091ec5fa5361bea3f33ccbebd42f3c7d7bb9553 100644 --- a/research/cv/ICNet/src/models/resnet50_v1.py +++ b/research/cv/ICNet/src/models/resnet50_v1.py @@ -258,7 +258,7 @@ class Resnet50v1b(nn.Cell): return out -def get_resnet50v1b(class_num=1001, ckpt_root='', pretrained=True): +def get_resnet50v1b(class_num=1001, ckpt_root='', pretrained=True, istraining=False): """ Get SE-ResNet50 neural network. Default : GE Theta+ version (best) @@ -278,7 +278,7 @@ def get_resnet50v1b(class_num=1001, ckpt_root='', pretrained=True): out_channels=[256, 512, 1024, 2048], strides=[1, 2, 2, 2], num_classes=class_num) - + model.set_train(istraining) if pretrained and model.training: pretrained_ckpt = ckpt_root param_dict = load_checkpoint(pretrained_ckpt)