diff --git a/oneflow/python/nn/modules/pixelshuffle.py b/oneflow/python/nn/modules/pixelshuffle.py
index bf01b0404807afacd4de4cf529c125e7ed602e3c..9bc0108b106691a84668c4424f3ccd2bd2070e7b 100644
--- a/oneflow/python/nn/modules/pixelshuffle.py
+++ b/oneflow/python/nn/modules/pixelshuffle.py
@@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+from typing import Optional
+
 from oneflow.python.framework.tensor import Tensor
 from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.nn.module import Module
@@ -20,13 +22,13 @@ from oneflow.python.nn.module import Module
 
 @oneflow_export("nn.PixelShuffle")
 @experimental_api
-class PixelShuffle(Module):
-    r"""The interface is consistent with PyTorch.
-    The documentation is referenced from:
+class PixelShufflev2(Module):
+    r"""
+    Part of the documentation is referenced from:
     https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle
 
-    Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
-    to a tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is an upscale factor.
+    Rearranges elements in a tensor of shape :math:`(*, C \times r_h \times r_w, H, W)`
+    to a tensor of shape :math:`(*, C, H \times r_h, W \times r_w)`, where r_h and r_w are upscale factors.
 
     This is useful for implementing efficient sub-pixel convolution
     with a stride of :math:`1/r`.
@@ -36,21 +38,33 @@ class PixelShuffle(Module):
     by Shi et. al (2016) for more details.
 
     Args:
-        upscale_factor (int): factor to increase spatial resolution by
+        upscale_factor (int, optional): factor to increase spatial resolution by, only use when factors of height and width spatial are the same.
+
+        h_upscale_factor (int, optional): factor to increase height spatial resolution by, only one of h_upscale_factor and upscale_factor can be used.
+        w_upscale_factor (int, optional): factor to increase width spatial resolution by, only one of w_upscale_factor and upscale_factor can be used.
 
     Shape:
         - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions
         - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where
 
-    .. math::
-        C_{out} = C_{in} \div \text{upscale_factor}^2
+    if use upscale_factor:
 
     .. math::
+        C_{out} = C_{in} \div \text{h_upscale_factor}^2
+
         H_{out} = H_{in} \times \text{upscale_factor}
 
-    .. math::
         W_{out} = W_{in} \times \text{upscale_factor}
 
+    if use h_upscale_factor and w_upscale_factor:
+
+    .. math::
+        C_{out} = C_{in} \div \text{h_upscale_factor} \div \text{w_upscale_factor}
+
+        H_{out} = H_{in} \times \text{h_upscale_factor}
+
+        W_{out} = W_{in} \times \text{w_upscale_factor}
+
     For example:
 
     .. code-block:: python
@@ -65,44 +79,77 @@ class PixelShuffle(Module):
         >>> print(y.size())
         flow.Size([3, 1, 10, 10])
 
-        >>> m = flow.nn.PixelShuffle(upscale_factor=3)
-        >>> x = flow.Tensor(np.random.randn(1, 18, 2, 2))
+        >>> m = flow.nn.PixelShuffle(h_upscale_factor=3, w_upscale_factor=4)
+        >>> x = flow.Tensor(np.random.randn(1, 24, 2, 2))
         >>> y = m(x)
         >>> print(y.size())
-        flow.Size([1, 2, 6, 6])
+        flow.Size([1, 2, 6, 8])
 
     .. _Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network:
         https://arxiv.org/abs/1609.05158
     """
 
-    def __init__(self, upscale_factor: int) -> None:
+    def __init__(
+        self,
+        upscale_factor: Optional[int] = None,
+        h_upscale_factor: Optional[int] = None,
+        w_upscale_factor: Optional[int] = None,
+    ) -> None:
         super().__init__()
-        assert upscale_factor > 0, "The scale factor must larger than zero"
-        self.upscale_factor = upscale_factor
+
+        if upscale_factor is None:
+            assert (
+                h_upscale_factor is not None and w_upscale_factor is not None
+            ), "h_upscale_factor and w_upscale_factor should be None if use upscale_factor"
+        else:
+            assert (
+                h_upscale_factor is None and w_upscale_factor is None
+            ), "upscale_factor should be None if use h_upscale_factor and w_upscale_factor"
+            h_upscale_factor = upscale_factor
+            w_upscale_factor = upscale_factor
+
+        assert (
+            h_upscale_factor > 0 and w_upscale_factor > 0
+        ), "The scale factor of height and width must larger than zero"
+        self.h_upscale_factor = h_upscale_factor
+        self.w_upscale_factor = w_upscale_factor
 
     def forward(self, input: Tensor) -> Tensor:
         assert len(input.shape) == 4, "Only Accept 4D Tensor"
 
         _batch, _channel, _height, _width = input.shape
         assert (
-            _channel % (self.upscale_factor ** 2) == 0
-        ), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)"
+            _channel % (self.h_upscale_factor * self.w_upscale_factor) == 0
+        ), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor) or (h_upscale_factor * w_upscale_factor)"
 
-        _new_c = int(_channel / (self.upscale_factor ** 2))
+        _new_c = int(_channel / (self.h_upscale_factor * self.w_upscale_factor))
 
         out = input.reshape(
-            [_batch, _new_c, self.upscale_factor ** 2, _height, _width,]
+            [
+                _batch,
+                _new_c,
+                self.h_upscale_factor * self.w_upscale_factor,
+                _height,
+                _width,
+            ]
         )
         out = out.reshape(
-            [_batch, _new_c, self.upscale_factor, self.upscale_factor, _height, _width,]
+            [
+                _batch,
+                _new_c,
+                self.h_upscale_factor,
+                self.w_upscale_factor,
+                _height,
+                _width,
+            ]
         )
         out = out.permute(0, 1, 4, 2, 5, 3)
         out = out.reshape(
             [
                 _batch,
                 _new_c,
-                _height * self.upscale_factor,
-                _width * self.upscale_factor,
+                _height * self.h_upscale_factor,
+                _width * self.w_upscale_factor,
             ]
         )
 
diff --git a/oneflow/python/test/modules/test_pixel_shuffle.py b/oneflow/python/test/modules/test_pixel_shuffle.py
index f21c10fbc09e1f6c3f2b7fbf1e51ed1cad3952a2..a2e586abe561526c0b05d1e4f80cd5993ebe9743 100644
--- a/oneflow/python/test/modules/test_pixel_shuffle.py
+++ b/oneflow/python/test/modules/test_pixel_shuffle.py
@@ -22,45 +22,49 @@ import oneflow.experimental as flow
 from test_util import GenArgList
 
 
-def _np_pixel_shuffle(input, factor):
+def _np_pixel_shuffle(input, h_factor, w_factor):
     _batch, _channel, _height, _width = input.shape
     assert (
-        _channel % (factor ** 2) == 0
-    ), "The channels of input tensor must be divisible by (upscale_factor * upscale_factor)"
-    _new_c = int(_channel / (factor ** 2))
+        _channel % (h_factor * w_factor) == 0
+    ), "The channels of input tensor must be divisible by (h_upscale_factor * w_upscale_factor)"
+    _new_c = int(_channel / (h_factor * w_factor))
 
-    out = np.reshape(input, [_batch, _new_c, factor ** 2, _height, _width])
-    out = np.reshape(out, [_batch, _new_c, factor, factor, _height, _width])
+    out = np.reshape(input, [_batch, _new_c, h_factor * w_factor, _height, _width])
+    out = np.reshape(out, [_batch, _new_c, h_factor, w_factor, _height, _width])
     out = np.transpose(out, [0, 1, 4, 2, 5, 3])
-    out = np.reshape(out, [_batch, _new_c, _height * factor, _width * factor])
+    out = np.reshape(out, [_batch, _new_c, _height * h_factor, _width * w_factor])
     return out
 
 
-def _np_pixel_shuffle_grad(input, factor):
+def _np_pixel_shuffle_grad(input, h_factor, w_factor):
     _batch, _new_channel, _height_mul_factor, _width_mul_factor = input.shape
-    _channel = _new_channel * (factor ** 2)
-    _height = _height_mul_factor // factor
-    _width = _width_mul_factor // factor
+    _channel = _new_channel * (h_factor * w_factor)
+    _height = _height_mul_factor // h_factor
+    _width = _width_mul_factor // w_factor
 
     out = np.ones(shape=(_batch, _channel, _height, _width))
     return out
 
 
-def _test_pixel_shuffle_impl(test_case, device, shape, upscale_factor):
+def _test_pixel_shuffle_impl(
+    test_case, device, shape, h_upscale_factor, w_upscale_factor
+):
     x = np.random.randn(*shape)
     input = flow.Tensor(
         x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
     )
 
-    m = flow.nn.PixelShuffle(upscale_factor)
+    m = flow.nn.PixelShuffle(
+        h_upscale_factor=h_upscale_factor, w_upscale_factor=w_upscale_factor
+    )
     m = m.to(device)
     of_out = m(input)
-    np_out = _np_pixel_shuffle(x, upscale_factor)
+    np_out = _np_pixel_shuffle(x, h_upscale_factor, w_upscale_factor)
     test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
 
     of_out = of_out.sum()
     of_out.backward()
-    np_grad = _np_pixel_shuffle_grad(np_out, upscale_factor)
+    np_grad = _np_pixel_shuffle_grad(np_out, h_upscale_factor, w_upscale_factor)
     test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
 
 
@@ -77,12 +81,14 @@ class TestPixelShuffleModule(flow.unittest.TestCase):
         arg_dict["device"] = ["cpu", "cuda"]
 
         arg_dict["shape"] = [(2, 144, 5, 5), (11, 144, 1, 1)]
-        arg_dict["upscale_factor"] = [2, 3, 4]
+        arg_dict["h_upscale_factor"] = [2, 3, 4]
+        arg_dict["w_upscale_factor"] = [2, 3, 4]
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
         arg_dict["shape"] = [(8, 25, 18, 18), (1, 25, 2, 2)]
-        arg_dict["upscale_factor"] = [5]
+        arg_dict["h_upscale_factor"] = [5]
+        arg_dict["w_upscale_factor"] = [5]
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])