From 28bb85e5036f19b5fa1b41b1282a5a0683d5c8d4 Mon Sep 17 00:00:00 2001
From: zjlablichenyang <84563719+zjlablichenyang@users.noreply.github.com>
Date: Wed, 28 Jul 2021 13:39:03 +0800
Subject: [PATCH] Add clip_grad_norm (#5299)

* Adapt to new structure

* auto format by CI

* fix bug

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
---
 python/oneflow/nn/__init__.py                 |   1 +
 python/oneflow/nn/graph.py                    |   2 +-
 python/oneflow/nn/graph_block.py              |   2 +-
 python/oneflow/nn/{utils.py => util.py}       |   0
 python/oneflow/nn/utils/__init__.py           |  16 ++
 python/oneflow/nn/utils/clip_grad.py          | 138 ++++++++++++++++++
 python/oneflow/test/modules/test_clip_grad.py |  81 ++++++++++
 7 files changed, 238 insertions(+), 2 deletions(-)
 rename python/oneflow/nn/{utils.py => util.py} (100%)
 create mode 100644 python/oneflow/nn/utils/__init__.py
 create mode 100644 python/oneflow/nn/utils/clip_grad.py
 create mode 100644 python/oneflow/test/modules/test_clip_grad.py

diff --git a/python/oneflow/nn/__init__.py b/python/oneflow/nn/__init__.py
index 890810e61..0993b8e4e 100644
--- a/python/oneflow/nn/__init__.py
+++ b/python/oneflow/nn/__init__.py
@@ -106,5 +106,6 @@ from oneflow.ops.domain_ops import (
     api_fused_self_attention_query_mul_key_and_value as fused_self_attention_query_mul_key_and_value,
 )
 from oneflow.ops.loss_ops import ctc_greedy_decoder
+from oneflow.nn import utils
 
 from . import functional
diff --git a/python/oneflow/nn/graph.py b/python/oneflow/nn/graph.py
index d0456ac68..351bca0b7 100644
--- a/python/oneflow/nn/graph.py
+++ b/python/oneflow/nn/graph.py
@@ -29,7 +29,7 @@ from oneflow.nn.graph_block import Block, BlockType
 from oneflow.nn.graph_optimizer import OptimizerConfig
 from oneflow.nn.module import Module
 from oneflow.nn.optimizer.optimizer import Optimizer
-from oneflow.nn.utils import add_indent
+from oneflow.nn.util import add_indent
 
 
 class Graph(object):
diff --git a/python/oneflow/nn/graph_block.py b/python/oneflow/nn/graph_block.py
index 35634cd09..418573ce2 100644
--- a/python/oneflow/nn/graph_block.py
+++ b/python/oneflow/nn/graph_block.py
@@ -21,7 +21,7 @@ import oneflow.framework.graph_build_util as graph_build_util
 from oneflow.framework.tensor import Tensor
 from oneflow.nn.module import Module
 from oneflow.nn.parameter import Parameter
-from oneflow.nn.utils import add_indent
+from oneflow.nn.util import add_indent
 
 
 class BlockType:
diff --git a/python/oneflow/nn/utils.py b/python/oneflow/nn/util.py
similarity index 100%
rename from python/oneflow/nn/utils.py
rename to python/oneflow/nn/util.py
diff --git a/python/oneflow/nn/utils/__init__.py b/python/oneflow/nn/utils/__init__.py
new file mode 100644
index 000000000..818a831cd
--- /dev/null
+++ b/python/oneflow/nn/utils/__init__.py
@@ -0,0 +1,16 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from oneflow.nn.utils.clip_grad import clip_grad_norm_
diff --git a/python/oneflow/nn/utils/clip_grad.py b/python/oneflow/nn/utils/clip_grad.py
new file mode 100644
index 000000000..6661cbc2a
--- /dev/null
+++ b/python/oneflow/nn/utils/clip_grad.py
@@ -0,0 +1,138 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import warnings
+from typing import Union, Iterable
+
+import numpy as np
+import oneflow as flow
+
+from oneflow.framework.tensor import Tensor
+from oneflow.framework.tensor import register_tensor_op
+from oneflow.nn.module import Module
+
+
+_tensor_or_tensors = Union[Tensor, Iterable[Tensor]]
+
+
+def clip_grad_norm_(
+    parameters: _tensor_or_tensors,
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = True,
+) -> Tensor:
+    r"""Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: True
+
+    Returns:
+        Total norm of the parameters (viewed as a single vector).
+        The gradients of :attr:`parameters` are clipped in place.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow as flow
+        >>> import numpy as np
+        >>> x1 = flow.Tensor(np.array([[2, 3, 4], [1.5, 2.6, 3.7]]).astype(np.float32), requires_grad=True)
+        >>> m1 = flow.nn.ReLU()
+        >>> out1 = m1(x1)
+        >>> out1 = out1.sum()
+        >>> out1.backward()
+        >>> norm1 = flow.nn.utils.clip_grad_norm_(x1, 0.6, 1.0)
+        >>> norm1
+        tensor([6.], dtype=oneflow.float32)
+        >>> x1.grad
+        tensor([[0.1, 0.1, 0.1],
+                [0.1, 0.1, 0.1]], dtype=oneflow.float32)
+        >>> x2 = flow.Tensor(np.array([[-2, -3, -4], [2.5, 0, 3.2]]).astype(np.float32), device='cuda:0', requires_grad=True)
+        >>> out2 = flow.atan(x2)
+        >>> out2 = out2.sum()
+        >>> out2.backward()
+        >>> norm2 = flow.nn.utils.clip_grad_norm_(x2, 0.5)
+        >>> norm2
+        tensor([1.0394], device='cuda:0', dtype=oneflow.float32)
+        >>> x2.grad
+        tensor([[0.0962, 0.0481, 0.0283],
+                [0.0663, 0.481 , 0.0428]], device='cuda:0', dtype=oneflow.float32)
+
+    """
+
+    if isinstance(parameters, (Tensor, flow._oneflow_internal.Tensor)):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return flow.tensor(0.0)
+    device = parameters[0].grad.device
+    if norm_type == float("inf"):
+        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
+        total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms))
+    elif norm_type == float("-inf"):
+        norms = [p.grad.detach().abs().min().to(device) for p in parameters]
+        total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))
+    else:
+        total_norm = flow.linalg.vector_norm(
+            flow.stack(
+                [
+                    flow.linalg.vector_norm(p.grad.detach(), norm_type).to(device)
+                    for p in parameters
+                ]
+            ),
+            norm_type,
+        )
+    if np.isnan(total_norm.numpy()) or np.isinf(total_norm.numpy()):
+        if error_if_nonfinite:
+            raise RuntimeError(
+                f"The total norm of order {norm_type} for gradients from "
+                "`parameters` is non-finite, so it cannot be clipped. To disable "
+                "this error and scale the gradients by the non-finite norm anyway, "
+                "set `error_if_nonfinite=False`"
+            )
+        else:
+            warnings.warn(
+                "Non-finite norm encountered in flow.nn.utils.clip_grad_norm_; continuing anyway. "
+                "Note that the default behavior will change in a future release to error out "
+                "if a non-finite total norm is encountered. At that point, setting "
+                "error_if_nonfinite=false will be required to retain the old behavior.",
+                FutureWarning,
+                stacklevel=2,
+            )
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for p in parameters:
+            # TODO: Switch to inplace multiply in future
+            p.grad[:] = p.grad.detach().mul(clip_coef.to(p.grad.device))
+    return total_norm
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/python/oneflow/test/modules/test_clip_grad.py b/python/oneflow/test/modules/test_clip_grad.py
new file mode 100644
index 000000000..9dc8d3dd4
--- /dev/null
+++ b/python/oneflow/test/modules/test_clip_grad.py
@@ -0,0 +1,81 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow as flow
+from test_util import GenArgList
+
+
+def _clip_grad_norm_np(input, max_norm, norm_type):
+    np_out = np.maximum(0, input)
+    np_grad = np.array(np_out > 0, dtype=np.float32)
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    input = [input]
+    if len(input) == 0:
+        return 0, 0
+    if norm_type == float("inf"):
+        total_norm = np.max(np.abs(np_grad))
+    elif norm_type == float("-inf"):
+        total_norm = np.min(np.abs(np_grad))
+    elif norm_type == 0:
+        total_norm = np.sum(np.stack([np.sum(np_grad != 0)]) != 0)
+    else:
+        total_norm = np_grad
+        for i in range(np_grad.ndim, 0, -1):
+            total_norm = np.linalg.norm(total_norm, norm_type, axis=i - 1)
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        np_grad = np_grad * clip_coef
+    return total_norm, np_grad
+
+
+def _test_clip_grad_norm_impl(test_case, shape, device, max_norm, norm_type):
+    np_input = np.random.rand(*shape)
+    of_input = flow.Tensor(
+        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
+    )
+    m = flow.nn.ReLU()
+    of_out = m(of_input)
+    of_out = of_out.sum()
+    of_out.backward()
+    of_total_norm = flow.nn.utils.clip_grad_norm_(of_input, max_norm, norm_type)
+    np_total_norm, np_grad = _clip_grad_norm_np(np_input, max_norm, norm_type)
+    test_case.assertTrue(
+        np.allclose(of_total_norm.numpy(), np_total_norm, 1e-4, 1e-4, equal_nan=True)
+    )
+    test_case.assertTrue(
+        np.allclose(of_input.grad.numpy(), np_grad, 1e-4, 1e-4, equal_nan=True)
+    )
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestClipGradNorm(flow.unittest.TestCase):
+    def test_clip_grad_norm(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
+        arg_dict["device"] = ["cpu", "cuda"]
+        arg_dict["max_norm"] = [0, 0.5, 1.0]
+        arg_dict["norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5]
+        for arg in GenArgList(arg_dict):
+            _test_clip_grad_norm_impl(test_case, *arg)
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
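
Usage note: a minimal sketch of how the flow.nn.utils.clip_grad_norm_ entry point added by this patch is typically called between backward() and the optimizer step in an eager-mode training loop. The model, optimizer, loss, and data below are illustrative placeholders chosen for the sketch, not code from this PR, and assume OneFlow's eager nn/optim APIs (nn.Sequential, nn.Linear, nn.MSELoss, optim.SGD).

    import numpy as np
    import oneflow as flow

    # Placeholder model and optimizer; any eager-mode module whose parameters
    # accumulate .grad after backward() would work the same way.
    model = flow.nn.Sequential(flow.nn.Linear(4, 8), flow.nn.ReLU(), flow.nn.Linear(8, 1))
    optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = flow.nn.MSELoss()

    # Random placeholder data for the sketch.
    x = flow.Tensor(np.random.randn(16, 4).astype(np.float32))
    y = flow.Tensor(np.random.randn(16, 1).astype(np.float32))

    for _ in range(3):
        loss = loss_fn(model(x), y)
        loss.backward()
        # Rescale all parameter gradients so their global L2 norm is at most 1.0;
        # the return value is the total norm measured before clipping.
        total_norm = flow.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=1.0, norm_type=2.0
        )
        optimizer.step()
        optimizer.zero_grad()

Because clipping is applied in place on each p.grad, it must run after loss.backward() and before optimizer.step(); called earlier, it finds no gradients and simply returns a zero tensor.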