Skip to content
Snippets Groups Projects
Unverified Commit 915a9884 authored by YongtaoShi's avatar YongtaoShi Committed by GitHub
Browse files

add pixelshufflev2 module (#5383)


* add pixelshufflev2 module

* update pixel shuffle module to v2

* handle pixelshufflev1

* handle pixelshufflev1

* auto format by CI

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: default avataroneflow-ci-bot <ci-bot@oneflow.org>
parent a30a6cee
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional
from oneflow.python.framework.tensor import Tensor
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
......@@ -20,13 +22,13 @@ from oneflow.python.nn.module import Module
@oneflow_export("nn.PixelShuffle")
@experimental_api
class PixelShuffle(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
class PixelShufflev2(Module):
r"""
Part of the documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle
Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.
Rearranges elements in a tensor of shape :math:`(*, C \times r_h \times r_w, H, W)`
to a tensor of shape :math:`(*, C, H \times r_h, W \times r_w)`, where r_h and r_w are upscale factors.
This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.
......@@ -36,21 +38,33 @@ class PixelShuffle(Module):
by Shi et. al (2016) for more details.
Args:
upscale_factor (int): factor to increase spatial resolution by
upscale_factor (int, optional): factor to increase spatial resolution by, only use when factors of height and width spatial are the same.
h_upscale_factor (int, optional): factor to increase height spatial resolution by, only one of h_upscale_factor and upscale_factor can be used.
w_upscale_factor (int, optional): factor to increase width spatial resolution by, only one of w_upscale_factor and upscale_factor can be used.
Shape:
- Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
- Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
.. math::
C_{out} = C_{in} \div \text{upscale_factor}^2
if use upscale_factor:
.. math::
C_{out} = C_{in} \div \text{h_upscale_factor}^2
H_{out} = H_{in} \times \text{upscale_factor}
.. math::
W_{out} = W_{in} \times \text{upscale_factor}
if use h_upscale_factor and w_upscale_factor:
.. math::
C_{out} = C_{in} \div \text{h_upscale_factor} \div \text{w_upscale_factor}
H_{out} = H_{in} \times \text{h_upscale_factor}
W_{out} = W_{in} \times \text{w_upscale_factor}
For example:
.. code-block:: python
......@@ -65,44 +79,77 @@ class PixelShuffle(Module):
>>> print(y.size())
flow.Size([3, 1, 10, 10])
>>> m = flow.nn.PixelShuffle(upscale_factor=3)
>>> x = flow.Tensor(np.random.randn(1, 18, 2, 2))
>>> m = flow.nn.PixelShuffle(h_upscale_factor=3, w_upscale_factor=4)
>>> x = flow.Tensor(np.random.randn(1, 24, 2, 2))
>>> y = m(x)
>>> print(y.size())
flow.Size([1, 2, 6, 6])
flow.Size([1, 2, 6, 8])
.. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
https://arxiv.org/abs/1609.05158
"""
def __init__(self, upscale_factor: int) -> None:
def __init__(
self,
upscale_factor: Optional[int] = None,
h_upscale_factor: Optional[int] = None,
w_upscale_factor: Optional[int] = None,
) -> None:
super().__init__()
assert upscale_factor > 0, "The scale factor must larger than zero"
self.upscale_factor = upscale_factor
if upscale_factor is None:
assert (
h_upscale_factor is not None and w_upscale_factor is not None
), "h_upscale_factor and w_upscale_factor should be None if use upscale_factor"
else:
assert (
h_upscale_factor is None and w_upscale_factor is None
), "upscale_factor should be None if use h_upscale_factor and w_upscale_factor"
h_upscale_factor = upscale_factor
w_upscale_factor = upscale_factor
assert (
h_upscale_factor > 0 and w_upscale_factor > 0
), "The scale factor of height and width must larger than zero"
self.h_upscale_factor = h_upscale_factor
self.w_upscale_factor = w_upscale_factor
def forward(self, input: Tensor) -> Tensor:
assert len(input.shape) == 4, "Only Accept 4D Tensor"
_batch, _channel, _height, _width = input.shape
assert (
_channel % (self.upscale_factor ** 2) == 0
), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)"
_channel % (self.h_upscale_factor * self.w_upscale_factor) == 0
), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)"
_new_c = int(_channel / (self.upscale_factor ** 2))
_new_c = int(_channel / (self.h_upscale_factor * self.w_upscale_factor))
out = input.reshape(
[_batch, _new_c, self.upscale_factor ** 2, _height, _width,]
[
_batch,
_new_c,
self.h_upscale_factor * self.w_upscale_factor,
_height,
_width,
]
)
out = out.reshape(
[_batch, _new_c, self.upscale_factor, self.upscale_factor, _height, _width,]
[
_batch,
_new_c,
self.h_upscale_factor,
self.w_upscale_factor,
_height,
_width,
]
)
out = out.permute(0, 1, 4, 2, 5, 3)
out = out.reshape(
[
_batch,
_new_c,
_height * self.upscale_factor,
_width * self.upscale_factor,
_height * self.h_upscale_factor,
_width * self.w_upscale_factor,
]
)
......
......@@ -22,45 +22,49 @@ import oneflow.experimental as flow
from test_util import GenArgList
def _np_pixel_shuffle(input, factor):
def _np_pixel_shuffle(input, h_factor, w_factor):
_batch, _channel, _height, _width = input.shape
assert (
_channel % (factor ** 2) == 0
), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)"
_new_c = int(_channel / (factor ** 2))
_channel % (h_factor * w_factor) == 0
), "The channels of input tensor must be divisible by (h_upscale_factor * w_upscale_factor)"
_new_c = int(_channel / (h_factor * w_factor))
out = np.reshape(input, [_batch, _new_c, factor ** 2, _height, _width])
out = np.reshape(out, [_batch, _new_c, factor, factor, _height, _width])
out = np.reshape(input, [_batch, _new_c, h_factor * w_factor, _height, _width])
out = np.reshape(out, [_batch, _new_c, h_factor, w_factor, _height, _width])
out = np.transpose(out, [0, 1, 4, 2, 5, 3])
out = np.reshape(out, [_batch, _new_c, _height * factor, _width * factor])
out = np.reshape(out, [_batch, _new_c, _height * h_factor, _width * w_factor])
return out
def _np_pixel_shuffle_grad(input, factor):
def _np_pixel_shuffle_grad(input, h_factor, w_factor):
_batch, _new_channel, _height_mul_factor, _width_mul_factor = input.shape
_channel = _new_channel * (factor ** 2)
_height = _height_mul_factor // factor
_width = _width_mul_factor // factor
_channel = _new_channel * (h_factor * w_factor)
_height = _height_mul_factor // h_factor
_width = _width_mul_factor // w_factor
out = np.ones(shape=(_batch, _channel, _height, _width))
return out
def _test_pixel_shuffle_impl(test_case, device, shape, upscale_factor):
def _test_pixel_shuffle_impl(
test_case, device, shape, h_upscale_factor, w_upscale_factor
):
x = np.random.randn(*shape)
input = flow.Tensor(
x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
)
m = flow.nn.PixelShuffle(upscale_factor)
m = flow.nn.PixelShuffle(
h_upscale_factor=h_upscale_factor, w_upscale_factor=w_upscale_factor
)
m = m.to(device)
of_out = m(input)
np_out = _np_pixel_shuffle(x, upscale_factor)
np_out = _np_pixel_shuffle(x, h_upscale_factor, w_upscale_factor)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
of_out = of_out.sum()
of_out.backward()
np_grad = _np_pixel_shuffle_grad(np_out, upscale_factor)
np_grad = _np_pixel_shuffle_grad(np_out, h_upscale_factor, w_upscale_factor)
test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
......@@ -77,12 +81,14 @@ class TestPixelShuffleModule(flow.unittest.TestCase):
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 144, 5, 5), (11, 144, 1, 1)]
arg_dict["upscale_factor"] = [2, 3, 4]
arg_dict["h_upscale_factor"] = [2, 3, 4]
arg_dict["w_upscale_factor"] = [2, 3, 4]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
arg_dict["shape"] = [(8, 25, 18, 18), (1, 25, 2, 2)]
arg_dict["upscale_factor"] = [5]
arg_dict["h_upscale_factor"] = [5]
arg_dict["w_upscale_factor"] = [5]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment