Skip to content
Snippets Groups Projects
Unverified commit f1ccf2a3, authored by Lyon, committed by GitHub
Browse files

Dev fix linear module (#4836)


* add broadcast matmul support

* refine

* add batch matmul support

* remove redundant test case

* linear module support high dimension input

* format

* fix linear

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent aa9f6f76
No related branches found
No related tags found
No related merge requests found
......@@ -117,16 +117,8 @@ class Linear(Module):
if bias:
self.bias = flow.nn.Parameter(flow.Tensor(out_features))
self._bias_add_op = (
flow.builtin_op("bias_add")
.Input("a")
.Input("b")
.Output("out")
.Attr("axis", 1)
.Build()
)
self._op = (
self._matmul_op = (
flow.builtin_op("matmul")
.Input("a")
.Input("b")
......@@ -136,6 +128,18 @@ class Linear(Module):
.Attr("alpha", 1.0)
.Build()
)
self._broadcast_matmul_op = (
flow.builtin_op("broadcast_matmul")
.Input("a")
.Input("b")
.Output("out")
.Attr("transpose_a", False)
.Attr("transpose_b", True)
.Attr("alpha", 1.0)
.Build()
)
self.reset_parameters()
def reset_parameters(self) -> None:
......@@ -147,8 +151,14 @@ class Linear(Module):
flow.nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x):
    """Apply the linear transform ``y = x @ weight.T (+ bias)``.

    The scraped diff interleaved the pre-commit body (``self._op``,
    ``self._bias_add_op`` — both removed by this commit) with the new
    one; only the post-commit implementation is kept here.

    2-D inputs go through the plain ``matmul`` op; inputs of rank > 2
    use ``broadcast_matmul`` so batched / high-dimension inputs work.
    Assumes x's last dimension equals in_features — TODO confirm, the
    underlying op enforces this at runtime.
    """
    assert len(x.shape) >= 2, "Tensor x's dim should >=2"
    if len(x.shape) == 2:
        # Plain (n, in_features) case: dedicated matmul op.
        res = self._matmul_op(x, self.weight)[0]
    else:
        # (..., in_features): broadcast_matmul handles leading batch dims.
        res = self._broadcast_matmul_op(x, self.weight)[0]
    if self.use_bias:
        # Broadcasting add replaces the removed bias_add builtin op.
        res += self.bias
    return res
......@@ -26,12 +26,6 @@ import oneflow.typing as tp
".numpy() doesn't work in lazy mode",
)
class TestLinear(flow.unittest.TestCase):
def test_identity(test_case):
    """nn.Identity must return its input unchanged."""
    identity = flow.nn.Identity()
    data = np.random.rand(2, 3, 4, 5)
    inp = flow.Tensor(data)
    out = identity(inp)
    test_case.assertTrue(np.allclose(inp.numpy(), out.numpy()))
def test_linear_v1(test_case):
linear = flow.nn.Linear(3, 8, False)
input_arr = np.array(
......@@ -53,7 +47,7 @@ class TestLinear(flow.unittest.TestCase):
flow.nn.init.constant_(linear.weight, 2.3)
of_out = linear(x)
np_out = np.matmul(input_arr, np_weight)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
def test_linear_v2(test_case):
linear = flow.nn.Linear(3, 8)
......@@ -80,7 +74,36 @@ class TestLinear(flow.unittest.TestCase):
of_out = linear(x)
np_out = np.matmul(input_arr, np_weight)
np_out += np_bias
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
def test_linear_3_dimension_input(test_case):
    """Linear with bias on a rank-3 input matches numpy matmul + bias."""
    data = np.random.randn(2, 3, 4)
    layer = flow.nn.Linear(4, 5, True)
    flow.nn.init.constant_(layer.weight, 5.6)
    flow.nn.init.constant_(layer.bias, 0.78)
    out = layer(flow.Tensor(data))
    # Build the numpy reference with the same constant weight/bias.
    ref_weight = np.full((4, 5), 5.6, dtype=np.float32)
    ref_bias = np.full((5,), 0.78)
    expected = np.matmul(data, ref_weight) + ref_bias
    test_case.assertTrue(np.allclose(out.numpy(), expected, 1e-5, 1e-5))
def test_linear_4_dimension_input(test_case):
    """Bias-free Linear on a rank-4 input matches numpy matmul."""
    data = np.random.randn(4, 5, 6, 7)
    layer = flow.nn.Linear(7, 3, False)
    flow.nn.init.constant_(layer.weight, 11.3)
    out = layer(flow.Tensor(data))
    # Numpy reference with the same constant weight, no bias term.
    ref_weight = np.full((7, 3), 11.3, dtype=np.float32)
    expected = np.matmul(data, ref_weight)
    test_case.assertTrue(np.allclose(out.numpy(), expected, 1e-5, 1e-5))
@unittest.skipIf(
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment