Skip to content
Snippets Groups Projects
Unverified Commit b0dd35e5 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Fix default value not set bug (#5483)


* add argmax test

* fix ci error

* fix docstring warning

* fix tensor greater and less bug

* fix conflict

* add test_flow_xxx_against_pytorch func

* fix default_value_not_set_bug

* fix bug

* auto format by CI

Co-authored-by: default avataroneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent f15297bc
No related branches found
No related tags found
No related merge requests found
......@@ -174,13 +174,29 @@ def test_against_pytorch(
else:
pytorch_call = eval(f"torch.{pytorch_callable_name}")
Spec = namedtuple(
"spec",
"args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations",
)
if has_full_args_spec(pytorch_call):
spec = inspect.getfullargspec(pytorch_call)
else:
Spec = namedtuple(
"spec",
"args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations",
tmp_spec = inspect.getfullargspec(pytorch_call)
new_defaults = tmp_spec.defaults
if new_defaults is None:
new_defaults = []
new_kwonlydefaults = tmp_spec.kwonlydefaults
if new_kwonlydefaults is None:
new_kwonlydefaults = []
spec = Spec(
tmp_spec.args,
tmp_spec.varargs,
tmp_spec.varkw,
new_defaults,
tmp_spec.kwonlyargs,
new_kwonlydefaults,
tmp_spec.annotations,
)
else:
args = list(extra_annotations.keys()) + list(extra_defaults.keys())
spec = Spec(args, None, None, [], [], {}, {})
......
......@@ -19,6 +19,7 @@ from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from automated_test_util import *
from test_util import (
GenArgList,
FlattenArray,
......@@ -111,6 +112,23 @@ class TestConstantPad2dModule(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
_test_ConstantPad2d(test_case, *arg)
def test_with_random_data(test_case):
for device in ["cpu", "cuda"]:
spatial_size = np.random.randint(10, 20)
test_module_against_pytorch(
test_case,
"nn.ConstantPad2d",
extra_annotations={"padding": int, "value": float},
extra_generators={
"input": random_tensor(
ndim=4, dim2=spatial_size, dim3=spatial_size
),
"padding": random(0, 6),
"value": random(0, 6),
},
device=device,
)
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment