Unverified commit 5da25cff authored by YongtaoShi, committed by GitHub

add mseloss module (#5116)


* add mseloss module

* add mseloss testcase

* delete debug code

* add mseloss testcase

* rename mseloss testcase

* fix docstring warning

* format docstring

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 55c3f17d
@@ -88,6 +88,7 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.Linear
.. autofunction:: oneflow.experimental.nn.CrossEntropyLoss
.. autofunction:: oneflow.experimental.nn.NLLLoss
.. autofunction:: oneflow.experimental.nn.MSELoss
.. autofunction:: oneflow.experimental.nn.MarginRankingLoss
.. autofunction:: oneflow.experimental.masked_fill
.. autofunction:: oneflow.experimental.Tensor.masked_fill
......
@@ -16,8 +16,10 @@ limitations under the License.
from typing import Optional
import oneflow as flow
from oneflow.python.framework.tensor import Tensor
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.nn.modules.math_ops import Subtract, Square, Sum, Mean
@oneflow_export("nn.CrossEntropyLoss")
@@ -296,6 +298,123 @@ class NLLLoss(Module):
        return res.mean()


@oneflow_export("nn.MSELoss")
@experimental_api
class MSELoss(Module):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html?highlight=mseloss#torch.nn.MSELoss
Creates a criterion that measures the mean squared error (squared L2 norm) between
each element in the input :math:`x` and target :math:`y`.
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
.. math::
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
l_n = \left( x_n - y_n \right)^2,
where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
(default ``'mean'``), then:
.. math::
\ell(x, y) =
\begin{cases}
\operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
\operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
\end{cases}
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total
of :math:`n` elements each.
The mean operation still operates over all the elements, and divides by :math:`n`.
The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``.
Args:
size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
the losses are averaged over each loss element in the batch. Note that for
some losses, there are multiple elements per sample. If the field :attr:`size_average`
is set to ``False``, the losses are instead summed for each minibatch. Ignored
when :attr:`reduce` is ``False``. Default: ``True``
reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
losses are averaged or summed over observations for each minibatch depending
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
batch element instead and ignores :attr:`size_average`. Default: ``True``
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
Shape:
- Input: :math:`(N, *)` where :math:`*` means, any number of additional
dimensions
- Target: :math:`(N, *)`, same shape as the input
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> input = flow.Tensor(
... [[-0.02557137, 0.03101675, 1.37493674],
... [0.25599439, -1.08372561, -0.21006816]], dtype=flow.float32)
>>> target = flow.Tensor(
... [[-1.53105064, -0.68137555, 0.5931354],
... [-0.49158347, 0.93673637, 0.1324141]], dtype=flow.float32)
>>> m = flow.nn.MSELoss(reduction="none")
>>> out = m(input, target)
>>> print(out.numpy())
[[2.266468 0.50750285 0.61121327]
[0.55887264 4.082267 0.1172941 ]]
>>> m = flow.nn.MSELoss(reduction="mean")
>>> out = m(input, target)
>>> print(out.numpy())
[1.3572696]
>>> m = flow.nn.MSELoss(reduction="sum")
>>> out = m(input, target)
>>> print(out.numpy())
[8.143618]
"""

    def __init__(
        self, reduction: str = "mean", size_average: bool = True, reduce: bool = True
    ) -> None:
        super().__init__()
        if size_average is False:
            raise ValueError("Argument size_average is not supported yet")
        if reduce is False:
            raise ValueError("Argument reduce is not supported yet")
        assert reduction in [
            "sum",
            "none",
            "mean",
            None,
        ], "Argument reduction only supports 'sum'/'mean'/'none'/None for now!"
        self.reduction = reduction
        self.square_op = Square()
        self.subtract_op = Subtract()
        self.sum_op = Sum()
        self.mean_op = Mean()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Element-wise squared difference; reduced below according to self.reduction.
        squared_difference = self.square_op(self.subtract_op(input, target))
        if self.reduction == "mean":
            return self.mean_op(squared_difference)
        elif self.reduction == "sum":
            return self.sum_op(squared_difference)
        else:
            # reduction == "none": return the unreduced, element-wise loss
            return squared_difference

@oneflow_export("nn.MarginRankingLoss")
@experimental_api
class MarginRankingLoss(Module):
......
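For reference, the three reduction modes of the new module map directly onto plain NumPy. The sketch below is illustrative only (it is not part of the commit, and the helper name np_mse_reference is made up here); it mirrors the docstring formulas, where the unreduced loss is l_n = (x_n - y_n)^2, 'mean' divides the summed loss by the element count, and 'sum' just sums.

import numpy as np

def np_mse_reference(input, target, reduction="mean"):
    # Unreduced loss: element-wise squared difference l_n = (x_n - y_n)^2.
    loss = np.square(input - target)
    if reduction == "mean":
        return np.mean(loss)  # sum divided by the total number of elements
    if reduction == "sum":
        return np.sum(loss)
    return loss  # reduction == "none": keep the element-wise values

MSELoss.forward composes the same computation from OneFlow ops (Subtract, Square, Mean, Sum), which is why the new test file below can check its outputs against an equivalent NumPy helper.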
@@ -3917,6 +3917,7 @@ def bce_with_logits_loss(
@oneflow_export("nn.MSELoss")
@stable_api
def mse_loss(
    input: oneflow._oneflow_internal.BlobDesc,
    target: oneflow._oneflow_internal.BlobDesc,
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList

def _np_mseloss(np_input, np_target):
    np_mse = np.square(np_target - np_input)
    np_mse_mean = np.mean(np_mse)
    np_mse_sum = np.sum(np_mse)
    return {
        "none": np_mse,
        "mean": np_mse_mean,
        "sum": np_mse_sum,
    }

def _np_mseloss_grad(np_input, np_target):
    elem_cnt = np_input.size
    # d/d(input) of (target - input)^2 is -2 * (target - input); the "mean"
    # reduction additionally divides by the element count. The "none" entry
    # reuses the "sum" gradient because the test sums the unreduced output
    # before calling backward(), so the two gradients coincide.
    np_mse_grad_sum = -2 * (np_target - np_input)
    np_mse_grad_mean = np_mse_grad_sum / elem_cnt
    return {
        "none": np_mse_grad_sum,
        "mean": np_mse_grad_mean,
        "sum": np_mse_grad_sum,
    }

def _test_mseloss_impl(test_case, device, shape, reduction):
    x = np.random.randn(*shape)
    y = np.random.randn(*shape)
    input = flow.Tensor(
        x, dtype=flow.float32, requires_grad=True, device=flow.device(device)
    )
    target = flow.Tensor(y, dtype=flow.float32, device=flow.device(device))
    loss = flow.nn.MSELoss(reduction=reduction)
    loss = loss.to(device)
    of_out = loss(input, target)
    np_out = _np_mseloss(x, y)[reduction]
    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))

    of_out = of_out.sum()
    of_out.backward()
    np_grad = _np_mseloss_grad(x, y)[reduction]
    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))

@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestMSELossModule(flow.unittest.TestCase):
    def test_mseloss(test_case):
        arg_dict = OrderedDict()
        arg_dict["test_fun"] = [
            _test_mseloss_impl,
        ]
        arg_dict["device"] = ["cpu", "cuda"]
        arg_dict["shape"] = [
            (3, 5),
            (10, 9, 21),
            (14, 22, 9, 21),
            (3, 2, 4, 16, 5),
            (1,),
        ]
        arg_dict["reduction"] = ["none", "mean", "sum"]
        for arg in GenArgList(arg_dict):
            arg[0](test_case, *arg[1:])


if __name__ == "__main__":
    unittest.main()
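GenArgList effectively enumerates the Cartesian product of the OrderedDict's value lists, so test_mseloss runs _test_mseloss_impl once per (device, shape, reduction) combination. A standalone sketch of that expansion (illustrative only, using itertools instead of the test_util helper):

import itertools

devices = ["cpu", "cuda"]
shapes = [(3, 5), (10, 9, 21), (14, 22, 9, 21), (3, 2, 4, 16, 5), (1,)]
reductions = ["none", "mean", "sum"]

# 2 devices * 5 shapes * 3 reductions = 30 test configurations in total.
for device, shape, reduction in itertools.product(devices, shapes, reductions):
    print(device, shape, reduction)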