From f1ccf2a324a0e73b74c20bb8c61585a8c2d3087c Mon Sep 17 00:00:00 2001
From: Lyon <flowingsun007@163.com>
Date: Mon, 10 May 2021 21:45:55 +0800
Subject: [PATCH] Dev fix linear module (#4836)

* add broadcast matmul support

* refine

* add batch matmul support

* remove redundant test case

* linear module support high dimension input

* format

* fix linear

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/python/nn/modules/linear.py        | 34 ++++++++++++-------
 oneflow/python/test/modules/test_linear.py | 39 +++++++++++++++++-----
 2 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/oneflow/python/nn/modules/linear.py b/oneflow/python/nn/modules/linear.py
index 4877b6ff5..62018262f 100644
--- a/oneflow/python/nn/modules/linear.py
+++ b/oneflow/python/nn/modules/linear.py
@@ -117,16 +117,8 @@ class Linear(Module):
 
         if bias:
             self.bias = flow.nn.Parameter(flow.Tensor(out_features))
-            self._bias_add_op = (
-                flow.builtin_op("bias_add")
-                .Input("a")
-                .Input("b")
-                .Output("out")
-                .Attr("axis", 1)
-                .Build()
-            )
-
-        self._op = (
+
+        self._matmul_op = (
             flow.builtin_op("matmul")
             .Input("a")
             .Input("b")
@@ -136,6 +128,18 @@ class Linear(Module):
             .Attr("alpha", 1.0)
             .Build()
         )
+
+        self._broadcast_matmul_op = (
+            flow.builtin_op("broadcast_matmul")
+            .Input("a")
+            .Input("b")
+            .Output("out")
+            .Attr("transpose_a", False)
+            .Attr("transpose_b", True)
+            .Attr("alpha", 1.0)
+            .Build()
+        )
+
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -147,8 +151,14 @@ class Linear(Module):
             flow.nn.init.uniform_(self.bias, -bound, bound)
 
     def forward(self, x):
-        res = self._op(x, self.weight)[0]
+        assert len(x.shape) >= 2, "Tensor x's dim should >=2"
+
+        if len(x.shape) == 2:
+            res = self._matmul_op(x, self.weight)[0]
+        else:
+            res = self._broadcast_matmul_op(x, self.weight)[0]
+
         if self.use_bias:
-            res = self._bias_add_op(res, self.bias)[0]
+            res += self.bias
         return res
 
diff --git a/oneflow/python/test/modules/test_linear.py b/oneflow/python/test/modules/test_linear.py
index 2a6a034f2..ffc32f0ec 100644
--- a/oneflow/python/test/modules/test_linear.py
+++ b/oneflow/python/test/modules/test_linear.py
@@ -26,12 +26,6 @@ import oneflow.typing as tp
     ".numpy() doesn't work in lazy mode",
 )
 class TestLinear(flow.unittest.TestCase):
-    def test_identity(test_case):
-        m = flow.nn.Identity()
-        x = flow.Tensor(np.random.rand(2, 3, 4, 5))
-        y = m(x)
-        test_case.assertTrue(np.allclose(x.numpy(), y.numpy()))
-
     def test_linear_v1(test_case):
         linear = flow.nn.Linear(3, 8, False)
         input_arr = np.array(
@@ -53,7 +47,7 @@ class TestLinear(flow.unittest.TestCase):
         flow.nn.init.constant_(linear.weight, 2.3)
         of_out = linear(x)
         np_out = np.matmul(input_arr, np_weight)
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
 
     def test_linear_v2(test_case):
         linear = flow.nn.Linear(3, 8)
@@ -80,7 +74,36 @@ class TestLinear(flow.unittest.TestCase):
         of_out = linear(x)
         np_out = np.matmul(input_arr, np_weight)
         np_out += np_bias
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+    def test_linear_3_dimension_input(test_case):
+        input_arr = np.random.randn(2, 3, 4)
+        x = flow.Tensor(input_arr)
+        m = flow.nn.Linear(4, 5, True)
+        flow.nn.init.constant_(m.weight, 5.6)
+        flow.nn.init.constant_(m.bias, 0.78)
+        of_out = m(x)
+
+        np_weight = np.ones((4, 5)).astype(np.float32)
+        np_weight.fill(5.6)
+        np_bias = np.ones((5))
+        np_bias.fill(0.78)
+        np_out = np.matmul(input_arr, np_weight)
+        np_out += np_bias
+
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+    def test_linear_4_dimension_input(test_case):
+        input_arr = np.random.randn(4, 5, 6, 7)
+        x = flow.Tensor(input_arr)
+        m = flow.nn.Linear(7, 3, False)
+        flow.nn.init.constant_(m.weight, 11.3)
+        of_out = m(x)
+
+        np_weight = np.ones((7, 3)).astype(np.float32)
+        np_weight.fill(11.3)
+        np_out = np.matmul(input_arr, np_weight)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
 
 
 @unittest.skipIf(
--
GitLab
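
Reviewer note (not part of the patch): a minimal NumPy sketch of the reference computation the new high-dimension tests check. It assumes the Linear weight has shape (out_features, in_features), which is consistent with the transpose_b=True attribute on the matmul/broadcast_matmul ops above; the shapes and constants below are illustrative only.

    import numpy as np

    # 3D input as in test_linear_3_dimension_input: the leading dims act as batch dims.
    in_features, out_features = 4, 5
    x = np.random.randn(2, 3, in_features)
    weight = np.full((out_features, in_features), 5.6, dtype=np.float32)
    bias = np.full((out_features,), 0.78, dtype=np.float32)

    # broadcast_matmul path: matmul over the last two dims, then the bias is
    # broadcast over the leading dims (the `res += self.bias` line in forward()).
    out = np.matmul(x, weight.T) + bias
    print(out.shape)  # (2, 3, 5)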