From ec0d02c1ff04f023d497dc9e88fe611e5f76800c Mon Sep 17 00:00:00 2001 From: ZhongHW <35329085+puchapu@users.noreply.github.com> Date: Sat, 17 Jul 2021 05:40:22 +0800 Subject: [PATCH] split vector-matrix norm (#5478) * split vector-matrix norm * fix_vector_norm * fix_docstring * fix_matrix_norm * fix-default * fix-doctest * auto format by CI Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- docs/source/experimental.rst | 2 + oneflow/python/nn/modules/norm.py | 398 +++++++++++++++++++++++------- 2 files changed, 312 insertions(+), 88 deletions(-) diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index cd07e80cd..4242190fa 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -209,6 +209,8 @@ Experimental features .. autofunction:: oneflow.experimental.nn.UpsamplingNearest2d .. autofunction:: oneflow.experimental.nn.UpsamplingBilinear2d .. autofunction:: oneflow.experimental.linalg.norm +.. autofunction:: oneflow.experimental.linalg.vector_norm +.. autofunction:: oneflow.experimental.linalg.matrix_norm .. autofunction:: oneflow.experimental.Tensor.norm .. autofunction:: oneflow.experimental.floor .. 
autofunction:: oneflow.experimental.Tensor.floor diff --git a/oneflow/python/nn/modules/norm.py b/oneflow/python/nn/modules/norm.py index fe5a64f7d..2290d5b9e 100644 --- a/oneflow/python/nn/modules/norm.py +++ b/oneflow/python/nn/modules/norm.py @@ -21,111 +21,190 @@ from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.framework.tensor import register_tensor_op -class Norm(Module): - def __init__(self, ord=None, dim=None, keepdim=False) -> None: - super().__init__() +def check_dim(num_dims, input_dim): + if input_dim == None: + dim = input_dim + elif isinstance(input_dim, (int, tuple)): + if isinstance(input_dim, int): + dim = input_dim if input_dim >= 0 else input_dim + num_dims + if dim >= num_dims or dim < 0: + raise IndexError("Dimension out of range") + else: + temp = list(input_dim) + for i in range(len(temp)): + temp[i] = temp[i] if temp[i] >= 0 else temp[i] + num_dims + if temp[i] >= num_dims or temp[i] < 0: + raise IndexError("Dimension out of range") + dim = temp + else: + raise TypeError( + "linalg_vector_norm(): argument 'dim' must be tuple of ints, not {}".format( + type(input_dim) + ) + ) + return dim - self.ord = ord + +def _norm_min_max(input, ord, dim, keepdim): + if ord > 0: + return flow.experimental.max(input, dim=dim, keepdim=keepdim) + else: + return flow.experimental.min(input, dim=dim, keepdim=keepdim) + + +class Vector_Norm(Module): + def __init__(self, ord=2, dim=None, keepdim=False) -> None: + super().__init__() + if ord == None: + self.ord = 2.0 + elif isinstance(ord, (int, float)): + self.ord = float(ord) + else: + raise TypeError( + "linalg_vector_norm(): argument 'ord' must be Number, not {}".format( + type(ord) + ) + ) self.dim = dim self.keepdim = keepdim - def _vector_norm(self, x, ord, dim): - if isinstance(ord, str) and ord in ["fro", "nuc"]: - raise ValueError("Norm order {} is not supported for vectors".format(ord)) - elif isinstance(ord, float) and ord in [float("inf"), float("-inf")]: 
- if ord == float("inf"): - return flow.experimental.max(flow.experimental.abs(x), dim=dim) - else: - return flow.experimental.min(flow.experimental.abs(x), dim=dim) - elif isinstance(ord, int): - if ord == 0: - # TODO: fix error when input are all zero vector - return flow.tensor([flow.experimental.argwhere(x).shape[0]]) - else: - return flow.experimental.pow( - flow.experimental.sum( - flow.experimental.pow(flow.experimental.abs(x), ord), dim=dim - ), - 1.0 / ord, - ) + def _vector_norm(self, x, ord, dim, keepdim=False): + if ord == 0: + # TODO: fix error when input are all zero vector + return flow.experimental.cast( + flow.tensor([flow.experimental.argwhere(x).shape[0]]), flow.float32 + ) + elif ord == float("inf"): + return flow.experimental.max( + flow.experimental.abs(x), dim=dim, keepdim=keepdim + ) + elif ord == float("-inf"): + return flow.experimental.min( + flow.experimental.abs(x), dim=dim, keepdim=keepdim + ) else: - raise ValueError("Invalid norm order: {}".format(ord)) + return flow.experimental.pow( + flow.experimental.sum( + flow.experimental.pow(flow.experimental.abs(x), ord), + dim=dim, + keepdim=keepdim, + ), + 1.0 / ord, + ) - def _matrix_norm(self, x, ord, dim): - if isinstance(ord, str) and ord in ["fro", "nuc"]: - if ord == "nuc": - raise NotImplementedError - else: - return flow.experimental.sqrt( - flow.experimental.sum(flow.experimental.square(x), dim=dim) - ) - elif isinstance(ord, float) and ord in [float("inf"), float("-inf")]: - if ord == float("inf"): - return flow.experimental.max( - flow.experimental.sum(flow.experimental.abs(x), dim=1) - ) - else: - return flow.experimental.min( - flow.experimental.sum(flow.experimental.abs(x), dim=1) - ) + def forward(self, x): + num_dims = len(x.shape) + dim = check_dim(num_dims, self.dim) + if dim == None: + return self._vector_norm( + x.flatten(), ord=self.ord, dim=self.dim, keepdim=self.keepdim + ) + else: + return self._vector_norm(x, ord=self.ord, dim=dim, keepdim=self.keepdim) + + +class 
Matrix_Norm(Module): + def __init__(self, ord="fro", dim=(-2, -1), keepdim=False) -> None: + super().__init__() + if isinstance(ord, str): + assert ord in ["fro", "nuc"], "{} are not supported in matrix norm".format( + ord + ) + self.ord = ord + elif isinstance(ord, float): + assert ord in [ + float("inf"), + float("-inf"), + ], "{} are not supported in matrix norm".format(ord) + self.ord = ord elif isinstance(ord, int): - if ord == 1: - return flow.experimental.max( - flow.experimental.sum(flow.experimental.abs(x), dim=0) - ) - elif ord == -1: - return flow.experimental.min( - flow.experimental.sum(flow.experimental.abs(x), dim=0) + assert ord in [1, -1, 2, -2], "{} are not supported in matrix norm".format( + ord + ) + self.ord = ord + elif ord == None: + self.ord = "fro" + else: + raise TypeError( + "linalg_matrix_norm(): argument 'ord' must be Number, not {}".format( + type(ord) ) - elif ord == 2: - raise NotImplementedError - elif ord == -2: - raise NotImplementedError - else: - raise ValueError( - "Norm order {} is not supported for matrices".format(ord) + ) + if isinstance(dim, tuple) and len(dim) == 2 and dim[0] != dim[1]: + self.dim = dim + else: + raise TypeError( + "linalg.matrix_norm(): dim must be a 2-tuple of ints with different elements" + ) + self.keepdim = keepdim + + def _matrix_norm(self, x, ord, dim, keepdim): + if ord == "nuc": + raise NotImplementedError + elif ord == "fro": + return flow.experimental.sqrt( + flow.experimental.sum( + flow.experimental.square(x), dim=dim, keepdim=keepdim ) + ) + + elif ord in [float("inf"), float("-inf")]: + dim_0, dim_1 = dim[0], dim[1] + dim_0, dim_1 = dim_1, dim_0 + if dim_1 > dim_0 and not keepdim: + dim_1 -= 1 + res = flow.experimental.sum( + flow.experimental.abs(x), dim=dim_0, keepdim=keepdim + ) + return _norm_min_max(res, ord, dim_1, keepdim) + + elif ord in [1, -1]: + dim_0, dim_1 = dim[0], dim[1] + if dim_1 > dim_0 and not keepdim: + dim_1 -= 1 + res = flow.experimental.sum( + 
flow.experimental.abs(x), dim=dim_0, keepdim=keepdim + ) + return _norm_min_max(res, ord, dim_1, keepdim) + elif ord in [2, -2]: + raise NotImplementedError else: raise ValueError("Invalid norm order: {}".format(ord)) - def _whether_keepdim(self, x): - if self.keepdim == True and self.dim != None: - return flow.experimental.unsqueeze(x, self.dim) - else: - return x + def forward(self, x): + num_dims = len(x.shape) + if num_dims < 2: + raise RuntimeError( + "linalg.matrix_norm(): input tensor must be a matrix or batch of matrices" + ) + dim = check_dim(num_dims, self.dim) + return self._matrix_norm(x, ord=self.ord, dim=dim, keepdim=self.keepdim) + + +class Norm(Module): + def __init__(self, ord=None, dim=None, keepdim=False) -> None: + super().__init__() + + self.ord = ord + self.dim = dim + self.keepdim = keepdim def forward(self, x): - num_axes = len(x.shape) - if self.dim == None and self.ord == None: - res = self._vector_norm(x.reshape((1, -1))[0], ord=2, dim=self.dim) + if isinstance(self.dim, int): + res = Vector_Norm(ord=self.ord, dim=self.dim, keepdim=self.keepdim)(x) + elif isinstance(self.dim, tuple): + res = Matrix_Norm(ord=self.ord, dim=self.dim, keepdim=self.keepdim)(x) elif self.dim == None and self.ord != None: assert ( - num_axes <= 2 + len(x.shape) <= 2 ), "input must be 1-D or 2-D when dim is None and ord is not None" - res = ( - self._vector_norm(x, self.ord, self.dim) - if num_axes == 1 - else self._matrix_norm(x, self.ord, self.dim) - ) - elif isinstance(self.dim, (int, tuple, list)): - if isinstance(self.dim, int): - self.dim = self.dim if self.dim >= 0 else self.dim + num_axes - assert 0 <= self.dim < num_axes, "dim out of range" - res = self._vector_norm( - x, ord=2 if self.ord == None else self.ord, dim=self.dim - ) + if len(x.shape) == 1: + res = Vector_Norm(ord=self.ord, keepdim=self.keepdim)(x) else: - temp = list(self.dim) if isinstance(self.dim, tuple) else self.dim - for i in range(len(temp)): - temp[i] = temp[i] if temp[i] >= 0 else 
temp[i] + num_axes - assert 0 <= temp[i] < num_axes, "dim out of range" - self.dim = temp - res = self._matrix_norm( - x, ord="fro" if self.ord == None else self.ord, dim=self.dim - ) - else: - raise ValueError("Invalid dimension: {}".format(self.dim)) - return self._whether_keepdim(res) + res = Matrix_Norm(ord=self.ord, keepdim=self.keepdim)(x) + elif self.dim == None and self.ord == None: + res = Vector_Norm(keepdim=self.keepdim)(x) + return res @oneflow_export("linalg.norm") @@ -253,11 +332,154 @@ def norm_op(input, ord=None, dim=None, keepdim=False): @experimental_api def norm_tensor_op(input, ord=None, dim=None, keepdim=False): r""" - See :func:`oneflow.experimental.linalg.norm.` + See :func:`oneflow.experimental.linalg.norm` """ return Norm(ord, dim, keepdim)(input) +@oneflow_export("linalg.vector_norm") +@experimental_api +def vector_norm_tensor_op(input, ord=2, dim=None, keepdim=False): + r""" + linalg.vector_norm(input, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + + Computes a vector norm. + + Supports input of float, double dtypes. + + This function does not necessarily treat multidimensional :attr:`input` as a batch of + vectors, instead: + + - If :attr:`dim`\ `= None`, :attr:`input` will be flattened before the norm is computed. + - If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions and the other dimensions will be treated as batch dimensions. + + This behavior is for consistency with :func:`flow.linalg.norm`. + + :attr:`ord` defines the vector norm that is computed. 
The following norms are supported: + + ====================== ======================================================== + :attr:`ord` vector norm + ====================== ======================================================== + `2` (default) `2`-norm (see below) + `inf` `max(abs(x))` + `-inf` `min(abs(x))` + `0` `sum(x != 0)` + other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}` + ====================== ======================================================== + + where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + + + Args: + input (Tensor): tensor, flattened by default, but this behavior can be + controlled using :attr:`dim`. + ord (int, float, inf, -inf, optional): order of norm. Default: `2` + dim (int, Tuple[int], optional): dimensions over which to compute + the norm. See above for the behavior when :attr:`dim`\ `= None`. + Default: `None` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + + Returns: + A real-valued tensor. 
+ + Examples:: + + >>> import oneflow.experimental as flow + >>> from oneflow.experimental import linalg as LA + >>> import numpy as np + >>> flow.enable_eager_execution() + >>> a = flow.tensor(np.arange(9, dtype=np.float32) - 4) + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.], dtype=oneflow.float32) + >>> b = a.reshape((3, 3)) + >>> b + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]], dtype=oneflow.float32) + >>> LA.vector_norm(a, ord=3.5) + tensor([5.4345], dtype=oneflow.float32) + >>> LA.vector_norm(b, ord=3.5) + tensor([5.4345], dtype=oneflow.float32) + """ + return Vector_Norm(ord, dim, keepdim)(input) + + +@oneflow_export("linalg.matrix_norm") +@experimental_api +def matrix_norm_tensor_op(input, ord="fro", dim=(-2, -1), keepdim=False): + r""" + linalg.matrix_norm(input, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor + + Computes a matrix norm. + + Supports input of float, double, cfloat and cdouble dtypes. + Also supports batches of matrices: the norm will be computed over the + dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will + be treated as batch dimensions. The output will have the same batch dimensions. + + :attr:`ord` defines the matrix norm that is computed. The following norms are supported: + + ====================== ======================================================== + :attr:`ord` matrix norm + ====================== ======================================================== + `'fro'` (default) Frobenius norm + `'nuc'` -- not supported yet -- + `inf` `max(sum(abs(x), dim=1))` + `-inf` `min(sum(abs(x), dim=1))` + `1` `max(sum(abs(x), dim=0))` + `-1` `min(sum(abs(x), dim=0))` + `2` -- not supported yet -- + `-2` -- not supported yet -- + ====================== ======================================================== + + where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + + Args: + input (Tensor): tensor with two or more dimensions. 
By default its + shape is interpreted as `(*, m, n)` where `*` is zero or more + batch dimensions, but this behavior can be controlled using :attr:`dim`. + ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'` + dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + + + Returns: + A real-valued tensor. + + Examples:: + + >>> import oneflow.experimental as flow + >>> from oneflow.experimental import linalg as LA + >>> import numpy as np + >>> flow.enable_eager_execution() + >>> a = flow.tensor(np.arange(9, dtype=np.float32)).reshape((3,3)) + >>> a + tensor([[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]], dtype=oneflow.float32) + >>> LA.matrix_norm(a) + tensor([14.2829], dtype=oneflow.float32) + >>> LA.matrix_norm(a, ord=-1) + tensor([9.], dtype=oneflow.float32) + >>> b = a.expand(2, -1, -1) + >>> b + tensor([[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]], + <BLANKLINE> + [[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]], dtype=oneflow.float32) + >>> LA.matrix_norm(b) + tensor([14.2829, 14.2829], dtype=oneflow.float32) + >>> LA.matrix_norm(b, dim=(0, 2)) + tensor([ 3.1623, 10. , 17.2627], dtype=oneflow.float32) + """ + return Matrix_Norm(ord, dim, keepdim)(input) + + if __name__ == "__main__": import doctest -- GitLab