From dbf34a1dd8cc959732aa9f37d609318c157b2569 Mon Sep 17 00:00:00 2001 From: ZZK <42901638+MARD1NO@users.noreply.github.com> Date: Mon, 2 Aug 2021 14:18:52 +0800 Subject: [PATCH] Add conv3d Module (#5327) * add conv3d module * add simple test case * use conv base class to write conv3d * still test error * add torch style conv3d unit test * fix format * add assert * unittest still error * auto format by CI * fix format and autotest * remove dir * remove useless file * add extra expr * auto format by CI * fix import * fix doc Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org> --- oneflow/core/functional/functional_api.yaml | 6 + oneflow/core/functional/impl/nn_functor.cpp | 9 + python/oneflow/nn/__init__.py | 2 +- python/oneflow/nn/modules/conv.py | 208 +++++++++++++++++++- python/oneflow/test/modules/test_conv3d.py | 46 +++++ 5 files changed, 268 insertions(+), 3 deletions(-) create mode 100644 python/oneflow/test/modules/test_conv3d.py diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index d721a6512..a3318bf39 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -303,6 +303,12 @@ "Tensor Conv2d(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, Int32List padding, Int32List dilation, Int32 groups=1)" bind_python: True + +- name: "conv3d" + signature: + "Tensor Conv3d(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, + Int32List padding, Int32List dilation, Int32 groups=1)" + bind_python: True - name: "conv_data_grad" signature: diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 29ba46df2..48e05f1ef 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -107,6 +107,14 @@ class Conv2dFunctor : public ConvBaseFunctor { } }; +class Conv3dFunctor : public ConvBaseFunctor { + public: + Conv3dFunctor() : ConvBaseFunctor(/*num_spatial_dims_=*/3) { + conv_op_ = + CHECK_JUST(one::OpBuilder("conv3d").Input("in").Input("weight").Output("out").Build()); + } +}; + class MatMulBaseFunctor { public: MatMulBaseFunctor() = default; @@ -508,6 +516,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::BiasAddFunctor>("BiasAdd"); m.add_functor<impl::Conv1dFunctor>("Conv1d"); m.add_functor<impl::Conv2dFunctor>("Conv2d"); + m.add_functor<impl::Conv3dFunctor>("Conv3d"); m.add_functor<impl::MatMulFunctor>("MatMul"); m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul"); m.add_functor<impl::BroadcastMatMulFunctor>("BroadcastMatMul"); diff --git a/python/oneflow/nn/__init__.py b/python/oneflow/nn/__init__.py index afbd5067a..c39f5e80d 100644 --- a/python/oneflow/nn/__init__.py +++ b/python/oneflow/nn/__init__.py @@ -49,7 +49,7 @@ from oneflow.nn.modules.container import ( ParameterList, Sequential, ) -from oneflow.nn.modules.conv import Conv1d, Conv2d +from oneflow.nn.modules.conv import Conv1d, Conv2d, Conv3d from oneflow.nn.modules.dataset import ( COCOReader, CoinFlip, diff --git a/python/oneflow/nn/modules/conv.py b/python/oneflow/nn/modules/conv.py index 7a24f11ed..0f4f0a9ed 100644 --- a/python/oneflow/nn/modules/conv.py +++ b/python/oneflow/nn/modules/conv.py @@ -17,9 +17,9 @@ import math import oneflow as flow from oneflow.nn import init -from oneflow.nn.common_types import _size_1_t, _size_2_t +from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t from oneflow.nn.module import Module -from oneflow.nn.modules.utils import _pair, _single +from oneflow.nn.modules.utils import _single, _pair, _triple def slice(x, begin, size): @@ -483,6 +483,210 @@ class Conv2d(Module): return s.format(**self.__dict__) +class Conv3d(Module): + r"""The interface is consistent with PyTorch. + The documentation is referenced from: https://pytorch.org/docs/master/generated/torch.nn.Conv3d.html#conv3d + + Applies a 3D convolution over an input signal composed of several input + planes. + In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` + and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: + .. math:: + out(N_i, C_{out_j}) = bias(C_{out_j}) + + \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) + where :math:`\star` is the valid 3D `cross-correlation`_ operator + This module supports :ref:`TensorFloat32<tf32_on_ampere>`. + * :attr:`stride` controls the stride for the cross-correlation. + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension. + * :attr:`dilation` controls the spacing between the kernel points; also known as the 脿 trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + * :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels, + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters, of size + :math:`\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor`. + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + Note: + Depending of the size of your kernel, several (of the last) + columns of the input might be lost, because it is a valid `cross-correlation`_, + and not a full `cross-correlation`_. + It is up to the user to add proper padding. + Note: + When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also termed in + literature as depthwise convolution. + In other words, for an input of size :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`, + a depthwise convolution with a depthwise multiplier `K`, can be constructed by arguments + :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`. + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] + \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + For example: + .. code-block:: python + >>> import numpy as np + >>> import oneflow as flow + >>> import oneflow.nn as nn + + >>> arr = np.random.randn(1, 2, 5, 5, 5) + >>> input = flow.Tensor(arr) + >>> m = nn.Conv3d(2, 4, kernel_size=3, stride=1) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", # TODO: refine this type + ): + super().__init__() + + assert padding_mode == "zeros" + self.kernel_size = _triple(kernel_size) + self.stride = _triple(stride) + self.padding = _triple(padding) + self.dilation = _triple(dilation) + self.groups = groups + assert in_channels % groups == 0, "in_channels must be divisible by groups" + assert out_channels % groups == 0, "out_channels must be divisible by groups" + self.in_channels = in_channels + self.out_channels = out_channels + self.weight = flow.nn.Parameter( + flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + self.out_channel_groups = out_channels // groups + self.bias = None + if bias: + self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self) -> None: + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, x): + if x.shape[1] != self.in_channels: + raise ValueError("The input channels should be equal to self.in_channels") + if x.device.type == "cpu" and self.groups > 1: + in_channel_axis = 1 + in_split_list = ConvUtil.split( + x, axis=in_channel_axis, split_num=self.groups + ) + out_list = [] + for i in range(len(in_split_list)): + out_list.append( + flow.F.conv3d( + in_split_list[i], + self.weight[ + i + * self.out_channel_groups : (i + 1) + * self.out_channel_groups, + :, + :, + :, + ], + self.bias[ + i + * self.out_channel_groups : (i + 1) + * self.out_channel_groups + ] + if self.bias + else None, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=1, + ) + ) + res = flow.experimental.cat(out_list, dim=in_channel_axis) + else: + res = flow.F.conv3d( + x, + self.weight, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + ) + return res + + def extra_repr(self): + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is None: + s += ", bias=False" + if self.padding_mode != "zeros": + s += ", padding_mode={padding_mode}" + return s.format(**self.__dict__) + + if __name__ == "__main__": import doctest diff --git a/python/oneflow/test/modules/test_conv3d.py b/python/oneflow/test/modules/test_conv3d.py new file mode 100644 index 000000000..9d2414cde --- /dev/null +++ b/python/oneflow/test/modules/test_conv3d.py @@ -0,0 +1,46 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import oneflow as flow +import oneflow.unittest +from automated_test_util import * + + +@flow.unittest.skip_unless_1n1d() +class TestConv3DModule(flow.unittest.TestCase): + @autotest(n=2) + def test_conv3d_with_random_data(test_case): + channels = random(1, 6) + m = torch.nn.Conv3d( + in_channels=channels, + out_channels=random(1, 6), + kernel_size=random(1, 3), + stride=random() | nothing(), + padding=random(1, 3).to(int) | nothing(), + dilation=random(1, 5) | nothing(), + groups=random(1, 5) | nothing(), + padding_mode=constant("zeros") | nothing(), + ) + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(ndim=5, dim0=2, dim1=channels).to(device) + y = m(x) + return y + + +if __name__ == "__main__": + unittest.main() -- GitLab