diff --git a/ci/test_multi_client/generic_test.sh b/ci/test_multi_client/generic_test.sh
index 1537bfc704ae3e236686a78a70f26b234ef738a5..941d727e6c6a61fdc12513788989218ad48a6911 100644
--- a/ci/test_multi_client/generic_test.sh
+++ b/ci/test_multi_client/generic_test.sh
@@ -6,6 +6,7 @@ export PYTHONUNBUFFERED=1
 src_dir=${ONEFLOW_SRC_DIR:-"$PWD"}
 test_dir=${ONEFLOW_TEST_DIR:-"$PWD/oneflow/python/test/modules"}
 test_tmp_dir=${ONEFLOW_TEST_TMP_DIR:-"./test_tmp_dir"}
+export ONEFLOW_TEST_UTILS_DIR=$src_dir/oneflow/python/test_utils
 
 
 rm -rf $test_tmp_dir
diff --git a/oneflow/python/nn/modules/matmul.py b/oneflow/python/nn/modules/matmul.py
index 252a90271db9c4151494524e0f7c0ab027b74229..61c443444f127b3aea5969026998dc9ba21b9b08 100644
--- a/oneflow/python/nn/modules/matmul.py
+++ b/oneflow/python/nn/modules/matmul.py
@@ -47,7 +47,7 @@ class MatMul(Module):
 @oneflow_export("matmul")
 @register_tensor_op("matmul")
 @experimental_api
-def matmul_op(a, b):
+def matmul_op(input, other):
     r"""This operator applies matrix multiplication to two Tensor.
 
     Args:
@@ -71,7 +71,7 @@ def matmul_op(a, b):
         flow.Size([2, 5])
 
     """
-    return MatMul()(a, b)
+    return MatMul()(input, other)
 
 
 if __name__ == "__main__":
diff --git a/oneflow/python/test/modules/automated_test_util.py b/oneflow/python/test/modules/automated_test_util.py
index 7359cf3f7e4735fe45df2597dddeaf2015993cb4..a52a0b0d7d783d00265babaa13b530fa81fa3ebe 100644
--- a/oneflow/python/test/modules/automated_test_util.py
+++ b/oneflow/python/test/modules/automated_test_util.py
@@ -13,416 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-import inspect
-import typing  # This unused import is needed
-from typing import Dict, Optional, Tuple, Any, Union
-from collections import namedtuple
-import random as random_util
 import os
+import sys
 
-import oneflow.experimental as flow
-import torch
-import numpy as np
+test_util_parent_dir = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
 
-TEST_MODULE = 0
-TEST_FLOW = 1
-TEST_TENSOR = 2
+oneflow_test_utils_dir_from_env = os.getenv("ONEFLOW_TEST_UTILS_DIR")
+if oneflow_test_utils_dir_from_env:
+    from pathlib import Path
 
-rng = np.random.default_rng()
+    oneflow_test_utils_dir_from_env = Path(oneflow_test_utils_dir_from_env)
+    test_util_parent_dir = str(oneflow_test_utils_dir_from_env.parent.absolute())
 
-default_generators = {}
-
-
-def data_generator(annotation):
-    def register_data_generator(func):
-        default_generators[annotation] = func
-        return func
-
-    return register_data_generator
-
-
-@data_generator(bool)
-def _random_bool():
-    val = random_util.choice([True, False])
-    return val, val
-
-
-@data_generator(torch.Tensor)
-def _random_tensor():
-    return random_tensor()(None)
-
-
-def random_tensor(ndim=None, dim0=1, dim1=None, dim2=None, dim3=None, dim4=None):
-    assert ndim is None or 1 <= ndim <= 5
-    if ndim is None:
-        ndim = rng.integers(low=1, high=6)
-    shape = rng.integers(low=1, high=8, size=ndim)
-    if dim0 is not None:
-        shape[0] = dim0
-    if ndim >= 2 and dim1 is not None:
-        shape[1] = dim1
-    if ndim >= 3 and dim2 is not None:
-        shape[2] = dim2
-    if ndim >= 4 and dim3 is not None:
-        shape[3] = dim3
-    if ndim == 5 and dim4 is not None:
-        shape[4] = dim4
-
-    def generator(_):
-        np_arr = rng.random(shape)
-        return flow.Tensor(np_arr), torch.Tensor(np_arr)
-
-    return generator
-
-
-def choose(x):
-    def generator(_):
-        val = random_util.choice(x)
-        return val, val
-
-    return generator
-
-
-def random(low, high):
-    def generator(annotation):
-        if hasattr(annotation, "__origin__"):
-            # PyTorch _size_2_t and similar types are defined by type variables,
-            # leading to unexpected __args__ and __origin__
-            #
-            # _size_2_t = Union[T, Tuple[T, T]][int]
-            # _size_2_t.__origin__
-            # >> typing.Union[~T, typing.Tuple[~T, ~T]]
-            #
-            # So recreate a new annotation object by repr and eval
-            #
-            # _size_2_t
-            # >> typing.Union[int, typing.Tuple[int, int]]
-            # _size_2_t_new = eval(repr(annotation))
-            # _size_2_t_new.__origin__
-            # >> typing.Union
-            annotation = eval(repr(annotation))
-            if annotation.__origin__ is Union:
-                x = random_util.choice(annotation.__args__)
-                return generator(x)
-            if annotation.__origin__ is Tuple or annotation.__origin__ is tuple:
-                t = [generator(x) for x in annotation.__args__]
-                return zip(*t)
-            else:
-                raise NotImplementedError(
-                    f"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}"
-                )
-
-        if annotation == int:
-            val = int(rng.integers(low, high))
-        elif annotation == float:
-            val = float(rng.random() * (high - low) + low)
-        else:
-            raise NotImplementedError(
-                f"Not implemented annotation {annotation} in random"
-            )
-        return val, val
-
-    return generator
-
-
-def constant(val):
-    def generator(_):
-        return val, val
-
-    return generator
-
-
-def test_against_pytorch(
-    test_case,
-    callable_name,
-    extra_annotations: Optional[Dict[str, Any]] = None,
-    extra_generators: Optional[Dict[str, Any]] = None,
-    extra_defaults: Optional[Dict[str, Any]] = None,
-    device: str = "cuda",
-    training: bool = True,
-    backward: bool = True,
-    rtol=1e-4,
-    atol=1e-5,
-    n=20,
-    pytorch_callable_name=None,
-    api_flag: int = TEST_MODULE,
-):
-    assert device in ["cuda", "cpu"]
-    if not training:
-        assert not backward
-    if extra_annotations is None:
-        extra_annotations = {}
-    if extra_generators is None:
-        extra_generators = {}
-    if extra_defaults is None:
-        extra_defaults = {}
-    if pytorch_callable_name is None:
-        pytorch_callable_name = callable_name
-
-    verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None
-
-    def has_full_args_spec(callable):
-        try:
-            spec = inspect.getfullargspec(callable)
-            return True
-        except Exception:
-            return False
-
-    if api_flag == TEST_TENSOR:
-        pytorch_tensor = torch.Tensor(1)
-        pytorch_call = eval(f"pytorch_tensor.{pytorch_callable_name}")
-    else:
-        pytorch_call = eval(f"torch.{pytorch_callable_name}")
-
-    Spec = namedtuple(
-        "spec",
-        "args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations",
-    )
-
-    if has_full_args_spec(pytorch_call):
-        tmp_spec = inspect.getfullargspec(pytorch_call)
-        new_defaults = tmp_spec.defaults
-        if new_defaults is None:
-            new_defaults = []
-        new_kwonlydefaults = tmp_spec.kwonlydefaults
-        if new_kwonlydefaults is None:
-            new_kwonlydefaults = []
-        spec = Spec(
-            tmp_spec.args,
-            tmp_spec.varargs,
-            tmp_spec.varkw,
-            new_defaults,
-            tmp_spec.kwonlyargs,
-            new_kwonlydefaults,
-            tmp_spec.annotations,
-        )
-    else:
-        args = list(extra_annotations.keys()) + list(extra_defaults.keys())
-        spec = Spec(args, None, None, [], [], {}, {})
-
-    annotations = spec.annotations
-    annotations.update(extra_annotations)
-
-    if "return" in annotations:
-        del annotations["return"]
-
-    args = (set(spec.args) | set(spec.kwonlyargs)) - {"self"}
-
-    assert args == set(
-        annotations.keys()
-    ), f"args = {args}, annotations = {annotations.keys()}"
-    annotations.update({"input": torch.Tensor})
-
-    def has_default(name):
-        if name in spec.args:
-            return (len(spec.args) - spec.args.index(name)) <= len(spec.defaults)
-        else:
-            assert name in spec.kwonlyargs
-            return (len(spec.kwonlyargs) - spec.kwonlyargs.index(name)) <= len(
-                spec.kwonlydefaults
-            )
-
-    def generate(name):
-        annotation = annotations[name]
-        if name in extra_generators:
-            return extra_generators[name](annotation)
-        return default_generators[annotation]()
-
-    while n > 0:
-        flow_attr_dict = {}
-        torch_attr_dict = {}
-        for name in args:
-            if has_default(name):
-                if rng.random() < 1 / 3:
-                    continue
-            flow_data, torch_data = generate(name)
-            if isinstance(torch_data, torch.Tensor):
-                torch_data = torch_data.to(device)
-            if isinstance(flow_data, flow.Tensor):
-                flow_data = flow_data.to(device)
-            flow_attr_dict[name] = flow_data
-            torch_attr_dict[name] = torch_data
-
-        if verbose:
-            print(f"attr = {torch_attr_dict}, device = {device}")
-
-        flow_input_original, torch_input_original = generate("input")
-        flow_input_original.requires_grad_(backward)
-        torch_input_original.requires_grad_(backward)
-        flow_input, torch_input = (
-            flow_input_original.to(device),
-            torch_input_original.to(device),
-        )
-        try:
-            if api_flag == TEST_MODULE:
-                torch_call = pytorch_call(**torch_attr_dict)
-                torch_call = torch_call.to(device)
-                torch_call.train(training)
-                torch_res = torch_call(torch_input)
-                state_dict = torch_call.state_dict()
-                state_dict = {
-                    k: v.detach().cpu().numpy() for k, v in state_dict.items()
-                }
-            elif api_flag == TEST_FLOW:
-                torch_xxx_func = eval(f"torch.{pytorch_callable_name}")
-                torch_res = torch_xxx_func(torch_input, **torch_attr_dict)
-
-            else:
-                torch_tensor_xxx_func = eval(f"torch_input.{pytorch_callable_name}")
-                torch_res = torch_tensor_xxx_func(**torch_attr_dict)
-
-            loss = torch_res.sum()
-            loss.backward()
-            if api_flag == TEST_MODULE:
-                state_dict = torch_call.state_dict()
-                state_dict = {
-                    k: v.detach().cpu().numpy() for k, v in state_dict.items()
-                }
-        except Exception as e:
-            if verbose:
-                print(f"PyTorch error: {e}")
-            # The random generated test data is not always valid,
-            # so just skip when PyTorch raises an exception
-            continue
-
-        if api_flag == TEST_MODULE:
-            flow_call_class = eval(f"flow.{callable_name}")
-            flow_call = flow_call_class(**flow_attr_dict)
-            flow_call = flow_call.to(device)
-            flow_call.train(training)
-            flow_call.load_state_dict(state_dict)
-            flow_res = flow_call(flow_input)
-        elif api_flag == TEST_FLOW:
-            flow_xxx_func = eval(f"flow.{callable_name}")
-            flow_res = flow_xxx_func(flow_input, **flow_attr_dict)
-        else:
-            flow_tensor_xxx_func = eval(f"flow_input.{callable_name}")
-            flow_res = flow_tensor_xxx_func(**flow_attr_dict)
-
-        loss = flow_res.sum()
-        loss.backward()
-
-        def allclose_or_fail(flow_tensor, torch_tensor):
-            is_allclose = np.allclose(
-                flow_tensor.numpy(),
-                torch_tensor.detach().cpu().numpy(),
-                rtol=rtol,
-                atol=atol,
-            )
-            test_case.assertTrue(
-                is_allclose,
-                f"flow_tensor = {flow_tensor},\ntorch_tensor = {torch_tensor},\nattr_dict = {torch_attr_dict},\nflow_input_tensor = {flow_input_original}",
-            )
-
-        allclose_or_fail(flow_res, torch_res)
-        allclose_or_fail(flow_input_original.grad, torch_input_original.grad)
-        if api_flag == TEST_MODULE:
-            flow_parameters = dict(flow_call.named_parameters())
-            for name, torch_param in torch_call.named_parameters():
-                flow_param = flow_parameters[name]
-                allclose_or_fail(flow_param.grad, torch_param.grad)
-        n -= 1
-
-
-def test_module_against_pytorch(
-    test_case,
-    callable_name,
-    extra_annotations: Optional[Dict[str, Any]] = None,
-    extra_generators: Optional[Dict[str, Any]] = None,
-    extra_defaults: Optional[Dict[str, Any]] = None,
-    device: str = "cuda",
-    training: bool = True,
-    backward: bool = True,
-    rtol=1e-4,
-    atol=1e-5,
-    n=20,
-    pytorch_callable_name=None,
-):
-    return test_against_pytorch(
-        test_case=test_case,
-        callable_name=callable_name,
-        extra_annotations=extra_annotations,
-        extra_generators=extra_generators,
-        extra_defaults=extra_defaults,
-        device=device,
-        training=training,
-        backward=backward,
-        rtol=rtol,
-        atol=atol,
-        n=n,
-        pytorch_callable_name=pytorch_callable_name,
-        api_flag=TEST_MODULE,
-    )
-
-
-def test_flow_against_pytorch(
-    test_case,
-    callable_name,
-    extra_annotations: Optional[Dict[str, Any]] = None,
-    extra_generators: Optional[Dict[str, Any]] = None,
-    extra_defaults: Optional[Dict[str, Any]] = None,
-    device: str = "cuda",
-    training: bool = True,
-    backward: bool = True,
-    rtol=1e-4,
-    atol=1e-5,
-    n=20,
-    pytorch_callable_name=None,
-):
-    return test_against_pytorch(
-        test_case=test_case,
-        callable_name=callable_name,
-        extra_annotations=extra_annotations,
-        extra_generators=extra_generators,
-        extra_defaults=extra_defaults,
-        device=device,
-        training=training,
-        backward=backward,
-        rtol=rtol,
-        atol=atol,
-        n=n,
-        pytorch_callable_name=pytorch_callable_name,
-        api_flag=TEST_FLOW,
-    )
-
-
-def test_tensor_against_pytorch(
-    test_case,
-    callable_name,
-    extra_annotations: Optional[Dict[str, Any]] = None,
-    extra_generators: Optional[Dict[str, Any]] = None,
-    extra_defaults: Optional[Dict[str, Any]] = None,
-    device: str = "cuda",
-    training: bool = True,
-    backward: bool = True,
-    rtol=1e-4,
-    atol=1e-5,
-    n=20,
-    pytorch_callable_name=None,
-):
-    return test_against_pytorch(
-        test_case=test_case,
-        callable_name=callable_name,
-        extra_annotations=extra_annotations,
-        extra_generators=extra_generators,
-        extra_defaults=extra_defaults,
-        device=device,
-        training=training,
-        backward=backward,
-        rtol=rtol,
-        atol=atol,
-        n=n,
-        pytorch_callable_name=pytorch_callable_name,
-        api_flag=TEST_TENSOR,
-    )
-
-
-__all__ = [
-    "random_tensor",
-    "random",
-    "choose",
-    "constant",
-    "test_module_against_pytorch",
-    "test_flow_against_pytorch",
-    "test_tensor_against_pytorch",
-]
+sys.path.append(test_util_parent_dir)
+from test_utils.automated_test_util import *
diff --git a/oneflow/python/test/modules/test_conv.py b/oneflow/python/test/modules/test_conv.py
index 1b179112780bb98c63499205f5f7d9bd8e5fb165..922aebfdfba8699321e99f0c5e0caf568e9020a8 100644
--- a/oneflow/python/test/modules/test_conv.py
+++ b/oneflow/python/test/modules/test_conv.py
@@ -1808,12 +1808,13 @@ class TestConv2d(flow.unittest.TestCase):
     @unittest.skip("need a more relaxed tolerance")
     def test_with_random_data(test_case):
         for device in ["cpu", "cuda"]:
+            channels = random(1, 6)
             test_module_against_pytorch(
                 test_case,
                 "nn.Conv2d",
                 extra_generators={
-                    "input": random_tensor(ndim=4, dim1=4),
-                    "in_channels": constant(4),
+                    "input": random_tensor(ndim=4, dim1=channels),
+                    "in_channels": channels,
                     "out_channels": random(1, 129),
                     "kernel_size": random(1, 4),
                     "stride": random(1, 4),
@@ -1825,6 +1826,30 @@ class TestConv2d(flow.unittest.TestCase):
                 device=device,
             )
 
+    @unittest.skip("need a more relaxed tolerance")
+    @autotest()
+    def test_against_pytorch(test_case):
+        channels = random(1, 6)
+        m = torch.nn.Conv2d(
+            channels,
+            random(1, 6),
+            random(1, 6),
+            stride=random(1, 3) | nothing(),
+            padding=random(1, 3) | nothing(),
+            dilation=random(1, 3) | nothing(),
+            groups=random(1, 3) | nothing(),
+            bias=random() | nothing(),
+            padding_mode=constant("zeros") | nothing(),
+        )
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(
+            ndim=4, dim1=channels, dim2=random(1, 8), dim3=random(1, 8)
+        ).to(device)
+        y = m(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/oneflow/python/test/modules/test_flatten.py b/oneflow/python/test/modules/test_flatten.py
index 44cc76685c5079cb47356fda214198d085dee50e..0c59ae1d2cd9b9b3079c8d855fa3b98561d58605 100644
--- a/oneflow/python/test/modules/test_flatten.py
+++ b/oneflow/python/test/modules/test_flatten.py
@@ -67,12 +67,20 @@ class TestFlattenModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
-    def test_with_random_data(test_case):
-        test_module_against_pytorch(
-            test_case,
-            "nn.Flatten",
-            extra_generators={"start_dim": random(1, 6), "end_dim": random(1, 6),},
+    # Our flatten produces a new tensor if flatten is effectively a no-op,
+    # while pytorch's flatten returns the input tensor itself,
+    # leading to the inconsistency on the leaf-ness of x and thus the existence of x's grad
+    @autotest(auto_backward=False)
+    def test_against_pytorch(test_case):
+        m = torch.nn.Flatten(
+            start_dim=random(1, 6) | nothing(), end_dim=random(1, 6) | nothing()
         )
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor().to(device)
+        y = m(x)
+        return y
 
 
 if __name__ == "__main__":
diff --git a/oneflow/python/test/modules/test_masked_fill.py b/oneflow/python/test/modules/test_masked_fill.py
index 56272068a02a5bced31bd9742a4d234549caff9f..040fa97e7d7d35c6f35764d4a3cdf39549e48325 100644
--- a/oneflow/python/test/modules/test_masked_fill.py
+++ b/oneflow/python/test/modules/test_masked_fill.py
@@ -23,6 +23,7 @@ from automated_test_util import *
 
 @flow.unittest.skip_unless_1n1d()
 class TestMaskedFill(flow.unittest.TestCase):
+    @unittest.skip("has bug now, need rewrite")
     def test_masked_fill_aginst_pytorch(test_case):
         import numpy as np
         import torch
diff --git a/oneflow/python/test/modules/test_matmul.py b/oneflow/python/test/modules/test_matmul.py
index 4ad3ab582958928a41f80678f521fd393cc81c9e..7ccd5e69e27ae50a417ad9781c9311f829629a93 100644
--- a/oneflow/python/test/modules/test_matmul.py
+++ b/oneflow/python/test/modules/test_matmul.py
@@ -17,9 +17,11 @@ from collections import OrderedDict
 
 import unittest
 import numpy as np
+import torch
 
 import oneflow.experimental as flow
 from test_util import GenArgList
+from automated_test_util import *
 
 
 def _test_matmul(test_case, device):
@@ -333,6 +335,14 @@ class TestModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_flow_matmul_with_random_data(test_case):
+        k = random(1, 6)
+        x = random_pytorch_tensor(ndim=2, dim1=k)
+        y = random_pytorch_tensor(ndim=2, dim0=k)
+        z = torch.matmul(x, y)
+        return z
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/oneflow/python/test/tensor/automated_test_util.py b/oneflow/python/test/tensor/automated_test_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..22334c9d366c516599c3d6382501439e32aff4de
--- /dev/null
+++ b/oneflow/python/test/tensor/automated_test_util.py
@@ -0,0 +1,23 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import sys
+
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(BASE_DIR)
+
+
+from test_utils.automated_test_util import *
diff --git a/oneflow/python/test/tensor/test_tensor.py b/oneflow/python/test/tensor/test_tensor.py
index 8e9a0847875a7f6664d4be8fb3c627533edf503f..c3288b793a71da16e84bfc487375e7165f91dd0d 100644
--- a/oneflow/python/test/tensor/test_tensor.py
+++ b/oneflow/python/test/tensor/test_tensor.py
@@ -748,6 +748,13 @@ class TestTensor(flow.unittest.TestCase):
             np.allclose(of_input.grad.numpy(), np_out_grad, 1e-4, 1e-4, equal_nan=True)
         )
 
+    # TODO: find a way to import automated_test_util here to enable the following test
+    #
+    # @autotest()
+    # def test_tensor_tan(test_case):
+    #     x = random_pytorch_tensor().to(random_device())
+    #     return x.tan()
+
     def test_tensor_tan(test_case):
         np_input = np.random.random((2, 3)) - 0.5
         of_input = flow.Tensor(np_input, dtype=flow.float32, requires_grad=True)
diff --git a/oneflow/python/test_utils/__init__.py b/oneflow/python/test_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..569f9a06b653d037755601d1f6fc11cf92f9386c
--- /dev/null
+++ b/oneflow/python/test_utils/__init__.py
@@ -0,0 +1,16 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+# Since the name of test_util is already occupied, the name of this package is called test_utils.
diff --git a/oneflow/python/test_utils/automated_test_util/__init__.py b/oneflow/python/test_utils/automated_test_util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf12a3daccf3e22b5adab11c6617dec088ea333
--- /dev/null
+++ b/oneflow/python/test_utils/automated_test_util/__init__.py
@@ -0,0 +1,17 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from .generators import *
+from .torch_flow_dual_object import *
diff --git a/oneflow/python/test_utils/automated_test_util/generators.py b/oneflow/python/test_utils/automated_test_util/generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dd25c1436db87d9e76e176842aeaa967bc6009a
--- /dev/null
+++ b/oneflow/python/test_utils/automated_test_util/generators.py
@@ -0,0 +1,635 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import inspect
+import typing  # This unused import is needed
+from typing import Dict, Optional, Tuple, Any, Union
+from collections import namedtuple
+import random as random_util
+import os
+
+import numpy as np
+import oneflow.experimental as flow
+import torch
+import numpy as np
+
+py_tuple = tuple
+
+TEST_MODULE = 0
+TEST_FLOW = 1
+TEST_TENSOR = 2
+
+rng = np.random.default_rng()
+
+annotation2default_generator = {}
+annotation2torch_to_flow_converter = {}
+
+
+def data_generator(annotation):
+    def register_data_generator(cls):
+        annotation2default_generator[annotation] = lambda: cls()
+        return cls
+
+    return register_data_generator
+
+
+def torch_to_flow_converter(annotation):
+    def register_flow_to_flow_converter(func):
+        annotation2torch_to_flow_converter[annotation] = func
+        return func
+
+    return register_flow_to_flow_converter
+
+
+@torch_to_flow_converter(torch.Tensor)
+def tensor_converter(torch_tensor):
+    return flow.tensor(torch_tensor.cpu().numpy())
+
+
+def convert_torch_object_to_flow(x):
+    for annotation, converter in annotation2torch_to_flow_converter.items():
+        if isinstance(x, annotation):
+            return converter(x)
+    return x
+
+
+def pack(x):
+    if isinstance(x, generator):
+        return x
+    return constant(x)
+
+
+class Nothing:
+    pass
+
+
+class generator:
+    def __init__(self, children):
+        self.children = children
+        self._value = None
+
+    def _init(self):
+        self._value = None
+        for x in self.children:
+            x._init()
+
+    def eval(self):
+        self._init()
+        return self.value()
+
+    def _calc_value(self):
+        raise NotImplementedError()
+
+    def value(self):
+        if self._value is None:
+            self._value = self._calc_value()
+        return self._value
+
+    def size(self):
+        return 1
+
+    def __or__(self, other):
+        other = pack(other)
+        return oneof(
+            self, other, possibility=self.size() / (self.size() + other.size())
+        )
+
+    def __ror__(self, other):
+        return self | other
+
+    def __add__(self, other):
+        return add(self, other)
+
+    def __radd__(self, other):
+        return self + other
+
+    def __sub__(self, other):
+        return self + neg(other)
+
+    def __rsub__(self, other):
+        return neg(self - other)
+
+    def to(self, annotation):
+        self._to(annotation)
+        for x in self.children:
+            x.to(annotation)
+        return self
+
+    def _to(self, annotation):
+        pass
+
+
+class add(generator):
+    def __init__(self, a, b):
+        self.a = pack(a)
+        self.b = pack(b)
+        super().__init__([self.a, self.b])
+
+    def _calc_value(self):
+        return self.a.value() + self.b.value()
+
+
+class neg(generator):
+    def __init__(self, a):
+        self.a = pack(a)
+        super().__init__([self.a])
+
+    def _calc_value(self):
+        return -self.a.value()
+
+
+class oneof(generator):
+    def __init__(self, *args, possibility=None):
+        self.args = list(map(pack, args))
+        super().__init__(self.args)
+        if isinstance(possibility, float):
+            assert len(args) == 2
+            possibility = [possibility, 1 - possibility]
+        if possibility is None:
+            possibility = [1 / len(args)] * len(args)
+        self.possibility = pack(possibility)
+
+    def _calc_value(self):
+        rand = rng.random()
+        sum = 0
+        for i, possibility in enumerate(self.possibility.value()):
+            sum += possibility
+            if sum > rand:
+                return self.args[i].value()
+        raise RuntimeError()
+
+    def size(self):
+        return sum([x.size() for x in self.args])
+
+
+class tuple(generator):
+    def __init__(self, *args):
+        self.args = list(map(pack, args))
+        super().__init__(self.args)
+
+    def _calc_value(self):
+        return py_tuple([x.value() for x in self.args])
+
+
+class constant(generator):
+    def __init__(self, x):
+        super().__init__([])
+        self.x = x
+
+    def _calc_value(self):
+        return self.x
+
+
+class nothing(generator):
+    def __init__(self):
+        super().__init__([])
+
+    def _calc_value(self):
+        return Nothing()
+
+
+class random(generator):
+    def __init__(self, low=1, high=6):
+        self.low = pack(low)
+        self.high = pack(high)
+        super().__init__([self.low, self.high])
+        self.annotation = None
+
+    def _to(self, annotation):
+        if self.annotation is not None:
+            return
+        if hasattr(annotation, "__origin__"):
+            # PyTorch _size_2_t and similar types are defined by type variables,
+            # leading to unexpected __args__ and __origin__
+            #
+            # >>> _size_2_t = Union[T, Tuple[T, T]][int]
+            # >>> _size_2_t.__origin__
+            # typing.Union[~T, typing.Tuple[~T, ~T]]
+            #
+            # So recreate a new annotation object by repr and eval
+            #
+            # >>> _size_2_t
+            # typing.Union[int, typing.Tuple[int, int]]
+            # >>> _size_2_t_new = eval(repr(annotation))
+            # >>> _size_2_t_new.__origin__
+            # typing.Union
+            annotation = eval(repr(annotation))
+        self.annotation = annotation
+
+    def _generate(self, annotation):
+        if hasattr(annotation, "__origin__"):
+            if annotation.__origin__ is Union:
+                x = random_util.choice(annotation.__args__)
+                return self._generate(x)
+            if annotation.__origin__ is Tuple or annotation.__origin__ is py_tuple:
+                return [self._generate(x) for x in annotation.__args__]
+            else:
+                raise NotImplementedError(
+                    f"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}"
+                )
+
+        low, high = self.low.value(), self.high.value()
+
+        if annotation == int:
+            val = int(rng.integers(low, high))
+        elif annotation == float:
+            val = float(rng.random() * (high - low) + low)
+        elif annotation == bool:
+            val = random_util.choice([True, False])
+        else:
+            raise NotImplementedError(
+                f"Not implemented annotation {annotation} in random"
+            )
+        return val
+
+    def _calc_value(self):
+        return self._generate(self.annotation)
+
+
+def random_or_nothing(low, high):
+    return oneof(random(low, high), nothing(), possibility=2 / 3)
+
+
+@data_generator(torch.Tensor)
+class random_tensor(generator):
+    def __init__(self, ndim=None, dim0=1, dim1=None, dim2=None, dim3=None, dim4=None):
+        if ndim is None:
+            ndim = random(1, 6)
+        if dim0 is None:
+            dim0 = random(1, 8)
+        if dim1 is None:
+            dim1 = random(1, 8)
+        if dim2 is None:
+            dim2 = random(1, 8)
+        if dim3 is None:
+            dim3 = random(1, 8)
+        if dim4 is None:
+            dim4 = random(1, 8)
+        self.ndim = pack(ndim).to(int)
+        self.dim0 = pack(dim0).to(int)
+        self.dim1 = pack(dim1).to(int)
+        self.dim2 = pack(dim2).to(int)
+        self.dim3 = pack(dim3).to(int)
+        self.dim4 = pack(dim4).to(int)
+        super().__init__(
+            [self.ndim, self.dim0, self.dim1, self.dim2, self.dim3, self.dim4]
+        )
+
+    def _calc_value(self):
+        ndim = self.ndim.value()
+        dim0 = self.dim0.value()
+        dim1 = self.dim1.value()
+        dim2 = self.dim2.value()
+        dim3 = self.dim3.value()
+        dim4 = self.dim4.value()
+        shape = rng.integers(low=1, high=8, size=ndim)
+        if dim0 is not None:
+            shape[0] = dim0
+        if ndim >= 2:
+            shape[1] = dim1
+        if ndim >= 3:
+            shape[2] = dim2
+        if ndim >= 4:
+            shape[3] = dim3
+        if ndim == 5:
+            shape[4] = dim4
+        np_arr = rng.random(shape)
+        return torch.Tensor(np_arr)
+
+
+@data_generator(bool)
+def random_bool():
+    return random().to(bool)
+
+
+class random_device(generator):
+    def __init__(self):
+        super().__init__([])
+
+    def _calc_value(self):
+        return random_util.choice(["cuda", "cpu"])
+
+
+def test_against_pytorch(
+    test_case,
+    callable_name,
+    extra_annotations: Optional[Dict[str, Any]] = None,
+    extra_generators: Optional[Dict[str, Any]] = None,
+    extra_defaults: Optional[Dict[str, Any]] = None,
+    device: str = "cuda",
+    training: bool = True,
+    backward: bool = True,
+    rtol=1e-4,
+    atol=1e-5,
+    n=20,
+    pytorch_callable_name=None,
+    api_flag: int = TEST_MODULE,
+):
+    assert device in ["cuda", "cpu"]
+    if not training:
+        assert not backward
+    if extra_annotations is None:
+        extra_annotations = {}
+    if extra_generators is None:
+        extra_generators = {}
+    if extra_defaults is None:
+        extra_defaults = {}
+    if pytorch_callable_name is None:
+        pytorch_callable_name = callable_name
+
+    verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None
+
+    def has_full_args_spec(callable):
+        try:
+            inspect.getfullargspec(callable)
+            return True
+        except Exception:
+            return False
+
+    if api_flag == TEST_TENSOR:
+        pytorch_tensor = torch.Tensor(1)
+        pytorch_call = eval(f"pytorch_tensor.{pytorch_callable_name}")
+    else:
+        pytorch_call = eval(f"torch.{pytorch_callable_name}")
+
+    Spec = namedtuple(
+        "spec",
+        "args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations",
+    )
+
+    if has_full_args_spec(pytorch_call):
+        tmp_spec = inspect.getfullargspec(pytorch_call)
+        new_defaults = tmp_spec.defaults
+        if new_defaults is None:
+            new_defaults = []
+        new_kwonlydefaults = tmp_spec.kwonlydefaults
+        if new_kwonlydefaults is None:
+            new_kwonlydefaults = []
+        spec = Spec(
+            tmp_spec.args,
+            tmp_spec.varargs,
+            tmp_spec.varkw,
+            new_defaults,
+            tmp_spec.kwonlyargs,
+            new_kwonlydefaults,
+            tmp_spec.annotations,
+        )
+    else:
+        args = list(extra_annotations.keys()) + list(extra_defaults.keys())
+        spec = Spec(args, None, None, [], [], {}, {})
+
+    annotations = spec.annotations
+    annotations.update(extra_annotations)
+
+    if "return" in annotations:
+        del annotations["return"]
+
+    args = (set(spec.args) | set(spec.kwonlyargs)) - {"self"}
+
+    assert args == set(
+        annotations.keys()
+    ), f"args = {args}, annotations = {annotations.keys()}"
+
+    if "input" not in annotations:
+        annotations.update({"input": torch.Tensor})
+
+    def has_default(name):
+        if name in spec.args:
+            return (len(spec.args) - spec.args.index(name)) <= len(spec.defaults)
+        else:
+            assert name in spec.kwonlyargs
+            return (len(spec.kwonlyargs) - spec.kwonlyargs.index(name)) <= len(
+                spec.kwonlydefaults
+            )
+
+    def get_generator(name):
+        annotation = annotations[name]
+        if name in extra_generators:
+            generator = extra_generators[name]
+        else:
+            generator = annotation2default_generator[annotation]()
+        generator = generator.to(annotation)
+        return generator
+
+    while n > 0:
+        flow_attr_dict = {}
+        torch_attr_dict = {}
+
+        generator_tuple = tuple(
+            *([get_generator(name) for name in args] + [get_generator("input")])
+        )
+        values = generator_tuple.eval()
+
+        for i, name in enumerate(args):
+            torch_data = values[i]
+            if isinstance(torch_data, Nothing):
+                continue
+            flow_data = convert_torch_object_to_flow(torch_data)
+            if isinstance(torch_data, torch.Tensor):
+                torch_data = torch_data.to(device)
+            if isinstance(flow_data, flow.Tensor):
+                flow_data = flow_data.to(device)
+            flow_attr_dict[name] = flow_data
+            torch_attr_dict[name] = torch_data
+
+        if verbose:
+            print(f"attr = {torch_attr_dict}, device = {device}")
+
+        torch_input_original = values[-1]
+        flow_input_original = convert_torch_object_to_flow(torch_input_original)
+        flow_input_original.requires_grad_(backward)
+        torch_input_original.requires_grad_(backward)
+        flow_input, torch_input = (
+            flow_input_original.to(device),
+            torch_input_original.to(device),
+        )
+        try:
+            if api_flag == TEST_MODULE:
+                torch_call = pytorch_call(**torch_attr_dict)
+                torch_call = torch_call.to(device)
+                torch_call.train(training)
+                torch_res = torch_call(torch_input)
+                state_dict = torch_call.state_dict()
+                state_dict = {
+                    k: v.detach().cpu().numpy() for k, v in state_dict.items()
+                }
+            elif api_flag == TEST_FLOW:
+                torch_xxx_func = eval(f"torch.{pytorch_callable_name}")
+                torch_res = torch_xxx_func(torch_input, **torch_attr_dict)
+
+            else:
+                torch_tensor_xxx_func = eval(f"torch_input.{pytorch_callable_name}")
+                torch_res = torch_tensor_xxx_func(**torch_attr_dict)
+
+            loss = torch_res.sum()
+            loss.backward()
+            if api_flag == TEST_MODULE:
+                state_dict = torch_call.state_dict()
+                state_dict = {
+                    k: v.detach().cpu().numpy() for k, v in state_dict.items()
+                }
+        except Exception as e:
+            if verbose:
+                print(f"PyTorch error: {e}")
+            # The random generated test data is not always valid,
+            # so just skip when PyTorch raises an exception
+            continue
+
+        if api_flag == TEST_MODULE:
+            flow_call_class = eval(f"flow.{callable_name}")
+            flow_call = flow_call_class(**flow_attr_dict)
+            flow_call = flow_call.to(device)
+            flow_call.train(training)
+            flow_call.load_state_dict(state_dict)
+            flow_res = flow_call(flow_input)
+        elif api_flag == TEST_FLOW:
+            flow_xxx_func = eval(f"flow.{callable_name}")
+            flow_res = flow_xxx_func(flow_input, **flow_attr_dict)
+        else:
+            flow_tensor_xxx_func = eval(f"flow_input.{callable_name}")
+            flow_res = flow_tensor_xxx_func(**flow_attr_dict)
+
+        loss = flow_res.sum()
+        loss.backward()
+
+        def allclose_or_fail(flow_tensor, torch_tensor):
+            is_allclose = np.allclose(
+                flow_tensor.numpy(),
+                torch_tensor.detach().cpu().numpy(),
+                rtol=rtol,
+                atol=atol,
+            )
+            test_case.assertTrue(
+                is_allclose,
+                f"flow_tensor = {flow_tensor},\ntorch_tensor = {torch_tensor},\nattr_dict = {torch_attr_dict},\nflow_input_tensor = {flow_input_original}",
+            )
+
+        allclose_or_fail(flow_res, torch_res)
+        allclose_or_fail(flow_input_original.grad, torch_input_original.grad)
+        if api_flag == TEST_MODULE:
+            flow_parameters = dict(flow_call.named_parameters())
+            for name, torch_param in torch_call.named_parameters():
+                flow_param = flow_parameters[name]
+                allclose_or_fail(flow_param.grad, torch_param.grad)
+        if verbose:
+            print("test passed")
+        n -= 1
+
+
+def test_module_against_pytorch(
+    test_case,
+    callable_name,
+    extra_annotations: Optional[Dict[str, Any]] = None,
+    extra_generators: Optional[Dict[str, Any]] = None,
+    extra_defaults: Optional[Dict[str, Any]] = None,
+    device: str = "cuda",
+    training: bool = True,
+    backward: bool = True,
+    rtol=1e-4,
+    atol=1e-5,
+    n=20,
+    pytorch_callable_name=None,
+):
+    return test_against_pytorch(
+        test_case=test_case,
+        callable_name=callable_name,
+        extra_annotations=extra_annotations,
+        extra_generators=extra_generators,
+        extra_defaults=extra_defaults,
+        device=device,
+        training=training,
+        backward=backward,
+        rtol=rtol,
+        atol=atol,
+        n=n,
+        pytorch_callable_name=pytorch_callable_name,
+        api_flag=TEST_MODULE,
+    )
+
+
+def test_flow_against_pytorch(
+    test_case,
+    callable_name,
+    extra_annotations: Optional[Dict[str, Any]] = None,
+    extra_generators: Optional[Dict[str, Any]] = None,
+    extra_defaults: Optional[Dict[str, Any]] = None,
+    device: str = "cuda",
+    training: bool = True,
+    backward: bool = True,
+    rtol=1e-4,
+    atol=1e-5,
+    n=20,
+    pytorch_callable_name=None,
+):
+    return test_against_pytorch(
+        test_case=test_case,
+        callable_name=callable_name,
+        extra_annotations=extra_annotations,
+        extra_generators=extra_generators,
+        extra_defaults=extra_defaults,
+        device=device,
+        training=training,
+        backward=backward,
+        rtol=rtol,
+        atol=atol,
+        n=n,
+        pytorch_callable_name=pytorch_callable_name,
+        api_flag=TEST_FLOW,
+    )
+
+
+def test_tensor_against_pytorch(
+    test_case,
+    callable_name,
+    extra_annotations: Optional[Dict[str, Any]] = None,
+    extra_generators: Optional[Dict[str, Any]] = None,
+    extra_defaults: Optional[Dict[str, Any]] = None,
+    device: str = "cuda",
+    training: bool = True,
+    backward: bool = True,
+    rtol=1e-4,
+    atol=1e-5,
+    n=20,
+    pytorch_callable_name=None,
+):
+    return test_against_pytorch(
+        test_case=test_case,
+        callable_name=callable_name,
+        extra_annotations=extra_annotations,
+        extra_generators=extra_generators,
+        extra_defaults=extra_defaults,
+        device=device,
+        training=training,
+        backward=backward,
+        rtol=rtol,
+        atol=atol,
+        n=n,
+        pytorch_callable_name=pytorch_callable_name,
+        api_flag=TEST_TENSOR,
+    )
+
+
+__all__ = [
+    "random_tensor",
+    "random_bool",
+    "random_device",
+    "random",
+    "random_or_nothing",
+    "constant",
+    "nothing",
+    "test_module_against_pytorch",
+    "test_flow_against_pytorch",
+    "test_tensor_against_pytorch",
+]
diff --git a/oneflow/python/test_utils/automated_test_util/torch_flow_dual_object.py b/oneflow/python/test_utils/automated_test_util/torch_flow_dual_object.py
new file mode 100644
index 0000000000000000000000000000000000000000..37594b94d4e987215a4f46731e5de9ff645a1367
--- /dev/null
+++ b/oneflow/python/test_utils/automated_test_util/torch_flow_dual_object.py
@@ -0,0 +1,313 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import collections.abc
+import inspect
+import functools
+import os
+
+import torch as torch_original
+import oneflow as flow_stable
+import oneflow.experimental as flow
+import numpy as np
+from .generators import generator, random_tensor, Nothing
+
+
+postulate = [".rand", ".Tensor"]
+
+
+def torch_tensor_to_flow(x):
+    return flow.tensor(x.cpu().numpy())
+
+
+class PyTorchDoesNotSupportError(Exception):
+    def __init__(self, exc):
+        self.exc = exc
+
+    def __str__(self):
+        return repr(self)
+
+    def __repr__(self):
+        return f"PyTorch error: {str(self.exc)}"
+
+
+def get_args(callable, *args, **kwargs):
+    try:
+        spec = inspect.getfullargspec(callable)
+        spec_args = spec.args
+        if spec_args[0] == "self":
+            del spec_args[0]
+        for i, arg in enumerate(args):
+            arg_name = spec_args[i]
+            annotation = spec.annotations[arg_name]
+            if isinstance(arg, generator):
+                arg.to(annotation)
+        for arg_name, arg in kwargs.items():
+            annotation = spec.annotations[arg_name]
+            if isinstance(arg, generator):
+                arg.to(annotation)
+    except:
+        pass
+    pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs = [], {}, [], {}
+
+    def get_pytorch_value(x):
+        if isinstance(x, DualObject):
+            return x.pytorch
+        return x
+
+    def get_oneflow_value(x):
+        if isinstance(x, DualObject):
+            return x.oneflow
+        return x
+
+    def get_generator_value(x):
+        if isinstance(x, generator):
+            return x.value()
+        return x
+
+    for arg in args:
+        arg = get_generator_value(arg)
+        pytorch_args.append(get_pytorch_value(arg))
+        oneflow_args.append(get_oneflow_value(arg))
+    for key, value in kwargs.items():
+        value = get_generator_value(value)
+        if isinstance(value, Nothing):
+            continue
+        pytorch_kwargs[key] = get_pytorch_value(value)
+        oneflow_kwargs[key] = get_oneflow_value(value)
+    return pytorch_args, pytorch_kwargs, oneflow_args, oneflow_kwargs
+
+
+counter = 0
+
+
+def GetDualObject(name, pytorch, oneflow):
+    global counter
+    counter += 1
+
+    skipped_magic_methods = [
+        "__class__",
+        "__mro__",
+        "__new__",
+        "__init__",
+        "__getattr__",
+        "__setattr__",
+        "__getattribute__",
+        "__dict__",
+        "__weakref__",
+        "__builtins__",
+        "__qualname__",
+        "__name__",
+        "__str__",
+        "__repr__",
+    ]
+
+    pytorch_methods = dir(pytorch)
+    if hasattr(pytorch, "__call__") and "__call__" not in pytorch_methods:
+        pytorch_methods.append("__call__")
+
+    magic_methods_for_new_cls = {}
+
+    for method_name in pytorch_methods:
+        if method_name.startswith("__") and method_name not in skipped_magic_methods:
+            # init a new 'method_name' variable other than the one in for loop,
+            # avoid a pitfall:
+            # https://python.plainenglish.io/python-pitfalls-with-variable-capture-dcfc113f39b7
+            def get_dual_method(method_name):
+                # __call__ is special. We should not delegate the '__call__' of the torch wrapper of class 'nn.Conv2d'
+                # to 'nn.Conv2d.__call__', as 'nn.Conv2d.__call__' belongs to the object of type 'nn.Conv2d'
+                # (not the class itself)
+                if method_name == "__call__":
+
+                    def dual_method(self, *args, **kwargs):
+                        (
+                            pytorch_args,
+                            pytorch_kwargs,
+                            oneflow_args,
+                            oneflow_kwargs,
+                        ) = get_args(pytorch, *args, **kwargs)
+                        # use () instead of '__call__'
+                        try:
+                            pytorch_res = pytorch(*pytorch_args, **pytorch_kwargs)
+                        except Exception as e:
+                            raise PyTorchDoesNotSupportError(e)
+                        # only check if the method is a postulate when it is called
+                        if name in postulate:
+                            oneflow_res = torch_tensor_to_flow(pytorch_res)
+                        else:
+                            oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs)
+                        return GetDualObject("unused", pytorch_res, oneflow_res)
+
+                else:
+
+                    def dual_method(self, *args, **kwargs):
+                        pytorch_method = getattr(pytorch, method_name)
+                        oneflow_method = getattr(oneflow, method_name)
+                        (
+                            pytorch_args,
+                            pytorch_kwargs,
+                            oneflow_args,
+                            oneflow_kwargs,
+                        ) = get_args(pytorch_method, *args, **kwargs)
+                        try:
+                            pytorch_res = pytorch_method(
+                                *pytorch_args, **pytorch_kwargs
+                            )
+                        except Exception as e:
+                            raise PyTorchDoesNotSupportError(e)
+                        oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs)
+                        return GetDualObject("unused", pytorch_res, oneflow_res)
+
+                return dual_method
+
+            magic_methods_for_new_cls[method_name] = get_dual_method(method_name)
+
+    Cls = type(f"{name}_{counter}", (DualObject,), magic_methods_for_new_cls)
+    return Cls(name, pytorch, oneflow)
+
+
+class DualObject:
+    def __init__(self, name, pytorch, oneflow):
+        self.name = name
+        self.pytorch = pytorch
+        self.oneflow = oneflow
+
+        if isinstance(pytorch, torch_original.nn.Module):
+            state_dict = pytorch.state_dict()
+            state_dict = {k: v.detach().cpu().numpy() for k, v in state_dict.items()}
+            oneflow.load_state_dict(state_dict)
+            dual_modules_to_test.append(self)
+
+        if isinstance(pytorch, torch_original.Tensor):
+            dual_objects_to_test.append(self)
+
+    def __repr__(self):
+        return f"PyTorch object:\n{self.pytorch}\n\nOneFlow object:\n{self.oneflow}"
+
+    def __getattr__(self, key):
+        pytorch_attr = getattr(self.pytorch, key)
+        oneflow_attr = getattr(self.oneflow, key)
+        new_name = f"{self.name}.{key}"
+
+        return GetDualObject(new_name, pytorch_attr, oneflow_attr)
+
+
+dual_modules_to_test = []
+dual_objects_to_test = []
+torch_type2checker = {}
+
+
+def equality_checker(torch_type, flow_type):
+    def deco(f):
+        torch_type2checker[(torch_type, flow_type)] = f
+        return f
+
+    return deco
+
+
+def check_equality(dual_object: DualObject):
+    checker = torch_type2checker.get(
+        (type(dual_object.pytorch), type(dual_object.oneflow)), None
+    )
+    if checker is None:
+        for key, value in torch_type2checker.items():
+            if isinstance(dual_object.pytorch, key[0]) and isinstance(
+                dual_object.oneflow, key[1]
+            ):
+                checker = value
+                break
+    assert checker is not None
+    return checker(dual_object.pytorch, dual_object.oneflow)
+
+
+@equality_checker(torch_original.Tensor, flow.Tensor)
+@equality_checker(torch_original.Tensor, flow_stable._oneflow_internal.Tensor)
+def check_tensor_equality(torch_tensor, flow_tensor):
+    # TODO: check dtype
+    if torch_tensor.grad is not None:
+        assert (
+            flow_tensor.grad is not None
+        ), "OneFlow tensor doesn't have grad while PyTorch tensor has one"
+        if not np.allclose(
+            torch_tensor.grad.detach().cpu().numpy(), flow_tensor.grad.numpy()
+        ):
+            return False
+    return np.allclose(torch_tensor.detach().cpu().numpy(), flow_tensor.numpy())
+
+
+def autotest(n=20, auto_backward=True, rtol=1e-4, atol=1e-5):
+    verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None
+
+    def deco(f):
+        @functools.wraps(f)
+        def new_f(test_case):
+            nonlocal n
+            while n > 0:
+                dual_modules_to_test.clear()
+                dual_objects_to_test.clear()
+                try:
+                    res = f(test_case)
+                except PyTorchDoesNotSupportError as e:
+                    if verbose:
+                        print(e)
+                    continue
+                # TODO: support types other than Tensor, like torch.Size/flow.Size
+                if res is not None:
+                    if not isinstance(res, collections.abc.Sequence):
+                        res = [res]
+                    for x in res:
+                        if auto_backward:
+                            if isinstance(x.pytorch, torch_original.Tensor):
+                                x.sum().backward()
+                        dual_objects_to_test.append(x)
+                for x in dual_modules_to_test:
+                    # x.state_dict().values() returns dual object with inconsistent values
+                    for key in x.pytorch.state_dict().keys():
+                        dual_objects_to_test.append(
+                            GetDualObject(
+                                "unused",
+                                x.pytorch.state_dict()[key],
+                                x.oneflow.state_dict()[key],
+                            )
+                        )
+                for x in dual_objects_to_test:
+                    test_case.assertTrue(check_equality(x))
+                if verbose:
+                    print("test passed")
+                n -= 1
+
+        return new_f
+
+    return deco
+
+
+def random_pytorch_tensor(
+    ndim=None, dim0=1, dim1=None, dim2=None, dim3=None, dim4=None, requires_grad=True
+):
+    if isinstance(requires_grad, generator):
+        requires_grad = requires_grad.value()
+    pytorch_tensor = (
+        random_tensor(ndim, dim0, dim1, dim2, dim3, dim4)
+        .value()
+        .requires_grad_(requires_grad)
+    )
+    flow_tensor = flow.tensor(pytorch_tensor.detach().cpu().numpy(), requires_grad=True)
+    return GetDualObject("unused", pytorch_tensor, flow_tensor)
+
+
+torch = GetDualObject("", torch_original, flow)
+
+
+__all__ = ["torch", "autotest", "random_pytorch_tensor"]
diff --git a/tools/check_src.py b/tools/check_src.py
index b7a9b396118914493c1dea3e8d738079dee2d277..7177f834429ec25a10ecdf3f3889d5fc928e0cb0 100644
--- a/tools/check_src.py
+++ b/tools/check_src.py
@@ -12,7 +12,10 @@ def check_unwanted_test_scripts(python_test_dir=None, allowed=None):
         os.path.relpath(os.path.join(python_test_dir, a), src_root) for a in allowed
     ]
     for (dirpath, dirnames, filenames) in os.walk(src_root):
-        if python_test_dir in dirpath and "__pycache__" not in dirpath:
+        if (
+            dirpath.startswith(os.path.abspath(python_test_dir) + os.sep)
+            and "__pycache__" not in dirpath
+        ):
             rel_to_python_test = os.path.relpath(dirpath, python_test_dir)
             rel_to_src_root = os.path.relpath(dirpath, src_root)
             print(f"checking: {rel_to_src_root}")
@@ -39,7 +42,7 @@ def check_unwanted_test_scripts(python_test_dir=None, allowed=None):
 
 check_unwanted_test_scripts(
     python_test_dir=os.path.join(src_root, "oneflow/python/test"),
-    allowed=["custom_ops", "dataloader", "graph", "models", "modules", "tensor",],
+    allowed=["custom_ops", "dataloader", "graph", "models", "modules", "tensor"],
 )
 
 check_unwanted_test_scripts(