Skip to content
Snippets Groups Projects
Unverified Commit ec0d02c1 authored by ZhongHW's avatar ZhongHW Committed by GitHub
Browse files

split vector-matrix norm (#5478)


* split vector-matrix norm

* fix_vector_norm

* fix_docstring

* fix_matrix_norm

* fix-default

* fix-doctest

* auto format by CI

Co-authored-by: default avataroneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 06da41da
No related branches found
No related tags found
No related merge requests found
......@@ -209,6 +209,8 @@ Experimental features
.. autofunction:: oneflow.experimental.nn.UpsamplingNearest2d
.. autofunction:: oneflow.experimental.nn.UpsamplingBilinear2d
.. autofunction:: oneflow.experimental.linalg.norm
.. autofunction:: oneflow.experimental.linalg.vector_norm
.. autofunction:: oneflow.experimental.linalg.matrix_norm
.. autofunction:: oneflow.experimental.Tensor.norm
.. autofunction:: oneflow.experimental.floor
.. autofunction:: oneflow.experimental.Tensor.floor
......
......@@ -21,111 +21,190 @@ from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.framework.tensor import register_tensor_op
class Norm(Module):
def __init__(self, ord=None, dim=None, keepdim=False) -> None:
super().__init__()
def check_dim(num_dims, input_dim):
if input_dim == None:
dim = input_dim
elif isinstance(input_dim, (int, tuple)):
if isinstance(input_dim, int):
dim = input_dim if input_dim >= 0 else input_dim + num_dims
if dim >= num_dims or dim < 0:
raise IndexError("Dimension out of range")
else:
temp = list(input_dim)
for i in range(len(temp)):
temp[i] = temp[i] if temp[i] >= 0 else temp[i] + num_dims
if temp[i] >= num_dims or temp[i] < 0:
raise IndexError("Dimension out of range")
dim = temp
else:
raise TypeError(
"linalg_vector_norm(): argument 'dim' must be tuple of ints, not {}".format(
type(input_dim)
)
)
return dim
self.ord = ord
def _norm_min_max(input, ord, dim, keepdim):
if ord > 0:
return flow.experimental.max(input, dim=dim, keepdim=keepdim)
else:
return flow.experimental.min(input, dim=dim, keepdim=keepdim)
class Vector_Norm(Module):
def __init__(self, ord=2, dim=None, keepdim=False) -> None:
super().__init__()
if ord == None:
self.ord = 2.0
elif isinstance(ord, (int, float)):
self.ord = float(ord)
else:
raise TypeError(
"linalg_vector_norm(): argument 'ord' must be Number, not {}".format(
type(ord)
)
)
self.dim = dim
self.keepdim = keepdim
def _vector_norm(self, x, ord, dim):
if isinstance(ord, str) and ord in ["fro", "nuc"]:
raise ValueError("Norm order {} is not supported for vectors".format(ord))
elif isinstance(ord, float) and ord in [float("inf"), float("-inf")]:
if ord == float("inf"):
return flow.experimental.max(flow.experimental.abs(x), dim=dim)
else:
return flow.experimental.min(flow.experimental.abs(x), dim=dim)
elif isinstance(ord, int):
if ord == 0:
# TODO: fix error when input are all zero vector
return flow.tensor([flow.experimental.argwhere(x).shape[0]])
else:
return flow.experimental.pow(
flow.experimental.sum(
flow.experimental.pow(flow.experimental.abs(x), ord), dim=dim
),
1.0 / ord,
)
def _vector_norm(self, x, ord, dim, keepdim=False):
if ord == 0:
# TODO: fix error when input are all zero vector
return flow.experimental.cast(
flow.tensor([flow.experimental.argwhere(x).shape[0]]), flow.float32
)
elif ord == float("inf"):
return flow.experimental.max(
flow.experimental.abs(x), dim=dim, keepdim=keepdim
)
elif ord == float("-inf"):
return flow.experimental.min(
flow.experimental.abs(x), dim=dim, keepdim=keepdim
)
else:
raise ValueError("Invalid norm order: {}".format(ord))
return flow.experimental.pow(
flow.experimental.sum(
flow.experimental.pow(flow.experimental.abs(x), ord),
dim=dim,
keepdim=keepdim,
),
1.0 / ord,
)
def _matrix_norm(self, x, ord, dim):
if isinstance(ord, str) and ord in ["fro", "nuc"]:
if ord == "nuc":
raise NotImplementedError
else:
return flow.experimental.sqrt(
flow.experimental.sum(flow.experimental.square(x), dim=dim)
)
elif isinstance(ord, float) and ord in [float("inf"), float("-inf")]:
if ord == float("inf"):
return flow.experimental.max(
flow.experimental.sum(flow.experimental.abs(x), dim=1)
)
else:
return flow.experimental.min(
flow.experimental.sum(flow.experimental.abs(x), dim=1)
)
def forward(self, x):
num_dims = len(x.shape)
dim = check_dim(num_dims, self.dim)
if dim == None:
return self._vector_norm(
x.flatten(), ord=self.ord, dim=self.dim, keepdim=self.keepdim
)
else:
return self._vector_norm(x, ord=self.ord, dim=dim, keepdim=self.keepdim)
class Matrix_Norm(Module):
def __init__(self, ord="fro", dim=(-2, -1), keepdim=False) -> None:
super().__init__()
if isinstance(ord, str):
assert ord in ["fro", "nuc"], "{} are not supported in matrix norm".format(
ord
)
self.ord = ord
elif isinstance(ord, float):
assert ord in [
float("inf"),
float("-inf"),
], "{} are not supported in matrix norm".format(ord)
self.ord = ord
elif isinstance(ord, int):
if ord == 1:
return flow.experimental.max(
flow.experimental.sum(flow.experimental.abs(x), dim=0)
)
elif ord == -1:
return flow.experimental.min(
flow.experimental.sum(flow.experimental.abs(x), dim=0)
assert ord in [1, -1, 2, -2], "{} are not supported in matrix norm".format(
ord
)
self.ord = ord
elif ord == None:
self.ord = "fro"
else:
raise TypeError(
"linalg_matrix_norm(): argument 'ord' must be Number, not {}".format(
type(ord)
)
elif ord == 2:
raise NotImplementedError
elif ord == -2:
raise NotImplementedError
else:
raise ValueError(
"Norm order {} is not supported for matrices".format(ord)
)
if isinstance(dim, tuple) and len(dim) == 2 and dim[0] != dim[1]:
self.dim = dim
else:
raise TypeError(
"linalg.matrix_norm(): dim must be a 2-tuple of ints with different elements"
)
self.keepdim = keepdim
def _matrix_norm(self, x, ord, dim, keepdim):
if ord == "nuc":
raise NotImplementedError
elif ord == "fro":
return flow.experimental.sqrt(
flow.experimental.sum(
flow.experimental.square(x), dim=dim, keepdim=keepdim
)
)
elif ord in [float("inf"), float("-inf")]:
dim_0, dim_1 = dim[0], dim[1]
dim_0, dim_1 = dim_1, dim_0
if dim_1 > dim_0 and not keepdim:
dim_1 -= 1
res = flow.experimental.sum(
flow.experimental.abs(x), dim=dim_0, keepdim=keepdim
)
return _norm_min_max(res, ord, dim_1, keepdim)
elif ord in [1, -1]:
dim_0, dim_1 = dim[0], dim[1]
if dim_1 > dim_0 and not keepdim:
dim_1 -= 1
res = flow.experimental.sum(
flow.experimental.abs(x), dim=dim_0, keepdim=keepdim
)
return _norm_min_max(res, ord, dim_1, keepdim)
elif ord in [2, -2]:
raise NotImplementedError
else:
raise ValueError("Invalid norm order: {}".format(ord))
def _whether_keepdim(self, x):
if self.keepdim == True and self.dim != None:
return flow.experimental.unsqueeze(x, self.dim)
else:
return x
def forward(self, x):
num_dims = len(x.shape)
if num_dims < 2:
raise RuntimeError(
"linalg.matrix_norm(): input tensor must be a matrix or batch of matrices"
)
dim = check_dim(num_dims, self.dim)
return self._matrix_norm(x, ord=self.ord, dim=dim, keepdim=self.keepdim)
class Norm(Module):
def __init__(self, ord=None, dim=None, keepdim=False) -> None:
super().__init__()
self.ord = ord
self.dim = dim
self.keepdim = keepdim
def forward(self, x):
num_axes = len(x.shape)
if self.dim == None and self.ord == None:
res = self._vector_norm(x.reshape((1, -1))[0], ord=2, dim=self.dim)
if isinstance(self.dim, int):
res = Vector_Norm(ord=self.ord, dim=self.dim, keepdim=self.keepdim)(x)
elif isinstance(self.dim, tuple):
res = Matrix_Norm(ord=self.ord, dim=self.dim, keepdim=self.keepdim)(x)
elif self.dim == None and self.ord != None:
assert (
num_axes <= 2
len(x.shape) <= 2
), "input must be 1-D or 2-D when dim is None and ord is not None"
res = (
self._vector_norm(x, self.ord, self.dim)
if num_axes == 1
else self._matrix_norm(x, self.ord, self.dim)
)
elif isinstance(self.dim, (int, tuple, list)):
if isinstance(self.dim, int):
self.dim = self.dim if self.dim >= 0 else self.dim + num_axes
assert 0 <= self.dim < num_axes, "dim out of range"
res = self._vector_norm(
x, ord=2 if self.ord == None else self.ord, dim=self.dim
)
if len(x.shape) == 1:
res = Vector_Norm(ord=self.ord, keepdim=self.keepdim)(x)
else:
temp = list(self.dim) if isinstance(self.dim, tuple) else self.dim
for i in range(len(temp)):
temp[i] = temp[i] if temp[i] >= 0 else temp[i] + num_axes
assert 0 <= temp[i] < num_axes, "dim out of range"
self.dim = temp
res = self._matrix_norm(
x, ord="fro" if self.ord == None else self.ord, dim=self.dim
)
else:
raise ValueError("Invalid dimension: {}".format(self.dim))
return self._whether_keepdim(res)
res = Matrix_Norm(ord=self.ord, keepdim=self.keepdim)(x)
elif self.dim == None and self.ord == None:
res = Vector_Norm(keepdim=self.keepdim)(x)
return res
@oneflow_export("linalg.norm")
......@@ -253,11 +332,154 @@ def norm_op(input, ord=None, dim=None, keepdim=False):
@experimental_api
def norm_tensor_op(input, ord=None, dim=None, keepdim=False):
r"""
See :func:`oneflow.experimental.linalg.norm.`
See :func:`oneflow.experimental.linalg.norm`
"""
return Norm(ord, dim, keepdim)(input)
@oneflow_export("linalg.vector_norm")
@experimental_api
def vector_norm_tensor_op(input, ord=2, dim=None, keepdim=False):
r"""
linalg.vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor
Computes a vector norm.
Supports input of float, double dtypes.
This function does not necessarily treat multidimensonal attr:`input` as a batch of
vectors, instead:
- If :attr:`dim`\ `= None`, :attr:`input` will be flattened before the norm is computed.
- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions and the other dimensions will be treated as batch dimensions.
This behavior is for consistency with :func:`flow.linalg.norm`.
:attr:`ord` defines the vector norm that is computed. The following norms are supported:
====================== ========================================================
:attr:`ord` vector norm
====================== ========================================================
`2` (default) `2`-norm (see below)
`inf` `max(abs(x))`
`-inf` `min(abs(x))`
`0` `sum(x != 0)`
other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}`
====================== ========================================================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
input (Tensor): tensor, flattened by default, but this behavior can be
controlled using :attr:`dim`.
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2`
dim (int, Tuple[int], optional): dimensions over which to compute
the norm. See above for the behavior when :attr:`dim`\ `= None`.
Default: `None`
keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor.
Examples::
>>> import oneflow.experimental as flow
>>> from oneflow.experimental import linalg as LA
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4)
>>> a
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32)
>>> b = a.reshape((3, 3))
>>> b
tensor([[-4., -3., -2.],
[-1., 0., 1.],
[ 2., 3., 4.]], dtype=oneflow.float32)
>>> LA.vector_norm(a, ord=3.5)
tensor([5.4345], dtype=oneflow.float32)
>>> LA.vector_norm(b, ord=3.5)
tensor([5.4345], dtype=oneflow.float32)
"""
return Vector_Norm(ord, dim, keepdim)(input)
@oneflow_export("linalg.matrix_norm")
@experimental_api
def matrix_norm_tensor_op(input, ord="fro", dim=(-2, -1), keepdim=False):
r"""
linalg.matrix_norm(input, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor
Computes a matrix norm.
Support input of float, double, cfloat and cdouble dtypes.
Also supports batches of matrices: the norm will be computed over the
dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will
be treated as batch dimensions. The output will have the same batch dimensions.
:attr:`ord` defines the matrix norm that is computed. The following norms are supported:
====================== ========================================================
:attr:`ord` matrix norm
====================== ========================================================
`'fro'` (default) Frobenius norm
`'nuc'` -- not supported yet --
`inf` `max(sum(abs(x), dim=1))`
`-inf` `min(sum(abs(x), dim=1))`
`1` `max(sum(abs(x), dim=0))`
`-1` `min(sum(abs(x), dim=0))`
`2` -- not supported yet --
`-2` -- not supported yet --
====================== ========================================================
where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
Args:
input (Tensor): tensor with two or more dimensions. By default its
shape is interpreted as `(*, m, n)` where `*` is zero or more
batch dimensions, but this behavior can be controlled using :attr:`dim`.
ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'`
dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)`
keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
in the result as dimensions with size one. Default: `False`
Returns:
A real-valued tensor.
Examples::
>>> import oneflow.experimental as flow
>>> from oneflow.experimental import linalg as LA
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape((3,3))
>>> a
tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]], dtype=oneflow.float32)
>>> LA.matrix_norm(a)
tensor([14.2829], dtype=oneflow.float32)
>>> LA.matrix_norm(a, ord=-1)
tensor([9.], dtype=oneflow.float32)
>>> b = a.expand(2, -1, -1)
>>> b
tensor([[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]],
<BLANKLINE>
[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]]], dtype=oneflow.float32)
>>> LA.matrix_norm(b)
tensor([14.2829, 14.2829], dtype=oneflow.float32)
>>> LA.matrix_norm(b, dim=(0, 2))
tensor([ 3.1623, 10. , 17.2627], dtype=oneflow.float32)
"""
return Matrix_Norm(ord, dim, keepdim)(input)
if __name__ == "__main__":
import doctest
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment