Commit 30a75cee authored by Xiaoyu Zhang, committed by GitHub

support nllloss 3dim (#4874)


* support nllloss 3dim

* support nllloss 3dim

* merge conflict

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 4e689d25
@@ -238,6 +238,13 @@ class NLLLoss(Module):
.Attr("dim", 1)
.Build()
)
self._transpose_op = (
flow.builtin_op("transpose")
.Input("input")
.Output("output")
.Attr("perm", [])
.Build()
)
def nllloss_1d(self, input, target):
target = flow.experimental.reshape(target, (target.shape[0], 1))
@@ -246,18 +253,25 @@ class NLLLoss(Module):
return res
def forward(self, input, target):
-       assert len(input.shape) == 2 or len(input.shape) == 4
+       assert len(input.shape) <= 4
assert len(target.shape) == len(input.shape) - 1
input = input.negative()
if len(input.shape) == 2:
res = self.nllloss_1d(input, target)
elif len(input.shape) == 3:
b, c, h = input.shape[0], input.shape[1], input.shape[2]
input = self._transpose_op(input, perm=(0, 2, 1))[0]
input = input.reshape(shape=[-1, input.shape[2]])
target = target.flatten()
res = self.nllloss_1d(input, target)
res = res.reshape((b, h))
elif len(input.shape) == 4:
b, c, h, w = input.shape[0], input.shape[1], input.shape[2], input.shape[3]
-           input = input.transpose((0, 2, 3, 1))
+           input = self._transpose_op(input, perm=(0, 2, 3, 1))[0]
input = input.reshape(shape=[-1, input.shape[3]])
target = target.flatten()
res = self.nllloss_1d(input, target)
res = res.reshape((b, h, w))
else:
raise NotImplementedError
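As an aside (not part of the diff), the new 3-dim branch above reduces NLLLoss on (N, C, H) input to the existing 1-dim gather: move the class axis last, flatten to (N*H, C), pick one class per row, and reshape the per-element losses back to (N, H). A minimal NumPy sketch of that reduction, using the hypothetical helper name nllloss_3d_as_1d and plain log-probabilities as input:

import numpy as np

def nllloss_3d_as_1d(log_probs, target):
    # log_probs: (N, C, H) class log-probabilities; target: (N, H) class indices
    n, c, h = log_probs.shape
    flat_input = np.transpose(log_probs, (0, 2, 1)).reshape(-1, c)  # (N*H, C)
    flat_target = target.reshape(-1)                                # (N*H,)
    # 1-dim NLL: pick the column named by the target for every row, then negate
    picked = -flat_input[np.arange(n * h), flat_target]
    return picked.reshape(n, h)                                     # back to (N, H)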
@@ -56,6 +56,24 @@ def nll_loss_2d(logs, targets, reduction="none"):
return out
def nll_loss_bert(logs, targets, reduction="none"):
input_shape = logs.shape
N = input_shape[0]
H = input_shape[2]
out = np.zeros_like(targets).astype(np.float64)
total_weight = N * H
for i in range(N):
for h in range(H):
cur_target = targets[i][h]
out[i][h] = -logs[i][cur_target][h]
if reduction == "sum":
return np.sum(out)
elif reduction == "mean":
return out.sum() / total_weight
elif reduction == "none":
return out
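For orientation, the reference helper above computes out[i, h] = -logs[i, targets[i, h], h]. A hand-worked usage on the fixture used in the tests below (not part of the commit; it assumes numpy is already imported as np in this file):

logs = np.array([[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]])
targets = np.array([[1, 0, 0, 1]])
nll_loss_bert(logs, targets)                    # [[-0.13, -0.36, -0.22, 0.96]]
nll_loss_bert(logs, targets, reduction="sum")   # 0.25
nll_loss_bert(logs, targets, reduction="mean")  # 0.0625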
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
@@ -154,6 +172,42 @@ class TestNLLLossModule(flow.unittest.TestCase):
np_out = nll_loss_2d(input.numpy(), target.numpy(), reduction="sum")
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
def test_nllloss_bert_none(test_case):
x = np.array([[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]]).astype(
np.float32
)
input = flow.Tensor(x, dtype=flow.float32)
y = np.array([[1, 0, 0, 1]]).astype(np.int)
target = flow.Tensor(y, dtype=flow.int64)
nll_loss = flow.nn.NLLLoss()
of_out = nll_loss(input, target)
np_out = nll_loss_bert(input.numpy(), target.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
def test_nllloss_bert_mean(test_case):
x = np.array([[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]]).astype(
np.float32
)
input = flow.Tensor(x, dtype=flow.float32)
y = np.array([[1, 0, 0, 1]]).astype(np.int)
target = flow.Tensor(y, dtype=flow.int64)
nll_loss = flow.nn.NLLLoss(reduction="mean")
of_out = nll_loss(input, target)
np_out = nll_loss_bert(input.numpy(), target.numpy(), reduction="mean")
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
def test_nllloss_bert_sum(test_case):
x = np.array([[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]]).astype(
np.float32
)
input = flow.Tensor(x, dtype=flow.float32)
y = np.array([[1, 0, 0, 1]]).astype(np.int)
target = flow.Tensor(y, dtype=flow.int64)
nll_loss = flow.nn.NLLLoss(reduction="sum")
of_out = nll_loss(input, target)
np_out = nll_loss_bert(input.numpy(), target.numpy(), reduction="sum")
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
if __name__ == "__main__":
unittest.main()
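As a hedged cross-check (not part of the commit), PyTorch's torch.nn.NLLLoss accepts the same (N, C, H) input and (N, H) target shapes, so the nll_loss_bert reference above, and by extension the new 3-dim OneFlow path, can be compared against it:

import numpy as np
import torch

logs = np.array(
    [[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]], dtype=np.float32
)
targets = np.array([[1, 0, 0, 1]], dtype=np.int64)
torch_out = torch.nn.NLLLoss(reduction="none")(
    torch.from_numpy(logs), torch.from_numpy(targets)
)
np_out = nll_loss_bert(logs, targets)  # reference defined earlier in this file
assert np.allclose(torch_out.numpy(), np_out)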