diff --git a/oneflow/python/nn/modules/loss.py b/oneflow/python/nn/modules/loss.py
index b0a94e814dd1a4623ef118062a999d250e3a5000..c65ac2439c0297cae382c1af68e2be54c01f223a 100644
--- a/oneflow/python/nn/modules/loss.py
+++ b/oneflow/python/nn/modules/loss.py
@@ -238,6 +238,13 @@ class NLLLoss(Module):
.Attr("dim", 1)
.Build()
)
+ self._transpose_op = (
+ flow.builtin_op("transpose")
+ .Input("input")
+ .Output("output")
+ .Attr("perm", [])
+ .Build()
+ )
def nllloss_1d(self, input, target):
target = flow.experimental.reshape(target, (target.shape[0], 1))
@@ -246,18 +253,25 @@ class NLLLoss(Module):
return res
def forward(self, input, target):
- assert len(input.shape) == 2 or len(input.shape) == 4
+ assert len(input.shape) <= 4
+ assert len(target.shape) == len(input.shape) - 1
input = input.negative()
if len(input.shape) == 2:
res = self.nllloss_1d(input, target)
+ elif len(input.shape) == 3:
+ b, c, h = input.shape[0], input.shape[1], input.shape[2]
+ input = self._transpose_op(input, perm=(0, 2, 1))[0]
+ input = input.reshape(shape=[-1, input.shape[2]])
+ target = target.flatten()
+ res = self.nllloss_1d(input, target)
+ res = res.reshape((b, h))
elif len(input.shape) == 4:
b, c, h, w = input.shape[0], input.shape[1], input.shape[2], input.shape[3]
- input = input.transpose((0, 2, 3, 1))
+ input = self._transpose_op(input, perm=(0, 2, 3, 1))[0]
input = input.reshape(shape=[-1, input.shape[3]])
target = target.flatten()
res = self.nllloss_1d(input, target)
res = res.reshape((b, h, w))
-
else:
raise NotImplemented
diff --git a/oneflow/python/test/modules/test_nllloss.py b/oneflow/python/test/modules/test_nllloss.py
index 42ef38d4b2ef1a32402a6c5a97b6d0d58e1716cc..5e8aeaf209934be6565b492ba639c38800452b15 100644
--- a/oneflow/python/test/modules/test_nllloss.py
+++ b/oneflow/python/test/modules/test_nllloss.py
@@ -56,6 +56,24 @@ def nll_loss_2d(logs, targets, reduction="none"):
return out
+def nll_loss_bert(logs, targets, reduction="none"):
+ input_shape = logs.shape
+ N = input_shape[0]
+ H = input_shape[2]
+ out = np.zeros_like(targets).astype(np.float64)
+ total_weight = N * H
+ for i in range(N):
+ for h in range(H):
+ cur_target = targets[i][h]
+ out[i][h] = -logs[i][cur_target][h]
+ if reduction == "sum":
+ return np.sum(out)
+ elif reduction == "mean":
+ return out.sum() / total_weight
+ elif reduction == "none":
+ return out
+
+
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
@@ -154,6 +172,42 @@ class TestNLLLossModule(flow.unittest.TestCase):
np_out = nll_loss_2d(input.numpy(), target.numpy(), reduction="sum")
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+ def test_nllloss_bert_none(test_case):
+ x = np.array([[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]]).astype(
+ np.float32
+ )
+ input = flow.Tensor(x, dtype=flow.float32)
+ y = np.array([[1, 0, 0, 1]]).astype(np.int)
+ target = flow.Tensor(y, dtype=flow.int64)
+ nll_loss = flow.nn.NLLLoss()
+ of_out = nll_loss(input, target)
+ np_out = nll_loss_bert(input.numpy(), target.numpy())
+ test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+
+ def test_nllloss_bert_mean(test_case):
+ x = np.array([[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]]).astype(
+ np.float32
+ )
+ input = flow.Tensor(x, dtype=flow.float32)
+ y = np.array([[1, 0, 0, 1]]).astype(np.int)
+ target = flow.Tensor(y, dtype=flow.int64)
+ nll_loss = flow.nn.NLLLoss(reduction="mean")
+ of_out = nll_loss(input, target)
+ np_out = nll_loss_bert(input.numpy(), target.numpy(), reduction="mean")
+ test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+
+ def test_nllloss_bert_sum(test_case):
+ x = np.array([[[0.12, 0.36, 0.22, 0.66], [0.13, 0.34, 0.52, -0.96]]]).astype(
+ np.float32
+ )
+ input = flow.Tensor(x, dtype=flow.float32)
+ y = np.array([[1, 0, 0, 1]]).astype(np.int)
+ target = flow.Tensor(y, dtype=flow.int64)
+ nll_loss = flow.nn.NLLLoss(reduction="sum")
+ of_out = nll_loss(input, target)
+ np_out = nll_loss_bert(input.numpy(), target.numpy(), reduction="sum")
+ test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+
if __name__ == "__main__":
unittest.main()