From f1ccf2a324a0e73b74c20bb8c61585a8c2d3087c Mon Sep 17 00:00:00 2001
From: Lyon <flowingsun007@163.com>
Date: Mon, 10 May 2021 21:45:55 +0800
Subject: [PATCH] Dev fix linear module (#4836)

* add broadcast matmul support

* refine

* add batch matmul support

* remove redundant test case

* linear module supports high-dimensional input (see the sketch below)

* format

* fix linear

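With this change, nn.Linear accepts inputs with more than two
dimensions: 2-D inputs still go through matmul, higher-dimensional
inputs dispatch to broadcast_matmul over the leading batch dims, and
the bias is applied with a broadcast add. A minimal sketch of the new
behavior (hypothetical values, assuming the eager nn.Linear API at
this commit):

    import numpy as np
    import oneflow as flow

    m = flow.nn.Linear(4, 5)                    # weight (5, 4), bias (5,)
    x = flow.Tensor(np.random.randn(2, 3, 4))   # 3-D input
    y = m(x)                                    # uses broadcast_matmul
    print(y.shape)                              # expected: (2, 3, 5)
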
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/python/nn/modules/linear.py        | 34 ++++++++++++-------
 oneflow/python/test/modules/test_linear.py | 39 +++++++++++++++++-----
 2 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/oneflow/python/nn/modules/linear.py b/oneflow/python/nn/modules/linear.py
index 4877b6ff5..62018262f 100644
--- a/oneflow/python/nn/modules/linear.py
+++ b/oneflow/python/nn/modules/linear.py
@@ -117,16 +117,8 @@ class Linear(Module):
 
         if bias:
             self.bias = flow.nn.Parameter(flow.Tensor(out_features))
-            self._bias_add_op = (
-                flow.builtin_op("bias_add")
-                .Input("a")
-                .Input("b")
-                .Output("out")
-                .Attr("axis", 1)
-                .Build()
-            )
-
-        self._op = (
+
+        self._matmul_op = (
             flow.builtin_op("matmul")
             .Input("a")
             .Input("b")
@@ -136,6 +128,18 @@ class Linear(Module):
             .Attr("alpha", 1.0)
             .Build()
         )
+
+        self._broadcast_matmul_op = (
+            flow.builtin_op("broadcast_matmul")
+            .Input("a")
+            .Input("b")
+            .Output("out")
+            .Attr("transpose_a", False)
+            .Attr("transpose_b", True)
+            .Attr("alpha", 1.0)
+            .Build()
+        )
+
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -147,8 +151,14 @@ class Linear(Module):
             flow.nn.init.uniform_(self.bias, -bound, bound)
 
     def forward(self, x):
-        res = self._op(x, self.weight)[0]
+        assert len(x.shape) >= 2, "input tensor's dim should be >= 2"
+
+        if len(x.shape) == 2:
+            res = self._matmul_op(x, self.weight)[0]
+        else:
+            res = self._broadcast_matmul_op(x, self.weight)[0]
+
         if self.use_bias:
-            res = self._bias_add_op(res, self.bias)[0]
+            res += self.bias
 
         return res
diff --git a/oneflow/python/test/modules/test_linear.py b/oneflow/python/test/modules/test_linear.py
index 2a6a034f2..ffc32f0ec 100644
--- a/oneflow/python/test/modules/test_linear.py
+++ b/oneflow/python/test/modules/test_linear.py
@@ -26,12 +26,6 @@ import oneflow.typing as tp
     ".numpy() doesn't work in lazy mode",
 )
 class TestLinear(flow.unittest.TestCase):
-    def test_identity(test_case):
-        m = flow.nn.Identity()
-        x = flow.Tensor(np.random.rand(2, 3, 4, 5))
-        y = m(x)
-        test_case.assertTrue(np.allclose(x.numpy(), y.numpy()))
-
     def test_linear_v1(test_case):
         linear = flow.nn.Linear(3, 8, False)
         input_arr = np.array(
@@ -53,7 +47,7 @@ class TestLinear(flow.unittest.TestCase):
         flow.nn.init.constant_(linear.weight, 2.3)
         of_out = linear(x)
         np_out = np.matmul(input_arr, np_weight)
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=1e-5, atol=1e-5))
 
     def test_linear_v2(test_case):
         linear = flow.nn.Linear(3, 8)
@@ -80,7 +74,36 @@ class TestLinear(flow.unittest.TestCase):
         of_out = linear(x)
         np_out = np.matmul(input_arr, np_weight)
         np_out += np_bias
-        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=1e-5, atol=1e-5))
+
+    def test_linear_3_dimension_input(test_case):
+        input_arr = np.random.randn(2, 3, 4)
+        x = flow.Tensor(input_arr)
+        m = flow.nn.Linear(4, 5, True)
+        flow.nn.init.constant_(m.weight, 5.6)
+        flow.nn.init.constant_(m.bias, 0.78)
+        of_out = m(x)
+
+        np_weight = np.ones((4, 5)).astype(np.float32)
+        np_weight.fill(5.6)
+        np_bias = np.ones(5)
+        np_bias.fill(0.78)
+        np_out = np.matmul(input_arr, np_weight)
+        np_out += np_bias
+
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=1e-5, atol=1e-5))
+
+    def test_linear_4_dimension_input(test_case):
+        input_arr = np.random.randn(4, 5, 6, 7)
+        x = flow.Tensor(input_arr)
+        m = flow.nn.Linear(7, 3, False)
+        flow.nn.init.constant_(m.weight, 11.3)
+        of_out = m(x)
+
+        np_weight = np.ones((7, 3)).astype(np.float32)
+        np_weight.fill(11.3)
+        np_out = np.matmul(input_arr, np_weight)
+        test_case.assertTrue(np.allclose(of_out.numpy(), np_out, rtol=1e-5, atol=1e-5))
 
 
 @unittest.skipIf(
-- 
GitLab