diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index c6ee560096b6b139bc70a75237e1c9bdd95de73c..ce02de4b7b147428bdd11c7db04fd9f60c854278 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -301,6 +301,33 @@ class Tensor:
     def __deepcopy__(self, memo):
         TODO()
 
+    def __mul__(self, other):
+        return self.mul(other)
+
+    def __rmul__(self, other):
+        return self.mul(other)
+
+    def __add__(self, other):
+        return self.add(other)
+
+    def __radd__(self, other):
+        return self.add(other)
+
+    def __sub__(self, other):
+        return self.sub(other)
+
+    def __rsub__(self, other):
+        return flow.sub(other, self)
+
+    def __truediv__(self, other):
+        return self.div(other)
+
+    def __rtruediv__(self, other):
+        return flow.div(other, self)
+
+    def __neg__(self):
+        return flow.mul(-1, self)
+
     def _determine_if_needed(self, determining_initializer=None):
         if not self.is_determined:
             self.determine(determining_initializer)
diff --git a/oneflow/python/nn/modules/loss.py b/oneflow/python/nn/modules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e471f0068a7009cf3824718c7ab05da0cd919dcc
--- /dev/null
+++ b/oneflow/python/nn/modules/loss.py
@@ -0,0 +1,88 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from typing import Optional
+
+import oneflow as flow
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.nn.module import Module
+
+
+@oneflow_export("nn.CrossEntropyLoss")
+class CrossEntropyLoss(Module):
+    r"""
+    Args:
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
+            be applied, ``'mean'``: the weighted mean of the output is taken,
+            ``'sum'``: the output will be summed. Note: :attr:`size_average`
+            and :attr:`reduce` are in the process of being deprecated, and in
+            the meantime, specifying either of those two args will override
+            :attr:`reduction`. Default: ``'mean'``
+
+    For example:
+    .. code-block:: python
+
+        import oneflow as flow
+        input = flow.Tensor(
+            [[-0.1664078, -1.7256707, -0.14690138],
+                [-0.21474946, 0.53737473, 0.99684894],
+                [-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32)
+        target = flow.Tensor(np.array([0, 1, 2]), dtype=flow.int32)
+        out = flow.nn.CrossEntropyLoss(reduction="none")(input, target)
+        # out: [0.80199665 1.1166505  0.35826027]
+        out_sum = flow.nn.CrossEntropyLoss(reduction="sum")(input, target)
+        # out_sum: [2.2769074]
+        out_mean = flow.nn.CrossEntropyLoss(reduction="mean")(input, target)
+        # out_mean: [0.7589692]
+
+    """
+
+    def __init__(
+        self,
+        weight=None,
+        ignore_index: Optional[int] = None,
+        reduction: Optional[str] = "mean",
+    ) -> None:
+        super().__init__()
+        if weight is not None:
+            raise ValueError("Argument weight is not supported yet")
+        if ignore_index is not None:
+            raise ValueError("Argument ignore_index is not supported yet")
+        assert reduction in [
+            "sum",
+            "none",
+            "mean",
+            None,
+        ], "only 'sum', 'mean' and None supported by now"
+
+        self.reduction = reduction
+        self._op = (
+            flow.builtin_op("sparse_softmax_cross_entropy")
+            .Input("prediction")
+            .Input("label")
+            .Output("prob")
+            .Output("out")
+            .Build()
+        )
+
+    def forward(self, input, target):
+        prob, out = self._op(input, target, depth=input.shape[len(input.shape) - 1])
+        if self.reduction == "mean":
+            return flow.mean(out)
+        elif self.reduction == "sum":
+            return flow.sum(out)
+        else:
+            return out
diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py
index 7384ea6900b02c831ee6b84cbf162b12d0458051..34cbdee759a445c511dc97a3771b2d641bf057ac 100644
--- a/oneflow/python/nn/modules/math_ops.py
+++ b/oneflow/python/nn/modules/math_ops.py
@@ -13,12 +13,534 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
+import collections
+from typing import Optional, Sequence, Union
+
 import oneflow as flow
-from oneflow.python.nn.module import Module
 from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.nn.module import Module
 from oneflow.python.framework.tensor import register_tensor_op
 
 
+def _check_axis(axis, shape):
+    # TODO(yaochi): refine this function when all related ops in `python/ops/math_ops.py` migrated
+    if axis is None:
+        axis = list(range(len(shape)))
+
+    if isinstance(axis, int):
+        axis = [axis]
+
+    assert isinstance(axis, (list, tuple)), "Invalid axis {}".format(axis)
+    for x in axis:
+        if x < 0:
+            x += len(shape)
+        assert x >= 0 and x < len(shape), "Invalid axis {}, len(shape): {}".format(
+            axis, len(shape)
+        )
+
+    return axis
+
+
+class Sum(Module):
+    def __init__(
+        self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False
+    ) -> None:
+        super().__init__()
+
+        self.axis = axis
+        self.keepdims = keepdims
+        self._op = (
+            flow.builtin_op("reduce_sum")
+            .Input("input_tensor")
+            .Output("output_tensor")
+            .Attr("keepdims", keepdims)
+            .Build()
+        )
+
+    def forward(self, input):
+        axis_checked = _check_axis(self.axis, input.shape)
+        if len(axis_checked) == 0:
+            return input
+        return self._op(input, axis=axis_checked)[0]
+
+
+@oneflow_export("sum")
+@register_tensor_op("sum")
+def _sum(input, dim=None, keepdims=False):
+    r"""Computes the sum of row of elements in a tensor in the given axis, if the axis is None, sum of all elements will be caculated.
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32)
+        of_out = flow.sum(input, dim=(2,1))
+
+    """
+
+    return Sum(dim, keepdims)(input)
+
+
+class ScalarMul(Module):
+    def __init__(self, operand) -> None:
+        super().__init__()
+        self._op = flow.builtin_op("scalar_mul").Input("in").Output("out")
+        if isinstance(operand, int):
+            self._op = (
+                self._op.Attr("has_int_operand", True)
+                .Attr("has_float_operand", False)
+                .Attr("int_operand", operand)
+                .Attr("float_operand", 0.0)
+                .Build()
+            )
+        elif isinstance(operand, float):
+            self._op = (
+                self._op.Attr("has_int_operand", False)
+                .Attr("has_float_operand", True)
+                .Attr("int_operand", 0)
+                .Attr("float_operand", operand)
+                .Build()
+            )
+        else:
+            raise ValueError("operand type can only be int or float")
+
+    def forward(self, x):
+        return self._op(x)[0]
+
+
+class ScalarMulByTensor(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("scalar_mul_by_tensor")
+            .Input("x")
+            .Input("scalar")
+            .Output("y")
+            .Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+class ElementwiseMul(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("multiply").Input("x").Input("y").Output("out").Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+class BroadcastMul(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("broadcast_mul").Input("x").Input("y").Output("z").Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+@oneflow_export("mul")
+@register_tensor_op("mul")
+def _mul(x, y):
+    r"""Computes the multiplication of x by y for each element, scalar and broadcast promotation are supported.
+    The formula is:
+    .. math::
+        out = x \times y
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+
+        # element-wise multiply
+        x = flow.Tensor(np.random.randn(2,3))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.mul(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # scalar mutiply
+        x = 5
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.mul(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # broadcast mutiply
+        x = flow.Tensor(np.random.randn(1,1))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.mul(x,y).numpy()
+        print(out.shape) # (2,3)
+
+    """
+
+    if isinstance(x, (int, float)):
+        return ScalarMul(x)(y)
+    elif isinstance(y, (int, float)):
+        return ScalarMul(y)(x)
+    elif x.shape == y.shape:
+        return ElementwiseMul()(x, y)
+    elif x.shape == (1,):
+        return ScalarMulByTensor()(y, x)
+    elif y.shape == (1,):
+        return ScalarMulByTensor()(x, y)
+    else:
+        return BroadcastMul()(x, y)
+
+
+class Mean(Module):
+    def __init__(
+        self,
+        axis: Optional[Union[collections.Sized, int]] = None,
+        keepdims: bool = False,
+    ) -> None:
+        super().__init__()
+        self.keepdims = keepdims
+        self.axis = axis
+        # TODO: add if input.is_dynamic branch like flow.math.reduce_mean
+        if axis is None:
+            self.axes = []
+        else:
+            self.axes = list(axis) if isinstance(axis, collections.Sized) else [axis]
+
+    def forward(self, input_tensor):
+        reduce_sum = flow.sum(input_tensor, dim=self.axis, keepdims=self.keepdims)
+        reduce_count = 1
+        if len(self.axes) == 0:
+            for dim in input_tensor.shape:
+                reduce_count *= dim
+        else:
+            for i in self.axes:
+                reduce_count *= input_tensor.shape[i]
+        return flow.mul(reduce_sum, 1.0 / reduce_count)
+
+
+@oneflow_export("mean")
+@register_tensor_op("mean")
+def _mean(input_tensor, dim=None, keepdim=False):
+    r"""Computes the mean of row of elements in a tensor in the given axis,
+    if the axis is None, mean of all elements will be caculated.
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+
+        input = flow.Tensor([[1, 2, 3], [4, 5, 6]])
+        out = flow.mean(input)
+        # out: [3.5]
+        print(out.numpy())
+
+        input = flow.Tensor([[1, 2, 3], [4, 5, 6]])
+        out = flow.mean(input, axis=0)
+        # out: [2.5 3.5 4.5]
+        print(out.numpy())
+
+        input = flow.Tensor([[1, 2, 3], [4, 5, 6]])
+        out = flow.mean(input, axis=1)
+        # out: [ 2. 5.]
+        print(out.numpy())
+
+    """
+
+    return Mean(axis=dim, keepdims=keepdim)(input_tensor)
+
+
+class ScalarSubByTensor(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("scalar_sub_by_tensor")
+            .Input("x")
+            .Input("scalar")
+            .Output("y")
+            .Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+class BroadcastSub(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("broadcast_sub").Input("x").Input("y").Output("z").Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+class ScalarAdd(Module):
+    def __init__(self, operand) -> None:
+        super().__init__()
+        self._op = flow.builtin_op("scalar_add").Input("in").Output("out")
+
+        if isinstance(operand, int):
+            self._op = (
+                self._op.Attr("has_int_operand", True)
+                .Attr("has_float_operand", False)
+                .Attr("int_operand", operand)
+                .Attr("float_operand", 0.0)
+                .Build()
+            )
+        elif isinstance(operand, float):
+            self._op = (
+                self._op.Attr("has_int_operand", False)
+                .Attr("has_float_operand", True)
+                .Attr("int_operand", 0)
+                .Attr("float_operand", operand)
+                .Build()
+            )
+        else:
+            raise ValueError("operand type can only be int or float")
+
+    def forward(self, x):
+        return self._op(x)[0]
+
+
+@oneflow_export("sub")
+@register_tensor_op("sub")
+def _sub(x, y):
+    r"""Computes the subtraction of x by y for each element, scalar and broadcast promotation are supported.
+    The formula is:
+    .. math::
+        out = x - y
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+
+        # element-wise subtract
+        x = flow.Tensor(np.random.randn(2,3))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.sub(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # scalar subtract
+        x = 5
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.sub(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # broadcast subtract
+        x = flow.Tensor(np.random.randn(1,1))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.sub(x,y).numpy()
+        print(out.shape) # (2,3)
+
+    """
+
+    if isinstance(x, (int, float)):
+        return ScalarAdd(x)(ScalarMul(-1)(y))
+    elif isinstance(y, (int, float)):
+        return ScalarAdd(-1 * y)(x)
+    elif x.shape == y.shape:
+        # TODO: add element-wise op
+        return BroadcastSub()(x, y)
+    elif y.shape == (1,):
+        return ScalarSubByTensor()(x, y)
+    else:
+        return BroadcastSub()(x, y)
+
+
+class BroadcastDiv(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("broadcast_div").Input("x").Input("y").Output("z").Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+class ScalarDivByTensor(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("scalar_div_by_tensor")
+            .Input("x")
+            .Input("scalar")
+            .Output("y")
+            .Build()
+        )
+
+    def forward(self, x, scalar):
+        return self._op(x, scalar)[0]
+
+
+@oneflow_export("div")
+@register_tensor_op("div")
+def _div(x, y):
+    r"""Computes the division of x by y for each element, scalar and broadcast promotation are supported.
+    The formula is:
+    .. math::
+        out = \frac{X}{Y}
+    Args:
+        x (Union[int, float, flow.Tensor]): X.
+        y (Union[int, float, flow.Tensor]): Y.
+    For example:
+    .. code-block:: python
+
+        import oneflow as flow
+
+        # element-wise divide
+        x = flow.Tensor(np.random.randn(2,3))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.div(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # scalar divide
+        x = 5
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.div(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # broadcast divide
+        x = flow.Tensor(np.random.randn(1,1))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.div(x,y).numpy()
+        print(out.shape) # (2,3)
+
+    """
+
+    if isinstance(x, (int, float)):
+        return ScalarMul(x)(flow.reciprocal(y))
+    elif isinstance(y, (int, float)):
+        if y == 0 or y == 0.0:
+            y = 0.0
+        else:
+            y = 1.0 / (float(y))
+        return ScalarMul(y)(x)
+    elif x.shape == y.shape:
+        return BroadcastDiv()(x, y)
+    elif y.shape == (1,):
+        return ScalarDivByTensor()(x, y)
+    else:
+        return BroadcastDiv()(x, y)
+
+
+class Reciprocal(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = flow.builtin_op("reciprocal_no_nan").Input("x").Output("y").Build()
+
+    def forward(self, x):
+        return self._op(x)[0]
+
+
+@oneflow_export("reciprocal")
+@register_tensor_op("reciprocal")
+def _reciprocal(x):
+    r"""Computes the safe reciprocal of x. If x is zero, the reciprocal will
+    be also set to zero.
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        x = flow.Tensor(np.array([[1, 2, 3], [4, 5, 6]]))
+        out = flow.reciprocal()(x)
+        # out [[1.         0.5        0.33333334]
+               [0.25       0.2        0.16666667]]
+
+    """
+
+    return Reciprocal()(x)
+
+
+class ScalarAddByTensor(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("scalar_add_by_tensor")
+            .Input("x")
+            .Input("scalar")
+            .Output("y")
+            .Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+class ElementwiseAdd(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = flow.builtin_op("add_n").Input("in", 2).Output("out").Build()
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+class BroadcastAdd(Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op = (
+            flow.builtin_op("broadcast_add").Input("x").Input("y").Output("z").Build()
+        )
+
+    def forward(self, x, y):
+        return self._op(x, y)[0]
+
+
+@oneflow_export("add")
+@register_tensor_op("add")
+def _add(x, y):
+    r"""Computes the addition of x by y for each element, scalar and broadcast promotation are supported.
+    The formula is:
+    .. math::
+        out = x + y
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+
+        # element-wise add
+        x = flow.Tensor(np.random.randn(2,3))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.add(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # scalar add
+        x = 5
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.add(x,y).numpy()
+        print(out.shape) # (2,3)
+
+        # broadcast add
+        x = flow.Tensor(np.random.randn(1,1))
+        y = flow.Tensor(np.random.randn(2,3))
+        out = flow.add(x,y).numpy()
+        print(out.shape) # (2,3)
+
+    """
+
+    if isinstance(x, (int, float)):
+        return ScalarAdd(x)(y)
+    elif isinstance(y, (int, float)):
+        return ScalarAdd(y)(x)
+    elif x.shape == y.shape:
+        return ElementwiseAdd()(x, y)
+    elif x.shape == (1,):
+        return ScalarAddByTensor()(y, x)
+    elif y.shape == (1,):
+        return ScalarAddByTensor()(x, y)
+    else:
+        return BroadcastAdd()(x, y)
+
+
 class Sin(Module):
     def __init__(self) -> None:
         super().__init__()
@@ -33,25 +555,18 @@ class Sin(Module):
 def sin_op(tensor):
     r"""
     Returns a new tensor with the sine of the elements of :attr:`input`.
-
     .. math::
         \text{out}_{i} = \sin(\text{input}_{i})
-
     Args:
         input (Tensor) 鈥� the input tensor.
-
     For example:
-
     .. code-block:: python
-
         import oneflow as flow
         import numpy as np
-
         arr = np.array([-0.5461,  0.1347, -2.7266, -0.2746])
         input = flow.Tensor(arr, dtype=flow.float32)
         output = flow.sin(input)
         # [-0.51935846  0.13429303 -0.40318328 -0.27116194]
-
     """
     return Sin()(tensor)
 
@@ -70,20 +585,14 @@ class Cos(Module):
 def cos_op(tensor):
     r"""
     Returns a new tensor with the cosine  of the elements of :attr:`input`.
-
     .. math::
         \text{out}_{i} = \cos(\text{input}_{i})
-
     Args:
         input (Tensor) 鈥� the input tensor.
-
     For example:
-
     .. code-block:: python
-
         import oneflow as flow
         import numpy as np
-
         arr = np.array([1.4309,  1.2706, -0.8562,  0.9796])
         input = flow.Tensor(arr, dtype=flow.float32)
         output = flow.cos(input)
@@ -107,20 +616,14 @@ class Log(Module):
 def log_op(tensor):
     r"""
     Returns a new tensor with the natural logarithm of the elements of :attr:`input`.
-
     .. math::
         y_{i} = \log_{e} (x_{i})
-
     Args:
         input (Tensor) 鈥� the input tensor.
-
     For example:
-
     .. code-block:: python
-
         import oneflow as flow
         import numpy as np
-
         arr = np.random.randn(2, 3, 4, 5)
         input = flow.Tensor(arr, dtype=flow.float32)
         output = flow.log(input)
diff --git a/oneflow/python/ops/math_ops.py b/oneflow/python/ops/math_ops.py
index 5a1f7088ed5a63b4831d4945d0b7af13050bf71b..29c8792d06c540f83cd7c3c0b95832458e8bc036 100644
--- a/oneflow/python/ops/math_ops.py
+++ b/oneflow/python/ops/math_ops.py
@@ -198,8 +198,6 @@ def subtract(
     elif x.shape == y.shape:
         # TODO: add element-wise op
         return broadcast_sub(x, y, name)
-    elif x.shape == (1,):
-        return scalar_sub_by_tensor(y, x, name)
     elif y.shape == (1,):
         return scalar_sub_by_tensor(x, y, name)
     else:
@@ -315,8 +313,6 @@ def divide(
     elif x.shape == y.shape:
         # TODO: add element-wise op
         return broadcast_div(x, y, name)
-    elif x.shape == (1,):
-        return scalar_div_by_tensor(y, x, name)
     elif y.shape == (1,):
         return scalar_div_by_tensor(x, y, name)
     else:
diff --git a/oneflow/python/test/modules/test_add.py b/oneflow/python/test/modules/test_add.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b0eb52fbabca0818f4ea8c16787661f12815d06
--- /dev/null
+++ b/oneflow/python/test/modules/test_add.py
@@ -0,0 +1,59 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestAddModule(flow.unittest.TestCase):
+    def test_add(test_case):
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.add(x, y)
+        np_out = np.add(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = 5
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.add(x, y)
+        np_out = np.add(x, y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = 5
+        of_out = flow.add(x, y)
+        np_out = np.add(x.numpy(), y)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.array([5]))
+        of_out = flow.add(x, y)
+        np_out = np.add(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(1, 1))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.add(x, y)
+        np_out = np.add(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_avgpool2d_module.py b/oneflow/python/test/modules/test_avgpool2d.py
similarity index 100%
rename from oneflow/python/test/modules/test_avgpool2d_module.py
rename to oneflow/python/test/modules/test_avgpool2d.py
diff --git a/oneflow/python/test/modules/test_crossentropyloss.py b/oneflow/python/test/modules/test_crossentropyloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f23358be6d63b5830da926383407fb1615a1a7
--- /dev/null
+++ b/oneflow/python/test/modules/test_crossentropyloss.py
@@ -0,0 +1,61 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+g_test_samples = [
+    {
+        "input": np.array(
+            [
+                [-0.6980871, 0.4765042, -1.969919, 0.28965086, -0.53548324],
+                [-0.26332688, 0.27541, 0.30080616, 0.09914763, 0.53522176],
+                [0.7332028, 0.38375184, -0.2831992, -0.9833142, 0.387824],
+            ]
+        ),
+        "target": np.array([3, 3, 4], dtype=np.int32),
+        "out": np.array([1.1380, 1.7332, 1.4287], dtype=np.float32),
+        "out_sum": np.array([4.2999], dtype=np.float32),
+        "out_mean": np.array([1.4333], dtype=np.float32),
+    }
+]
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestCrossEntropyLossModule(flow.unittest.TestCase):
+    def test_CrossEntropyLoss(test_case):
+        global g_test_samples
+        for sample in g_test_samples:
+            loss = flow.nn.CrossEntropyLoss(reduction=None)
+            input = flow.Tensor(sample["input"], dtype=flow.float32)
+            target = flow.Tensor(sample["target"], dtype=flow.int32)
+            of_out = loss(input, target)
+            assert np.allclose(of_out.numpy(), sample["out"], 1e-4, 1e-4)
+
+            loss_sum = flow.nn.CrossEntropyLoss(reduction="sum")
+            of_out_sum = loss_sum(input, target)
+            assert np.allclose(of_out_sum.numpy(), sample["out_sum"], 1e-4, 1e-4)
+
+            loss_mean = flow.nn.CrossEntropyLoss(reduction="mean")
+            of_out_mean = loss_mean(input, target)
+            assert np.allclose(of_out_mean.numpy(), sample["out_mean"], 1e-4, 1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_div.py b/oneflow/python/test/modules/test_div.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e6f530f9cb26b2dba3a7e330f0a6dc03d57450a
--- /dev/null
+++ b/oneflow/python/test/modules/test_div.py
@@ -0,0 +1,65 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestDiv(flow.unittest.TestCase):
+    def test_div(test_case):
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.div(x, y)
+        np_out = np.divide(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = 5
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.div(x, y)
+        np_out = np.divide(x, y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = 5
+        of_out = flow.div(x, y)
+        np_out = np.divide(x.numpy(), y)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.random.randn(1, 1))
+        of_out = flow.div(x, y)
+        np_out = np.divide(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.array([5]))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.div(x, y)
+        np_out = np.divide(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.array([5]))
+        of_out = flow.div(x, y)
+        np_out = np.divide(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_mean.py b/oneflow/python/test/modules/test_mean.py
new file mode 100644
index 0000000000000000000000000000000000000000..542fb97bc7682fb8a204fa73455cf23465bbc417
--- /dev/null
+++ b/oneflow/python/test/modules/test_mean.py
@@ -0,0 +1,39 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestMeanModule(flow.unittest.TestCase):
+    def test_mean(test_case):
+        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
+        of_out = flow.mean(input, dim=1)
+        np_out = np.mean(input.numpy(), axis=1)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
+        of_out = flow.mean(input, dim=0)
+        np_out = np.mean(input.numpy(), axis=0)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_mul.py b/oneflow/python/test/modules/test_mul.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7bba0d0843ecaa93e9f4d2404cf4b3dd696080f
--- /dev/null
+++ b/oneflow/python/test/modules/test_mul.py
@@ -0,0 +1,53 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestMulModule(flow.unittest.TestCase):
+    def test_mul(test_case):
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.mul(x, y)
+        np_out = np.multiply(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = 5
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.mul(x, y)
+        np_out = np.multiply(x, y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = 5
+        of_out = flow.mul(x, y)
+        np_out = np.multiply(x.numpy(), y)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(1, 1))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.mul(x, y)
+        np_out = np.multiply(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_reciprocal.py b/oneflow/python/test/modules/test_reciprocal.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad64f316974ecc295c386b1c6c4c10340ba30d66
--- /dev/null
+++ b/oneflow/python/test/modules/test_reciprocal.py
@@ -0,0 +1,34 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestReciprocalModule(flow.unittest.TestCase):
+    def test_reciprocal(test_case):
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.reciprocal(x)
+        np_out = np.reciprocal(x.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_sub.py b/oneflow/python/test/modules/test_sub.py
new file mode 100644
index 0000000000000000000000000000000000000000..15f720eb84bc5da4a47b11d39ae1423c33f4a8c8
--- /dev/null
+++ b/oneflow/python/test/modules/test_sub.py
@@ -0,0 +1,65 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestSubModule(flow.unittest.TestCase):
+    def test_sub(test_case):
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.sub(x, y)
+        np_out = np.subtract(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = 5
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = flow.sub(x, y)
+        np_out = np.subtract(x, y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = 5
+        of_out = flow.sub(x, y)
+        np_out = np.subtract(x.numpy(), y)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        y = flow.Tensor(np.random.randn(1, 1))
+        of_out = flow.sub(x, y)
+        np_out = np.subtract(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.array([5]))
+        y = flow.Tensor(np.random.randn(1, 1))
+        of_out = flow.sub(x, y)
+        np_out = np.subtract(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(1, 1))
+        y = flow.Tensor(np.array([5]))
+        of_out = flow.sub(x, y)
+        np_out = np.subtract(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/modules/test_sum.py b/oneflow/python/test/modules/test_sum.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d49b384ebece9d2d9bb2ca3acda62c00073c617
--- /dev/null
+++ b/oneflow/python/test/modules/test_sum.py
@@ -0,0 +1,51 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestSumModule(flow.unittest.TestCase):
+    def test_sum(test_case):
+        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
+        of_out = flow.sum(input, dim=0)
+        np_out = np.sum(input.numpy(), axis=0)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
+        of_out = flow.sum(input, dim=0)
+        np_out = np.sum(input.numpy(), axis=0)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
+        of_out = flow.sum(input, dim=1)
+        of_out2 = input.sum(dim=1)
+        np_out = np.sum(input.numpy(), axis=1)
+        test_case.assertTrue(np.allclose(of_out2.numpy(), of_out.numpy(), 1e-4, 1e-4))
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32)
+        of_out = flow.sum(input, dim=(2, 1))
+        np_out = np.sum(input.numpy(), axis=(2, 1))
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/tensor/test_tensor.py b/oneflow/python/test/tensor/test_tensor.py
index 7f1137af30658de6cbeba201d24e60a53aaf8de4..1974e0567e98bf2db9973f662ad282478f5734a8 100644
--- a/oneflow/python/test/tensor/test_tensor.py
+++ b/oneflow/python/test/tensor/test_tensor.py
@@ -94,6 +94,97 @@ class TestTensor(flow.unittest.TestCase):
         test_case.assertTrue(np.array_equal(y.numpy(), 5 * np.ones(y.shape)))
         test_case.assertTrue(np.array_equal(z.numpy(), 5 * np.ones(z.shape)))
 
+    def test_div(test_case):
+        x = flow.Tensor(np.random.randn(1, 1))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = x / y
+        np_out = np.divide(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = x / 3
+        np_out = np.divide(x.numpy(), 3)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = 3 / x
+        np_out = np.divide(3, x.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(1))
+        of_out = 3 / x
+        np_out = np.divide(3, x.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+    def test_mul(test_case):
+        x = flow.Tensor(np.random.randn(1, 1))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = x * y
+        np_out = np.multiply(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = x * 3
+        np_out = np.multiply(x.numpy(), 3)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = 3 * x
+        np_out = np.multiply(3, x.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+    def test_add_tensor_method(test_case):
+        x = flow.Tensor(np.random.randn(1, 1))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = x + y
+        np_out = np.add(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = x + 3
+        np_out = np.add(x.numpy(), 3)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = 3 + x
+        np_out = np.add(3, x.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+    def test_sub_tensor_method(test_case):
+        x = flow.Tensor(np.random.randn(1, 1))
+        y = flow.Tensor(np.random.randn(2, 3))
+        of_out = x - y
+        np_out = np.subtract(x.numpy(), y.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = x - 3
+        np_out = np.subtract(x.numpy(), 3)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+        x = flow.Tensor(np.random.randn(2, 3))
+        of_out = 3 - x
+        np_out = np.subtract(3, x.numpy())
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+    def test_sum(test_case):
+        input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32)
+        of_out = input.sum(dim=(2, 1))
+        np_out = np.sum(input.numpy(), axis=(2, 1))
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+    def test_mean(test_case):
+        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
+        of_out = input.mean(dim=0)
+        np_out = np.mean(input.numpy(), axis=0)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
+    def test_neg(test_case):
+        input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
+        of_out = -input
+        np_out = -input.numpy()
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+
 
 if __name__ == "__main__":
     unittest.main()