Unverified commit 41786ad7, authored by Xiaoyu Zhang, committed by GitHub

Add flow xxx and tensor xxx autotest (#5386)


* add argmax test

* fix ci error

* fix docstring warning

* fix tensor greater and less bug

* fix conflict

* add test_flow_xxx_against_pytorch func

* add test_flow_tensor_xxx_against_pytorch func

* add test

* fix first comments

* code reuse

* fix comments

* fix comments

* fix conv test bug

* fix conv test bug

* fix comments

* fix comment

* fix comment

* fix comment

* fix comment

* fix bug

* fix comment

* fix comment

* fix comments

* fix comments

* fix comments

* fix comments

* fix comments

* code format

* refine autotest

* fix comments

* fix comments

* fix comment

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent b05dd645
@@ -13,10 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import inspect
import typing # This unused import is needed
from typing import Dict, Optional, Tuple, Any, Union
from collections import namedtuple
import random as random_util
import os
@@ -24,6 +24,10 @@ import oneflow.experimental as flow
import torch
import numpy as np
TEST_MODULE = 0
TEST_FLOW = 1
TEST_TENSOR = 2
rng = np.random.default_rng()
default_generators = {}
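For context (my reading of the diff, not text from the PR): the three TEST_* constants select which of the three calling conventions the shared test driver exercises — module instantiation (torch.nn.Xxx(...)(input)), free function (torch.xxx(input, ...)), or tensor method (input.xxx(...)). A minimal standalone sketch of that dispatch, with call_pytorch as a hypothetical name:

# Hypothetical sketch: how an api_flag maps to the three calling conventions.
import torch

TEST_MODULE, TEST_FLOW, TEST_TENSOR = 0, 1, 2

def call_pytorch(api_flag, name, x):
    if api_flag == TEST_MODULE:
        return eval(f"torch.{name}")()(x)  # e.g. torch.nn.ReLU()(x)
    elif api_flag == TEST_FLOW:
        return eval(f"torch.{name}")(x)    # e.g. torch.abs(x)
    else:
        return eval(f"x.{name}")()         # e.g. x.abs()

x = torch.randn(3)
assert torch.equal(call_pytorch(TEST_FLOW, "abs", x), call_pytorch(TEST_TENSOR, "abs", x))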
@@ -100,13 +104,14 @@ def random(low, high):
            if annotation.__origin__ is Union:
                x = random_util.choice(annotation.__args__)
                return generator(x)
            if annotation.__origin__ is Tuple or annotation.__origin__ is tuple:
                t = [generator(x) for x in annotation.__args__]
                return zip(*t)
            else:
                raise NotImplementedError(
                    f"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}"
                )
        if annotation == int:
            val = int(rng.integers(low, high))
        elif annotation == float:
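A side note on the Tuple branch above (a sketch under my assumptions, not the library's exact code): on Python 3.7+, typing.Tuple[...].__origin__ is the builtin tuple, which is why the diff accepts either form. The same resolution idea can be written with typing.get_origin (Python 3.8+); the helper below is hypothetical and simplifies the real generator, which zips per-element generators:

# Hypothetical, simplified resolver for Union/Tuple annotations.
import random as random_util
from typing import Tuple, Union, get_args, get_origin

def resolve(annotation, scalar):
    origin = get_origin(annotation)
    if origin is Union:
        return resolve(random_util.choice(get_args(annotation)), scalar)
    if origin is tuple:  # typing.Tuple and builtin tuple both report tuple here
        return tuple(resolve(a, scalar) for a in get_args(annotation))
    return scalar

print(resolve(Union[int, Tuple[int, int]], 3))  # prints 3 or (3, 3)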
@@ -127,18 +132,20 @@ def constant(val):
    return generator

def test_module_against_pytorch(
def test_against_pytorch(
    test_case,
    module_class_name,
    callable_name,
    extra_annotations: Optional[Dict[str, Any]] = None,
    extra_generators: Optional[Dict[str, Any]] = None,
    extra_defaults: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
    training: bool = True,
    backward: bool = True,
    rtol=1e-4,
    atol=1e-5,
    n=20,
    pytorch_module_class_name=None,
    pytorch_callable_name=None,
    api_flag: int = TEST_MODULE,
):
    assert device in ["cuda", "cpu"]
    if not training:
@@ -147,18 +154,44 @@ def test_module_against_pytorch(
        extra_annotations = {}
    if extra_generators is None:
        extra_generators = {}
    if pytorch_module_class_name is None:
        pytorch_module_class_name = module_class_name
    if extra_defaults is None:
        extra_defaults = {}
    if pytorch_callable_name is None:
        pytorch_callable_name = callable_name
    verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None
    torch_module_class = eval(f"torch.{pytorch_module_class_name}")
    spec = inspect.getfullargspec(torch_module_class)

    def has_full_args_spec(callable):
        try:
            spec = inspect.getfullargspec(callable)
            return True
        except Exception:
            return False

    if api_flag == TEST_TENSOR:
        pytorch_tensor = torch.Tensor(1)
        pytorch_call = eval(f"pytorch_tensor.{pytorch_callable_name}")
    else:
        pytorch_call = eval(f"torch.{pytorch_callable_name}")

    if has_full_args_spec(pytorch_call):
        spec = inspect.getfullargspec(pytorch_call)
    else:
        Spec = namedtuple(
            "spec",
            "args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations",
        )
        args = list(extra_annotations.keys()) + list(extra_defaults.keys())
        spec = Spec(args, None, None, [], [], {}, {})

    annotations = spec.annotations
    annotations.update(extra_annotations)
    if "return" in annotations:
        del annotations["return"]
    args = (set(spec.args) | set(spec.kwonlyargs)) - {"self"}
    assert args == set(
        annotations.keys()
    ), f"args = {args}, annotations = {annotations.keys()}"
@@ -201,14 +234,30 @@
            torch_input_original.to(device),
        )
        try:
            torch_module = torch_module_class(**torch_attr_dict)
            torch_module = torch_module.to(device)
            torch_module.train(training)
            torch_res = torch_module(torch_input)
            if api_flag == TEST_MODULE:
                torch_call = pytorch_call(**torch_attr_dict)
                torch_call = torch_call.to(device)
                torch_call.train(training)
                torch_res = torch_call(torch_input)
                state_dict = torch_call.state_dict()
                state_dict = {
                    k: v.detach().cpu().numpy() for k, v in state_dict.items()
                }
            elif api_flag == TEST_FLOW:
                torch_xxx_func = eval(f"torch.{pytorch_callable_name}")
                torch_res = torch_xxx_func(torch_input, **torch_attr_dict)
            else:
                torch_tensor_xxx_func = eval(f"torch_input.{pytorch_callable_name}")
                torch_res = torch_tensor_xxx_func(**torch_attr_dict)
            loss = torch_res.sum()
            loss.backward()
            state_dict = torch_module.state_dict()
            state_dict = {k: v.detach().cpu().numpy() for k, v in state_dict.items()}
            if api_flag == TEST_MODULE:
                state_dict = torch_call.state_dict()
                state_dict = {
                    k: v.detach().cpu().numpy() for k, v in state_dict.items()
                }
        except Exception as e:
            if verbose:
                print(f"PyTorch error: {e}")
@@ -216,12 +265,20 @@
            # so just skip when PyTorch raises an exception
            continue

        flow_module_class = eval(f"flow.{module_class_name}")
        flow_module = flow_module_class(**flow_attr_dict)
        flow_module = flow_module.to(device)
        flow_module.train(training)
        flow_module.load_state_dict(state_dict)
        flow_res = flow_module(flow_input)
        if api_flag == TEST_MODULE:
            flow_call_class = eval(f"flow.{callable_name}")
            flow_call = flow_call_class(**flow_attr_dict)
            flow_call = flow_call.to(device)
            flow_call.train(training)
            flow_call.load_state_dict(state_dict)
            flow_res = flow_call(flow_input)
        elif api_flag == TEST_FLOW:
            flow_xxx_func = eval(f"flow.{callable_name}")
            flow_res = flow_xxx_func(flow_input, **flow_attr_dict)
        else:
            flow_tensor_xxx_func = eval(f"flow_input.{callable_name}")
            flow_res = flow_tensor_xxx_func(**flow_attr_dict)
        loss = flow_res.sum()
        loss.backward()
@@ -239,17 +296,113 @@
        allclose_or_fail(flow_res, torch_res)
        allclose_or_fail(flow_input_original.grad, torch_input_original.grad)
        flow_parameters = dict(flow_module.named_parameters())
        for name, torch_param in torch_module.named_parameters():
            flow_param = flow_parameters[name]
            allclose_or_fail(flow_param.grad, torch_param.grad)
        if api_flag == TEST_MODULE:
            flow_parameters = dict(flow_call.named_parameters())
            for name, torch_param in torch_call.named_parameters():
                flow_param = flow_parameters[name]
                allclose_or_fail(flow_param.grad, torch_param.grad)
        n -= 1
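A side note on the branch above (my reading): per-parameter gradient comparison only makes sense on the TEST_MODULE path, because free functions and tensor methods carry no learnable state. A minimal illustration:

# Only modules expose named parameters whose gradients can be compared.
import torch

m = torch.nn.Linear(2, 2)
m(torch.randn(1, 2)).sum().backward()
print({name: p.grad.shape for name, p in m.named_parameters()})
# {'weight': torch.Size([2, 2]), 'bias': torch.Size([2])}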
def test_module_against_pytorch(
    test_case,
    callable_name,
    extra_annotations: Optional[Dict[str, Any]] = None,
    extra_generators: Optional[Dict[str, Any]] = None,
    extra_defaults: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
    training: bool = True,
    backward: bool = True,
    rtol=1e-4,
    atol=1e-5,
    n=20,
    pytorch_callable_name=None,
):
    return test_against_pytorch(
        test_case=test_case,
        callable_name=callable_name,
        extra_annotations=extra_annotations,
        extra_generators=extra_generators,
        extra_defaults=extra_defaults,
        device=device,
        training=training,
        backward=backward,
        rtol=rtol,
        atol=atol,
        n=n,
        pytorch_callable_name=pytorch_callable_name,
        api_flag=TEST_MODULE,
    )

def test_flow_against_pytorch(
    test_case,
    callable_name,
    extra_annotations: Optional[Dict[str, Any]] = None,
    extra_generators: Optional[Dict[str, Any]] = None,
    extra_defaults: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
    training: bool = True,
    backward: bool = True,
    rtol=1e-4,
    atol=1e-5,
    n=20,
    pytorch_callable_name=None,
):
    return test_against_pytorch(
        test_case=test_case,
        callable_name=callable_name,
        extra_annotations=extra_annotations,
        extra_generators=extra_generators,
        extra_defaults=extra_defaults,
        device=device,
        training=training,
        backward=backward,
        rtol=rtol,
        atol=atol,
        n=n,
        pytorch_callable_name=pytorch_callable_name,
        api_flag=TEST_FLOW,
    )

def test_tensor_against_pytorch(
    test_case,
    callable_name,
    extra_annotations: Optional[Dict[str, Any]] = None,
    extra_generators: Optional[Dict[str, Any]] = None,
    extra_defaults: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
    training: bool = True,
    backward: bool = True,
    rtol=1e-4,
    atol=1e-5,
    n=20,
    pytorch_callable_name=None,
):
    return test_against_pytorch(
        test_case=test_case,
        callable_name=callable_name,
        extra_annotations=extra_annotations,
        extra_generators=extra_generators,
        extra_defaults=extra_defaults,
        device=device,
        training=training,
        backward=backward,
        rtol=rtol,
        atol=atol,
        n=n,
        pytorch_callable_name=pytorch_callable_name,
        api_flag=TEST_TENSOR,
    )
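The three wrappers above differ only in api_flag. A more compact but less explicit alternative (hypothetical, not what the PR does) would be partial application; the PR's explicit signatures keep argument names and defaults visible to readers and tooling:

# Hypothetical alternative: derive the wrappers by fixing api_flag.
import functools

test_module_against_pytorch = functools.partial(test_against_pytorch, api_flag=TEST_MODULE)
test_flow_against_pytorch = functools.partial(test_against_pytorch, api_flag=TEST_FLOW)
test_tensor_against_pytorch = functools.partial(test_against_pytorch, api_flag=TEST_TENSOR)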
__all__ = [
    "random_tensor",
    "random",
    "choose",
    "constant",
    "test_module_against_pytorch",
    "test_flow_against_pytorch",
    "test_tensor_against_pytorch",
]
@@ -22,6 +22,7 @@ import oneflow.experimental as flow
from test_util import GenArgList
from automated_test_util import *
def _test_abs_forward(test_case, device):
@@ -75,6 +76,18 @@ class TestAbs(flow.unittest.TestCase):
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])

    def test_flow_abs_with_random_data(test_case):
        for device in ["cpu", "cuda"]:
            test_flow_against_pytorch(
                test_case, "abs", device=device,
            )

    def test_flow_tensor_abs_with_random_data(test_case):
        for device in ["cpu", "cuda"]:
            test_tensor_against_pytorch(
                test_case, "abs", device=device,
            )

if __name__ == "__main__":
    unittest.main()
@@ -20,6 +20,7 @@ import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
from automated_test_util import *
def _test_squeeze(test_case, device):
@@ -101,6 +102,26 @@ class TestSqueeze(flow.unittest.TestCase):
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])

    def test_flow_squeeze_with_random_data(test_case):
        for device in ["cpu", "cuda"]:
            test_flow_against_pytorch(
                test_case,
                "squeeze",
                extra_annotations={"dim": int},
                extra_generators={"dim": random(0, 6)},
                device=device,
            )

    def test_flow_tensor_squeeze_with_random_data(test_case):
        for device in ["cpu", "cuda"]:
            test_tensor_against_pytorch(
                test_case,
                "squeeze",
                extra_annotations={"dim": int},
                extra_generators={"dim": random(0, 6)},
                device=device,
            )
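For context (my explanation, not from the PR): torch.squeeze is C-implemented, so its dim argument is invisible to inspect; extra_annotations declares it, and extra_generators={"dim": random(0, 6)} draws values for it. A standalone emulation of that generator pattern, simplified from the library's random(low, high):

# Hypothetical sketch of the `random(low, high)` generator used for `dim`.
import numpy as np

rng = np.random.default_rng()

def random(low, high):
    def generator(annotation):
        assert annotation is int  # only int is needed for `dim` here
        return int(rng.integers(low, high))
    return generator

dim_gen = random(0, 6)
print(dim_gen(int))  # e.g. 4, passed on as squeeze(dim=4)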
if __name__ == "__main__":
    unittest.main()