Skip to content
Snippets Groups Projects
Unverified Commit 1af27df2 authored by YongtaoShi's avatar YongtaoShi Committed by GitHub
Browse files

add pixel_shuffle module (#5135)


* add pixel_shuffle module

* add pixel_shuffle testcase

* simplify implement

* simplify implement

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent a9a87f13
No related branches found
No related tags found
No related merge requests found
......@@ -86,6 +86,7 @@ Experimental features
.. autofunction:: oneflow.experimental.lt
.. autofunction:: oneflow.experimental.Tensor.lt
.. autofunction:: oneflow.experimental.nn.Identity
.. autofunction:: oneflow.experimental.nn.PixelShuffle
.. autofunction:: oneflow.experimental.nn.Linear
.. autofunction:: oneflow.experimental.nn.CrossEntropyLoss
.. autofunction:: oneflow.experimental.nn.NLLLoss
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow.python.framework.tensor import Tensor
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
@oneflow_export("nn.PixelShuffle")
@experimental_api
class PixelShuffle(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle
Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.
This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.
See the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_
by Shi et. al (2016) for more details.
Args:
upscale_factor (int): factor to increase spatial resolution by
Shape:
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
.. math::
C_{out} = C_{in} \div \text{upscale_factor}^2
.. math::
H_{out} = H_{in} \times \text{upscale_factor}
.. math::
W_{out} = W_{in} \times \text{upscale_factor}
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> m = flow.nn.PixelShuffle(upscale_factor=2)
>>> x = flow.Tensor(np.random.randn(3, 4, 5, 5))
>>> y = m(x)
>>> print(y.size())
flow.Size([3, 1, 10, 10])
>>> m = flow.nn.PixelShuffle(upscale_factor=3)
>>> x = flow.Tensor(np.random.randn(1, 18, 2, 2))
>>> y = m(x)
>>> print(y.size())
flow.Size([1, 2, 6, 6])
.. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
https://arxiv.org/abs/1609.05158
"""
def __init__(self, upscale_factor: int) -> None:
super().__init__()
assert upscale_factor > 0, "The scale factor must larger than zero"
self.upscale_factor = upscale_factor
def forward(self, input: Tensor) -> Tensor:
assert len(input.shape) == 4, "Only Accept 4D Tensor"
_batch, _channel, _height, _width = input.shape
assert (
_channel % (self.upscale_factor ** 2) == 0
), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)"
_new_c = int(_channel / (self.upscale_factor ** 2))
out = input.reshape(
[_batch, _new_c, self.upscale_factor ** 2, _height, _width,]
)
out = out.reshape(
[_batch, _new_c, self.upscale_factor, self.upscale_factor, _height, _width,]
)
out = out.permute(0, 1, 4, 2, 5, 3)
out = out.reshape(
[
_batch,
_new_c,
_height * self.upscale_factor,
_width * self.upscale_factor,
]
)
return out
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
......@@ -4289,6 +4289,7 @@ def triplet_margin_loss(
@oneflow_export("nn.PixelShuffle")
@stable_api
def pixel_shuffle(
input: oneflow._oneflow_internal.BlobDesc,
upscale_factor: int,
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
def _np_pixel_shuffle(input, factor):
_batch, _channel, _height, _width = input.shape
assert (
_channel % (factor ** 2) == 0
), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)"
_new_c = int(_channel / (factor ** 2))
out = np.reshape(input, [_batch, _new_c, factor ** 2, _height, _width])
out = np.reshape(out, [_batch, _new_c, factor, factor, _height, _width])
out = np.transpose(out, [0, 1, 4, 2, 5, 3])
out = np.reshape(out, [_batch, _new_c, _height * factor, _width * factor])
return out
def _np_pixel_shuffle_grad(input, factor):
_batch, _new_channel, _height_mul_factor, _width_mul_factor = input.shape
_channel = _new_channel * (factor ** 2)
_height = _height_mul_factor // factor
_width = _width_mul_factor // factor
out = np.ones(shape=(_batch, _channel, _height, _width))
return out
def _test_pixel_shuffle_impl(test_case, device, shape, upscale_factor):
x = np.random.randn(*shape)
input = flow.Tensor(
x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
)
m = flow.nn.PixelShuffle(upscale_factor)
m = m.to(device)
of_out = m(input)
np_out = _np_pixel_shuffle(x, upscale_factor)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
of_out = of_out.sum()
of_out.backward()
np_grad = _np_pixel_shuffle_grad(np_out, upscale_factor)
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestPixelShuffleModule(flow.unittest.TestCase):
def test_pixel_shuffle(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_pixel_shuffle_impl,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 144, 5, 5), (11, 144, 1, 1)]
arg_dict["upscale_factor"] = [2, 3, 4]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
arg_dict["shape"] = [(8, 25, 18, 18), (1, 25, 2, 2)]
arg_dict["upscale_factor"] = [5]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment