From 28bb85e5036f19b5fa1b41b1282a5a0683d5c8d4 Mon Sep 17 00:00:00 2001
From: zjlablichenyang <84563719+zjlablichenyang@users.noreply.github.com>
Date: Wed, 28 Jul 2021 13:39:03 +0800
Subject: [PATCH] Add clip_grad_norm (#5299)

* Adapt to new structure

* auto format by CI

* fix bug

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
---
 python/oneflow/nn/__init__.py                 |   1 +
 python/oneflow/nn/graph.py                    |   2 +-
 python/oneflow/nn/graph_block.py              |   2 +-
 python/oneflow/nn/{utils.py => util.py}       |   0
 python/oneflow/nn/utils/__init__.py           |  16 ++
 python/oneflow/nn/utils/clip_grad.py          | 138 ++++++++++++++++++
 python/oneflow/test/modules/test_clip_grad.py |  81 ++++++++++
 7 files changed, 238 insertions(+), 2 deletions(-)
 rename python/oneflow/nn/{utils.py => util.py} (100%)
 create mode 100644 python/oneflow/nn/utils/__init__.py
 create mode 100644 python/oneflow/nn/utils/clip_grad.py
 create mode 100644 python/oneflow/test/modules/test_clip_grad.py

diff --git a/python/oneflow/nn/__init__.py b/python/oneflow/nn/__init__.py
index 890810e61..0993b8e4e 100644
--- a/python/oneflow/nn/__init__.py
+++ b/python/oneflow/nn/__init__.py
@@ -106,5 +106,6 @@ from oneflow.ops.domain_ops import (
     api_fused_self_attention_query_mul_key_and_value as fused_self_attention_query_mul_key_and_value,
 )
 from oneflow.ops.loss_ops import ctc_greedy_decoder
+from oneflow.nn import utils
 
 from . import functional
diff --git a/python/oneflow/nn/graph.py b/python/oneflow/nn/graph.py
index d0456ac68..351bca0b7 100644
--- a/python/oneflow/nn/graph.py
+++ b/python/oneflow/nn/graph.py
@@ -29,7 +29,7 @@ from oneflow.nn.graph_block import Block, BlockType
 from oneflow.nn.graph_optimizer import OptimizerConfig
 from oneflow.nn.module import Module
 from oneflow.nn.optimizer.optimizer import Optimizer
-from oneflow.nn.utils import add_indent
+from oneflow.nn.util import add_indent
 
 
 class Graph(object):
diff --git a/python/oneflow/nn/graph_block.py b/python/oneflow/nn/graph_block.py
index 35634cd09..418573ce2 100644
--- a/python/oneflow/nn/graph_block.py
+++ b/python/oneflow/nn/graph_block.py
@@ -21,7 +21,7 @@ import oneflow.framework.graph_build_util as graph_build_util
 from oneflow.framework.tensor import Tensor
 from oneflow.nn.module import Module
 from oneflow.nn.parameter import Parameter
-from oneflow.nn.utils import add_indent
+from oneflow.nn.util import add_indent
 
 
 class BlockType:
diff --git a/python/oneflow/nn/utils.py b/python/oneflow/nn/util.py
similarity index 100%
rename from python/oneflow/nn/utils.py
rename to python/oneflow/nn/util.py
diff --git a/python/oneflow/nn/utils/__init__.py b/python/oneflow/nn/utils/__init__.py
new file mode 100644
index 000000000..818a831cd
--- /dev/null
+++ b/python/oneflow/nn/utils/__init__.py
@@ -0,0 +1,16 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from oneflow.nn.utils.clip_grad import clip_grad_norm_
diff --git a/python/oneflow/nn/utils/clip_grad.py b/python/oneflow/nn/utils/clip_grad.py
new file mode 100644
index 000000000..6661cbc2a
--- /dev/null
+++ b/python/oneflow/nn/utils/clip_grad.py
@@ -0,0 +1,138 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import warnings
+from typing import Union, Iterable
+
+import numpy as np
+import oneflow as flow
+
+from oneflow.framework.tensor import Tensor
+from oneflow.framework.tensor import register_tensor_op
+from oneflow.nn.module import Module
+
+
+_tensor_or_tensors = Union[Tensor, Iterable[Tensor]]
+
+
+def clip_grad_norm_(
+    parameters: _tensor_or_tensors,
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = True,
+) -> Tensor:
+    r"""Clips gradient norm of an iterable of parameters.
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the p-norm to use. Can be ``'inf'``
+            for the infinity norm. Default: 2.0
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: True
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+        The gradients of :attr:`parameters` are clipped in place.
+
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow as flow
+        >>> import numpy as np
+        >>> x1 = flow.Tensor(np.array([[2, 3, 4], [1.5, 2.6, 3.7]]).astype(np.float32), requires_grad=True)
+        >>> m1 = flow.nn.ReLU()
+        >>> out1 = m1(x1)
+        >>> out1 = out1.sum()
+        >>> out1.backward()
+        >>> norm1 = flow.nn.utils.clip_grad_norm_(x1, 0.6, 1.0)
+        >>> norm1
+        tensor([6.], dtype=oneflow.float32)
+        >>> x1.grad
+        tensor([[0.1, 0.1, 0.1],
+                [0.1, 0.1, 0.1]], dtype=oneflow.float32)
+        >>> x2 = flow.Tensor(np.array([[-2, -3, -4], [2.5, 0, 3.2]]).astype(np.float32), device='cuda:0', requires_grad=True)
+        >>> out2 = flow.atan(x2)
+        >>> out2 = out2.sum()
+        >>> out2.backward()
+        >>> norm2 = flow.nn.utils.clip_grad_norm_(x2, 0.5)
+        >>> norm2
+        tensor([1.0394], device='cuda:0', dtype=oneflow.float32)
+        >>> x2.grad
+        tensor([[0.0962, 0.0481, 0.0283],
+                [0.0663, 0.481 , 0.0428]], device='cuda:0', dtype=oneflow.float32)
+
+    """
+
+    if isinstance(parameters, (Tensor, flow._oneflow_internal.Tensor)):
+        parameters = [parameters]
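+    # keep only parameters that actually received a gradient in the backward pass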
+    parameters = [p for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return flow.tensor(0.0)
+    device = parameters[0].grad.device
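+    # Aggregate per-parameter gradient norms into a single total norm:
+    #   inf-norm:  max absolute value over all gradient entries
+    #   -inf-norm: min absolute value over all gradient entries
+    #   p-norm:    p-norm of the per-parameter p-norms, which (for finite p > 0)
+    #              equals the p-norm of all gradients concatenated into one vector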
+    if norm_type == float("inf"):
+        norms = [p.grad.detach().abs().max().to(device) for p in parameters]
+        total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms))
+    elif norm_type == float("-inf"):
+        norms = [p.grad.detach().abs().min().to(device) for p in parameters]
+        total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))
+    else:
+        total_norm = flow.linalg.vector_norm(
+            flow.stack(
+                [
+                    flow.linalg.vector_norm(p.grad.detach(), norm_type).to(device)
+                    for p in parameters
+                ]
+            ),
+            norm_type,
+        )
+    if np.isnan(total_norm.numpy()) or np.isinf(total_norm.numpy()):
+        if error_if_nonfinite:
+            raise RuntimeError(
+                f"The total norm of order {norm_type} for gradients from "
+                "`parameters` is non-finite, so it cannot be clipped. To disable "
+                "this error and scale the gradients by the non-finite norm anyway, "
+                "set `error_if_nonfinite=False`"
+            )
+        else:
+            warnings.warn(
+                "Non-finite norm encountered in flow.nn.utils.clip_grad_norm_; continuing anyway. "
+                "Note that the default behavior will change in a future release to error out "
+                "if a non-finite total norm is encountered. At that point, setting "
+                "error_if_nonfinite=false will be required to retain the old behavior.",
+                FutureWarning,
+                stacklevel=2,
+            )
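+    # Scale gradients in place by max_norm / total_norm (plus a small epsilon for
+    # numerical stability); gradients are only ever scaled down, never up.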
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for p in parameters:
+            # TODO: Switch to inplace multiply in future
+            p.grad[:] = p.grad.detach().mul(clip_coef.to(p.grad.device))
+    return total_norm
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/python/oneflow/test/modules/test_clip_grad.py b/python/oneflow/test/modules/test_clip_grad.py
new file mode 100644
index 000000000..9dc8d3dd4
--- /dev/null
+++ b/python/oneflow/test/modules/test_clip_grad.py
@@ -0,0 +1,81 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow as flow
+from test_util import GenArgList
+
+
+def _clip_grad_norm_np(input, max_norm, norm_type):
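+    # NumPy reference: the test passes `input` through ReLU and sums the output,
+    # so the expected gradient is 1 where input > 0 and 0 elsewhere; clip it the
+    # same way clip_grad_norm_ does.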
+    np_out = np.maximum(0, input)
+    np_grad = np.array(np_out > 0, dtype=np.float32)
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    input = [input]
+    if len(input) == 0:
+        return 0, 0
+    if norm_type == float("inf"):
+        total_norm = np.max(np.abs(np_grad))
+    elif norm_type == float("-inf"):
+        total_norm = np.min(np.abs(np_grad))
+    elif norm_type == 0:
+        total_norm = np.sum(np.stack([np.sum(np_grad != 0)]) != 0)
+    else:
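+        # reducing one axis at a time with the same p is equivalent to taking the
+        # p-norm of the flattened gradient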
+        total_norm = np_grad
+        for i in range(np_grad.ndim, 0, -1):
+            total_norm = np.linalg.norm(total_norm, norm_type, axis=i - 1)
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        np_grad = np_grad * clip_coef
+    return total_norm, np_grad
+
+
+def _test_clip_grad_norm_impl(test_case, shape, device, max_norm, norm_type):
+    np_input = np.random.rand(*shape)
+    of_input = flow.Tensor(
+        np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
+    )
+    m = flow.nn.ReLU()
+    of_out = m(of_input)
+    of_out = of_out.sum()
+    of_out.backward()
+    of_total_norm = flow.nn.utils.clip_grad_norm_(of_input, max_norm, norm_type)
+    np_total_norm, np_grad = _clip_grad_norm_np(np_input, max_norm, norm_type)
+    test_case.assertTrue(
+        np.allclose(of_total_norm.numpy(), np_total_norm, 1e-4, 1e-4, equal_nan=True)
+    )
+    test_case.assertTrue(
+        np.allclose(of_input.grad.numpy(), np_grad, 1e-4, 1e-4, equal_nan=True)
+    )
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestClipGradNorm(flow.unittest.TestCase):
+    def test_clip_grad_norm(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 4, 5, 6)]
+        arg_dict["device"] = ["cpu", "cuda"]
+        arg_dict["max_norm"] = [0, 0.5, 1.0]
+        arg_dict["norm_type"] = ["inf", "-inf", 0.0, 1.0, 2.0, 3.5]
+        for arg in GenArgList(arg_dict):
+            _test_clip_grad_norm_impl(test_case, *arg)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab