diff --git a/oneflow/python/nn/modules/activation.py b/oneflow/python/nn/modules/activation.py
index c03cf767e20cc8d184be6e327af0158119d6d254..2331b1a8428b2d92a3d9d6fee4bf5a41280830bb 100644
--- a/oneflow/python/nn/modules/activation.py
+++ b/oneflow/python/nn/modules/activation.py
@@ -18,17 +18,25 @@ import oneflow._oneflow_internal
 from oneflow.python.nn.module import Module
 from oneflow.python.oneflow_export import oneflow_export
 from oneflow.python.framework.tensor import register_tensor_op
+from typing import Optional
 
 
-@oneflow_export("nn.Sigmoid")
-class Sigmoid(Module):
-    def __init__(self):
-        super().__init__()
-        self._op = flow.builtin_op("sigmoid").Input("in").Output("out").Build()
+def _softmax_need_transpose(x, axis):
+    assert type(axis) is int
+    dim_num = len(x.shape)
+    assert dim_num >= 2
+    if axis < 0:
+        axis += dim_num
+    assert axis >= 0
+    assert axis < dim_num
 
-    def forward(self, x):
-        res = self._op(x)[0]
-        return res
+    need_transpose = False
+    permute = list(range(dim_num))
+    if axis != dim_num - 1:
+        need_transpose = True
+        permute[axis] = permute[-1]
+        permute[-1] = axis
+    return need_transpose, permute
 
 
 @oneflow_export("nn.ReLU")
@@ -199,7 +207,7 @@ def gelu_op(x):
 
     Args:
         x (oneflow.Tensor): Input Tensor
-
+ 
     Returns:
         oneflow.Tensor: A Tensor.
 
@@ -216,8 +224,221 @@ def gelu_op(x):
         gelu = flow.nn.GELU()
         
         out = gelu(input)
-
         # out [-0.15426877, 0., 0.34573123]
-
     """
     return GELU()(x)
+
+
+@oneflow_export("nn.Sigmoid")
+class Sigmoid(Module):
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import numpy as np
+
+        x = flow.Tensor(
+            np.array(
+                [
+                    [0.81733328, 0.43621480, 0.10351428],
+                    [-1.15555191, -0.67776406, 0.27372134],
+                ]
+            )
+        )
+        m = flow.nn.Sigmoid() # or y = flow.sigmoid(x)
+        y = m(x)
+        # [[0.69366997, 0.60735673, 0.52585548],
+        # [0.23947647, 0.33676055, 0.56800622]]
+
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._op = flow.builtin_op("sigmoid").Input("in").Output("out").Build()
+
+    def forward(self, x):
+        return self._op(x)[0]
+
+
+@oneflow_export("sigmoid")
+@register_tensor_op("sigmoid")
+def sigmoid_op(x):
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import numpy as np
+
+        x = flow.Tensor(
+            np.array(
+                [
+                    [0.81733328, 0.43621480, 0.10351428],
+                    [-1.15555191, -0.67776406, 0.27372134],
+                ]
+            )
+        )
+        y = x.sigmoid()
+        # [[0.69366997, 0.60735673, 0.52585548],
+        # [0.23947647, 0.33676055, 0.56800622]]
+
+    """
+    return Sigmoid()(x)
+
+
+@oneflow_export("nn.Softmax")
+@oneflow_export("softmax")
+class Softmax(Module):
+    def __init__(self, dim: Optional[int] = None):
+        super().__init__()
+        self.axis = -1 if dim is None else dim
+        self._op = flow.builtin_op("softmax").Input("in").Output("out").Build()
+
+    def forward(self, x):
+        need_transpose, permute = _softmax_need_transpose(x, self.axis)
+        if need_transpose:
+            x = x.transpose(perm=permute)
+
+        res = self._op(x)[0]
+        if need_transpose:
+            res = res.transpose(perm=permute)
+        return res
+
+
+@oneflow_export("softmax")
+def softmax_op(tensor, dim=None):
+    r"""Applies the Softmax function to an n-dimensional input Tensor
+    rescaling them so that the elements of the n-dimensional output Tensor
+    lie in the range [0,1] and sum to 1.
+
+    Softmax is defined as:
+
+    .. math::
+        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+    When the input Tensor is a sparse tensor then the unspecifed
+    values are treated as ``-inf``.
+
+    Shape:
+        - Input: :math:`(*)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(*)`, same shape as the input
+
+    Returns:
+        a Tensor of the same dimension and shape as the input with
+        values in the range [0, 1]
+
+    Args:
+        dim (int): A dimension along which Softmax will be computed (so every slice
+            along dim will sum to 1).
+
+    For example: 
+
+    .. code-block:: python 
+
+        import oneflow as flow
+        import numpy as np
+
+        m = flow.nn.Softmax(dim = 2)
+        x = flow.Tensor(
+            np.array(
+                [[[[-0.46716809,  0.40112534,  0.61984003],
+                [-1.31244969, -0.42528763,  1.47953856]]],
+
+                [[[ 1.02978742, -0.49383053,  1.88214159],
+                [ 1.35351622, -1.46251285, -1.40751374]]]]
+            )
+        )
+        y = m(x)
+        # [[[[0.6995764  0.6955959  0.29740235]
+        # [0.3004236  0.30440408 0.7025977 ]]]
+
+        # [[[0.4197673  0.7248568  0.96407217]
+        # [0.58023274 0.27514324 0.03592779]]]]
+    """
+    return Softmax(dim)(tensor)
+
+
+@oneflow_export("nn.LogSoftmax")
+class LogSoftmax(Module):
+    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
+    input Tensor.
+    The LogSoftmax formulation can be simplified as:
+
+    .. math::
+        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
+
+    Args:
+        dim (int): A dimension along which LogSoftmax will be computed.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import numpy as np
+
+        m = flow.nn.LogSoftmax(dim=1)
+        x = flow.Tensor(
+            np.array(
+                [[ 0.4296, -1.1957,  2.5463],
+                [ 1.2552, -1.5747,  0.6923]]
+            )
+        )
+        y = m(x)
+        # [[-2.251349   -3.8766491  -0.13464898]
+        # [-0.48770458 -3.3176045  -1.0506046 ]]
+    """
+
+    def __init__(
+        self, dim: Optional[int] = 1,
+    ):
+        super().__init__()
+        self.dim = dim
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if not hasattr(self, "dim"):
+            self.dim = None
+
+    def forward(self, x):
+        need_transpose, permute = _softmax_need_transpose(x, self.dim)
+        if need_transpose:
+            x = x.transpose(perm=permute)
+
+        x = flow.softmax(x)
+        res = flow.log(x)
+
+        if need_transpose:
+            res = res.transpose(perm=permute)
+
+        return res
+
+    def extra_repr(self):
+        return "dim={dim}".format(dim=self.dim)
diff --git a/oneflow/python/nn/modules/arange.py b/oneflow/python/nn/modules/arange.py
new file mode 100644
index 0000000000000000000000000000000000000000..36a31acc21750f63482c5a3daa88e1e509881b06
--- /dev/null
+++ b/oneflow/python/nn/modules/arange.py
@@ -0,0 +1,70 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.framework.tensor import register_tensor_op
+
+
+class Arange(Module):
+    def __init__(self, start, end, step=1) -> None:
+        super().__init__()
+        self.start = 0 if start is None else start
+        self.end = 1 if end is None else end
+        self.step = step
+        self.dtype = flow.int64  # "Only support dtype: `flow.int64` for now!"
+        assert self.end > self.start, "end should be larger than start"
+        assert self.step <= self.end - self.start, "step is ilegal"
+        assert type(self.start) == int, "Params `start`'s type should be int"
+        assert type(self.end) == int, "Params `end`'s type should be int"
+        assert type(self.step) == int, "Params `step`'s type should be int"
+        # TODO: zhaoluyang Put dtype attr in forward() after bug fixed
+        self._op_arange = (
+            flow.builtin_op("range").Output("out").Attr("dtype", self.dtype).Build()
+        )
+
+    def forward(self):
+        return self._op_arange(start=self.start, delta=self.step, limit=self.end)[0]
+
+
+@oneflow_export("arange")
+def arange_op(start=1, end=1, step=1):
+    r"""
+    Returns a 1-D tensor of size :math:`\left\lfloor \frac{\text{end} - \text{start}}{\text{step}} \right\rfloor + 1`
+    with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is
+    the gap between two values in the tensor.
+
+    .. math::
+    \text{out}_{i+1} = \text{out}_i + \text{step}.
+
+    Args:
+    start (float): the starting value for the set of points. Default: ``0``.
+    end (float): the ending value for the set of points
+    step (float): the gap between each pair of adjacent points. Default: ``1``.
+
+    Keyword args:
+    dtype: If `dtype` is not given, the `dtype` is inferred to be the default dtype.
+
+    For example: 
+
+    .. code-block:: python 
+
+        import oneflow as flow
+        y = flow.arange(0, 5)
+        # [0, 1, 2, 3, 4]
+
+    """
+    return Arange(start, end, step)()
diff --git a/oneflow/python/nn/modules/eq.py b/oneflow/python/nn/modules/eq.py
new file mode 100644
index 0000000000000000000000000000000000000000..3181a552c47d07e8855b1b754d107b3115aefca5
--- /dev/null
+++ b/oneflow/python/nn/modules/eq.py
@@ -0,0 +1,75 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.framework.tensor import register_tensor_op
+
+
+class Eq(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.eq_op = (
+            flow.builtin_op("broadcast_equal").Input("x").Input("y").Output("z").Build()
+        )
+
+    def forward(self, input, other):
+        if isinstance(other, flow.Tensor):
+            for i in range(len(input.size())):
+                assert (
+                    input.shape[i] >= other.shape[i]
+                ), "The second tensor's shape should broadcastable with the first argument."
+        elif isinstance(other, int) or isinstance(other, float):
+            raise NotImplementedError(
+                "Unsupport data type, int or float data type are not support yet!"
+            )
+        else:
+            raise NotImplementedError(
+                "Unsupport data type, The second argument can be a tensor whose shape is broadcastable with the first argument."
+            )
+
+        return self.eq_op(input, other)[0]
+
+
+@oneflow_export("eq", "equal")
+@register_tensor_op("eq")
+def eq_op(input, other):
+    r"""
+    Computes element-wise equality.
+    The second argument can be a number or a tensor whose shape is broadcastable with the first argument.
+
+    Args:
+    input (Tensor): the tensor to compare
+    other (Tensor): the tensor to compare
+
+    Returns:
+    A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import numpy as np
+
+        input = flow.Tensor(np.array([2, 3, 4, 5]), dtype=flow.float32)
+        other = flow.Tensor(np.array([2, 3, 4, 1]), dtype=flow.float32)
+
+        y = flow.eq(input, other)
+        # [1 1 1 0]
+
+    """
+    return Eq()(input, other)
diff --git a/oneflow/python/nn/modules/masked_fill.py b/oneflow/python/nn/modules/masked_fill.py
new file mode 100644
index 0000000000000000000000000000000000000000..57cf929bc30a51a1853bbefc73dad2faf0dc0b07
--- /dev/null
+++ b/oneflow/python/nn/modules/masked_fill.py
@@ -0,0 +1,84 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.framework.tensor import register_tensor_op
+
+
+class MaskedFill(Module):
+    def __init__(self, value) -> None:
+        super().__init__()
+        self.value = value
+        self._where_op = (
+            flow.builtin_op("where")
+            .Input("condition")
+            .Input("x")
+            .Input("y")
+            .Output("out")
+            .Build()
+        )
+
+    def forward(self, input, mask):
+        in_shape = tuple(input.shape)
+        value_like_x = flow.Tensor(*in_shape)
+        value_like_x.fill_(self.value)
+        return self._where_op(mask, value_like_x, input)[0]
+
+
+@oneflow_export("tmp.masked_fill")
+@register_tensor_op("masked_fill")
+def masked_fill_op(tensor, mask, value):
+    r"""
+    Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is True.
+    The shape of :attr:`mask` must be broadcastable with the shape of the underlying tensor.
+
+    Args:
+        mask (BoolTensor) 鈥� the boolean mask
+        value (float) 鈥� the value to fill in with
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import numpy as np
+
+        in_arr = np.array(
+            [[[-0.13169311,  0.97277078,  1.23305363,  1.56752789],
+            [-1.51954275,  1.87629473, -0.53301206,  0.53006478],
+            [-1.38244183, -2.63448052,  1.30845795, -0.67144869]],
+
+            [[ 0.41502161,  0.14452418,  0.38968   , -1.76905653],
+            [ 0.34675095, -0.7050969 , -0.7647731 , -0.73233418],
+            [-1.90089858,  0.01262963,  0.74693893,  0.57132389]]]
+        )
+
+        fill_value = 8.7654321 # random value e.g. -1e9 3.1415
+        input = flow.Tensor(in_arr, dtype=flow.float32)
+        mask = flow.Tensor((in_arr > 0).astype(np.int8), dtype=flow.int)
+
+        output = input.masked_fill(mask, fill_value)
+        #  [[[-0.13169311  8.765432    8.765432    8.765432  ]
+        #   [-1.5195427   8.765432   -0.53301203  8.765432  ]
+        #   [-1.3824419  -2.6344805   8.765432   -0.6714487 ]]
+
+        #  [[ 8.765432    8.765432    8.765432   -1.7690566 ]
+        #   [ 8.765432   -0.7050969  -0.7647731  -0.7323342 ]
+        #   [-1.9008986   8.765432    8.765432    8.765432  ]]]
+
+    """
+    return MaskedFill(value)(tensor, mask)
diff --git a/oneflow/python/nn/modules/unsqueeze.py b/oneflow/python/nn/modules/unsqueeze.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b452a9a2215b68ed9ce2fdf7b933834d917db2e
--- /dev/null
+++ b/oneflow/python/nn/modules/unsqueeze.py
@@ -0,0 +1,65 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.framework.tensor import register_tensor_op
+
+
+class Unsqueeze(Module):
+    def __init__(self, dim: int = 0) -> None:
+        super().__init__()
+        self.dim = dim
+        self._op = flow.builtin_op("expand_dims").Input("in").Output("out").Build()
+
+    def forward(self, input):
+        assert (
+            -(1 + input.ndimension()) <= self.dim <= input.ndimension()
+        ), "dim should within the range [-input.ndimension() - 1, input.ndimension() + 1)"
+
+        if self.dim < 0:
+            self.dim = 1 + input.ndimension() + self.dim
+        return self._op(input, axis=self.dim)[0]
+
+
+@oneflow_export("unsqueeze")
+@register_tensor_op("unsqueeze")
+def unsqueeze_op(input, dim):
+    r"""Returns a new tensor with a dimension of size one inserted at the
+    specified position.
+
+    The returned tensor shares the same underlying data with this tensor.
+
+    A :attr:`dim` value within the range `[-input.ndimension() - 1, input.ndimension() + 1)`
+    can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze`
+    applied at :attr:`dim` = ``dim + input.ndimension() + 1``.
+
+    Args:
+        input (Tensor) 鈥� the input tensor.
+        dim (int): the index at which to insert the singleton dimension
+
+    For example: 
+
+    .. code-block:: python 
+
+        import numpy as np
+        import oneflow as flow
+
+        x = flow.Tensor(np.random.rand(2, 3, 4))
+        y = x.unsqueeze(2)
+
+    """
+    return Unsqueeze(dim)(input)
diff --git a/oneflow/python/test/modules/test_activation.py b/oneflow/python/test/modules/test_activation.py
index 0708bff7a129c864ebe2382f8b9b583a7bfdd3f3..d6c5794bc0b344089121e86eeeaaf44a51392835 100644
--- a/oneflow/python/test/modules/test_activation.py
+++ b/oneflow/python/test/modules/test_activation.py
@@ -119,5 +119,121 @@ class TestGeLU(flow.unittest.TestCase):
         test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))
 
 
+def numpy_sigmoid(x):
+    return 1.0 / (1 + np.exp(-x))
+
+
+def numpy_softmax(x, axis):
+    x = x - x.max(axis=axis, keepdims=True)
+    y = np.exp(x)
+    return y / y.sum(axis=axis, keepdims=True)
+
+
+def numpy_logsoftmax(x, dim):
+    e_x = np.exp(x - np.max(x, axis=dim, keepdims=True))
+    return np.log(e_x / e_x.sum(axis=dim, keepdims=True))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestSigmoidModule(flow.unittest.TestCase):
+    def test_sigmoid(test_case):
+        m = flow.nn.Sigmoid()
+        input_arr = np.random.randn(2, 3, 4, 5)
+        x = flow.Tensor(input_arr)
+
+        y = m(x)
+        y2 = flow.sigmoid(x)
+        y3 = x.sigmoid()
+        output = numpy_sigmoid(input_arr)
+
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+        test_case.assertTrue(np.allclose(y2.numpy(), output, rtol=1e-05))
+        test_case.assertTrue(np.allclose(y3.numpy(), output, rtol=1e-05))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestSoftmaxModule(flow.unittest.TestCase):
+    def test_softmax(test_case):
+        axis = 0
+        m = flow.nn.Softmax(dim=axis)
+        arr = np.random.randn(2, 3, 4, 5)
+        x = flow.Tensor(arr)
+        y = m(x)
+        output = numpy_softmax(arr, axis)
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+
+    def test_softmax_dim_1(test_case):
+        axis = 1
+        m = flow.nn.Softmax(dim=axis)
+        arr = np.random.randn(9, 7, 8, 16)
+        x = flow.Tensor(arr)
+        y = m(x)
+        output = numpy_softmax(arr, axis)
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+
+    def test_softmax_dim_2(test_case):
+        axis = 2
+        m = flow.nn.Softmax(dim=axis)
+        arr = np.random.randn(2, 5, 6, 3)
+        x = flow.Tensor(arr)
+        y = m(x)
+        output = numpy_softmax(arr, axis)
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+
+    def test_softmax_dim_3(test_case):
+        axis = 3
+        m = flow.nn.Softmax(dim=axis)
+        arr = np.random.randn(1, 3, 4, 7)
+        x = flow.Tensor(arr)
+        y = m(x)
+        output = numpy_softmax(arr, axis)
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+
+        axis2 = -1
+        m2 = flow.nn.Softmax(dim=axis)
+        y2 = m(x)
+        output2 = numpy_softmax(arr, axis)
+        test_case.assertTrue(np.allclose(y2.numpy(), output2, rtol=1e-05))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestLogSoftmaxModule(flow.unittest.TestCase):
+    def test_logsoftmax(test_case):
+        dim = 1
+        m = flow.nn.LogSoftmax(dim)
+        input_arr = np.random.randn(4, 7)
+        x = flow.Tensor(input_arr)
+        y = m(x)
+        output = numpy_logsoftmax(input_arr, dim)
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+
+    def test_logsoftmax_dim_2(test_case):
+        dim = 2
+        m = flow.nn.LogSoftmax(dim)
+        input_arr = np.random.randn(3, 4, 5)
+        x = flow.Tensor(input_arr)
+        y = m(x)
+        output = numpy_logsoftmax(input_arr, dim)
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+
+    def test_logsoftmax_dim_3(test_case):
+        dim = 3
+        m = flow.nn.LogSoftmax(dim)
+        input_arr = np.random.randn(8, 9, 7, 3)
+        x = flow.Tensor(input_arr)
+        y = m(x)
+        output = numpy_logsoftmax(input_arr, dim)
+        test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/oneflow/python/test/modules/test_arange.py b/oneflow/python/test/modules/test_arange.py
new file mode 100644
index 0000000000000000000000000000000000000000..5df044862775fc00f4ad1da238af8bdd01cd0129
--- /dev/null
+++ b/oneflow/python/test/modules/test_arange.py
@@ -0,0 +1,47 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+
+import numpy as np
+import oneflow as flow
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestArange(flow.unittest.TestCase):
+    def test_arange(test_case):
+        np_out = np.arange(5)
+        of_out = flow.arange(0, end=5)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+
+        np_out2 = np.arange(0, 20, 2)
+        of_out2 = flow.arange(0, 20, step=2)
+        test_case.assertTrue(np.allclose(of_out2.numpy(), np_out2))
+
+    def test_arange_v2(test_case):
+        np_out = np.arange(20)
+        of_out = flow.arange(start=0, end=20)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+
+        np_out2 = np.arange(0, 100, 3)
+        of_out2 = flow.arange(start=0, end=100, step=3)
+        test_case.assertTrue(np.allclose(of_out2.numpy(), np_out2))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_eq.py b/oneflow/python/test/modules/test_eq.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7f4355cb558d29b24a0aea828609cf9318755b4
--- /dev/null
+++ b/oneflow/python/test/modules/test_eq.py
@@ -0,0 +1,51 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+
+import numpy as np
+import oneflow as flow
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestEq(flow.unittest.TestCase):
+    def test_eq(test_case):
+        arr1 = np.array([2, 3, 4, 5,])
+        arr2 = np.array([2, 3, 4, 1])
+        input = flow.Tensor(arr1, dtype=flow.float32)
+        other = flow.Tensor(arr2, dtype=flow.float32)
+
+        of_out = flow.eq(input, other)
+        of_out2 = flow.equal(input, other)
+        np_out = np.equal(arr1, arr2)
+        test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+        test_case.assertTrue(np.array_equal(of_out2.numpy(), np_out))
+
+    def test_eq_tensor_function(test_case):
+        arr1 = np.random.randint(1, 10, size=(2, 3, 4, 5))
+        arr2 = np.random.randint(1, 10, size=(2, 3, 4, 5))
+        input = flow.Tensor(arr1, dtype=flow.float32)
+        other = flow.Tensor(arr2, dtype=flow.float32)
+
+        of_out = input.eq(other)
+        np_out = np.equal(arr1, arr2)
+        test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_masked_fill.py b/oneflow/python/test/modules/test_masked_fill.py
new file mode 100644
index 0000000000000000000000000000000000000000..56cbe7814e77d8febc8f36f2e8cc4445765375ad
--- /dev/null
+++ b/oneflow/python/test/modules/test_masked_fill.py
@@ -0,0 +1,72 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+
+import numpy as np
+import oneflow as flow
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestMaskedFill(flow.unittest.TestCase):
+    def test_masked_fill(test_case):
+        input_arr = np.array(
+            [
+                [
+                    [-0.13169311, 0.97277078, 1.23305363, 1.56752789],
+                    [-1.51954275, 1.87629473, -0.53301206, 0.53006478],
+                    [-1.38244183, -2.63448052, 1.30845795, -0.67144869],
+                ],
+                [
+                    [0.41502161, 0.14452418, 0.38968, -1.76905653],
+                    [0.34675095, -0.7050969, -0.7647731, -0.73233418],
+                    [-1.90089858, 0.01262963, 0.74693893, 0.57132389],
+                ],
+            ]
+        )
+
+        output = np.array(
+            [
+                [
+                    [-0.1316931, 8.7654321, 8.7654321, 8.7654321],
+                    [-1.5195428, 8.7654321, -0.5330121, 8.7654321],
+                    [-1.3824418, -2.6344805, 8.7654321, -0.6714487],
+                ],
+                [
+                    [8.7654321, 8.7654321, 8.7654321, -1.7690565],
+                    [8.7654321, -0.7050969, -0.7647731, -0.7323342],
+                    [-1.9008986, 8.7654321, 8.7654321, 8.7654321],
+                ],
+            ]
+        )
+
+        fill_value = 8.7654321  # random value e.g. -1e9 3.14
+
+        input = flow.Tensor(input_arr, dtype=flow.float32)
+        mask = flow.Tensor((input_arr > 0).astype(np.int8), dtype=flow.int)
+        of_out = input.masked_fill(mask, value=fill_value)
+        test_case.assertTrue(np.allclose(of_out.numpy(), output))
+
+        input2 = flow.Tensor(input_arr, dtype=flow.float32)
+        mask2 = flow.Tensor((input_arr > 0).astype(np.int8), dtype=flow.int)
+        of_out2 = flow.tmp.masked_fill(input2, mask, value=fill_value)
+        test_case.assertTrue(np.allclose(of_out2.numpy(), output))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_unsqueeze.py b/oneflow/python/test/modules/test_unsqueeze.py
new file mode 100644
index 0000000000000000000000000000000000000000..14f050a6fba6f996a82d75cc58a93f3e33b87094
--- /dev/null
+++ b/oneflow/python/test/modules/test_unsqueeze.py
@@ -0,0 +1,51 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+
+import numpy as np
+import oneflow as flow
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestUnsqueeze(flow.unittest.TestCase):
+    def test_unsqueeze(test_case):
+        np_arr = np.random.rand(2, 6, 9, 3)
+        x = flow.Tensor(np_arr)
+        y = flow.unsqueeze(x, dim=1)
+        output = np.expand_dims(np_arr, axis=1)
+        test_case.assertTrue(np.allclose(output, y.numpy(), rtol=1e-05))
+
+    def test_unsqueeze_tensor_function(test_case):
+        np_arr = np.random.rand(2, 3, 4)
+        x = flow.Tensor(np_arr)
+        y = x.unsqueeze(dim=2)
+        output = np.expand_dims(np_arr, axis=2)
+        test_case.assertTrue(np.allclose(output, y.numpy(), rtol=1e-05))
+
+    def test_unsqueeze_different_dim(test_case):
+        np_arr = np.random.rand(4, 5, 6, 7)
+        x = flow.Tensor(np_arr)
+        for axis in range(-5, 5):
+            y = flow.unsqueeze(x, dim=axis)
+            output = np.expand_dims(np_arr, axis=axis)
+            test_case.assertTrue(np.allclose(output, y.numpy(), rtol=1e-05))
+
+
+if __name__ == "__main__":
+    unittest.main()