From dbf34a1dd8cc959732aa9f37d609318c157b2569 Mon Sep 17 00:00:00 2001
From: ZZK <42901638+MARD1NO@users.noreply.github.com>
Date: Mon, 2 Aug 2021 14:18:52 +0800
Subject: [PATCH] Add conv3d Module (#5327)

* add conv3d module

* add simple test case

* use conv base class to write conv3d

* still test error

* add torch style conv3d unit test

* fix format

* add assert

* unittest still error

* auto format by CI

* fix format and autotest

* remove dir

* remove useless file

* add extra expr

* auto format by CI

* fix import

* fix doc

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
---
 oneflow/core/functional/functional_api.yaml |   6 +
 oneflow/core/functional/impl/nn_functor.cpp |   9 +
 python/oneflow/nn/__init__.py               |   2 +-
 python/oneflow/nn/modules/conv.py           | 208 +++++++++++++++++++-
 python/oneflow/test/modules/test_conv3d.py  |  46 +++++
 5 files changed, 268 insertions(+), 3 deletions(-)
 create mode 100644 python/oneflow/test/modules/test_conv3d.py

diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index d721a6512..a3318bf39 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -303,6 +303,12 @@
     "Tensor Conv2d(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, 
                    Int32List padding, Int32List dilation, Int32 groups=1)"
   bind_python: True
+
+- name: "conv3d"  # same argument list as the Conv2d signature above
+  signature:
+    "Tensor Conv3d(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, 
+                   Int32List padding, Int32List dilation, Int32 groups=1)"
+  bind_python: True
   
 - name: "conv_data_grad"
   signature:
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 29ba46df2..48e05f1ef 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -107,6 +107,14 @@ class Conv2dFunctor : public ConvBaseFunctor {
   }
 };
 
+class Conv3dFunctor : public ConvBaseFunctor {  // conv functor for 3 spatial dims
+ public:
+  Conv3dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/3) {
+    conv_op_ =  // the "conv3d" op: inputs "in" and "weight", output "out"
+        CHECK_JUST(one::OpBuilder("conv3d").Input("in").Input("weight").Output("out").Build());
+  }
+};
+
 class MatMulBaseFunctor {
  public:
   MatMulBaseFunctor() = default;
@@ -508,6 +516,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::BiasAddFunctor>("BiasAdd");
   m.add_functor<impl::Conv1dFunctor>("Conv1d");
   m.add_functor<impl::Conv2dFunctor>("Conv2d");
+  m.add_functor<impl::Conv3dFunctor>("Conv3d");
   m.add_functor<impl::MatMulFunctor>("MatMul");
   m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul");
   m.add_functor<impl::BroadcastMatMulFunctor>("BroadcastMatMul");
diff --git a/python/oneflow/nn/__init__.py b/python/oneflow/nn/__init__.py
index afbd5067a..c39f5e80d 100644
--- a/python/oneflow/nn/__init__.py
+++ b/python/oneflow/nn/__init__.py
@@ -49,7 +49,7 @@ from oneflow.nn.modules.container import (
     ParameterList,
     Sequential,
 )
-from oneflow.nn.modules.conv import Conv1d, Conv2d
+from oneflow.nn.modules.conv import Conv1d, Conv2d, Conv3d
 from oneflow.nn.modules.dataset import (
     COCOReader,
     CoinFlip,
diff --git a/python/oneflow/nn/modules/conv.py b/python/oneflow/nn/modules/conv.py
index 7a24f11ed..0f4f0a9ed 100644
--- a/python/oneflow/nn/modules/conv.py
+++ b/python/oneflow/nn/modules/conv.py
@@ -17,9 +17,9 @@ import math
 
 import oneflow as flow
 from oneflow.nn import init
-from oneflow.nn.common_types import _size_1_t, _size_2_t
+from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t
 from oneflow.nn.module import Module
-from oneflow.nn.modules.utils import _pair, _single
+from oneflow.nn.modules.utils import _single, _pair, _triple
 
 
 def slice(x, begin, size):
@@ -483,6 +483,210 @@ class Conv2d(Module):
         return s.format(**self.__dict__)
 
 
+class Conv3d(Module):
+    r"""The interface is consistent with PyTorch.    
+    The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.Conv3d.html#conv3d
+    
+    Applies a 3D convolution over an input signal composed of several input
+    planes.
+    In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)`
+    and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as:
+    .. math::
+        out(N_i, C_{out_j}) = bias(C_{out_j}) +
+                                \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k)
+    where :math:`\star` is the valid 3D `cross-correlation`_ operator
+    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
+    * :attr:`stride` controls the stride for the cross-correlation.
+    * :attr:`padding` controls the amount of implicit zero-paddings on both
+      sides for :attr:`padding` number of points for each dimension.
+    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
+      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
+    * :attr:`groups` controls the connections between inputs and outputs.
+      :attr:`in_channels` and :attr:`out_channels` must both be divisible by
+      :attr:`groups`. For example,
+        * At groups=1, all inputs are convolved to all outputs.
+        * At groups=2, the operation becomes equivalent to having two conv
+          layers side by side, each seeing half the input channels,
+          and producing half the output channels, and both subsequently
+          concatenated.
+        * At groups= :attr:`in_channels`, each input channel is convolved with
+          its own set of filters, of size
+          :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`.
+    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:
+        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
+        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
+          the second `int` for the height dimension and the third `int` for the width dimension
+    Note:
+         Depending on the size of your kernel, several (of the last)
+         columns of the input might be lost, because it is a valid `cross-correlation`_,
+         and not a full `cross-correlation`_.
+         It is up to the user to add proper padding.
+    Note:
+        When `groups == in_channels` and `out_channels == K * in_channels`,
+        where `K` is a positive integer, this operation is also termed in
+        literature as depthwise convolution.
+        In other words, for an input of size :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`,
+        a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments
+        :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
+    Note:
+        In some circumstances when using the CUDA backend with CuDNN, this operator
+        may select a nondeterministic algorithm to increase performance. If this is
+        undesirable, you can try to make the operation deterministic (potentially at
+        a performance cost) by setting ``torch.backends.cudnn.deterministic =
+        True``.
+        Please see the notes on :doc:`/notes/randomness` for background.
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+        padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
+    Shape:
+        - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
+          .. math::
+              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
+                    \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
+          .. math::
+              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
+                    \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
+          .. math::
+              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2]
+                    \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape
+                         :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
+                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`.
+                         The values of these weights are sampled from
+                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+        bias (Tensor):   the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
+                         then the values of these weights are
+                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
+                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}`
+    For example: 
+    .. code-block:: python
+        >>> import numpy as np
+        >>> import oneflow as flow
+        >>> import oneflow.nn as nn
+
+        >>> arr = np.random.randn(1, 2, 5, 5, 5)
+        >>> input = flow.Tensor(arr)
+        >>> m = nn.Conv3d(2, 4, kernel_size=3, stride=1)
+        >>> output = m(input)
+        
+    .. _cross-correlation:
+        https://en.wikipedia.org/wiki/Cross-correlation
+    .. _link:
+        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: _size_3_t,
+        stride: _size_3_t = 1,
+        padding: _size_3_t = 0,
+        dilation: _size_3_t = 1,
+        groups: int = 1,
+        bias: bool = True,
+        padding_mode: str = "zeros",  # TODO: refine this type
+    ):
+        """Validate the configuration and allocate the weight and optional bias parameters."""
+        super().__init__()
+        assert padding_mode == "zeros", "only padding_mode='zeros' is supported"
+        self.padding_mode = padding_mode  # must exist: extra_repr() reads self.padding_mode
+        self.kernel_size = _triple(kernel_size)
+        self.stride = _triple(stride)
+        self.padding = _triple(padding)
+        self.dilation = _triple(dilation)
+        self.groups = groups
+        assert in_channels % groups == 0, "in_channels must be divisible by groups"
+        assert out_channels % groups == 0, "out_channels must be divisible by groups"
+        self.in_channels, self.out_channels = in_channels, out_channels
+        self.weight = flow.nn.Parameter(
+            flow.Tensor(out_channels, in_channels // groups, *self.kernel_size)
+        )
+        self.out_channel_groups = out_channels // groups  # output channels per group
+        self.bias = None
+        if bias:
+            self.bias = flow.nn.Parameter(flow.Tensor(out_channels))
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Initialise weight with Kaiming-uniform and bias with U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
+        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            limit = 1 / math.sqrt(init._calculate_fan_in_and_fan_out(self.weight)[0])
+            init.uniform_(self.bias, -limit, limit)
+
+    def forward(self, x):
+        """Apply the 3D convolution to ``x``.
+
+        Args:
+            x: input tensor; its axis 1 must have ``self.in_channels`` entries.
+
+        Returns:
+            The ``flow.F.conv3d`` output. For ``groups > 1`` on a CPU tensor, the
+            per-group outputs concatenated along axis 1.
+
+        Raises:
+            ValueError: if ``x.shape[1] != self.in_channels``.
+        """
+        if x.shape[1] != self.in_channels:
+            raise ValueError("The input channels should be equal to self.in_channels")
+        if x.device.type != "cpu" or self.groups <= 1:
+            return flow.F.conv3d(
+                x,
+                self.weight,
+                self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
+        # Grouped convolution on a CPU tensor: run one groups=1 convolution per
+        # channel group, then concatenate the results along the channel axis.
+        channel_axis = 1
+        step = self.out_channel_groups
+        x_groups = ConvUtil.split(x, axis=channel_axis, split_num=self.groups)
+        out_list = []
+        for i, x_group in enumerate(x_groups):
+            lo, hi = i * step, (i + 1) * step
+            out_list.append(
+                flow.F.conv3d(
+                    x_group,
+                    self.weight[lo:hi, :, :, :, :],  # all five weight axes spelled out
+                    # Test against None explicitly; a tensor's truth value is not a presence check.
+                    self.bias[lo:hi] if self.bias is not None else None,
+                    stride=self.stride,
+                    padding=self.padding,
+                    dilation=self.dilation,
+                    groups=1,
+                )
+            )
+        return flow.experimental.cat(out_list, dim=channel_axis)
+
+    def extra_repr(self):
+        """Describe the layer's configuration; optional fields appear only when non-default."""
+        # Placeholders are filled from self.__dict__, so each one must be an instance attribute.
+        optional = (
+            (self.padding != (0,) * len(self.padding), ", padding={padding}"),
+            (self.dilation != (1,) * len(self.dilation), ", dilation={dilation}"),
+            (self.groups != 1, ", groups={groups}"),
+            (self.bias is None, ", bias=False"),
+            (self.padding_mode != "zeros", ", padding_mode={padding_mode}"),
+        )
+        s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}"
+        s += "".join(text for shown, text in optional if shown)
+        return s.format(**self.__dict__)
+
+
 if __name__ == "__main__":
     import doctest
 
diff --git a/python/oneflow/test/modules/test_conv3d.py b/python/oneflow/test/modules/test_conv3d.py
new file mode 100644
index 000000000..9d2414cde
--- /dev/null
+++ b/python/oneflow/test/modules/test_conv3d.py
@@ -0,0 +1,46 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+import oneflow as flow
+import oneflow.unittest
+from automated_test_util import *
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestConv3DModule(flow.unittest.TestCase):
+    @autotest(n=2)
+    def test_conv3d_with_random_data(test_case):
+        channels = random(1, 6)  # reused for in_channels and the input's dim1 so the two agree
+        m = torch.nn.Conv3d(  # NOTE(review): `torch` is not imported here; presumably supplied by the star import
+            in_channels=channels,
+            out_channels=random(1, 6),
+            kernel_size=random(1, 3),
+            stride=random() | nothing(),
+            padding=random(1, 3).to(int) | nothing(),
+            dilation=random(1, 5) | nothing(),
+            groups=random(1, 5) | nothing(),  # NOTE(review): drawn independently of the channel counts; may not divide them - confirm autotest tolerates this
+            padding_mode=constant("zeros") | nothing(),  # "zeros" is the only value Conv3d.__init__ asserts as valid
+        )
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(ndim=5, dim0=2, dim1=channels).to(device)
+        y = m(x)
+        return y
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab