From f476d48d9934efaf2450a8070305a9dd5e906af8 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Mon, 10 May 2021 22:34:56 +0800
Subject: [PATCH] Add rmsprop optimizer (#4834)

* add rmsprop optimizer

* fix rmsprop optimizer bug

* fix rmsprop optimizer bug

* add rmsprop optimizer docs

* add rmsprop docs

* fix comment

* fix comment

* fix comment

* fix comment

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/python/nn/optimizer/rmsprop.py        | 187 ++++++++++++++++++
 .../python/test/modules/test_optim_rmsprop.py | 126 ++++++++++++
 2 files changed, 313 insertions(+)
 create mode 100644 oneflow/python/nn/optimizer/rmsprop.py
 create mode 100644 oneflow/python/test/modules/test_optim_rmsprop.py

diff --git a/oneflow/python/nn/optimizer/rmsprop.py b/oneflow/python/nn/optimizer/rmsprop.py
new file mode 100644
index 000000000..b0e9c9730
--- /dev/null
+++ b/oneflow/python/nn/optimizer/rmsprop.py
@@ -0,0 +1,187 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import List, Dict, Callable, Union, Iterator, Tuple
+from types import GeneratorType
+
+import oneflow as flow
+
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.nn.parameter import Parameter
+from oneflow.python.nn.optimizer.optimizer import ParamGroup, Optimizer
+
+
+@oneflow_export("optim.RMSprop")
+class RMSprop(Optimizer):
+    r"""Implements RMSprop algorithm.
+
+    oot Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning
+    rate method. The original slides proposed RMSProp: Slide 29 of
+    http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf .
+
+    The original equation is as follows:
+
+    .. math::
+
+        r(w, t) = \alpha r(w, t-1) + (1 - \alpha)(\nabla Q_{i}(w))^2
+
+        W = w - \frac{\eta} {\\sqrt{r(w,t) + \epsilon}} \nabla Q_{i}(w)
+
+    The first equation calculates moving average of the squared gradient for
+    each weight. Then dividing the gradient by :math:`sqrt{v(w,t)}`.
+    In some cases, adding a momentum term :math: `\beta` is beneficial.
+    In our implementation, Nesterov momentum is used:
+
+    .. math::
+
+        r(w, t) = \alpha r(w, t-1) + (1 - \alpha)(\nabla Q_{i}(w))^2
+
+        v(w, t) = \beta v(w, t-1) + \frac{\eta} {\\sqrt{r(w,t) +
+            \epsilon}} \nabla Q_{i}(w)
+  
+        w = w - v(w, t)
+
+    if centered is True:
+
+    .. math::
+
+        r(w, t) = \alpha r(w, t-1) + (1 - \alpha)(\nabla Q_{i}(w))^2
+
+        g(w, t) = \alpha g(w, t-1) + (1 - \alpha)\nabla Q_{i}(w)
+
+        v(w, t) = \beta v(w, t-1) + \frac{\eta} {\\sqrt{r(w,t) - (g(w, t))^2 +
+            \epsilon}} \nabla Q_{i}(w)
+        
+        w = w - v(w, t)
+    
+    where, :math:`\alpha` is a hyperparameter and typical values are 0.99, 0.95
+    and so on. :math:`\beta` is the momentum term. :math:`\epsilon` is a
+    smoothing term to avoid division by zero, usually set somewhere in range
+    from 1e-4 to 1e-8.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-2)
+        momentum (float, optional): momentum factor (default: 0, oneflow not support momenmtum > 0 now!)
+        alpha (float, optional): smoothing constant (default: 0.99)
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        centered (bool, optional) : if ``True``, compute the centered RMSProp,
+            the gradient is normalized by an estimation of its variance
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+    """
+
+    def __init__(
+        self,
+        parameters: Union[Iterator[Parameter], List[Dict]],
+        lr: float = 1e-3,
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+        weight_decay: float = 0,
+        momentum: float = 0.0,
+        centered: bool = False,
+        scale: float = 1.0,
+    ):
+        super().__init__()
+        assert lr >= 0.0, f"Invalid learning rate: {lr}"
+        assert alpha >= 0.0, f"Invalid alpha value: {alpha}"
+        assert eps >= 0.0, f"Invalid epsilon value: {eps}"
+        assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}"
+        assert scale > 0.0, f"Invalid scale factor: {scale}"
+        assert momentum == 0.0, f"Not support momentum greater than zeros now!"
+
+        self._default_options["lr"] = lr
+        self._default_options["alpha"] = alpha
+        self._default_options["eps"] = eps
+        self._default_options["weight_decay"] = weight_decay
+        self._default_options["centered"] = centered
+        self._default_options["scale"] = scale
+
+        # Add parameters
+        if isinstance(parameters, GeneratorType):
+            self._param_groups.append(ParamGroup(parameters, self._default_options))
+        else:  # List[Dict]
+            for param in parameters:
+                self._param_groups.append(ParamGroup(param, self._default_options))
+
+        for param_group in self._param_groups:
+            for param in param_group.parameters:
+                assert param.is_leaf, "parameters must be leaf tensor"
+                self._state[param] = dict()
+                self._state[param]["square_avg"] = flow.tmp.zeros_like(param)
+                if "centered" in self._default_options:
+                    self._state[param]["grad_avg"] = flow.tmp.zeros_like(param)
+        if centered:
+            self._op = (
+                flow.builtin_op("rmsprop_update")
+                .Input("model")
+                .Input("model_diff")
+                .Input("learning_rate")
+                .Input("mean_square")
+                .Input("mean_gradient")
+                .Attr("scale", self._default_options["scale"])
+                .Attr("l1", 0.0)
+                .Attr("l2", 0.0)
+                .Attr("centered", self._default_options["centered"])
+                .Attr("epsilon", self._default_options["eps"])
+                .Attr("decay_rate", self._default_options["alpha"])
+                .Attr("weight_decay", self._default_options["weight_decay"])
+                .Build()
+            )
+        else:
+            self._op = (
+                flow.builtin_op("rmsprop_update")
+                .Input("model")
+                .Input("model_diff")
+                .Input("learning_rate")
+                .Input("mean_square")
+                .Attr("scale", self._default_options["scale"])
+                .Attr("l1", 0.0)
+                .Attr("l2", 0.0)
+                .Attr("centered", self._default_options["centered"])
+                .Attr("epsilon", self._default_options["eps"])
+                .Attr("decay_rate", self._default_options["alpha"])
+                .Attr("weight_decay", self._default_options["weight_decay"])
+                .Build()
+            )
+
+    def step(self, closure: Callable = None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for param_group in self._param_groups:
+            lr_tensor = flow.Tensor([param_group.options["lr"]])
+            for param in param_group.parameters:
+                if param.grad is None:
+                    continue
+                ms_tensor = self._state[param]["square_avg"]
+                if self._default_options["centered"]:
+                    mg_tensor = self._state[param]["grad_avg"]
+                    self._op(param, param.grad, lr_tensor, ms_tensor, mg_tensor)
+                else:
+                    self._op(param, param.grad, lr_tensor, ms_tensor)
+
+        self._state["step"] = self._state["step"] + 1
+
+        return loss
diff --git a/oneflow/python/test/modules/test_optim_rmsprop.py b/oneflow/python/test/modules/test_optim_rmsprop.py
new file mode 100644
index 000000000..6c9171395
--- /dev/null
+++ b/oneflow/python/test/modules/test_optim_rmsprop.py
@@ -0,0 +1,126 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+import oneflow as flow
+from test_util import GenArgList
+from oneflow.python.nn.parameter import Parameter
+
+
+def compare_with_numpy_rmsprop(
+    test_case,
+    x_shape,
+    scale,
+    learning_rate,
+    momentum,
+    train_iters,
+    alpha,
+    eps,
+    weight_decay,
+    centered,
+):
+    # generate random number sequences
+    random_grad_seq = []
+    for _ in range(train_iters):
+        random_grad_seq.append(np.random.uniform(size=x_shape).astype(np.float32))
+
+    init_value = np.random.uniform(size=x_shape).astype(np.float32)
+
+    def train_by_oneflow():
+        x = Parameter(flow.Tensor(init_value))
+        param_list = list()
+        param_list.append(x)
+        rmsprop = flow.optim.RMSprop(
+            [{"param": param_list}],
+            lr=learning_rate,
+            momentum=momentum,
+            scale=scale,
+            alpha=alpha,
+            eps=eps,
+            weight_decay=weight_decay,
+            centered=centered,
+        )
+
+        def train_one_iter(grad):
+            grad_tensor = flow.Tensor(grad, requires_grad=False)
+            loss = x * grad_tensor
+            loss = flow.sum(x * grad_tensor)
+            loss.backward()
+            rmsprop.step()
+            rmsprop.zero_grad()
+
+        for i in range(train_iters):
+            train_one_iter(random_grad_seq[i])
+        return x
+
+    def train_by_numpy():
+        x = init_value
+        r = np.zeros_like(x)
+        v = np.zeros_like(x)
+        g = np.zeros_like(x)
+
+        def train_one_iter(grad):
+            grad = grad * scale
+
+            if centered:
+                r_ = alpha * r + (1 - alpha) * grad * grad
+                g_ = alpha * g + (1 - alpha) * grad
+                v_ = momentum * v + learning_rate / np.sqrt(r_ - g_ * g_ + eps) * grad
+            else:
+                r_ = alpha * r + (1 - alpha) * grad * grad
+                g_ = g
+                v_ = momentum * v + learning_rate / np.sqrt(r_ + eps) * grad
+
+            param = x - v_
+
+            return param, r_, g_, v_
+
+        for i in range(train_iters):
+            x, r, g, v = train_one_iter(random_grad_seq[i])
+
+        return x
+
+    oneflow_res = train_by_oneflow().numpy()
+    numpy_res = train_by_numpy()
+    test_case.assertTrue(
+        np.allclose(oneflow_res.flatten(), numpy_res.flatten(), rtol=1e-4, atol=1e-4)
+    )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestRMSProp(flow.unittest.TestCase):
+    def test_rmsprop(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["x_shape"] = [(10,)]
+        arg_dict["scale"] = [1.0, 0.9]
+        arg_dict["learning_rate"] = [1]
+        arg_dict["momentum"] = [0.0]  # TODO: support nonzero momentum
+        arg_dict["train_iters"] = [10]
+        arg_dict["alpha"] = [0.9, 0.99]
+        arg_dict["eps"] = [1e-8, 1e-5]
+        arg_dict["weight_decay"] = [0.1, 0.99]
+        arg_dict["centered"] = [False, True]
+        for arg in GenArgList(arg_dict):
+            compare_with_numpy_rmsprop(test_case, *arg)
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab