diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index 7159922ea16f9b777746dc69c37b7caf70bca1e8..ea1ebf7f616872b23926e9635017f192ce593b2a 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -197,5 +197,6 @@ Experimental features
 .. autofunction:: oneflow.experimental.Tensor.ceil
 .. autofunction:: oneflow.experimental.expm1
 .. autofunction:: oneflow.experimental.Tensor.expm1
+.. autofunction:: oneflow.experimental.meshgrid
 .. autofunction:: oneflow.experimental.topk
 .. autofunction:: oneflow.experimental.Tensor.topk
diff --git a/oneflow/python/nn/modules/meshgrid.py b/oneflow/python/nn/modules/meshgrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..437a4ce768a506d85e06cc71d909000a38273bb6
--- /dev/null
+++ b/oneflow/python/nn/modules/meshgrid.py
@@ -0,0 +1,97 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+
+
+class MeshGrid(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, inputs):
+        size = len(inputs)
+        assert size > 0, f"meshgrid expects a non-empty TensorList"
+        shape = list()
+        for i in range(size):
+            assert inputs[i].dim() <= 1, f(
+                "Expected scalar or 1D tensor in the tensor list but got: ", inputs[i]
+            )
+            if inputs[i].dim() == 0:
+                shape.append(1)
+            else:
+                shape.append(inputs[i].shape[0])
+        for i in range(size - 1):
+            assert (
+                inputs[i].dtype == inputs[i + 1].dtype
+                and inputs[i].device == inputs[i + 1].device
+            ), f"meshgrid expects all tensors to have the same dtype and device"
+        outputs = []
+        for i in range(size):
+            view_shape = [1] * size
+            view_shape[i] = -1
+            # TODO(BBuf) change reshape to view
+            outputs.append(inputs[i].reshape(view_shape).expand(*shape))
+        return outputs
+
+
+@oneflow_export("meshgrid")
+@experimental_api
+def meshgrid_op(*inputs):
+    r"""The interface is consistent with PyTorch.
+    The documentation is referenced from:
+    https://pytorch.org/docs/stable/_modules/torch/functional.html#meshgrid
+    
+    Take :math:`N` tensors, each of which can be either scalar or 1-dimensional
+    vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by
+    expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.
+
+    Args:
+        tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
+            treated as tensors of size :math:`(1,)` automatically
+
+    Returns:
+        seq (sequence of Tensors): If the input has :math:`k` tensors of size
+        :math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also have :math:`k` tensors,
+        where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow.experimental as flow
+        >>> flow.enable_eager_execution()
+
+        >>> input1 = flow.Tensor(np.array([1, 2, 3]), dtype=flow.float32)
+        >>> input2 = flow.Tensor(np.array([4, 5, 6]), dtype=flow.float32)
+        >>> of_x, of_y = flow.meshgrid(input1, input2)
+        >>> of_x
+        tensor([[1., 1., 1.],
+                [2., 2., 2.],
+                [3., 3., 3.]], dtype=oneflow.float32)
+        >>> of_y
+        tensor([[4., 5., 6.],
+                [4., 5., 6.],
+                [4., 5., 6.]], dtype=oneflow.float32)
+    """
+    return MeshGrid()(inputs)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/test/modules/test_meshgrid.py b/oneflow/python/test/modules/test_meshgrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3673725e189971ee48b90a0884261c5c77f719e
--- /dev/null
+++ b/oneflow/python/test/modules/test_meshgrid.py
@@ -0,0 +1,87 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList
+
+
+def _test_meshgrid_forawd(test_case, device):
+    input1 = flow.Tensor(
+        np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device),
+    )
+    input2 = flow.Tensor(
+        np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device),
+    )
+    np_x, np_y = np.meshgrid(input1.numpy(), input2.numpy(), indexing="ij")
+    of_x, of_y = flow.meshgrid(input1, input2)
+
+    test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 1e-4, 1e-4))
+    test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 1e-4, 1e-4))
+
+
+def _test_meshgrid_forawd_scalr(test_case, device):
+    input1 = flow.Tensor(np.array(1.0), dtype=flow.float32, device=flow.device(device),)
+    input2 = flow.Tensor(np.array(2.0), dtype=flow.float32, device=flow.device(device),)
+    np_x, np_y = np.meshgrid(input1.numpy(), input2.numpy(), indexing="ij")
+    of_x, of_y = flow.meshgrid(input1, input2)
+
+    test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 1e-4, 1e-4))
+    test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 1e-4, 1e-4))
+
+
+def _test_meshgrid_forawd_3tensor(test_case, device):
+    input1 = flow.Tensor(
+        np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device),
+    )
+    input2 = flow.Tensor(
+        np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device),
+    )
+    input3 = flow.Tensor(
+        np.array([7, 8, 9]), dtype=flow.float32, device=flow.device(device),
+    )
+    np_x, np_y, np_z = np.meshgrid(
+        input1.numpy(), input2.numpy(), input3.numpy(), indexing="ij"
+    )
+    of_x, of_y, of_z = flow.meshgrid(input1, input2, input3)
+
+    test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 1e-4, 1e-4))
+    test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 1e-4, 1e-4))
+    test_case.assertTrue(np.allclose(of_z.numpy(), np_z, 1e-4, 1e-4))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestGreater(flow.unittest.TestCase):
+    def test_greter(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_meshgrid_forawd,
+            _test_meshgrid_forawd_scalr,
+            _test_meshgrid_forawd_3tensor,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()