diff --git a/oneflow/python/nn/modules/flatten.py b/oneflow/python/nn/modules/flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3df133d758470064679cefe517180d2ee70582f
--- /dev/null
+++ b/oneflow/python/nn/modules/flatten.py
@@ -0,0 +1,81 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.framework.tensor import register_tensor_op
+
+
+@oneflow_export("nn.Flatten")
+class Flatten(Module):
+ """Flattens a contiguous range of dims into a tensor. For use with: nn.Sequential.
+
+ Args:
+ start_dim: first dim to flatten (default = 1).
+ end_dim: last dim to flatten (default = -1).
+
+
+ For example:
+
+ .. code-block:: python
+
+ import oneflow as flow
+ input = flow.Tensor(32, 1, 5, 5)
+ m = flow.nn.Flatten()
+ output = m(input)
+ output.size()
+ # out flow.Size([32, 25])
+
+ """
+
+ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
+ super().__init__()
+ self.op_ = (
+ flow.builtin_op("flatten")
+ .Input("in")
+ .Output("out")
+ .Attr("start_dim", start_dim)
+ .Attr("end_dim", end_dim)
+ .Build()
+ )
+
+ def forward(self, input):
+ return self.op_(input)[0]
+
+
+@oneflow_export("tmp.flatten")
+@register_tensor_op("flatten")
+def _flow_flatten(input, start_dim: int = 0, end_dim: int = -1):
+ """Flattens a contiguous range of dims into a tensor.
+
+ Args:
+ start_dim: first dim to flatten (default = 0).
+ end_dim: last dim to flatten (default = -1).
+
+
+ For example:
+
+ .. code-block:: python
+
+ import oneflow as flow
+ input = flow.Tensor(32, 1, 5, 5)
+ output = input.flatten(start_dim=1)
+ # output = flow.tmp.flatten(input, start_dim=1)
+ output.size()
+ # out flow.Size([32, 25])
+
+ """
+ return Flatten(start_dim=start_dim, end_dim=end_dim)(input)
diff --git a/oneflow/python/test/modules/test_flatten.py b/oneflow/python/test/modules/test_flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d03be645048bfe9f68113bdc02b7cead1f25957
--- /dev/null
+++ b/oneflow/python/test/modules/test_flatten.py
@@ -0,0 +1,53 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+import numpy as np
+
+import oneflow as flow
+
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestFlattenModule(flow.unittest.TestCase):
+ def test_flatten(test_case):
+ m = flow.nn.Flatten()
+ x = flow.Tensor(32, 2, 5, 5)
+ flow.nn.init.uniform_(x)
+ y = m(x)
+ test_case.assertTrue(y.shape == flow.Size((32, 50)))
+ test_case.assertTrue(np.array_equal(y.numpy().flatten(), x.numpy().flatten()))
+
+ y2 = flow.tmp.flatten(x, start_dim=2)
+ test_case.assertTrue(y2.shape == flow.Size((32, 2, 25)))
+ test_case.assertTrue(np.array_equal(y2.numpy().flatten(), x.numpy().flatten()))
+
+ y3 = x.flatten(start_dim=1)
+ test_case.assertTrue(y3.shape == flow.Size((32, 50)))
+ test_case.assertTrue(np.array_equal(y3.numpy().flatten(), x.numpy().flatten()))
+
+ y4 = x.flatten(start_dim=1, end_dim=2)
+ test_case.assertTrue(y4.shape == flow.Size((32, 10, 5)))
+ test_case.assertTrue(np.array_equal(y4.numpy().flatten(), x.numpy().flatten()))
+
+ y5 = flow.tmp.flatten(x)
+ test_case.assertTrue(y5.shape == flow.Size((1600,)))
+ test_case.assertTrue(np.array_equal(y5.numpy().flatten(), x.numpy().flatten()))
+
+
+if __name__ == "__main__":
+ unittest.main()