diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index 833bf90a72d51a3021fe2512e69b33e73086ccd5..d447cbf024498adf78b52208ed51bd2056f98afe 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -162,6 +162,8 @@ Experimental features
 .. autofunction:: oneflow.experimental.nn.MaxPool3d
 .. autofunction:: oneflow.experimental.repeat
 .. autofunction:: oneflow.experimental.Tensor.repeat
+.. autofunction:: oneflow.experimental.tile
+.. autofunction:: oneflow.experimental.Tensor.tile
 .. autofunction:: oneflow.experimental.reshape
 .. autofunction:: oneflow.experimental.Tensor.reshape
 .. autofunction:: oneflow.experimental.squeeze
diff --git a/oneflow/python/nn/modules/tile.py b/oneflow/python/nn/modules/tile.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e583c44272eb3f4498ad6b1a9e613fb6a05d378
--- /dev/null
+++ b/oneflow/python/nn/modules/tile.py
@@ -0,0 +1,95 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from typing import Union
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.framework.tensor import Tensor, register_tensor_op
+
+
+class Tile(Module):
+    def __init__(self, reps: tuple) -> None:
+        super().__init__()
+        self.reps = reps
+
+    def forward(self, input: Tensor) -> Tensor:
+        reps = self.reps
+        for s in self.reps:
+            assert s > 0
+        input_shape = input.shape
+        diff = len(input_shape) - len(reps)
+        if diff > 0:
+            shape = [1 for _ in range(diff)]
+            shape.extend([i for i in reps])
+            reps = tuple(shape)
+        return input.repeat(reps)
+
+
+@oneflow_export("tile")
+@register_tensor_op("tile")
+@experimental_api
+def tile_op(x, reps):
+    r"""The interface is consistent with PyTorch.
+    The documentation is referenced from:
+    https://pytorch.org/docs/stable/generated/torch.tile.html
+
+    Constructs a tensor by repeating the elements of ``input``.  The ``reps`` argument specifies the number
+    of repetitions in each dimension.
+
+    If ``reps`` specifies fewer dimensions than ``input`` has, then ones are prepended to ``reps`` until
+    all dimensions are specified.  For example, if ``input`` has shape (8, 6, 4, 2) and ``reps`` is (2, 2),
+    then ``reps`` is treated as (1, 1, 2, 2).
+
+    Analogously, if ``input`` has fewer dimensions than ``reps`` specifies, then ``input`` is treated as
+    if it were unsqueezed at dimension zero until it has as many dimensions as ``reps`` specifies.
+    For example, if ``input`` has shape (4, 2) and ``reps`` is (3, 3, 2, 2), then ``input`` is treated as
+    if it had the shape (1, 1, 4, 2).
+
+    .. note::
+        This function is similar to NumPy’s tile function.
+
+    Args:
+        input (oneflow.Tensor): the tensor whose elements to repeat.
+        reps (tuple): the number of repetitions per dimension.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+
+        >>> x = np.array([1, 2]).astype(np.int32)
+        >>> input = flow.Tensor(x, dtype=flow.int32)
+        >>> out = input.tile(reps=(2,))
+        >>> out
+        tensor([1, 2, 1, 2], dtype=oneflow.int32)
+
+        >>> x = np.random.randn(5, 2, 1)
+        >>> input = flow.Tensor(x)
+        >>> out = input.tile(reps=(3, 4))
+        >>> out.size()
+        flow.Size([5, 6, 4])
+
+    """
+    return Tile(reps=reps)(x)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/test/modules/test_tile.py b/oneflow/python/test/modules/test_tile.py
new file mode 100644
index 0000000000000000000000000000000000000000..12ff77fbdf23bec2391fd56282f5f0f962d6cc7c
--- /dev/null
+++ b/oneflow/python/test/modules/test_tile.py
@@ -0,0 +1,181 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList
+
+
+def np_tile(x, sizes):
+    return np.tile(x, sizes)
+
+
+def np_tile_grad(x, sizes):
+    times = np.array(sizes).prod()
+    return np.ones(shape=x.shape) * times
+
+
+def _test_tile_less_dim_a(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(2, 4, 1, 3), dtype=flow.float32, device=flow.device(device)
+    )
+    sizes = (2,)
+    np_out = np_tile(input.numpy(), sizes)
+    of_out = input.tile(reps=sizes)
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+
+
+def _test_tile_less_dim_b(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(3, 2, 5), dtype=flow.float32, device=flow.device(device)
+    )
+    sizes = (3, 4)
+    np_out = np_tile(input.numpy(), sizes)
+    of_out = input.tile(reps=sizes)
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+
+
+def _test_tile_less_dim_c(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(4, 3, 2, 5, 3), dtype=flow.float32, device=flow.device(device)
+    )
+    sizes = (2, 3, 4, 4)
+    np_out = np_tile(input.numpy(), sizes)
+    of_out = input.tile(reps=sizes)
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+
+
+def _test_tile_same_dim(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(1, 2, 5, 3), dtype=flow.float32, device=flow.device(device)
+    )
+    sizes = (4, 2, 3, 19)
+    of_out = input.tile(reps=sizes)
+    np_out = np_tile(input.numpy(), sizes)
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+
+
+def _test_tile_same_dim_int(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(1, 2, 5, 3), dtype=flow.int32, device=flow.device(device)
+    )
+    size_tensor = flow.Tensor(np.random.randn(4, 2, 3, 19))
+    sizes = size_tensor.size()
+    of_out = input.tile(reps=sizes)
+    np_out = np_tile(input.numpy(), sizes)
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out.astype(np.int32)))
+
+
+def _test_tile_same_dim_int8(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(1, 2, 5, 3), dtype=flow.int8, device=flow.device(device)
+    )
+    size_tensor = flow.Tensor(np.random.randn(4, 2, 3, 19))
+    sizes = size_tensor.size()
+    of_out = input.tile(reps=sizes)
+    np_out = np_tile(input.numpy(), sizes)
+    test_case.assertTrue(np.array_equal(of_out.numpy(), np_out.astype(np.int32)))
+
+
+def _test_tile_less_dim_a_backward(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(2, 4, 1, 3),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=True,
+    )
+    sizes = (2,)
+    of_out = input.tile(reps=sizes)
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = np_tile_grad(input.numpy(), sizes)
+    test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
+
+
+def _test_tile_less_dim_b_backward(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(3, 2, 5),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=True,
+    )
+    sizes = (3, 4)
+    of_out = input.tile(reps=sizes)
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = np_tile_grad(input.numpy(), sizes)
+    test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
+
+
+def _test_tile_less_dim_c_backward(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(4, 3, 2, 5, 3),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=True,
+    )
+    sizes = (2, 3, 4, 4)
+    of_out = input.tile(reps=sizes)
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = np_tile_grad(input.numpy(), sizes)
+    test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
+
+
+def _test_tile_same_dim_backward(test_case, device):
+    input = flow.Tensor(
+        np.random.randn(1, 2, 5, 3),
+        dtype=flow.float32,
+        device=flow.device(device),
+        requires_grad=True,
+    )
+    sizes = (1, 2, 3, 1)
+    of_out = input.tile(reps=sizes)
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = np_tile_grad(input.numpy(), sizes)
+    test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestTile(flow.unittest.TestCase):
+    def test_tile(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_tile_less_dim_a,
+            _test_tile_less_dim_b,
+            _test_tile_less_dim_c,
+            _test_tile_same_dim,
+            _test_tile_same_dim_int,
+            _test_tile_same_dim_int8,
+            _test_tile_less_dim_a_backward,
+            _test_tile_less_dim_b_backward,
+            _test_tile_less_dim_c_backward,
+            _test_tile_same_dim_backward,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()