Unverified commit a9a87f13, authored by Xiaoyu Zhang, committed by GitHub

Add convtranspose module (#5101)


* add convtranspose module

* add deconv2d module

* fix convtranspose2d build bug

* fix convtranspose module backward bug

* add deconv samples

* refine convtranspose2d docs

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 0eef39fb
@@ -62,6 +62,7 @@ Experimental features
 .. autofunction:: oneflow.experimental.nn.ModuleList
 .. autofunction:: oneflow.experimental.nn.ModuleDict
 .. autofunction:: oneflow.experimental.nn.Conv2d
+.. autofunction:: oneflow.experimental.nn.ConvTranspose2d
 .. autofunction:: oneflow.experimental.nn.Dropout
 .. autofunction:: oneflow.experimental.eq
 .. autofunction:: oneflow.experimental.to
@@ -62,9 +62,9 @@ Maybe<void> DeConvolutionNd::Init(const OpExpr& op) {
   int32_t ndims = kernel_size_->size();
   CHECK_EQ_OR_RETURN(ndims, strides_->size());
   CHECK_EQ_OR_RETURN(ndims, dilation_rate_->size());
-  int32_t filters = JUST(op_trait_->GetAttr<int32_t>("filters"));
+  // int32_t filters = JUST(op_trait_->GetAttr<int32_t>("filters"));
   activation_grad_op_ =
-      JUST(op_expr_helper::ConvNdOp(filters, *kernel_size_, *strides_, *padding_before_,
+      JUST(op_expr_helper::ConvNdOp(/*filters=1*/ 1, *kernel_size_, *strides_, *padding_before_,
                                     *dilation_rate_, /*groups=*/1, *data_format_));
   weight_grad_op_ = JUST(op_expr_helper::ConvNdFilterGradOp(
       *kernel_size_, *strides_, *padding_before_, *dilation_rate_, /*groups=*/1, *data_format_));
@@ -89,8 +89,11 @@ Maybe<void> DeConvolutionNd::Apply(const DeConvolutionNdInterpState* ctx,
   in_grads->resize(2);
   if (ctx->activation_requires_grad) {
     const auto& weight = ctx->SavedTensors().at(0);
-    in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*activation_grad_op_,
-                                                          {out_grads.at(0), weight}, /*attrs=*/{}));
+    MutableAttrMap attrs;
+    const int32_t filters = weight->shape()->At(0);
+    JUST(attrs.SetAttr<int32_t>("filters", filters));
+    in_grads->at(0) = JUST(
+        OpInterpUtil::Dispatch<Tensor>(*activation_grad_op_, {out_grads.at(0), weight}, attrs));
   }
   if (ctx->weight_requires_grad) {
     int idx = ctx->activation_requires_grad;
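The substance of this backward fix: the input gradient of a transposed convolution is an ordinary convolution of the output gradient with the same weight, and that conv's "filters" count equals the weight's first dimension (the deconv's in_channels). This value is only known from the weight tensor at apply time, which is why the attribute moves from Init to Apply. A minimal NumPy sketch of the shape relationship (my illustration, not part of the commit; assumes stride 1, no padding or dilation, and a hypothetical helper name):

import numpy as np

def deconv2d_input_grad(grad_out, weight):
    # weight layout: (in_channels, out_channels, kH, kW). The gradient w.r.t.
    # the deconv input is a cross-correlation of grad_out with weight, summed
    # over out_channels -- so its channel count is weight.shape[0], i.e. the
    # "filters" attribute must track the runtime weight shape.
    in_c, out_c, kh, kw = weight.shape
    n, _, h, w = grad_out.shape
    grad_in = np.zeros((n, in_c, h - kh + 1, w - kw + 1))
    for i in range(grad_in.shape[2]):
        for j in range(grad_in.shape[3]):
            patch = grad_out[:, :, i:i + kh, j:j + kw]  # (n, out_c, kh, kw)
            grad_in[:, :, i, j] = np.tensordot(
                patch, weight, axes=([1, 2, 3], [1, 2, 3])
            )
    return grad_in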
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.modules.utils import _pair
from oneflow.python.nn.common_types import _size_2_t
from oneflow.python.nn import init


@oneflow_export("nn.ConvTranspose2d")
@experimental_api
class ConvTranspose2d(Module):
r"""
Applies a 2D transposed convolution operator over an input image composed of several input planes.
This module can be seen as the gradient of Conv2d with respect to its input.
It is also known as a fractionally-strided convolution or
a deconvolution (although it is not an actual deconvolution operation).
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding
will be added to both sides of each dimension in the input. Default: 0
output_padding (int or tuple, optional): Additional size added to one side
of each dimension in the output shape. Default: 0
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
Shape:
- Input: :math:`(N, C_{in}, H_{in}, W_{in})`
- Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
.. math::
H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0]
\times (\text{kernel_size}[0] - 1) + \text{output_padding}[0] + 1
.. math::
W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1]
\times (\text{kernel_size}[1] - 1) + \text{output_padding}[1] + 1
Attributes:
weight (Tensor): the learnable weights of the module of shape
:math:`(\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},`
:math:`\text{kernel_size[0]}, \text{kernel_size[1]})`.
The values of these weights are sampled from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}`
bias (Tensor): the learnable bias of the module of shape (out_channels)
If :attr:`bias` is ``True``, then the values of these weights are
sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}`
Examples::
>>> import numpy as np
>>> import oneflow.experimental as flow
>>> import oneflow.experimental.nn as nn
>>> flow.enable_eager_execution()
>>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> m = m.to("cuda")
>>> input = flow.Tensor(np.random.randn(20, 16, 50, 100), device=flow.device("cuda"))
>>> output = m(input)
>>> output.size()
flow.Size([20, 33, 93, 100])
.. _cross-correlation:
https://en.wikipedia.org/wiki/Cross-correlation
.. _link:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: _size_2_t = 0,
output_padding: _size_2_t = 0,
groups: int = 1,
bias: bool = True,
dilation: int = 1,
padding_mode: str = "zeros",
) -> None:
super().__init__()
        assert padding_mode == "zeros"
        assert groups == 1, "group ConvTranspose2d is not supported yet"
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
output_padding = _pair(output_padding)
dilation = _pair(dilation)
self.groups = groups
self.weight = flow.nn.Parameter(
flow.Tensor(in_channels, out_channels // groups, *kernel_size)
)
self.bias = None
self._bias_add_op = None
if bias:
self.bias = flow.nn.Parameter(flow.Tensor(out_channels // groups))
self._bias_add_op = (
flow.builtin_op("bias_add")
.Input("a")
.Input("b")
.Output("out")
.Attr("axis", 1)
.Build()
)
self._op = (
flow.builtin_op("deconv2d")
.Input("in")
.Input("weight")
.Attr("filters", out_channels)
.Attr("padding_before", padding)
.Attr("data_format", "channels_first")
.Attr("kernel_size", kernel_size)
.Attr("strides", stride)
.Attr("dilation_rate", dilation)
.Attr("output_padding", output_padding)
.Attr("groups", groups)
.Output("out")
.Build()
)
self.reset_parameters()

    def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
res = self._op(x, self.weight)[0]
if self._bias_add_op is not None:
res = self._bias_add_op(res, self.bias)[0]
return res


if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
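As a quick sanity check of the H_out/W_out formula in the docstring above, here is a small per-dimension helper (hypothetical, not part of the commit):

def convtranspose2d_out_size(in_size, kernel_size, stride=1, padding=0,
                             output_padding=0, dilation=1):
    # Mirrors the docstring formula:
    # out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# Matches the doctest: a (50, 100) input through kernel (3, 5), stride (2, 1),
# padding (4, 2) yields (93, 100).
assert convtranspose2d_out_size(50, 3, stride=2, padding=4) == 93
assert convtranspose2d_out_size(100, 5, stride=1, padding=2) == 100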
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
import oneflow.experimental.nn as nn
from test_util import GenArgList


def _test_deconv_bias_false(test_case, device):
np_arr = np.array(
[
[
[
[0.2735021114349365, -1.3842310905456543],
[1.058540940284729, -0.03388553857803345],
]
]
]
)
input = flow.Tensor(
np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
weight = np.array(
[
[
[
[0.06456436216831207, -0.10852358490228653, -0.21638715267181396],
[-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],
[0.05026858672499657, 0.10818571597337723, 0.02056501805782318],
],
[
[0.205095112323761, 0.1488947868347168, -0.2344113141298294],
[0.1684819906949997, -0.21986986696720123, 0.1082606166601181],
[-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],
],
]
]
)
m = nn.ConvTranspose2d(1, 2, 3, stride=1, bias=False)
m.weight = flow.nn.Parameter(flow.Tensor(weight))
m = m.to(device)
output = m(input)
np_out = np.array(
[
[
[
[
0.01765848882496357,
-0.1190534234046936,
0.09103937447071075,
0.2995298206806183,
],
[
0.006009865552186966,
0.2388070970773697,
-0.37657976150512695,
-0.26200416684150696,
],
[
-0.22750461101531982,
0.12405071407556534,
0.056831881403923035,
-0.035060010850429535,
],
[
0.053211357444524765,
0.11281562596559525,
0.0181029811501503,
-0.0006968567031435668,
],
],
[
[
0.05609394609928131,
-0.24317599833011627,
-0.27021679282188416,
0.32447943091392517,
],
[
0.26318174600601196,
-0.14269141852855682,
0.08078087121248245,
-0.14191456139087677,
],
[
0.13652732968330383,
0.020019691437482834,
-0.10959184169769287,
-0.03072327747941017,
],
[
-0.16184815764427185,
0.1864076405763626,
0.014887845143675804,
-0.0006622931105084717,
],
],
]
]
)
test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6))
output = output.sum()
output.backward()
np_grad = [
[
[
[0.24731683731079102, 0.24731683731079102],
[0.24731683731079102, 0.24731683731079102],
]
]
]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6))


def _test_deconv_bias_true(test_case, device):
np_arr = np.array(
[
[
[
[0.2735021114349365, -1.3842310905456543],
[1.058540940284729, -0.03388553857803345],
]
]
]
)
input = flow.Tensor(
np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
weight = np.array(
[
[
[
[0.06456436216831207, -0.10852358490228653, -0.21638715267181396],
[-0.2279110550880432, 0.1476770043373108, 0.19457484781742096],
[0.05026858672499657, 0.10818571597337723, 0.02056501805782318],
],
[
[0.205095112323761, 0.1488947868347168, -0.2344113141298294],
[0.1684819906949997, -0.21986986696720123, 0.1082606166601181],
[-0.1528974026441574, 0.17120417952537537, 0.01954500749707222],
],
]
]
)
bias = np.array([0.06456436216831207, -0.10852358490228653])
m = nn.ConvTranspose2d(1, 2, 3, stride=1)
m.weight = flow.nn.Parameter(flow.Tensor(weight))
m.bias = flow.nn.Parameter(flow.Tensor(bias))
m = m.to(device)
output = m(input)
np_out = [
[
[
[
0.0822228491306305,
-0.05448906123638153,
0.15560373663902283,
0.36409419775009155,
],
[
0.07057422399520874,
0.30337145924568176,
-0.3120154142379761,
-0.19743980467319489,
],
[
-0.16294024884700775,
0.188615083694458,
0.12139624357223511,
0.029504351317882538,
],
[
0.11777572333812714,
0.17737999558448792,
0.08266734331846237,
0.06386750191450119,
],
],
[
[
-0.05242963880300522,
-0.3516995906829834,
-0.3787403702735901,
0.21595585346221924,
],
[
0.15465816855430603,
-0.25121501088142395,
-0.027742713689804077,
-0.2504381537437439,
],
[
0.028003744781017303,
-0.088503897190094,
-0.2181154191493988,
-0.139246866106987,
],
[
-0.2703717350959778,
0.07788405567407608,
-0.09363573789596558,
-0.10918587446212769,
],
],
]
]
test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6))
output = output.sum()
output.backward()
np_grad = [
[
[
[0.24731683731079102, 0.24731683731079102],
[0.24731683731079102, 0.24731683731079102],
]
]
]
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6))


@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestConvTranspose2d(flow.unittest.TestCase):
    def test_convtranspose2d(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_deconv_bias_false,
_test_deconv_bias_true,
]
arg_dict["device"] = ["cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])


if __name__ == "__main__":
unittest.main()
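The hard-coded expectations above (np_out, np_grad) could be regenerated against a reference implementation; a hedged sketch, not part of the commit, assuming PyTorch is available as the oracle (its ConvTranspose2d uses the same (in_channels, out_channels, kH, kW) weight layout as these tests):

import torch

def reference_deconv(np_arr, weight, bias=None):
    # Build the same 1-in / 2-out, 3x3, stride-1 module as the tests above.
    m = torch.nn.ConvTranspose2d(1, 2, 3, stride=1, bias=bias is not None)
    m.weight = torch.nn.Parameter(torch.tensor(weight, dtype=torch.float32))
    if bias is not None:
        m.bias = torch.nn.Parameter(torch.tensor(bias, dtype=torch.float32))
    x = torch.tensor(np_arr, dtype=torch.float32, requires_grad=True)
    y = m(x)
    y.sum().backward()  # same reduction as output.sum().backward() in the tests
    return y.detach().numpy(), x.grad.numpy()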