diff --git a/oneflow/python/nn/modules/pixelshuffle.py b/oneflow/python/nn/modules/pixelshuffle.py index bf01b0404807afacd4de4cf529c125e7ed602e3c..9bc0108b106691a84668c4424f3ccd2bd2070e7b 100644 --- a/oneflow/python/nn/modules/pixelshuffle.py +++ b/oneflow/python/nn/modules/pixelshuffle.py @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +from typing import Optional + from oneflow.python.framework.tensor import Tensor from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.nn.module import Module @@ -20,13 +22,13 @@ from oneflow.python.nn.module import Module @oneflow_export("nn.PixelShuffle") @experimental_api -class PixelShuffle(Module): - r"""The interface is consistent with PyTorch. - The documentation is referenced from: +class PixelShufflev2(Module): + r""" + Part of the documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle - Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` - to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor. + Rearranges elements in a tensor of shape :math:`(*, C \times r_h \times r_w, H, W)` + to a tensor of shape :math:`(*, C, H \times r_h, W \times r_w)`, where r_h and r_w are upscale factors. This is useful for implementing efficient sub-pixel convolution with a stride of :math:`1/r`. @@ -36,21 +38,33 @@ class PixelShuffle(Module): by Shi et. al (2016) for more details. Args: - upscale_factor (int): factor to increase spatial resolution by + upscale_factor (int, optional): factor to increase spatial resolution by, only use when factors of height and width spatial are the same. + + h_upscale_factor (int, optional): factor to increase height spatial resolution by, only one of h_upscale_factor and upscale_factor can be used. 
+ w_upscale_factor (int, optional): factor to increase width spatial resolution by, only one of w_upscale_factor and upscale_factor can be used. Shape: - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where - .. math:: - C_{out} = C_{in} \div \text{upscale_factor}^2 + if upscale_factor is used: .. math:: + C_{out} = C_{in} \div \text{upscale_factor}^2 + H_{out} = H_{in} \times \text{upscale_factor} - .. math:: W_{out} = W_{in} \times \text{upscale_factor} + if h_upscale_factor and w_upscale_factor are used: + + .. math:: + C_{out} = C_{in} \div \text{h_upscale_factor} \div \text{w_upscale_factor} + + H_{out} = H_{in} \times \text{h_upscale_factor} + + W_{out} = W_{in} \times \text{w_upscale_factor} + For example: .. code-block:: python @@ -65,44 +79,77 @@ class PixelShuffle(Module): >>> print(y.size()) flow.Size([3, 1, 10, 10]) - >>> m = flow.nn.PixelShuffle(upscale_factor=3) - >>> x = flow.Tensor(np.random.randn(1, 18, 2, 2)) + >>> m = flow.nn.PixelShuffle(h_upscale_factor=3, w_upscale_factor=4) + >>> x = flow.Tensor(np.random.randn(1, 24, 2, 2)) >>> y = m(x) >>> print(y.size()) - flow.Size([1, 2, 6, 6]) + flow.Size([1, 2, 6, 8]) .. 
_Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network: https://arxiv.org/abs/1609.05158 """ - def __init__(self, upscale_factor: int) -> None: + def __init__( + self, + upscale_factor: Optional[int] = None, + h_upscale_factor: Optional[int] = None, + w_upscale_factor: Optional[int] = None, + ) -> None: super().__init__() - assert upscale_factor > 0, "The scale factor must larger than zero" - self.upscale_factor = upscale_factor + + if upscale_factor is None: + assert ( + h_upscale_factor is not None and w_upscale_factor is not None + ), "h_upscale_factor and w_upscale_factor must both be specified if upscale_factor is None" + else: + assert ( + h_upscale_factor is None and w_upscale_factor is None + ), "upscale_factor should be None if use h_upscale_factor and w_upscale_factor" + h_upscale_factor = upscale_factor + w_upscale_factor = upscale_factor + + assert ( + h_upscale_factor > 0 and w_upscale_factor > 0 + ), "The scale factor of height and width must larger than zero" + self.h_upscale_factor = h_upscale_factor + self.w_upscale_factor = w_upscale_factor def forward(self, input: Tensor) -> Tensor: assert len(input.shape) == 4, "Only Accept 4D Tensor" _batch, _channel, _height, _width = input.shape assert ( - _channel % (self.upscale_factor ** 2) == 0 - ), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)" + _channel % (self.h_upscale_factor * self.w_upscale_factor) == 0 + ), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)" - _new_c = int(_channel / (self.upscale_factor ** 2)) + _new_c = int(_channel / (self.h_upscale_factor * self.w_upscale_factor)) out = input.reshape( - [_batch, _new_c, self.upscale_factor ** 2, _height, _width,] + [ + _batch, + _new_c, + self.h_upscale_factor * self.w_upscale_factor, + _height, + _width, + ] ) out = out.reshape( - [_batch, _new_c, self.upscale_factor, self.upscale_factor, 
_height, _width,] + [ + _batch, + _new_c, + self.h_upscale_factor, + self.w_upscale_factor, + _height, + _width, + ] ) out = out.permute(0, 1, 4, 2, 5, 3) out = out.reshape( [ _batch, _new_c, - _height * self.upscale_factor, - _width * self.upscale_factor, + _height * self.h_upscale_factor, + _width * self.w_upscale_factor, ] ) diff --git a/oneflow/python/test/modules/test_pixel_shuffle.py b/oneflow/python/test/modules/test_pixel_shuffle.py index f21c10fbc09e1f6c3f2b7fbf1e51ed1cad3952a2..a2e586abe561526c0b05d1e4f80cd5993ebe9743 100644 --- a/oneflow/python/test/modules/test_pixel_shuffle.py +++ b/oneflow/python/test/modules/test_pixel_shuffle.py @@ -22,45 +22,49 @@ import oneflow.experimental as flow from test_util import GenArgList -def _np_pixel_shuffle(input, factor): +def _np_pixel_shuffle(input, h_factor, w_factor): _batch, _channel, _height, _width = input.shape assert ( - _channel % (factor ** 2) == 0 - ), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)" - _new_c = int(_channel / (factor ** 2)) + _channel % (h_factor * w_factor) == 0 + ), "The channels of input tensor must be divisible by (h_upscale_factor * w_upscale_factor)" + _new_c = int(_channel / (h_factor * w_factor)) - out = np.reshape(input, [_batch, _new_c, factor ** 2, _height, _width]) - out = np.reshape(out, [_batch, _new_c, factor, factor, _height, _width]) + out = np.reshape(input, [_batch, _new_c, h_factor * w_factor, _height, _width]) + out = np.reshape(out, [_batch, _new_c, h_factor, w_factor, _height, _width]) out = np.transpose(out, [0, 1, 4, 2, 5, 3]) - out = np.reshape(out, [_batch, _new_c, _height * factor, _width * factor]) + out = np.reshape(out, [_batch, _new_c, _height * h_factor, _width * w_factor]) return out -def _np_pixel_shuffle_grad(input, factor): +def _np_pixel_shuffle_grad(input, h_factor, w_factor): _batch, _new_channel, _height_mul_factor, _width_mul_factor = input.shape - _channel = _new_channel * (factor ** 2) - _height = 
_height_mul_factor // factor - _width = _width_mul_factor // factor + _channel = _new_channel * (h_factor * w_factor) + _height = _height_mul_factor // h_factor + _width = _width_mul_factor // w_factor out = np.ones(shape=(_batch, _channel, _height, _width)) return out -def _test_pixel_shuffle_impl(test_case, device, shape, upscale_factor): +def _test_pixel_shuffle_impl( + test_case, device, shape, h_upscale_factor, w_upscale_factor +): x = np.random.randn(*shape) input = flow.Tensor( x, dtype=flow.float32, requires_grad=True, device=flow.device(device) ) - m = flow.nn.PixelShuffle(upscale_factor) + m = flow.nn.PixelShuffle( + h_upscale_factor=h_upscale_factor, w_upscale_factor=w_upscale_factor + ) m = m.to(device) of_out = m(input) - np_out = _np_pixel_shuffle(x, upscale_factor) + np_out = _np_pixel_shuffle(x, h_upscale_factor, w_upscale_factor) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) of_out = of_out.sum() of_out.backward() - np_grad = _np_pixel_shuffle_grad(np_out, upscale_factor) + np_grad = _np_pixel_shuffle_grad(np_out, h_upscale_factor, w_upscale_factor) test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5)) @@ -77,12 +81,14 @@ class TestPixelShuffleModule(flow.unittest.TestCase): arg_dict["device"] = ["cpu", "cuda"] arg_dict["shape"] = [(2, 144, 5, 5), (11, 144, 1, 1)] - arg_dict["upscale_factor"] = [2, 3, 4] + arg_dict["h_upscale_factor"] = [2, 3, 4] + arg_dict["w_upscale_factor"] = [2, 3, 4] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) arg_dict["shape"] = [(8, 25, 18, 18), (1, 25, 2, 2)] - arg_dict["upscale_factor"] = [5] + arg_dict["h_upscale_factor"] = [5] + arg_dict["w_upscale_factor"] = [5] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:])