Skip to content
Snippets Groups Projects
Unverified Commit d1a76aa8 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Add negative and nllloss module (#4789)


* add negative and nllloss module

* add negative and nllloss module

* add negative and nllloss module

* add negative and nllloss module

* add nllloss implementation, relies on other modules, do not merge now

* fix nllloss bug

* support 2d nllloss bug

* fix comment

* add flow.neg api

* support bert nllloss

* fix nllloss 2d bug

* add docs

* fix comment

* fix comment

* fix comment

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 67bdf16e
No related branches found
No related tags found
No related merge requests found
......@@ -332,7 +332,7 @@ class Tensor:
return flow.div(other, self)
def __neg__(self):
    """Implement unary ``-`` by delegating to ``flow.neg``."""
    # A stale ``return flow.mul(-1, self)`` used to sit in front of this
    # statement, which made the dedicated negative op unreachable.
    return flow.neg(self)
def _determine_if_needed(self, determining_initializer=None):
if not self.is_determined:
......
......@@ -86,3 +86,133 @@ class CrossEntropyLoss(Module):
return flow.sum(out)
else:
return out
@oneflow_export("nn.NLLLoss")
class NLLLoss(Module):
    r"""The negative log likelihood loss. It is useful to train a classification
    problem with `C` classes.

    The `input` given through a forward call is expected to contain
    log-probabilities of each class. This implementation accepts an `input` of
    size :math:`(minibatch, C)` or :math:`(minibatch, C, H, W)`; any other
    rank is rejected.

    Obtaining log-probabilities in a neural network is easily achieved by
    adding a `LogSoftmax` layer in the last layer of your network.
    You may use `CrossEntropyLoss` instead, if you prefer not to add an extra
    layer.

    The `target` that this loss expects should be a class index in the range
    :math:`[0, C-1]` where `C = number of classes`. Its size is
    :math:`(minibatch)` for a 2-D `input` and :math:`(minibatch, H, W)` for a
    4-D `input`.

    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be
    described as:

    .. math::
        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
        l_n = - w_{y_n} x_{n,y_n}, \quad
        w_{c} = \mathbb{1},

    where :math:`x` is the input, :math:`y` is the target, :math:`w` is the
    weight, and :math:`N` is the batch size. If :attr:`reduction` is
    ``'mean'`` or ``'sum'``, then

    .. math::
        \ell(x, y) = \begin{cases}
            \sum_{n=1}^N \frac{1}{N} l_n, &
            \text{if reduction} = \text{`mean';}\\
            \sum_{n=1}^N l_n, &
            \text{if reduction} = \text{`sum'.}
        \end{cases}

    For a 4-D `input` (e.g. 2D images) the loss is computed per pixel, and the
    reductions run over every element of the unreduced
    :math:`(minibatch, H, W)` result.

    Args:
        weight: Not supported yet; must be left as ``None``.
        ignore_index (int, optional): Not supported yet; must be left as
            ``None``.
        reduction (string, optional): Specifies the reduction to apply to the
            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no
            reduction will be applied, ``'mean'``: the mean of the output is
            taken, ``'sum'``: the output will be summed. ``None`` is also
            accepted and is handled like ``'mean'``. Default: ``'none'``

    Raises:
        ValueError: if ``weight`` or ``ignore_index`` is given.
        AssertionError: if ``reduction`` is not one of the accepted values, or
            (in ``forward``) if ``input`` is neither 2-D nor 4-D.

    For example:

    .. code-block:: python

        import oneflow as flow
        import numpy as np

        input = flow.Tensor(
            [[-0.1664078, -1.7256707, -0.14690138],
             [-0.21474946, 0.53737473, 0.99684894],
             [-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32)
        target = flow.Tensor(np.array([0, 1, 2]), dtype=flow.int32)
        # Each entry is -input[i, target[i]].
        out = flow.nn.NLLLoss(reduction="none")(input, target)
        # out: [ 0.1664078  -0.53737473 -0.7645404 ]
        out_sum = flow.nn.NLLLoss(reduction="sum")(input, target)
        # out_sum: [-1.1355073]
        out_mean = flow.nn.NLLLoss(reduction="mean")(input, target)
        # out_mean: [-0.3785024]

    """

    def __init__(
        self, weight=None, ignore_index: int = None, reduction: str = "none",
    ) -> None:
        super().__init__()
        # Identity checks on purpose: ``!= None`` dispatches to the argument's
        # own ``__ne__``, which need not return a plain bool.
        if weight is not None:
            raise ValueError("Argument weight is not supported yet")
        if ignore_index is not None:
            raise ValueError("Argument ignore_index is not supported yet")
        assert reduction in [
            "sum",
            "none",
            "mean",
            None,
        ], "only 'sum', 'mean', 'none' and None supported by now"

        self.reduction = reduction
        self._gather_nd_op = (
            flow.builtin_op("gather_nd")
            .Input("params")
            .Input("indices")
            .Output("out")
            .Build()
        )

    def nllloss_1d(self, input, target):
        """Gather ``input[i, target[i]]`` for every row ``i`` of a 2-D ``input``.

        Builds ``(row, class)`` index pairs and feeds them to the ``gather_nd``
        op; no negation or reduction happens here.
        """
        n = input.shape[0]
        row_idx = flow.unsqueeze(flow.arange(0, n, 1), dim=1)
        class_idx = flow.unsqueeze(target, dim=1)
        # Row i of ``pairs`` is [i, target[i]].
        pairs = flow.cat([row_idx, class_idx], dim=1)
        return self._gather_nd_op(input, pairs)[0]

    def forward(self, input, target):
        """Compute the loss for a 2-D ``(N, C)`` or 4-D ``(N, C, H, W)`` input."""
        ndim = len(input.shape)
        assert ndim == 2 or ndim == 4, "NLLLoss expects a 2-D or 4-D input"

        input = flow.negative(input)
        if ndim == 2:
            res = self.nllloss_1d(input, target)
        elif ndim == 4:
            b, h, w = input.shape[0], input.shape[2], input.shape[3]
            # Move the class axis last, then treat every pixel as one row of
            # an (N*H*W, C) problem.
            input = flow.tmp.transpose(input, (0, 2, 3, 1))
            input = flow.tmp.reshape(input, shape=[-1, input.shape[3]])
            target = flow.tmp.flatten(target)
            res = self.nllloss_1d(input, target)
            res = flow.tmp.reshape(res, (b, h, w))
        else:
            # Only reachable when assertions are disabled (``python -O``).
            raise NotImplementedError("NLLLoss expects a 2-D or 4-D input")

        if self.reduction == "none":
            return res
        elif self.reduction == "sum":
            return flow.sum(res)
        else:
            # Covers both "mean" and None.
            return flow.mean(res)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
class Negative(Module):
    """Module wrapping the builtin ``negative`` op (input ``x``, output ``y``)."""

    def __init__(self) -> None:
        super().__init__()
        builder = flow.builtin_op("negative")
        builder = builder.Input("x")
        builder = builder.Output("y")
        self._op = builder.Build()

    def forward(self, x):
        """Run the op on ``x`` and return the first of its outputs."""
        outputs = self._op(x)
        return outputs[0]
@oneflow_export("negative", "neg")
@register_tensor_op("negative")
def negative_op(x):
    """Compute the negative value of a Tensor.

    Exported as ``flow.negative`` and ``flow.neg``, and registered as the
    ``negative`` tensor op.

    Args:
        x (oneflow.Tensor): A Tensor

    Returns:
        oneflow.Tensor: The result Tensor

    For example:

    .. code-block:: python

        import oneflow as flow
        import numpy as np

        input = flow.Tensor(
            np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32
        )
        out = flow.negative(input).numpy()
        # out [-1.0, 1.0, -2.3]

    """
    module = Negative()
    return module(x)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestNegativeModule(flow.unittest.TestCase):
    """Checks every spelling of tensor negation against NumPy's unary minus."""

    @staticmethod
    def _sample():
        """Build the float32 tensor shared by all tests."""
        values = np.array([1.0, -1.0, 2.3]).astype(np.float32)
        return flow.Tensor(values, dtype=flow.float32)

    def _expect_negated(test_case, source, result):
        """Assert that ``result`` equals ``-source`` element for element."""
        expected = -(source.numpy())
        test_case.assertTrue(np.array_equal(result.numpy(), expected))

    def test_negtive(test_case):
        # flow.negative(...)
        source = test_case._sample()
        test_case._expect_negated(source, flow.negative(source))

    def test_negative_neg(test_case):
        # flow.neg(...) alias
        source = test_case._sample()
        test_case._expect_negated(source, flow.neg(source))

    def test_tensor_negative(test_case):
        # Tensor.negative() method
        source = test_case._sample()
        test_case._expect_negated(source, source.negative())

    def test_self_tensor_negative(test_case):
        # unary minus operator
        source = test_case._sample()
        test_case._expect_negated(source, -source)


if __name__ == "__main__":
    unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
def nll_loss_1d(logs, targets, reduction="none"):
    """NumPy reference for the NLL loss on ``(N, C)`` inputs.

    Args:
        logs: array of shape ``(N, C)`` holding one row of per-class values
            for each sample.
        targets: integer array of shape ``(N,)``; ``targets[i]`` is the class
            index picked from row ``i``.
        reduction: ``"none"``, ``"sum"`` or ``"mean"``.

    Returns:
        ``"none"``: float64 array of shape ``(N,)`` with
        ``-logs[i, targets[i]]``; ``"sum"``: the sum of those values;
        ``"mean"``: that sum divided by ``N``.

    Raises:
        ValueError: if ``reduction`` is not one of the three accepted strings
            (previously this silently returned ``None``).
    """
    if reduction not in ("none", "sum", "mean"):
        raise ValueError(
            "reduction must be 'none', 'sum' or 'mean', got {!r}".format(reduction)
        )

    num_samples = logs.shape[0]
    # One advanced-indexing lookup replaces the per-row Python loop; the cast
    # keeps the float64 accumulation dtype of the original helper.
    out = -logs[np.arange(num_samples), targets].astype(np.float64)

    if reduction == "sum":
        return np.sum(out)
    if reduction == "mean":
        return out.sum() / num_samples
    return out
def nll_loss_2d(logs, targets, reduction="none"):
    """NumPy reference for the per-pixel NLL loss on ``(N, C, H, W)`` inputs.

    Args:
        logs: array of shape ``(N, C, H, W)`` holding per-class values for
            every pixel.
        targets: integer array of shape ``(N, H, W)``; ``targets[n, h, w]`` is
            the class index picked at that pixel.
        reduction: ``"none"``, ``"sum"`` or ``"mean"``.

    Returns:
        ``"none"``: float64 array of shape ``(N, H, W)`` with
        ``-logs[n, targets[n, h, w], h, w]``; ``"sum"``: the sum of those
        values; ``"mean"``: that sum divided by ``N * H * W``.

    Raises:
        ValueError: if ``reduction`` is not one of the three accepted strings
            (previously this silently returned ``None``).
    """
    if reduction not in ("none", "sum", "mean"):
        raise ValueError(
            "reduction must be 'none', 'sum' or 'mean', got {!r}".format(reduction)
        )

    N, _, H, W = logs.shape
    # Open grids broadcast against ``targets`` so a single advanced-indexing
    # lookup replaces the triple nested loop.
    n_idx, h_idx, w_idx = np.ogrid[:N, :H, :W]
    out = -logs[n_idx, targets, h_idx, w_idx].astype(np.float64)

    if reduction == "sum":
        return np.sum(out)
    if reduction == "mean":
        return out.sum() / (N * H * W)
    return out
@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestNLLLossModule(flow.unittest.TestCase):
    """Compares ``flow.nn.NLLLoss`` with the NumPy reference helpers.

    Covers the 2-D ``(N, C)`` and 4-D ``(N, C, H, W)`` paths, each with the
    default reduction, ``"mean"`` and ``"sum"``.
    """

    # (N, C) = (5, 3) values with one class index per row.
    _CLS_INPUT = [
        [0.88103855, 0.9908683, 0.6226845],
        [0.53331435, 0.07999352, 0.8549948],
        [0.25879037, 0.39530203, 0.698465],
        [0.73427284, 0.63575995, 0.18827209],
        [0.05689114, 0.0862954, 0.6325046],
    ]
    _CLS_TARGET = [0, 2, 1, 1, 0]

    # (N, C, H, W) = (1, 2, 2, 2) values with one class index per pixel.
    _SEG_INPUT = [[[[0.12, 0.36], [0.22, 0.66]], [[0.13, 0.34], [0.52, -0.96]]]]
    _SEG_TARGET = [[[1, 0], [0, 1]]]

    def _check(test_case, data, labels, reference, **loss_kwargs):
        """Run NLLLoss and ``reference`` with the same kwargs and compare.

        Passing no kwargs exercises the default reduction of both sides.
        """
        x = np.array(data).astype(np.float32)
        # ``np.int`` was a deprecated alias of the builtin ``int`` and is gone
        # from newer NumPy; use the explicit width the target tensor has.
        y = np.array(labels).astype(np.int64)
        inputs = flow.Tensor(x, dtype=flow.float32)
        target = flow.Tensor(y, dtype=flow.int64)
        nll_loss = flow.nn.NLLLoss(**loss_kwargs)
        of_out = nll_loss(inputs, target)
        np_out = reference(inputs.numpy(), target.numpy(), **loss_kwargs)
        test_case.assertTrue(np.allclose(of_out.numpy(), np_out))

    def test_nllloss_none(test_case):
        test_case._check(test_case._CLS_INPUT, test_case._CLS_TARGET, nll_loss_1d)

    def test_nllloss_mean(test_case):
        test_case._check(
            test_case._CLS_INPUT, test_case._CLS_TARGET, nll_loss_1d, reduction="mean"
        )

    def test_nllloss_sum(test_case):
        test_case._check(
            test_case._CLS_INPUT, test_case._CLS_TARGET, nll_loss_1d, reduction="sum"
        )

    def test_nllloss_segmentation_none(test_case):
        test_case._check(test_case._SEG_INPUT, test_case._SEG_TARGET, nll_loss_2d)

    def test_nllloss_segmentation_mean(test_case):
        test_case._check(
            test_case._SEG_INPUT, test_case._SEG_TARGET, nll_loss_2d, reduction="mean"
        )

    def test_nllloss_segmentation_sum(test_case):
        test_case._check(
            test_case._SEG_INPUT, test_case._SEG_TARGET, nll_loss_2d, reduction="sum"
        )


if __name__ == "__main__":
    unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment