diff --git a/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..85e19231f78cfeb0cb4569025842b4a2c8730a23
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
@@ -0,0 +1,76 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct ClipByScalarMaxInterpState : public OpExprInterpState {
+  bool requires_grad;
+  functional::Scalar max;
+};
+
+class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(ClipByScalarMaxInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ctx->SaveTensorForBackward(inputs.at(0));
+
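+    // Read the clip threshold from the op attrs; which attr holds it depends on
+    // whether the input dtype is floating point or integral.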
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    if (IsFloatingDataType(inputs.at(0)->dtype())) {
+      ctx->max = functional::Scalar(JUST(composed_attrs.GetAttr<double>("floating_max")));
+    } else if (IsIntegralDataType(inputs.at(0)->dtype())) {
+      ctx->max = functional::Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_max")));
+    } else {
+      UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
+    }
+    return Maybe<void>::Ok();
+  }
+
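+  // Backward of y = min(x, max): the incoming gradient is passed through where x
+  // was not clipped and zeroed where it was; ClipByScalarMaxGrad applies that mask.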
+  Maybe<void> Apply(const ClipByScalarMaxInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::ClipByScalarMaxGrad(out_grads.at(0), x, ctx->max));
+    }
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_max", ClipByScalarMax);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/python/oneflow/nn/modules/activation.py b/python/oneflow/nn/modules/activation.py
index 8c3abb6e5f5aa516a1fb6aaecb0c5fe08d04819a..406d0187e6e40ba0a5bc52f58b851e39eb5344ea 100644
--- a/python/oneflow/nn/modules/activation.py
+++ b/python/oneflow/nn/modules/activation.py
@@ -199,7 +199,7 @@ class Tanh(Module):
         out = \\frac{e^x-e^{-x}}{e^x+e^{-x}}
 
     Args:
-        x (oneflow.Tensor): A Tensor
+        input (oneflow.Tensor): A Tensor
 
     Returns:
         oneflow.Tensor: The result Tensor
@@ -223,12 +223,12 @@ class Tanh(Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x):
-        return flow.F.tanh(x)
+    def forward(self, input):
+        return flow.F.tanh(input)
 
 
 @register_tensor_op("tanh")
-def tanh_op(x):
+def tanh_op(input):
     """This operator computes the hyperbolic tangent value of Tensor.
 
     The equation is:
@@ -258,7 +258,7 @@ def tanh_op(x):
         tensor([-0.7616,  0.    ,  0.7616], dtype=oneflow.float32)
 
     """
-    return Tanh()(x)
+    return Tanh()(input)
 
 
 class ELU(Module):
diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py
index d684591b6b1820fab5a9871da10e332ee4c244c5..a54d81c5fb798331968aa561d46077b9da9401f4 100644
--- a/python/oneflow/test/modules/test_activation.py
+++ b/python/oneflow/test/modules/test_activation.py
@@ -60,7 +60,7 @@ class TestReLUModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_relu_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_relu_module_with_random_data(test_case):
         m = torch.nn.ReLU()
         m.train(random())
@@ -101,7 +101,7 @@ class TestReLU6Module(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_relu6_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_relu6_module_with_random_data(test_case):
         m = torch.nn.ReLU6()
         m.train(random())
@@ -153,7 +153,7 @@ class TestTanh(flow.unittest.TestCase):
             _test_tanh_nn_impl(test_case, *arg)
             _test_tanh_function_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_tanh_module_with_random_data(test_case):
         m = torch.nn.Tanh()
         m.train(random())
@@ -163,11 +163,11 @@ class TestTanh(flow.unittest.TestCase):
         y = m(x)
         return y
 
-    @autotest
+    @autotest()
     def test_flow_tanh_with_random_data(test_case):
         device = random_device()
         x = random_pytorch_tensor().to(device)
-        y = flow.tanh(x)
+        y = torch.tanh(x)
         return y
 
 
@@ -199,7 +199,7 @@ class TestELUModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_elu_function_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_elu_module_with_random_data(test_case):
         m = torch.nn.ELU(alpha=random() | nothing())
         m.train(random())
@@ -678,7 +678,8 @@ class TestSoftplusModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
-    @autotest
+    @unittest.skip("pytorch softplus backward has a bug")
+    @autotest()
     def test_softplus_module_with_random_data(test_case):
         m = torch.nn.Softplus(beta=random() | nothing(), threshold=random() | nothing())
         m.train(random())
@@ -782,7 +783,7 @@ class TestLeakyReLUModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_leakyrelu_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_leakyrelu_module_with_random_data(test_case):
         m = torch.nn.LeakyReLU(negative_slope=random() | nothing())
         m.train(random())
diff --git a/python/oneflow/test/modules/test_addmm.py b/python/oneflow/test/modules/test_addmm.py
index 8dd938c29567139d19312047a1bac5afe17e0633..787b914090611d9612a9084c86f92915b3aff33f 100644
--- a/python/oneflow/test/modules/test_addmm.py
+++ b/python/oneflow/test/modules/test_addmm.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_addmm(test_case, shape, alpha, beta, device):
@@ -65,6 +66,36 @@ class TestAddmm(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_addmm_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = torch.addmm(
+            input,
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
+    @autotest()
+    def test_addmm_broadcast_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=1, dim1=1).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = torch.addmm(
+            input,
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_batchnorm.py b/python/oneflow/test/modules/test_batchnorm.py
index ffe4c767ea53e4f414689fab15a355dfb6893ad3..193fb36454d3e4b0c8a6caada514cbd868a84733 100644
--- a/python/oneflow/test/modules/test_batchnorm.py
+++ b/python/oneflow/test/modules/test_batchnorm.py
@@ -516,6 +516,17 @@ class TestBatchNorm(flow.unittest.TestCase):
                     n=10,
                 )
 
+    @autotest(n=1, auto_backward=False)
+    def test_batchnorm3d_module_with_random_data(test_case):
+        channel = random().to(int)
+        m = torch.nn.BatchNorm3d(num_features=channel, track_running_stats=False)
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(ndim=5, dim1=channel, requires_grad=False).to(device)
+        y = m(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_ceil.py b/python/oneflow/test/modules/test_ceil.py
index 6eb12d7a32be4b65546b8bd0a3e2c472cbacdb5f..002251d666225c6838edab796332a94a3896d97e 100644
--- a/python/oneflow/test/modules/test_ceil.py
+++ b/python/oneflow/test/modules/test_ceil.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_ceil_impl(test_case, device, shape):
@@ -46,6 +47,13 @@ class TestCeilModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_ceil_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.ceil(input)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_clamp.py b/python/oneflow/test/modules/test_clamp.py
index 4ef548bfd688db283eb8d4856a8a5ab201691a6f..640408dd1d68370a146f0f9e64c070df14fac650 100644
--- a/python/oneflow/test/modules/test_clamp.py
+++ b/python/oneflow/test/modules/test_clamp.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_clamp(test_case, shape, device):
@@ -106,6 +107,56 @@
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_clamp_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clamp(input, min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clamp_min_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clamp(
+            input, min=random().to(float) | nothing(), max=random().to(float)
+        )
+        return y
+
+    @autotest()
+    def test_clamp_max_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clamp(
+            input, min=random().to(float), max=random().to(float) | nothing()
+        )
+        return y
+
+    @autotest()
+    def test_clip_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clip(input, min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clip_min_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clip(
+            input, min=random().to(float) | nothing(), max=random().to(float)
+        )
+        return y
+
+    @autotest()
+    def test_clip_max_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clip(
+            input, min=random().to(float), max=random().to(float) | nothing()
+        )
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_expm1.py b/python/oneflow/test/modules/test_expm1.py
index 9084a370b62d39c1ade318bc7cf775d57ebe2b05..454b5daf6297318353285796f1f26e082949af58 100644
--- a/python/oneflow/test/modules/test_expm1.py
+++ b/python/oneflow/test/modules/test_expm1.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_expm1_impl(test_case, device, shape):
@@ -46,6 +47,13 @@ class TestExpm1Module(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_expm1_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.expm1(input)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/tensor/test_tensor.py b/python/oneflow/test/tensor/test_tensor.py
index b5bc832f3e36b7bb4bcbe0af4950bcaccb5a7c10..072cd2a4ea9c6c8f37694cc5dd5437701c311301 100644
--- a/python/oneflow/test/tensor/test_tensor.py
+++ b/python/oneflow/test/tensor/test_tensor.py
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import random
 import unittest
 from collections import OrderedDict
 
@@ -714,13 +713,89 @@ class TestTensor(flow.unittest.TestCase):
             np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
         )
 
-    def test_tensor_addmm_(test_case):
-        input = flow.Tensor(np.random.randn(2, 6), dtype=flow.float32)
-        mat1 = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
-        mat2 = flow.Tensor(np.random.randn(3, 6), dtype=flow.float32)
-        of_out = input.addmm(mat1, mat2, alpha=1, beta=2)
-        np_out = np.add(2 * input.numpy(), 1 * np.matmul(mat1.numpy(), mat2.numpy()))
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
+    @autotest()
+    def test_addmm_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = input.addmm(
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
+    @autotest()
+    def test_addmm_broadcast_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=1, dim1=1).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = input.addmm(
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
+    @autotest()
+    def test_clamp_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clamp(min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clamp_minnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clamp(min=random().to(float) | nothing(), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clamp_maxnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clamp(min=random().to(float), max=random().to(float) | nothing())
+        return y
+
+    @autotest()
+    def test_clip_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clip(min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clip_minnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clip(min=random().to(float) | nothing(), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clip_maxnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clip(min=random().to(float), max=random().to(float) | nothing())
+        return y
+
+    @autotest()
+    def test_ceil_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.ceil()
+        return y
+
+    @autotest()
+    def test_expm1_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.expm1()
+        return y
 
     def test_norm_tensor_function(test_case):
         input = flow.Tensor(
@@ -818,7 +893,7 @@ class TestTensor(flow.unittest.TestCase):
     )
     def test_tensor_fmod(test_case):
         x = flow.Tensor(np.random.uniform(-100, 100, (5, 5)), requires_grad=True)
-        y = random.uniform(-10, 10)
+        y = np.random.uniform(-10, 10)
         of_out = x.fmod(y)
         np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y))
         test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@@ -834,7 +909,7 @@ class TestTensor(flow.unittest.TestCase):
     )
     def test_magic_fmod(test_case):
         x = flow.Tensor(np.random.uniform(-100, 100, (5, 5)), requires_grad=True)
-        y = random.uniform(-10, 10)
+        y = np.random.uniform(-10, 10)
         of_out = x % y
         np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y))
         test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@@ -844,32 +919,6 @@ class TestTensor(flow.unittest.TestCase):
             np.allclose(x.grad.numpy(), np.ones((5, 5)), 0.0001, 0.0001)
         )
 
-    @unittest.skipIf(
-        not flow.unittest.env.eager_execution_enabled(),
-        "numpy doesn't work in lazy mode",
-    )
-    def test_tensor_ceil(test_case):
-        x = flow.Tensor(np.random.randn(2, 3), requires_grad=True)
-        of_out = x.ceil()
-        np_out = np.ceil(x.numpy())
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
-        of_out = of_out.sum()
-        of_out.backward()
-        test_case.assertTrue(
-            np.allclose(x.grad.numpy(), np.zeros((2, 3)), 0.0001, 0.0001)
-        )
-
-    def test_tensor_expm1(test_case):
-        x = flow.Tensor(np.random.randn(2, 3), requires_grad=True)
-        of_out = x.expm1()
-        np_out = np.expm1(x.numpy())
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
-        of_out = of_out.sum()
-        of_out.backward()
-        test_case.assertTrue(
-            np.allclose(x.grad.numpy(), np.exp(x.numpy()), 0.0001, 0.0001)
-        )
-
     def test_tensor_mish(test_case):
         def np_mish(x):
             f = 1 + np.exp(x)