diff --git a/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..85e19231f78cfeb0cb4569025842b4a2c8730a23
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/clip_by_scalar_max.cpp
@@ -0,0 +1,72 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct ClipByScalarMaxInterpState : public OpExprInterpState {
+  bool requires_grad;
+  functional::Scalar max;
+};
+
+class ClipByScalarMax : public OpExprGradFunction<ClipByScalarMaxInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(ClipByScalarMaxInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ctx->SaveTensorForBackward(inputs.at(0));
+
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    if (IsFloatingDataType(inputs.at(0)->dtype())) {
+      ctx->max = functional::Scalar(JUST(composed_attrs.GetAttr<double>("floating_max")));
+    } else if (IsIntegralDataType(inputs.at(0)->dtype())) {
+      ctx->max = functional::Scalar(JUST(composed_attrs.GetAttr<int64_t>("integral_max")));
+    } else {
+      UNIMPLEMENTED_THEN_RETURN() << "Data type is not floating or integral type.";
+    }
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const ClipByScalarMaxInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    in_grads->resize(1);
+    if (ctx->requires_grad) {
+      const auto& x = ctx->SavedTensors().at(0);
+      in_grads->at(0) = JUST(functional::ClipByScalarMaxGrad(out_grads.at(0), x, ctx->max));
+    }
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("clip_by_scalar_max", ClipByScalarMax);
+
+}  // namespace one
+}  // namespace oneflow
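Note (not part of the patch): `Apply` above defers to `functional::ClipByScalarMaxGrad`, whose kernel is not shown in this diff. As a rough mental model only, here is a NumPy sketch of the usual clip-by-max backward rule; the helper name `clip_by_scalar_max_grad` and the treatment of the `x == max` boundary are assumptions for illustration, not taken from this PR.

```python
import numpy as np

def clip_by_scalar_max_grad(dy, x, max_value):
    """Sketch of the usual backward rule for y = minimum(x, max_value).

    Elements that were clipped receive zero gradient; the rest pass the
    upstream gradient through unchanged. Whether the boundary x == max_value
    passes gradient is a per-framework convention, not specified here.
    """
    return np.where(x < max_value, dy, np.zeros_like(dy))

x = np.array([0.5, 2.0, -1.0])
dy = np.ones_like(x)
print(clip_by_scalar_max_grad(dy, x, max_value=1.0))  # -> [1. 0. 1.]
```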
diff --git a/python/oneflow/nn/modules/activation.py b/python/oneflow/nn/modules/activation.py
index 8c3abb6e5f5aa516a1fb6aaecb0c5fe08d04819a..406d0187e6e40ba0a5bc52f58b851e39eb5344ea 100644
--- a/python/oneflow/nn/modules/activation.py
+++ b/python/oneflow/nn/modules/activation.py
@@ -199,7 +199,7 @@ class Tanh(Module):
         out = \\frac{e^x-e^{-x}}{e^x+e^{-x}}
 
     Args:
-        x (oneflow.Tensor): A Tensor
+        input (oneflow.Tensor): A Tensor
 
     Returns:
         oneflow.Tensor: The result Tensor
@@ -223,12 +223,12 @@ class Tanh(Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x):
-        return flow.F.tanh(x)
+    def forward(self, input):
+        return flow.F.tanh(input)
 
 
 @register_tensor_op("tanh")
-def tanh_op(x):
+def tanh_op(input):
     """This operator computes the hyperbolic tangent value of Tensor.
 
     The equation is:
@@ -258,7 +258,7 @@ def tanh_op(x):
         tensor([-0.7616,  0.    ,  0.7616], dtype=oneflow.float32)
 
     """
-    return Tanh()(x)
+    return Tanh()(input)
 
 
 class ELU(Module):
diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py
index d684591b6b1820fab5a9871da10e332ee4c244c5..a54d81c5fb798331968aa561d46077b9da9401f4 100644
--- a/python/oneflow/test/modules/test_activation.py
+++ b/python/oneflow/test/modules/test_activation.py
@@ -60,7 +60,7 @@ class TestReLUModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_relu_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_relu_module_with_random_data(test_case):
         m = torch.nn.ReLU()
         m.train(random())
@@ -101,7 +101,7 @@ class TestReLU6Module(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_relu6_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_relu6_module_with_random_data(test_case):
         m = torch.nn.ReLU6()
         m.train(random())
@@ -153,7 +153,7 @@ class TestTanh(flow.unittest.TestCase):
             _test_tanh_nn_impl(test_case, *arg)
             _test_tanh_function_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_tanh_module_with_random_data(test_case):
         m = torch.nn.Tanh()
         m.train(random())
@@ -163,11 +163,11 @@ class TestTanh(flow.unittest.TestCase):
         y = m(x)
         return y
 
-    @autotest
+    @autotest()
     def test_flow_tanh_with_random_data(test_case):
         device = random_device()
         x = random_pytorch_tensor().to(device)
-        y = flow.tanh(x)
+        y = torch.tanh(x)
         return y
 
 
@@ -199,7 +199,7 @@ class TestELUModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_elu_function_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_elu_module_with_random_data(test_case):
         m = torch.nn.ELU(alpha=random() | nothing())
         m.train(random())
@@ -678,7 +678,8 @@ class TestSoftplusModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
-    @autotest
+    @unittest.skip("pytorch softplus backward has bug")
+    @autotest()
     def test_softplus_module_with_random_data(test_case):
         m = torch.nn.Softplus(beta=random() | nothing(), threshold=random() | nothing())
         m.train(random())
@@ -782,7 +783,7 @@ class TestLeakyReLUModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             _test_leakyrelu_impl(test_case, *arg)
 
-    @autotest
+    @autotest()
     def test_leakyrelu_module_with_random_data(test_case):
         m = torch.nn.LeakyReLU(negative_slope=random() | nothing())
         m.train(random())
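Note (not part of the patch): the recurring `@autotest` → `@autotest()` change matters because `autotest` is used elsewhere in this PR as a decorator factory (e.g. `@autotest(n=1, auto_backward=False)` in test_batchnorm.py). The sketch below is a simplified stand-in written only to show the failure mode; it is not OneFlow's actual implementation.

```python
# Simplified stand-in for a decorator factory like autotest (illustrative only).
def autotest(n=20, auto_backward=True):
    def decorator(test_func):
        def wrapper(*args, **kwargs):
            for _ in range(n):  # run the randomized check n times
                test_func(*args, **kwargs)
        return wrapper
    return decorator

@autotest()      # correct: the factory is called, the returned decorator wraps the test
def good_case():
    print("test body runs")

@autotest        # wrong: the test function is passed as the `n` argument and the name
def bad_case():  # `bad_case` now refers to the inner decorator, so the body never runs
    print("never executed by the test runner as intended")
```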
diff --git a/python/oneflow/test/modules/test_addmm.py b/python/oneflow/test/modules/test_addmm.py
index 8dd938c29567139d19312047a1bac5afe17e0633..787b914090611d9612a9084c86f92915b3aff33f 100644
--- a/python/oneflow/test/modules/test_addmm.py
+++ b/python/oneflow/test/modules/test_addmm.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_addmm(test_case, shape, alpha, beta, device):
@@ -65,6 +66,36 @@ class TestAddmm(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_addmm_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = torch.addmm(
+            input,
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
+    @autotest()
+    def test_addmm_broadcast_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=1, dim1=1).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = torch.addmm(
+            input,
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
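Note (not part of the patch), as a quick sanity check on the shapes used above: `addmm(input, mat1, mat2, beta=..., alpha=...)` computes `beta * input + alpha * (mat1 @ mat2)`, so the 2x4 by 4x3 product is 2x3, and in the broadcast test the 1x1 `input` broadcasts against it. A small NumPy illustration:

```python
import numpy as np

def addmm(input, mat1, mat2, beta=1.0, alpha=1.0):
    # beta * input + alpha * (mat1 @ mat2); `input` may broadcast against the product.
    return beta * input + alpha * (mat1 @ mat2)

mat1 = np.random.randn(2, 4)
mat2 = np.random.randn(4, 3)
full = addmm(np.random.randn(2, 3), mat1, mat2, beta=0.5, alpha=2.0)
bcast = addmm(np.random.randn(1, 1), mat1, mat2)  # (1, 1) broadcasts to (2, 3)
print(full.shape, bcast.shape)  # (2, 3) (2, 3)
```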
diff --git a/python/oneflow/test/modules/test_batchnorm.py b/python/oneflow/test/modules/test_batchnorm.py
index ffe4c767ea53e4f414689fab15a355dfb6893ad3..193fb36454d3e4b0c8a6caada514cbd868a84733 100644
--- a/python/oneflow/test/modules/test_batchnorm.py
+++ b/python/oneflow/test/modules/test_batchnorm.py
@@ -516,6 +516,17 @@ class TestBatchNorm(flow.unittest.TestCase):
             n=10,
         )
 
+    @autotest(n=1, auto_backward=False)
+    def test_batchnorm3d_module_with_random_data(test_case):
+        channel = random().to(int)
+        m = torch.nn.BatchNorm2d(num_features=channel, track_running_stats=False)
+        m.train(random())
+        device = random_device()
+        m.to(device)
+        x = random_pytorch_tensor(ndim=4, dim1=channel, requires_grad=False).to(device)
+        y = m(x)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_ceil.py b/python/oneflow/test/modules/test_ceil.py
index 6eb12d7a32be4b65546b8bd0a3e2c472cbacdb5f..002251d666225c6838edab796332a94a3896d97e 100644
--- a/python/oneflow/test/modules/test_ceil.py
+++ b/python/oneflow/test/modules/test_ceil.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_ceil_impl(test_case, device, shape):
@@ -46,6 +47,13 @@ class TestCeilModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_ceil_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.ceil(input)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/oneflow/test/modules/test_clamp.py b/python/oneflow/test/modules/test_clamp.py
index 4ef548bfd688db283eb8d4856a8a5ab201691a6f..640408dd1d68370a146f0f9e64c070df14fac650 100644
--- a/python/oneflow/test/modules/test_clamp.py
+++ b/python/oneflow/test/modules/test_clamp.py
@@ -22,6 +22,7 @@ from test_util import GenArgList
 
 import oneflow as flow
 import oneflow.unittest
+from automated_test_util import *
 
 
 def _test_clamp(test_case, shape, device):
@@ -106,6 +107,56 @@ class TestClampModule(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    @autotest()
+    def test_clamp_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clamp(input, min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clamp_min_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clamp(
+            input, min=random().to(float) | nothing(), max=random().to(float)
+        )
+        return y
+
+    @autotest()
+    def test_clamp_max_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clamp(
+            input, min=random().to(float), max=random().to(float) | nothing()
+        )
+        return y
+
+    @autotest()
+    def test_clip_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clip(input, min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clip_min_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clip(
+            input, min=random().to(float) | nothing(), max=random().to(float)
+        )
+        return y
+
+    @autotest()
+    def test_clip_max_none_flow_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = torch.clip(
+            input, min=random().to(float), max=random().to(float) | nothing()
+        )
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
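Note (not part of the patch), for readers skimming the six new cases: `clip` is an alias of `clamp`, and the `_min_none`/`_max_none` variants let `nothing()` drop one bound so the single-sided code path is exercised. The semantics under test, sketched in NumPy with keyword names mirroring the torch/oneflow API:

```python
import numpy as np

def clamp(x, min=None, max=None):
    # Elementwise clamp: values below `min` are raised to `min`, values above
    # `max` are lowered to `max`; a bound left as None is simply ignored.
    out = x
    if min is not None:
        out = np.maximum(out, min)
    if max is not None:
        out = np.minimum(out, max)
    return out

x = np.array([-2.0, 0.5, 3.0])
print(clamp(x, min=-1.0, max=1.0))  # [-1.   0.5  1. ]
print(clamp(x, max=1.0))            # [-2.   0.5  1. ]  (min omitted)
print(clamp(x, min=-1.0))           # [-1.   0.5  3. ]  (max omitted)
```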
""" -import random import unittest from collections import OrderedDict @@ -714,13 +713,89 @@ class TestTensor(flow.unittest.TestCase): np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True) ) - def test_tensor_addmm_(test_case): - input = flow.Tensor(np.random.randn(2, 6), dtype=flow.float32) - mat1 = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) - mat2 = flow.Tensor(np.random.randn(3, 6), dtype=flow.float32) - of_out = input.addmm(mat1, mat2, alpha=1, beta=2) - np_out = np.add(2 * input.numpy(), 1 * np.matmul(mat1.numpy(), mat2.numpy())) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05)) + @autotest() + def test_addmm_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device) + mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device) + mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device) + y = input.addmm( + mat1, + mat2, + beta=random().to(float) | nothing(), + alpha=random().to(float) | nothing(), + ) + return y + + @autotest() + def test_addmm_broadcast_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor(ndim=2, dim0=1, dim1=1).to(device) + mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device) + mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device) + y = input.addmm( + mat1, + mat2, + beta=random().to(float) | nothing(), + alpha=random().to(float) | nothing(), + ) + return y + + @autotest() + def test_clamp_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.clamp(min=random().to(float), max=random().to(float)) + return y + + @autotest() + def test_clamp_minnone_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.clamp(min=random().to(float) | nothing(), max=random().to(float)) + return y + + @autotest() + def test_clamp_maxnone_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.clamp(min=random().to(float), max=random().to(float) | nothing()) + return y + + @autotest() + def test_clip_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.clip(min=random().to(float), max=random().to(float)) + return y + + @autotest() + def test_clip_minnone_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.clip(min=random().to(float) | nothing(), max=random().to(float)) + return y + + @autotest() + def test_clip_maxnone_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.clip(min=random().to(float), max=random().to(float) | nothing()) + return y + + @autotest() + def test_ceil_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.ceil() + return y + + @autotest() + def test_expm1_tensor_with_random_data(test_case): + device = random_device() + input = random_pytorch_tensor().to(device) + y = input.expm1() + return y def test_norm_tensor_function(test_case): input = flow.Tensor( @@ -818,7 +893,7 @@ class TestTensor(flow.unittest.TestCase): ) def test_tensor_fmod(test_case): x = flow.Tensor(np.random.uniform(-100, 100, (5, 5)), requires_grad=True) - y = random.uniform(-10, 10) + y = np.random.uniform(-10, 10) of_out = x.fmod(y) np_out = 
diff --git a/python/oneflow/test/tensor/test_tensor.py b/python/oneflow/test/tensor/test_tensor.py
index b5bc832f3e36b7bb4bcbe0af4950bcaccb5a7c10..072cd2a4ea9c6c8f37694cc5dd5437701c311301 100644
--- a/python/oneflow/test/tensor/test_tensor.py
+++ b/python/oneflow/test/tensor/test_tensor.py
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import random
 import unittest
 from collections import OrderedDict
 
 import numpy as np
@@ -714,13 +713,89 @@ class TestTensor(flow.unittest.TestCase):
             np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05, equal_nan=True)
         )
 
-    def test_tensor_addmm_(test_case):
-        input = flow.Tensor(np.random.randn(2, 6), dtype=flow.float32)
-        mat1 = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32)
-        mat2 = flow.Tensor(np.random.randn(3, 6), dtype=flow.float32)
-        of_out = input.addmm(mat1, mat2, alpha=1, beta=2)
-        np_out = np.add(2 * input.numpy(), 1 * np.matmul(mat1.numpy(), mat2.numpy()))
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-05, 1e-05))
+    @autotest()
+    def test_addmm_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = input.addmm(
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
+    @autotest()
+    def test_addmm_broadcast_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor(ndim=2, dim0=1, dim1=1).to(device)
+        mat1 = random_pytorch_tensor(ndim=2, dim0=2, dim1=4).to(device)
+        mat2 = random_pytorch_tensor(ndim=2, dim0=4, dim1=3).to(device)
+        y = input.addmm(
+            mat1,
+            mat2,
+            beta=random().to(float) | nothing(),
+            alpha=random().to(float) | nothing(),
+        )
+        return y
+
+    @autotest()
+    def test_clamp_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clamp(min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clamp_minnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clamp(min=random().to(float) | nothing(), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clamp_maxnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clamp(min=random().to(float), max=random().to(float) | nothing())
+        return y
+
+    @autotest()
+    def test_clip_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clip(min=random().to(float), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clip_minnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clip(min=random().to(float) | nothing(), max=random().to(float))
+        return y
+
+    @autotest()
+    def test_clip_maxnone_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.clip(min=random().to(float), max=random().to(float) | nothing())
+        return y
+
+    @autotest()
+    def test_ceil_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.ceil()
+        return y
+
+    @autotest()
+    def test_expm1_tensor_with_random_data(test_case):
+        device = random_device()
+        input = random_pytorch_tensor().to(device)
+        y = input.expm1()
+        return y
 
     def test_norm_tensor_function(test_case):
         input = flow.Tensor(
@@ -818,7 +893,7 @@ class TestTensor(flow.unittest.TestCase):
     )
     def test_tensor_fmod(test_case):
         x = flow.Tensor(np.random.uniform(-100, 100, (5, 5)), requires_grad=True)
-        y = random.uniform(-10, 10)
+        y = np.random.uniform(-10, 10)
        of_out = x.fmod(y)
         np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y))
         test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@@ -834,7 +909,7 @@ class TestTensor(flow.unittest.TestCase):
     )
     def test_magic_fmod(test_case):
         x = flow.Tensor(np.random.uniform(-100, 100, (5, 5)), requires_grad=True)
-        y = random.uniform(-10, 10)
+        y = np.random.uniform(-10, 10)
         of_out = x % y
         np_out = np.sign(x.numpy()) * np.abs(np.fmod(x.numpy(), y))
         test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
@@ -844,32 +919,6 @@ class TestTensor(flow.unittest.TestCase):
             np.allclose(x.grad.numpy(), np.ones((5, 5)), 0.0001, 0.0001)
         )
 
-    @unittest.skipIf(
-        not flow.unittest.env.eager_execution_enabled(),
-        "numpy doesn't work in lazy mode",
-    )
-    def test_tensor_ceil(test_case):
-        x = flow.Tensor(np.random.randn(2, 3), requires_grad=True)
-        of_out = x.ceil()
-        np_out = np.ceil(x.numpy())
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
-        of_out = of_out.sum()
-        of_out.backward()
-        test_case.assertTrue(
-            np.allclose(x.grad.numpy(), np.zeros((2, 3)), 0.0001, 0.0001)
-        )
-
-    def test_tensor_expm1(test_case):
-        x = flow.Tensor(np.random.randn(2, 3), requires_grad=True)
-        of_out = x.expm1()
-        np_out = np.expm1(x.numpy())
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 0.0001, 0.0001))
-        of_out = of_out.sum()
-        of_out.backward()
-        test_case.assertTrue(
-            np.allclose(x.grad.numpy(), np.exp(x.numpy()), 0.0001, 0.0001)
-        )
-
     def test_tensor_mish(test_case):
         def np_mish(x):
             f = 1 + np.exp(x)