From b0dd35e5e23a5e40016c679563a2bded99d35d1e Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Wed, 14 Jul 2021 22:43:01 +0800 Subject: [PATCH] Fix default value not set bug (#5483) * add argmax test * fix ci error * fix docstring warning * fix tensor greater and less bug * fix conflict * add test_flow_xxx_against_pytorch func * fix default_value_not_set_bug * fix bug * auto format by CI Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../test/modules/automated_test_util.py | 26 +++++++++++++++---- .../python/test/modules/test_constantpad2d.py | 18 +++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/oneflow/python/test/modules/automated_test_util.py b/oneflow/python/test/modules/automated_test_util.py index 0459756ab..7359cf3f7 100644 --- a/oneflow/python/test/modules/automated_test_util.py +++ b/oneflow/python/test/modules/automated_test_util.py @@ -174,13 +174,29 @@ def test_against_pytorch( else: pytorch_call = eval(f"torch.{pytorch_callable_name}") + Spec = namedtuple( + "spec", + "args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations", + ) + if has_full_args_spec(pytorch_call): - spec = inspect.getfullargspec(pytorch_call) - else: - Spec = namedtuple( - "spec", - "args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations", + tmp_spec = inspect.getfullargspec(pytorch_call) + new_defaults = tmp_spec.defaults + if new_defaults is None: + new_defaults = [] + new_kwonlydefaults = tmp_spec.kwonlydefaults + if new_kwonlydefaults is None: + new_kwonlydefaults = [] + spec = Spec( + tmp_spec.args, + tmp_spec.varargs, + tmp_spec.varkw, + new_defaults, + tmp_spec.kwonlyargs, + new_kwonlydefaults, + tmp_spec.annotations, ) + else: args = list(extra_annotations.keys()) + list(extra_defaults.keys()) spec = Spec(args, None, None, [], [], {}, {}) diff --git a/oneflow/python/test/modules/test_constantpad2d.py b/oneflow/python/test/modules/test_constantpad2d.py index a29e7c96a..4f351ab8a 100644 --- a/oneflow/python/test/modules/test_constantpad2d.py +++ b/oneflow/python/test/modules/test_constantpad2d.py @@ -19,6 +19,7 @@ from collections import OrderedDict import numpy as np import oneflow.experimental as flow +from automated_test_util import * from test_util import ( GenArgList, FlattenArray, @@ -111,6 +112,23 @@ class TestConstantPad2dModule(flow.unittest.TestCase): for arg in GenArgList(arg_dict): _test_ConstantPad2d(test_case, *arg) + def test_with_random_data(test_case): + for device in ["cpu", "cuda"]: + spatial_size = np.random.randint(10, 20) + test_module_against_pytorch( + test_case, + "nn.ConstantPad2d", + extra_annotations={"padding": int, "value": float}, + extra_generators={ + "input": random_tensor( + ndim=4, dim2=spatial_size, dim3=spatial_size + ), + "padding": random(0, 6), + "value": random(0, 6), + }, + device=device, + ) + if __name__ == "__main__": unittest.main() -- GitLab