From d1a76aa816781fc382be2613a60325ed1848ec3f Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 6 May 2021 20:47:52 +0800 Subject: [PATCH] Add negative and nllloss module (#4789) * add negative and nllloss module * add negative and nllloss module * add negative and nllloss module * add negative and nllloss module * add nllloss impleted, relay on other module, do not merge now * fix nllloss bug * support 2d nllloss bug * fix comment * add flow.neg api * support bert nllloss * fix nllloss 2d bug * add docs * fix comment * fix comment * fix comment Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/python/framework/tensor.py | 2 +- oneflow/python/nn/modules/loss.py | 130 +++++++++++++++ oneflow/python/nn/modules/negative.py | 57 +++++++ oneflow/python/test/modules/test_negative.py | 61 +++++++ oneflow/python/test/modules/test_nllloss.py | 159 +++++++++++++++++++ 5 files changed, 408 insertions(+), 1 deletion(-) create mode 100644 oneflow/python/nn/modules/negative.py create mode 100644 oneflow/python/test/modules/test_negative.py create mode 100644 oneflow/python/test/modules/test_nllloss.py diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index 4555a7248..96f6bc70c 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -332,7 +332,7 @@ class Tensor: return flow.div(other, self) def __neg__(self): - return flow.mul(-1, self) + return flow.neg(self) def _determine_if_needed(self, determining_initializer=None): if not self.is_determined: diff --git a/oneflow/python/nn/modules/loss.py b/oneflow/python/nn/modules/loss.py index e471f0068..398dda795 100644 --- a/oneflow/python/nn/modules/loss.py +++ b/oneflow/python/nn/modules/loss.py @@ -86,3 +86,133 @@ class CrossEntropyLoss(Module): return flow.sum(out) else: return out + + +@oneflow_export("nn.NLLLoss") +class NLLLoss(Module): + r""" The negative log likelihood loss. It is useful to train a classification + problem with `C` classes. + + The `input` given through a forward call is expected to contain + log-probabilities of each class. `input` has to be a Tensor of size either + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case (described later). + + Obtaining log-probabilities in a neural network is easily achieved by + adding a `LogSoftmax` layer in the last layer of your network. + You may use `CrossEntropyLoss` instead, if you prefer not to add an extra + layer. + + The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` + where `C = number of classes`; + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} x_{n,y_n}, \quad + w_{c} = \mathbb{1}, + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and + :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{N} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Can also be used for higher dimension inputs, such as 2D images, by providing + an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`, + where :math:`K` is the number of dimensions, and a target of appropriate shape + (see below). In the case of images, it computes NLL loss per-pixel. + + Args: + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + For example: + + .. code-block:: python + + import oneflow as flow + import numpy as np + + input = flow.Tensor( + [[-0.1664078, -1.7256707, -0.14690138], + [-0.21474946, 0.53737473, 0.99684894], + [-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32) + target = flow.Tensor(np.array([0, 1, 2]), dtype=flow.int32) + out = flow.nn.NLLLoss(reduction="none")(input, target) + # out: [0.80199665 1.1166505 0.35826027] + + out_sum = flow.nn.NLLLoss(reduction="sum")(input, target) + # out_sum: [2.2769074] + + out_mean = flow.nn.NLLLoss(reduction="mean")(input, target) + # out_mean: [0.7589692] + + """ + + def __init__( + self, weight=None, ignore_index: int = None, reduction: str = "none", + ) -> None: + super().__init__() + if weight != None: + raise ValueError("Argument weight is not supported yet") + if ignore_index != None: + raise ValueError("Argument ignore_index is not supported yet") + assert reduction in [ + "sum", + "none", + "mean", + None, + ], "only 'sum', 'mean' and None supported by now" + + self.reduction = reduction + self._gather_nd_op = ( + flow.builtin_op("gather_nd") + .Input("params") + .Input("indices") + .Output("out") + .Build() + ) + + def nllloss_1d(self, input, target): + n = input.shape[0] + idx = flow.unsqueeze(flow.arange(0, n, 1), dim=1) + target = flow.unsqueeze(target, dim=1) + t = flow.cat([idx, target], dim=1) + res = self._gather_nd_op(input, t)[0] + return res + + def forward(self, input, target): + assert len(input.shape) == 2 or len(input.shape) == 4 + input = flow.negative(input) + if len(input.shape) == 2: + res = self.nllloss_1d(input, target) + elif len(input.shape) == 4: + b, c, h, w = input.shape[0], input.shape[1], input.shape[2], input.shape[3] + input = flow.tmp.transpose(input, (0, 2, 3, 1)) + input = flow.tmp.reshape(input, shape=[-1, input.shape[3]]) + target = flow.tmp.flatten(target) + res = self.nllloss_1d(input, target) + res = flow.tmp.reshape(res, (b, h, w)) + + else: + raise NotImplemented + + if self.reduction == "none": + return res + elif self.reduction == "sum": + return flow.sum(res) + else: + return flow.mean(res) diff --git a/oneflow/python/nn/modules/negative.py b/oneflow/python/nn/modules/negative.py new file mode 100644 index 000000000..e6b48d8bc --- /dev/null +++ b/oneflow/python/nn/modules/negative.py @@ -0,0 +1,57 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +from oneflow.python.nn.module import Module +from oneflow.python.oneflow_export import oneflow_export +from oneflow.python.framework.tensor import register_tensor_op + + +class Negative(Module): + def __init__(self) -> None: + super().__init__() + self._op = flow.builtin_op("negative").Input("x").Output("y").Build() + + def forward(self, x): + return self._op(x)[0] + + +@oneflow_export("negative", "neg") +@register_tensor_op("negative") +def negative_op(x): + """This operator computes the negative value of Tensor. + + Args: + x (oneflow.Tensor): A Tensor + + Returns: + oneflow.Tensor: The result Tensor + + For example: + + .. code-block:: python + + import oneflow as flow + import numpy as np + + input = flow.Tensor( + np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32 + ) + out = flow.negative(input).numpy() + + # out [-1.0, 1.0, -2.3] + + """ + return Negative()(x) diff --git a/oneflow/python/test/modules/test_negative.py b/oneflow/python/test/modules/test_negative.py new file mode 100644 index 000000000..18e19c9b3 --- /dev/null +++ b/oneflow/python/test/modules/test_negative.py @@ -0,0 +1,61 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest + +import numpy as np +import oneflow as flow + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestNegativeModule(flow.unittest.TestCase): + def test_negtive(test_case): + input = flow.Tensor( + np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32 + ) + of_out = flow.negative(input) + np_out = -(input.numpy()) + test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) + + def test_negative_neg(test_case): + input = flow.Tensor( + np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32 + ) + of_out = flow.neg(input) + np_out = -(input.numpy()) + test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) + + def test_tensor_negative(test_case): + input = flow.Tensor( + np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32 + ) + of_out = input.negative() + np_out = -(input.numpy()) + test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) + + def test_self_tensor_negative(test_case): + input = flow.Tensor( + np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32 + ) + of_out = -input + np_out = -(input.numpy()) + test_case.assertTrue(np.array_equal(of_out.numpy(), np_out)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_nllloss.py b/oneflow/python/test/modules/test_nllloss.py new file mode 100644 index 000000000..67149cc9e --- /dev/null +++ b/oneflow/python/test/modules/test_nllloss.py @@ -0,0 +1,159 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest + +import numpy as np +import oneflow as flow + + +def nll_loss_1d(logs, targets, reduction="none"): + input_shape = logs.shape + N = input_shape[0] + C = input_shape[1] + out = np.zeros_like(targets).astype(np.float64) + total_weight = N + for i in range(N): + cur_target = targets[i] + out[i] = -logs[i][cur_target] + if reduction == "sum": + return np.sum(out) + elif reduction == "mean": + return out.sum() / total_weight + elif reduction == "none": + return out + + +def nll_loss_2d(logs, targets, reduction="none"): + input_shape = logs.shape + N = input_shape[0] + H = input_shape[2] + W = input_shape[3] + out = np.zeros_like(targets).astype(np.float64) + total_weight = N * H * W + for i in range(N): + for h in range(H): + for w in range(W): + cur_target = targets[i][h][w] + out[i][h][w] = -logs[i][cur_target][h][w] + if reduction == "sum": + return np.sum(out) + elif reduction == "mean": + return out.sum() / total_weight + elif reduction == "none": + return out + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestNLLLossModule(flow.unittest.TestCase): + def test_nllloss_none(test_case): + x = np.array( + [ + [0.88103855, 0.9908683, 0.6226845], + [0.53331435, 0.07999352, 0.8549948], + [0.25879037, 0.39530203, 0.698465], + [0.73427284, 0.63575995, 0.18827209], + [0.05689114, 0.0862954, 0.6325046], + ] + ).astype(np.float32) + y = np.array([0, 2, 1, 1, 0]).astype(np.int) + input = flow.Tensor(x, dtype=flow.float32) + + target = flow.Tensor(y, dtype=flow.int64) + nll_loss = flow.nn.NLLLoss() + of_out = nll_loss(input, target) + np_out = nll_loss_1d(input.numpy(), target.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + + def test_nllloss_mean(test_case): + x = np.array( + [ + [0.88103855, 0.9908683, 0.6226845], + [0.53331435, 0.07999352, 0.8549948], + [0.25879037, 0.39530203, 0.698465], + [0.73427284, 0.63575995, 0.18827209], + [0.05689114, 0.0862954, 0.6325046], + ] + ).astype(np.float32) + y = np.array([0, 2, 1, 1, 0]).astype(np.int) + input = flow.Tensor(x, dtype=flow.float32) + + target = flow.Tensor(y, dtype=flow.int64) + nll_loss = flow.nn.NLLLoss(reduction="mean") + of_out = nll_loss(input, target) + np_out = nll_loss_1d(input.numpy(), target.numpy(), reduction="mean") + test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + + def test_nllloss_sum(test_case): + x = np.array( + [ + [0.88103855, 0.9908683, 0.6226845], + [0.53331435, 0.07999352, 0.8549948], + [0.25879037, 0.39530203, 0.698465], + [0.73427284, 0.63575995, 0.18827209], + [0.05689114, 0.0862954, 0.6325046], + ] + ).astype(np.float32) + y = np.array([0, 2, 1, 1, 0]).astype(np.int) + input = flow.Tensor(x, dtype=flow.float32) + + target = flow.Tensor(y, dtype=flow.int64) + nll_loss = flow.nn.NLLLoss(reduction="sum") + of_out = nll_loss(input, target) + np_out = nll_loss_1d(input.numpy(), target.numpy(), reduction="sum") + test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + + def test_nllloss_segmentation_none(test_case): + x = np.array( + [[[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]]] + ).astype(np.float32) + input = flow.Tensor(x, dtype=flow.float32) + y = np.array([[[1, 0], [0, 1]]]).astype(np.int) + target = flow.Tensor(y, dtype=flow.int64) + nll_loss = flow.nn.NLLLoss() + of_out = nll_loss(input, target) + np_out = nll_loss_2d(input.numpy(), target.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + + def test_nllloss_segmentation_mean(test_case): + x = np.array( + [[[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]]] + ).astype(np.float32) + input = flow.Tensor(x, dtype=flow.float32) + y = np.array([[[1, 0], [0, 1]]]).astype(np.int) + target = flow.Tensor(y, dtype=flow.int64) + nll_loss = flow.nn.NLLLoss(reduction="mean") + of_out = nll_loss(input, target) + np_out = nll_loss_2d(input.numpy(), target.numpy(), reduction="mean") + test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + + def test_nllloss_segmentation_sum(test_case): + x = np.array( + [[[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]]] + ).astype(np.float32) + input = flow.Tensor(x, dtype=flow.float32) + y = np.array([[[1, 0], [0, 1]]]).astype(np.int) + target = flow.Tensor(y, dtype=flow.int64) + nll_loss = flow.nn.NLLLoss(reduction="sum") + of_out = nll_loss(input, target) + np_out = nll_loss_2d(input.numpy(), target.numpy(), reduction="sum") + test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + + +if __name__ == "__main__": + unittest.main() -- GitLab