From 1ed758bf49f5495808d2af67a49f3f99db94f132 Mon Sep 17 00:00:00 2001
From: deepr <hexiangdong2020@outlook.com>
Date: Mon, 19 Jul 2021 22:52:43 +0800
Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E8=B0=83=E8=AF=95DeepLab=20V?=
 =?UTF-8?q?3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py                       |   4 +-
 nets/__init__.py              |   0
 nets/deeplab_v3/__init__.py   |   0
 nets/deeplab_v3/deeplab_v3.py | 219 ++++++++++++++++++++++++++++++++++
 nets/net_factory.py           |  18 +++
 5 files changed, 240 insertions(+), 1 deletion(-)
 create mode 100644 nets/__init__.py
 create mode 100644 nets/deeplab_v3/__init__.py
 create mode 100644 nets/deeplab_v3/deeplab_v3.py
 create mode 100644 nets/net_factory.py

diff --git a/main.py b/main.py
index f13b0a5..e6b7edc 100644
--- a/main.py
+++ b/main.py
@@ -11,12 +11,14 @@ from mindspore.train.callback import TimeMonitor, LossMonitor
 from mindspore import Model
 
 from unet_medical.unet_model import UNetMedical
+from nets.deeplab_v3 import deeplab_v3
 from dataset import GetDatasetGenerator
 from loss import SoftmaxCrossEntropyLoss
 
 context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False,
                     device_target='Ascend', device_id=7)
-net = UNetMedical(n_channels=3, n_classes=6)
+# net = UNetMedical(n_channels=3, n_classes=6)
+net = deeplab_v3.DeepLabV3(phase='train', num_classes=6, output_stride=16, freeze_bn=False)
 
 train_dataset_generator = GetDatasetGenerator('./datasets', 'train')
 train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True)
diff --git a/nets/__init__.py b/nets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nets/deeplab_v3/__init__.py b/nets/deeplab_v3/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nets/deeplab_v3/deeplab_v3.py b/nets/deeplab_v3/deeplab_v3.py
new file mode 100644
index 0000000..e41de7d
--- /dev/null
+++ b/nets/deeplab_v3/deeplab_v3.py
@@ -0,0 +1,219 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+
+
+def conv1x1(in_planes, out_planes, stride=1):  # 1x1 (pointwise) nn.Conv2d factory, Xavier-uniform init
+    return nn.Conv2d(in_planes, out_planes, 1, stride=stride, weight_init='xavier_uniform')
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1):  # 3x3 nn.Conv2d factory, explicit padding
+    options = dict(kernel_size=3, stride=stride, pad_mode='pad', padding=padding, dilation=dilation)
+    return nn.Conv2d(in_planes, out_planes, weight_init='xavier_uniform', **options)
+
+
+class Resnet(nn.Cell):
+    """Dilated ResNet backbone: stem conv/BN/ReLU/max-pool followed by four residual stages."""
+
+    def __init__(self, block, block_num, output_stride, use_batch_statistics=True):
+        super(Resnet, self).__init__()
+        if output_stride not in (8, 16):
+            # layer3/layer4 are only defined for these two strides, so reject anything else up front.
+            raise ValueError('output_stride must be 8 or 16, got {}'.format(output_stride))
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3,
+                               weight_init='xavier_uniform')
+        self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics)
+        self.relu = nn.ReLU()
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
+        self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics)
+        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics)
+        # Stride 16 downsamples once more in layer3; stride 8 swaps that stride for dilation instead.
+        l3_stride, l3_dilation, l4_dilation = (2, 1, 2) if output_stride == 16 else (1, 2, 4)
+        self.layer3 = self._make_layer(block, 256, block_num[2], stride=l3_stride, base_dilation=l3_dilation,
+                                       use_batch_statistics=use_batch_statistics)
+        self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=l4_dilation,
+                                       grids=[1, 2, 4], use_batch_statistics=use_batch_statistics)
+
+    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True):
+        """Stack `blocks` residual blocks; only the first one receives `stride` and the projection shortcut."""
+        downsample = None  # no projection shortcut unless the stride or channel count changes
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.SequentialCell([
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics)
+            ])
+        if grids is None:
+            grids = [1] * blocks
+
+        layers = [
+            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0],
+                  use_batch_statistics=use_batch_statistics)
+        ]
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(
+                block(self.inplanes, planes, dilation=base_dilation * grids[i],
+                      use_batch_statistics=use_batch_statistics))
+
+        return nn.SequentialCell(layers)
+
+    def construct(self, x):
+        """Run the stem and the four stages in order and return the layer4 feature map."""
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.maxpool(out)
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        return out
+
+
+class Bottleneck(nn.Cell):
+    """Residual bottleneck block: 1x1 reduce -> 3x3 (optionally strided/dilated) -> 1x1 expand, plus shortcut.
+
+    Args:
+        inplanes: channel count of the incoming tensor.
+        planes: width of the two inner convolutions; the block outputs ``planes * expansion`` channels.
+        stride: stride applied by the 3x3 convolution.
+        downsample: optional cell applied to the input before the residual add; ``None`` adds the raw input.
+        dilation: dilation of the 3x3 convolution, also used as its padding.
+        use_batch_statistics: forwarded to every ``nn.BatchNorm2d`` in the block.
+    """
+
+    expansion = 4  # output channels = planes * expansion
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True):
+        super().__init__()
+        out_planes = planes * self.expansion
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
+        self.conv2 = conv3x3(planes, planes, stride=stride, dilation=dilation, padding=dilation)
+        self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
+        self.conv3 = conv1x1(planes, out_planes)
+        self.bn3 = nn.BatchNorm2d(out_planes, use_batch_statistics=use_batch_statistics)
+        self.relu = nn.ReLU()
+        self.downsample = downsample
+        self.add = P.Add()
+
+    def construct(self, x):
+        """Apply the three conv/BN stages, add the shortcut branch and finish with ReLU."""
+        y = self.relu(self.bn1(self.conv1(x)))
+        y = self.relu(self.bn2(self.conv2(y)))
+        y = self.bn3(self.conv3(y))
+
+        shortcut = x
+        if self.downsample is not None:
+            shortcut = self.downsample(x)
+
+        y = self.add(y, shortcut)
+        return self.relu(y)
+
+
+class ASPP(nn.Cell):
+    """Atrous Spatial Pyramid Pooling head that maps backbone features to per-class score maps.
+
+    Four ASPPConv branches (one per entry of ``atrous_rates``) and one ASPPPooling branch are
+    concatenated on axis 1, fused by a 1x1 conv/BN/ReLU and projected to ``num_classes`` channels.
+    Exactly four rates are read (indices 0-3). Dropout is applied only when ``phase == 'train'``.
+    """
+
+    def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21,
+                 use_batch_statistics=True):
+        super().__init__()
+        self.phase = phase
+        width = 256  # channels produced by every branch and by the fusion conv
+        bn_mode = use_batch_statistics
+        self.aspp1 = ASPPConv(in_channels, width, atrous_rates[0], use_batch_statistics=bn_mode)
+        self.aspp2 = ASPPConv(in_channels, width, atrous_rates[1], use_batch_statistics=bn_mode)
+        self.aspp3 = ASPPConv(in_channels, width, atrous_rates[2], use_batch_statistics=bn_mode)
+        self.aspp4 = ASPPConv(in_channels, width, atrous_rates[3], use_batch_statistics=bn_mode)
+        self.aspp_pooling = ASPPPooling(in_channels, width, use_batch_statistics=bn_mode)
+        branch_count = len(atrous_rates) + 1
+        self.conv1 = nn.Conv2d(width * branch_count, width, kernel_size=1, weight_init='xavier_uniform')
+        self.bn1 = nn.BatchNorm2d(width, use_batch_statistics=bn_mode)
+        self.relu = nn.ReLU()
+        self.conv2 = nn.Conv2d(width, num_classes, kernel_size=1, weight_init='xavier_uniform', has_bias=True)
+        self.concat = P.Concat(axis=1)
+        self.drop = nn.Dropout(0.3)
+
+    def construct(self, x):
+        """Run all five branches on ``x``, fuse them and return the ``num_classes``-channel output."""
+        branches = (
+            self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x), self.aspp_pooling(x),
+        )
+        fused = self.concat(branches)  # channel order: aspp1..aspp4, then the pooling branch
+
+        fused = self.relu(self.bn1(self.conv1(fused)))
+        if self.phase == 'train':
+            fused = self.drop(fused)
+        return self.conv2(fused)
+
+
+class ASPPPooling(nn.Cell):
+    """Image-level ASPP branch: global average pool -> 1x1 conv/BN/ReLU -> resize back to the input H x W."""
+
+    def __init__(self, in_channels, out_channels, use_batch_statistics=True):
+        super(ASPPPooling, self).__init__()
+        self.conv = nn.SequentialCell([
+            nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'),
+            nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics),
+            nn.ReLU()
+        ])
+        self.shape = P.Shape()
+
+    def construct(self, x):
+        size = self.shape(x)  # dims 2 and 3 are used as height and width (assumes NCHW input)
+        out = self.conv(nn.AvgPool2d((size[2], size[3]))(x))  # full H x W window: stays global when H != W
+        return P.ResizeNearestNeighbor((size[2], size[3]), True)(out)
+
+
+class ASPPConv(nn.Cell):
+    """One ASPP branch (conv -> BN -> ReLU): rate 1 is a 1x1 conv, else a 3x3 conv dilated/padded by the rate."""
+
+    def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True):
+        super().__init__()
+        if atrous_rate != 1:
+            layer = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate,
+                              dilation=atrous_rate, weight_init='xavier_uniform')
+        else:
+            layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False, weight_init='xavier_uniform')
+        self.aspp_conv = nn.SequentialCell([
+            layer, nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics), nn.ReLU()])
+
+    def construct(self, x):
+        return self.aspp_conv(x)
+
+
+class DeepLabV3(nn.Cell):
+    """DeepLab V3: Resnet backbone (Bottleneck, [3, 4, 23, 3]) + ASPP head, upsampled to the input H x W."""
+
+    def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
+        super().__init__()
+        bn_uses_batch_stats = not freeze_bn  # freeze_bn=True -> BatchNorm built with use_batch_statistics=False
+        self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride,
+                             use_batch_statistics=bn_uses_batch_stats)
+        self.aspp = ASPP([1, 6, 12, 18], phase, 2048, num_classes, use_batch_statistics=bn_uses_batch_stats)
+        self.shape = P.Shape()
+
+    def construct(self, x):
+        in_shape = self.shape(x)
+        logits = self.aspp(self.resnet(x))
+        # Bilinear resize (align_corners=True) back to the input's dims 2 and 3.
+        return P.ResizeBilinear((in_shape[2], in_shape[3]), True)(logits)
diff --git a/nets/net_factory.py b/nets/net_factory.py
new file mode 100644
index 0000000..1f5c50a
--- /dev/null
+++ b/nets/net_factory.py
@@ -0,0 +1,18 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from nets.deeplab_v3 import deeplab_v3  # package lives at repo-root 'nets/', the same path main.py imports
+nets_map = {'deeplab_v3_s8': deeplab_v3.DeepLabV3,  # both keys map to the same class;
+            'deeplab_v3_s16': deeplab_v3.DeepLabV3}  # the output stride is a DeepLabV3 constructor argument
-- 
GitLab