Skip to content
Snippets Groups Projects
Unverified Commit 6e2e6693 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Add autotest part2 (#5467)


* add argmax test

* fix ci error

* fix docstring warning

* fix tensor greater and less bug

* fix conflict

* add test_flow_xxx_against_pytorch func

* add asinh and arcsinh test

* add asinh and arcsinh test

* add sinh test

* add atan2 module test

* add softplus test, softplus module has bug

* comment softplus

* add elu module autotest

* refine

* auto format by CI

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: default avataroneflow-ci-bot <ci-bot@oneflow.org>
parent bdd1100d
No related branches found
No related tags found
No related merge requests found
......@@ -220,6 +220,10 @@ def test_against_pytorch(
if rng.random() < 1 / 3:
continue
flow_data, torch_data = generate(name)
if isinstance(torch_data, torch.Tensor):
torch_data = torch_data.to(device)
if isinstance(flow_data, flow.Tensor):
flow_data = flow_data.to(device)
flow_attr_dict[name] = flow_data
torch_attr_dict[name] = torch_data
......@@ -291,7 +295,7 @@ def test_against_pytorch(
)
test_case.assertTrue(
is_allclose,
f"flow_tensor = {flow_tensor},\ntorch_tensor = {torch_tensor},\nattr_dict = {torch_attr_dict}",
f"flow_tensor = {flow_tensor},\ntorch_tensor = {torch_tensor},\nattr_dict = {torch_attr_dict},\nflow_input_tensor = {flow_input_original}",
)
allclose_or_fail(flow_res, torch_res)
......
......@@ -214,6 +214,16 @@ class TestELUModule(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
_test_elu_function_impl(test_case, *arg)
def test_elu_module_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_module_against_pytorch(
test_case,
"nn.ELU",
extra_annotations={"alpha": float},
extra_generators={"alpha": random(0, 6)},
device=device,
)
def _np_gelu(x):
return 0.5 * x * (1 + special.erf(x / np.sqrt(2)))
......@@ -693,6 +703,17 @@ class TestSoftplusModule(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
@unittest.skip("Pytorch Softplus has bug")
def test_softplus_module_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_module_against_pytorch(
test_case,
"nn.Softplus",
extra_annotations={"beta": int, "threshold": int},
extra_generators={"beta": random(3, 4), "threshold": random(1, 2)},
device=device,
)
def _test_hardswish_impl(test_case, shape, device):
m = flow.nn.Hardswish()
......
......@@ -20,6 +20,7 @@ import numpy as np
from test_util import GenArgList
import oneflow.experimental as flow
from automated_test_util import *
def _test_atan2_forward(test_case, shape, scalar, device):
......@@ -120,6 +121,19 @@ class TestAtan2(flow.unittest.TestCase):
for arg in GenArgList(arg_dict):
_test_atan2_backward(test_case, *arg)
def test_flow_atan2_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_flow_against_pytorch(
test_case,
"atan2",
extra_annotations={"other": flow.Tensor},
extra_generators={
"input": random_tensor(ndim=1, dim1=1),
"other": random_tensor(ndim=1, dim1=1),
},
device=device,
)
if __name__ == "__main__":
unittest.main()
......@@ -78,6 +78,47 @@ class TestVariance(flow.unittest.TestCase):
arg[0](test_case, *arg[1:])
def _test_sinh_impl(test_case, shape, device):
np_input = np.random.randn(*shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
np_x_grad = np.cosh(np_input)
of_out = flow.sinh(of_input)
np_out = np.sinh(np_input)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
of_out = of_out.sum()
of_out.backward()
test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_x_grad, 1e-4, 1e-4))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class Testsinh(flow.unittest.TestCase):
def test_sinh(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(2, 3), (2, 4, 5, 6)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_sinh_impl(test_case, *arg)
def test_flow_sinh_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_flow_against_pytorch(
test_case, "sinh", device=device,
)
def test_flow_tensor_sinh_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_tensor_against_pytorch(
test_case, "sinh", device=device,
)
def _test_sin(test_case, shape, device):
input = flow.Tensor(np.random.randn(*shape), device=flow.device(device))
of_out = flow.sin(input)
......@@ -522,6 +563,30 @@ class TestAsinh(flow.unittest.TestCase):
_test_asinh(test_case, *arg)
_test_arcsinh(test_case, *arg)
def test_flow_asinh_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_flow_against_pytorch(
test_case, "asinh", device=device,
)
def test_flow_arcsinh_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_flow_against_pytorch(
test_case, "arcsinh", device=device,
)
def test_flow_tensor_asinh_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_tensor_against_pytorch(
test_case, "asinh", device=device,
)
def test_flow_tensor_arcsinh_with_random_data(test_case):
for device in ["cpu", "cuda"]:
test_tensor_against_pytorch(
test_case, "arcsinh", device=device,
)
def _topk_np(input, k, dim: int = None, largest: bool = True, _sorted: bool = True):
in_dims = input.shape
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
from test_util import GenArgList
import oneflow.experimental as flow
def _test_sinh_impl(test_case, shape, device):
np_input = np.random.randn(*shape)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
np_x_grad = np.cosh(np_input)
of_out = flow.sinh(of_input)
np_out = np.sinh(np_input)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
of_out = of_out.sum()
of_out.backward()
test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_x_grad, 1e-4, 1e-4))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class Testsinh(flow.unittest.TestCase):
def test_sinh(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(2, 3), (2, 4, 5, 6)]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_sinh_impl(test_case, *arg)
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment