diff --git a/python/oneflow/nn/modules/padding.py b/python/oneflow/nn/modules/padding.py index 5283c5c32b5ac40ad13fc42e0497ba06820af246..24031436630823fc649648c9744ceac2ac200007 100644 --- a/python/oneflow/nn/modules/padding.py +++ b/python/oneflow/nn/modules/padding.py @@ -17,6 +17,8 @@ from typing import Union import oneflow as flow from oneflow.nn.module import Module +from oneflow.nn.common_types import _size_4_t +from oneflow.nn.modules.utils import _quadruple class ReplicationPad2d(Module): @@ -77,30 +79,19 @@ class ReplicationPad2d(Module): """ - def __init__(self, padding: Union[int, tuple, list]): + def __init__(self, padding: _size_4_t): super().__init__() if isinstance(padding, (tuple, list)): assert len(padding) == 4, ValueError("Length of padding must be 4") - boundary = [padding[0], padding[1], padding[2], padding[3]] + boundary = [*padding] elif isinstance(padding, int): - boundary = [padding, padding, padding, padding] + boundary = _quadruple(padding) else: raise ValueError("padding must be int or list or tuple!") self.padding = boundary def forward(self, x): - (_, _, h, w) = x.shape - if ( - self.padding[2] < h - and self.padding[3] < h - and (self.padding[0] < w) - and (self.padding[1] < w) - ): - return flow.F.pad(x, pad=self.padding, mode="replicate") - else: - raise AssertionError( - "Padding size should be less than the corresponding input dimension. Please check." - ) + return flow.F.pad(x, pad=self.padding, mode="replicate") def extra_repr(self) -> str: return "{}".format(self.padding) @@ -152,13 +143,13 @@ class ReflectionPad2d(Module): """ - def __init__(self, padding: Union[int, tuple]) -> None: + def __init__(self, padding: _size_4_t) -> None: super().__init__() if isinstance(padding, tuple): assert len(padding) == 4, ValueError("Padding length must be 4") - boundary = [padding[0], padding[1], padding[2], padding[3]] + boundary = [*padding] elif isinstance(padding, int): - boundary = [padding, padding, padding, padding] + boundary = _quadruple(padding) else: raise ValueError("padding must be in or list or tuple!") self.padding = boundary diff --git a/python/oneflow/test/modules/test_arange.py b/python/oneflow/test/modules/test_arange.py index 9c4b92ffaafd5e1f2fa82f62b5c09454cb2c087f..e62fbd2ffe9661244f4bb7b8dd7ab498e71ddb54 100644 --- a/python/oneflow/test/modules/test_arange.py +++ b/python/oneflow/test/modules/test_arange.py @@ -22,6 +22,7 @@ from test_util import GenArgList import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_arange(test_case, device): @@ -53,7 +54,7 @@ def _test_arange_backward(test_case, device): @flow.unittest.skip_unless_1n1d() class TestArange(flow.unittest.TestCase): - def test_transpose(test_case): + def test_arange(test_case): arg_dict = OrderedDict() arg_dict["function_test"] = [ _test_arange, @@ -65,6 +66,16 @@ class TestArange(flow.unittest.TestCase): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5) + def test_arange_with_random_data(test_case): + start = random().to(int) + end = start + random().to(int) + step = random(0, end - start).to(int) + x = torch.arange(start=start, end=end, step=step) + device = random_device() + x.to(device) + return x + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_argmax.py b/python/oneflow/test/modules/test_argmax.py index d1299eaa8cf764d8529236045a99d03431a43d30..316edd4a4d4be23b5e42356f4770f92cdd27b0d7 100644 --- a/python/oneflow/test/modules/test_argmax.py +++ b/python/oneflow/test/modules/test_argmax.py @@ -22,6 +22,7 @@ from test_util import GenArgList import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_argmax_aixs_negative(test_case, device): @@ -91,6 +92,14 @@ class TestArgmax(flow.unittest.TestCase): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5) + def test_argmax_with_random_data(test_case): + device = random_device() + ndim = random(1, 6).to(int) + x = random_pytorch_tensor(ndim=ndim).to(device) + y = torch.argmax(x, dim=random(0, ndim).to(int), keepdim=random().to(bool)) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_argwhere.py b/python/oneflow/test/modules/test_argwhere.py index 375c1f52660d90c530ba43a4d15b4e4b16043d76..ac7cd37b511d0dfc53ac5ad84784bfa8dd99b9c3 100644 --- a/python/oneflow/test/modules/test_argwhere.py +++ b/python/oneflow/test/modules/test_argwhere.py @@ -22,6 +22,7 @@ from test_util import GenArgList import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_argwhere(test_case, shape, device): @@ -43,6 +44,14 @@ class TestArgwhere(flow.unittest.TestCase): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @unittest.skip("pytorch do not have argwhere fn/module yet!") + @autotest(n=5, rtol=1e-5, atol=1e-5) + def test_argwhere_with_random_data(test_case): + device = random_device() + x = random_pytorch_tensor(ndim=random(2, 5).to(int)).to(device) + y = torch.argwhere(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_instancenorm.py b/python/oneflow/test/modules/test_instancenorm.py index 27760fa9a13eabf4bcf3befd0a488130673697c8..bb61600e709cf09b06a4f0647d2bad784d546475 100644 --- a/python/oneflow/test/modules/test_instancenorm.py +++ b/python/oneflow/test/modules/test_instancenorm.py @@ -22,6 +22,7 @@ from test_util import GenArgList import oneflow as flow import oneflow.unittest +from automated_test_util import * def _test_instancenorm1d(test_case, device): @@ -418,6 +419,69 @@ class TestInstanceNorm(flow.unittest.TestCase): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(n=5, auto_backward=True, rtol=1e-4, atol=1e-4) + def test_instancenorm_with_random_data(test_case): + height = random(1, 6).to(int) + width = random(1, 6).to(int) + m = torch.nn.InstanceNorm1d( + num_features=height, + eps=random().to(float) | nothing(), + momentum=random().to(float) | nothing(), + affine=random().to(bool), + track_running_stats=random().to(bool), + ) + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(ndim=3, dim1=height, dim2=width).to(device) + y = m(x) + return y + + @autotest(n=5, auto_backward=True, rtol=1e-4, atol=1e-4) + def test_instancenorm_with_random_data(test_case): + channel = random(1, 6).to(int) + height = random(1, 6).to(int) + width = random(1, 6).to(int) + m = torch.nn.InstanceNorm2d( + num_features=channel, + eps=random().to(float) | nothing(), + momentum=random().to(float) | nothing(), + affine=random().to(bool), + track_running_stats=random().to(bool), + ) + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to( + device + ) + y = m(x) + return y + + @autotest(n=5, auto_backward=False, rtol=1e-4, atol=1e-4) + def test_instancenorm_with_random_data(test_case): + channel = random(1, 6).to(int) + depth = random(1, 6).to(int) + height = random(1, 6).to(int) + width = random(1, 6).to(int) + # Set auto_backward=True will raise AssertionError: False is not true + # Set track_running_stats=True will raise error: Unexpected key(s) in state_dict: "num_batches_tracked". + m = torch.nn.InstanceNorm3d( + num_features=channel, + eps=random().to(float) | nothing(), + momentum=random().to(float) | nothing(), + affine=random().to(bool), + track_running_stats=False, + ) + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor( + ndim=5, dim1=channel, dim2=depth, dim3=height, dim4=width + ).to(device) + y = m(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_normalization.py b/python/oneflow/test/modules/test_normalization.py index 2542238376abd229b53c485dbe4e9abfcf869f10..f10851c85bf02f306b67c8aa322d9bf7ea7965c4 100644 --- a/python/oneflow/test/modules/test_normalization.py +++ b/python/oneflow/test/modules/test_normalization.py @@ -22,6 +22,7 @@ from test_util import GenArgList import oneflow as flow import oneflow.unittest +from automated_test_util import * input_arr = np.array( [ @@ -136,6 +137,25 @@ class TestLayerNorm(flow.unittest.TestCase): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + @autotest(n=5, auto_backward=False, rtol=1e-4, atol=1e-4) + def test_layernorm_with_random_data(test_case): + channel = random(1, 6).to(int) + height = random(1, 6).to(int) + width = random(1, 6).to(int) + m = torch.nn.LayerNorm( + normalized_shape=random(1, 6).to(int), + eps=random().to(float) | nothing(), + elementwise_affine=random().to(bool), + ) + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to( + device + ) + y = m(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_replicationpad2d.py b/python/oneflow/test/modules/test_replicationpad2d.py index b80fd8fcd810b5f52b382235d5ec000ecff22731..d93a832e24919f12249e686f83b61888437c34e0 100644 --- a/python/oneflow/test/modules/test_replicationpad2d.py +++ b/python/oneflow/test/modules/test_replicationpad2d.py @@ -22,6 +22,7 @@ from test_util import Array2Numpy, FlattenArray, GenArgList, Index2Coordinate import oneflow as flow import oneflow.unittest +from automated_test_util import * def _np_replication_pad2d_grad(src, dest, padding): @@ -102,6 +103,19 @@ class TestReplicationPad2dModule(flow.unittest.TestCase): for arg in GenArgList(arg_dict): _test_ReplicationPad2d(test_case, *arg) + @autotest(n=5) + def test_replication_pad2d_with_random_data(test_case): + c = random(1, 6).to(int) + h = random(1, 6).to(int) + w = random(1, 6).to(int) + m = torch.nn.ReplicationPad2d(padding=random(low=0, high=7)) + m.train(random()) + device = random_device() + m.to(device) + x = random_pytorch_tensor(ndim=4, dim1=c, dim2=h, dim3=w).to(device) + y = m(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/tensor/test_tensor.py b/python/oneflow/test/tensor/test_tensor.py index 4316349e12da9f9e2047766e3a64c048ce3d5e32..d941e8af0bbcccade0354f22edbed12c65edcb25 100644 --- a/python/oneflow/test/tensor/test_tensor.py +++ b/python/oneflow/test/tensor/test_tensor.py @@ -355,6 +355,24 @@ class TestTensor(flow.unittest.TestCase): np_out = np.sum(input.numpy(), axis=(2, 1)) test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001)) + def test_argwhere(test_case): + shape = (2, 3, 4, 5) + precision = 1e-5 + np_input = np.random.randn(*shape) + input = flow.Tensor(np_input) + of_out = input.argwhere() + np_out = np.argwhere(np_input) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, precision, precision)) + test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape)) + + @autotest(n=5, auto_backward=False) + def test_tensor_argmax_with_random_data(test_case): + device = random_device() + ndim = random(1, 6).to(int) + x = random_pytorch_tensor(ndim=ndim).to(device) + y = x.argmax(dim=random(0, ndim).to(int), keepdim=random().to(bool)) + return y + @autotest() def test_tensor_tanh_with_random_data(test_case): device = random_device()