diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index d447cbf024498adf78b52208ed51bd2056f98afe..95d7dd1cadfec4d6c154ec794721dc143ccbcb84 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -211,3 +211,4 @@ Experimental features
 .. autofunction:: oneflow.experimental.topk
 .. autofunction:: oneflow.experimental.Tensor.topk
 .. autofunction:: oneflow.experimental.nn.GroupNorm
+.. autofunction:: oneflow.experimental.nn.ZeroPad2d
diff --git a/oneflow/python/nn/modules/zeropad2d.py b/oneflow/python/nn/modules/zeropad2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..19b862a491ca2ff1bc1a10b7f3328b702b8819c2
--- /dev/null
+++ b/oneflow/python/nn/modules/zeropad2d.py
@@ -0,0 +1,128 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from __future__ import absolute_import
+
+from typing import Union
+
+import oneflow as flow
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.nn.module import Module
+
+
+@oneflow_export("nn.ZeroPad2d")
+@experimental_api
+class ZeroPad2d(Module):
+    r"""The interface is consistent with PyTorch.
+    The documentation is referenced from:
+    https://pytorch.org/docs/stable/generated/torch.nn.ZeroPad2d.html
+
+    Pads the input tensor boundaries with zero. User can set the amount of padding by setting the parameter `paddings`.
+
+    Args:
+        padding (Union[int, tuple]):  the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
+
+    Shape:
+        - Input: :math:`(N, C, H_{in}, W_{in})`
+        - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+            :math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
+
+            :math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+        >>> zeropad_layer_int = flow.nn.ZeroPad2d(2)
+        >>> zeropad_layer_tuple = flow.nn.ZeroPad2d((1,2,2,0))
+        >>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))
+        >>> output_int = zeropad_layer_int(input)
+        >>> output_int.shape
+        flow.Size([1, 2, 7, 7])
+        >>> output_int
+        tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  0.,  1.,  2.,  0.,  0.],
+                  [ 0.,  0.,  3.,  4.,  5.,  0.,  0.],
+                  [ 0.,  0.,  6.,  7.,  8.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]],
+        <BLANKLINE>
+                 [[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  9., 10., 11.,  0.,  0.],
+                  [ 0.,  0., 12., 13., 14.,  0.,  0.],
+                  [ 0.,  0., 15., 16., 17.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.,  0.]]]], dtype=oneflow.float32)
+        >>> output_tuple = zeropad_layer_tuple(input)
+        >>> output_tuple
+        tensor([[[[ 0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  1.,  2.,  0.,  0.],
+                  [ 0.,  3.,  4.,  5.,  0.,  0.],
+                  [ 0.,  6.,  7.,  8.,  0.,  0.]],
+        <BLANKLINE>
+                 [[ 0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  0.,  0.,  0.,  0.,  0.],
+                  [ 0.,  9., 10., 11.,  0.,  0.],
+                  [ 0., 12., 13., 14.,  0.,  0.],
+                  [ 0., 15., 16., 17.,  0.,  0.]]]], dtype=oneflow.float32)
+    """
+
+    def __init__(self, padding: Union[int, tuple]):
+        super().__init__()
+        if isinstance(padding, tuple):
+            assert len(padding) == 4, ValueError("Length of padding must be 4")
+            boundary = [padding[0], padding[1], padding[2], padding[3]]
+        elif isinstance(padding, int):
+            boundary = [padding, padding, padding, padding]
+        else:
+            raise ValueError("padding must be int  or tuple!")
+
+        self.padding = boundary
+        self.value = 0.0000
+
+    def forward(self, x):
+        _, _, h, w = x.shape
+
+        if x.dtype in [flow.float32, flow.float16, flow.float64]:
+            floating_value = float(self.value)
+            integral_value = int(0)
+        else:
+            floating_value = float(0)
+            integral_value = int(self.value)
+        self._op = (
+            flow.builtin_op("constant_pad2d")
+            .Input("x")
+            .Output("y")
+            .Attr("padding", self.padding)
+            .Attr("floating_value", floating_value)
+            .Attr("integral_value", integral_value)
+            .Build()
+        )
+
+        res = self._op(x)[0]
+        return res
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/test/modules/test_zeropad2d.py b/oneflow/python/test/modules/test_zeropad2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6067430a62479a2e02c2efae5dec18f53258323c
--- /dev/null
+++ b/oneflow/python/test/modules/test_zeropad2d.py
@@ -0,0 +1,116 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import (
+    GenArgList,
+    FlattenArray,
+    Array2Numpy,
+    Index2Coordinate,
+)
+
+
+def _np_zero_pad2d_grad(src, dest, padding):
+    c_idx, h_idx, w_idx = 1, 2, 3
+    pad_left = padding[0]
+    pad_right = padding[1]
+    pad_top = padding[2]
+    pad_bottom = padding[3]
+    dx_height, dx_width = dest.shape[h_idx], dest.shape[w_idx]
+    dy_height, dy_width = src.shape[h_idx], src.shape[w_idx]
+
+    numpy_src = np.ones(src.shape, np.int32)
+    numpy_dest = np.zeros(dest.shape, np.int32)
+    array_src = FlattenArray(numpy_src)
+    array_dest = FlattenArray(numpy_dest)
+
+    src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx]
+    dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx]
+    elements_num = src.shape[0] * src_num
+    for iter_n in range(elements_num):
+        coords = Index2Coordinate(iter_n, src.shape)
+        n, c, i, j = coords[0], coords[c_idx], coords[h_idx], coords[w_idx]
+        ip_x = ip_y = 0
+        if (
+            j >= pad_left
+            and j < (dx_width + pad_left)
+            and i >= pad_top
+            and i < (dx_height + pad_top)
+        ):
+            ip_x = j - pad_left
+            ip_y = i - pad_top
+            src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j
+            dest_index = (
+                n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x
+            )
+            array_dest[dest_index] += array_src[src_index]
+    numpy_dest = Array2Numpy(array_dest, dest.shape)
+    return numpy_dest
+
+
+def _test_ZeroPad2d(test_case, shape, padding, value, device):
+    np_input = np.random.random(shape)
+    of_input = flow.Tensor(
+        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
+    )
+
+    if isinstance(padding, int):
+        np_boundary = ((0, 0), (0, 0), (padding, padding), (padding, padding))
+
+    elif isinstance(padding, (tuple, int)) and len(padding) == 4:
+        np_boundary = (
+            (0, 0),
+            (0, 0),
+            (padding[2], padding[3]),
+            (padding[0], padding[1]),
+        )
+    else:
+        raise ValueError("padding must be in  or tuple!")
+
+    layer = flow.nn.ZeroPad2d(padding=padding)
+    of_out = layer(of_input)
+    np_out = np.pad(np_input, np_boundary, mode="constant", constant_values=value)
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+    of_out = of_out.sum()
+    of_out.backward()
+
+    np_out_grad = _np_zero_pad2d_grad(np_out, np_input, layer.padding)
+    test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_out_grad, 1e-5, 1e-5))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestZeroPad2dModule(flow.unittest.TestCase):
+    def test_ConstantPad2d(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["shape"] = [(1, 2, 3, 4), (8, 3, 4, 4)]
+        arg_dict["padding"] = [(2), (1, 1, 2, 2)]
+        arg_dict["value"] = [0.0]
+        arg_dict["device"] = ["cpu", "cuda"]
+
+        for arg in GenArgList(arg_dict):
+            _test_ZeroPad2d(test_case, *arg)
+
+
+if __name__ == "__main__":
+    unittest.main()