From cd6ffac6215df1231894d269e1c26f1eeb23b841 Mon Sep 17 00:00:00 2001
From: Yinggang Wang <wyg19970408@gmail.com>
Date: Mon, 10 May 2021 16:34:39 +0800
Subject: [PATCH] Refine optimizer (#4840)

* refactor(Optim): refine optimizer code

* docs(SGD): add documentation for SGD

* docs(SGD): fix code

* test(Adam): fix test_optim_adam bug

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/python/nn/optimizer/adam.py           |  13 +-
 oneflow/python/nn/optimizer/optimizer.py      |  93 +-----------
 oneflow/python/nn/optimizer/sgd.py            | 132 ++++++++++++++++++
 .../test_optim_adam.py}                       |   8 +-
 .../{test_optimizers.py => test_optim_sgd.py} |   5 +-
 5 files changed, 142 insertions(+), 109 deletions(-)
 create mode 100644 oneflow/python/nn/optimizer/sgd.py
 rename oneflow/python/test/{optimizer/test_adam.py => modules/test_optim_adam.py} (93%)
 rename oneflow/python/test/modules/{test_optimizers.py => test_optim_sgd.py} (95%)

diff --git a/oneflow/python/nn/optimizer/adam.py b/oneflow/python/nn/optimizer/adam.py
index 8a7d23ddd..b1e8417e4 100644
--- a/oneflow/python/nn/optimizer/adam.py
+++ b/oneflow/python/nn/optimizer/adam.py
@@ -21,8 +21,7 @@ import oneflow as flow
 
 from oneflow.python.oneflow_export import oneflow_export
 from oneflow.python.nn.parameter import Parameter
-from oneflow.python.nn.optimizer.optimizer import ParamGroup
-from oneflow.python.nn.optimizer.optimizer import Optimizer
+from oneflow.python.nn.optimizer.optimizer import Optimizer, ParamGroup
 
 
 @oneflow_export("optim.Adam")
@@ -106,14 +105,8 @@ class Adam(Optimizer):
             for param in param_group.parameters:
                 assert param.is_leaf, "parameters must be leaf tensor"
                 self._state[param] = dict()
-                self._state[param]["exp_avg"] = flow.tmp.zeros(
-                    # TODO: zeros module support flow.Size parameter
-                    tuple(param.shape)
-                )
-                self._state[param]["exp_avg_sq"] = flow.tmp.zeros(
-                    # TODO: zeros module support flow.Size parameter
-                    tuple(param.shape)
-                )
+                self._state[param]["exp_avg"] = flow.tmp.zeros_like(param)
+                self._state[param]["exp_avg_sq"] = flow.tmp.zeros_like(param)
 
         self._op = (
             flow.builtin_op("adam_update")
diff --git a/oneflow/python/nn/optimizer/optimizer.py b/oneflow/python/nn/optimizer/optimizer.py
index 8e04fcfe9..9fb16b1ac 100644
--- a/oneflow/python/nn/optimizer/optimizer.py
+++ b/oneflow/python/nn/optimizer/optimizer.py
@@ -14,12 +14,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-from typing import List, Dict, Callable, Union, Any, Iterator
+from typing import Dict, Callable, Union, Any, Iterator
 from types import GeneratorType
 
-import oneflow as flow
-
-from oneflow.python.oneflow_export import oneflow_export
 from oneflow.python.nn.parameter import Parameter
 from oneflow.python.framework.tensor import Tensor
 
@@ -81,91 +78,3 @@ class Optimizer(object):
                 else:
                     param.grad.fill_(0)
                     # param.grad.zeros_()
-
-
-@oneflow_export("optim.SGD")
-class SGD(Optimizer):
-    r"""
-    TODO
-    """
-
-    def __init__(
-        self,
-        parameters: Union[Iterator[Parameter], List[Dict]],
-        lr: float,
-        momentum: float = 0.0,
-        scale: float = 1.0,
-    ):
-        super().__init__()
-        assert lr >= 0.0, f"Invalid learning rate: {lr}"
-        assert momentum >= 0.0, f"Invalid momentum: {momentum}"
-        assert scale >= 0.0, f"Invalid scale factor: {scale}"
-
-        self._default_options["lr"] = lr
-        self._default_options["scale"] = scale
-        if momentum != 0.0:
-            self._default_options["momentum"] = momentum
-
-        # Add parameters
-        if isinstance(parameters, GeneratorType):
-            self._param_groups.append(ParamGroup(parameters, self._default_options))
-        else:  # List[Dict]
-            for param in parameters:
-                self._param_groups.append(ParamGroup(param, self._default_options))
-
-        for param_group in self._param_groups:
-            for param in param_group.parameters:
-                assert param.is_leaf, "parameters must be leaf tensor"
-                self._state[param] = dict()
-                if "momentum" in self._default_options:
-                    self._state[param]["momentum_buf"] = flow.tmp.zeros(
-                        # TODO: zeros module support flow.Size parameter
-                        tuple(param.shape)
-                    )
-
-        if "momentum" in self._default_options.keys():
-            self._op = (
-                flow.builtin_op("momentum_update")
-                .Input("model")
-                .Input("model_diff")
-                .Input("learning_rate")
-                .Input("momentum")
-                .Attr("scale", self._default_options["scale"])
-                .Attr("l1", 0.0)
-                .Attr("l2", 0.0)
-                .Attr("beta", self._default_options["momentum"])
-                .Attr("weight_decay", 0.0)
-                .Build()
-            )
-        else:
-            self._op = (
-                flow.builtin_op("sgd_update")
-                .Input("model")
-                .Input("model_diff")
-                .Input("learning_rate")
-                .Attr("scale", self._default_options["scale"])
-                .Attr("weight_decay", 0.0)
-                .Attr("l1", 0.0)
-                .Attr("l2", 0.0)
-                .Build()
-            )
-
-    def step(self, closure: Callable = None):
-        with flow.no_grad():
-            loss = None
-            if closure is not None:
-                loss = closure()
-
-            for param_group in self._param_groups:
-                lr_tensor = flow.Tensor([param_group.options["lr"]])
-                for param in param_group.parameters:
-                    if param.grad is None:
-                        continue
-                    if "momentum" in self._default_options:
-                        momentum_buf = self._state[param]["momentum_buf"]
-                        self._op(param, param.grad, lr_tensor, momentum_buf)
-                    else:
-                        self._op(param, param.grad, lr_tensor)
-
-            self._state["step"] = self._state["step"] + 1
-            return loss
diff --git a/oneflow/python/nn/optimizer/sgd.py b/oneflow/python/nn/optimizer/sgd.py
new file mode 100644
index 000000000..be18970ec
--- /dev/null
+++ b/oneflow/python/nn/optimizer/sgd.py
@@ -0,0 +1,132 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import List, Dict, Callable, Union, Iterator
+from types import GeneratorType
+
+import oneflow as flow
+
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.nn.parameter import Parameter
+from .optimizer import Optimizer, ParamGroup
+
+
+@oneflow_export("optim.SGD")
+class SGD(Optimizer):
+    r"""Implements SGD algorithm.
+
+    This algorithm uses the gradient of a randomly sampled mini-batch as an approximation of the full gradient.
+
+    When momentum = 0, the parameter update equation is:
+
+        .. math::
+
+            param_{new} = param_{old} - learning\_rate * grad
+
+    With momentum, the parameter update equation is:
+
+        .. math::
+
+            & V_t = \beta * V_{t-1} + learning\_rate * g_t
+
+            & param_{new} = param_{old} - V_t
+
+    Args:
+        parameters (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        momentum (float, optional): Momentum factor (default: 0.0)
+        scale (float, optional): the scale factor of loss (default: 1.0)
+
+    """
+
+    def __init__(
+        self,
+        parameters: Union[Iterator[Parameter], List[Dict]],
+        lr: float = 1e-3,
+        momentum: float = 0.0,
+        scale: float = 1.0,
+    ):
+        super().__init__()
+        assert lr >= 0.0, f"Invalid learning rate: {lr}"
+        assert momentum >= 0.0, f"Invalid momentum: {momentum}"
+        assert scale >= 0.0, f"Invalid scale factor: {scale}"
+
+        self._default_options["lr"] = lr
+        self._default_options["scale"] = scale
+        if momentum != 0.0:
+            self._default_options["momentum"] = momentum
+
+        # Add parameters
+        if isinstance(parameters, GeneratorType):
+            self._param_groups.append(ParamGroup(parameters, self._default_options))
+        else:  # List[Dict]
+            for param in parameters:
+                self._param_groups.append(ParamGroup(param, self._default_options))
+
+        for param_group in self._param_groups:
+            for param in param_group.parameters:
+                assert param.is_leaf, "parameters must be leaf tensor"
+                self._state[param] = dict()
+                if "momentum" in self._default_options:
+                    self._state[param]["momentum_buf"] = flow.tmp.zeros_like(param)
+
+        if "momentum" in self._default_options.keys():
+            self._op = (
+                flow.builtin_op("momentum_update")
+                .Input("model")
+                .Input("model_diff")
+                .Input("learning_rate")
+                .Input("momentum")
+                .Attr("scale", self._default_options["scale"])
+                .Attr("l1", 0.0)
+                .Attr("l2", 0.0)
+                .Attr("beta", self._default_options["momentum"])
+                .Attr("weight_decay", 0.0)
+                .Build()
+            )
+        else:
+            self._op = (
+                flow.builtin_op("sgd_update")
+                .Input("model")
+                .Input("model_diff")
+                .Input("learning_rate")
+                .Attr("scale", self._default_options["scale"])
+                .Attr("weight_decay", 0.0)
+                .Attr("l1", 0.0)
+                .Attr("l2", 0.0)
+                .Build()
+            )
+
+    def step(self, closure: Callable = None):
+        with flow.no_grad():
+            loss = None
+            if closure is not None:
+                loss = closure()
+
+            for param_group in self._param_groups:
+                lr_tensor = flow.Tensor([param_group.options["lr"]])
+                for param in param_group.parameters:
+                    if param.grad is None:
+                        continue
+                    if "momentum" in self._default_options:
+                        momentum_buf = self._state[param]["momentum_buf"]
+                        self._op(param, param.grad, lr_tensor, momentum_buf)
+                    else:
+                        self._op(param, param.grad, lr_tensor)
+
+            self._state["step"] = self._state["step"] + 1
+            return loss
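
For reference, a minimal usage sketch of the optim.SGD API introduced in sgd.py above,
mirroring the pattern used in the unit tests. The tensor shapes and hyperparameters are
illustrative only, and the constructor is fed a generator expression because the
Iterator[Parameter] branch of __init__ checks isinstance(parameters, GeneratorType):

    import numpy as np
    import oneflow as flow
    from oneflow.python.nn.parameter import Parameter

    # One learnable leaf tensor; a model's parameters() generator could be passed instead.
    x = Parameter(flow.Tensor(np.random.randn(4, 4).astype(np.float32)))
    sgd = flow.optim.SGD((p for p in [x]), lr=0.01, momentum=0.9)

    for _ in range(10):
        grad = flow.Tensor(np.random.randn(4, 4).astype(np.float32), requires_grad=False)
        loss = flow.sum(x * grad)  # scalar loss, as in test_optim_sgd.py below
        loss.backward()
        sgd.step()       # momentum_update path, since momentum != 0
        sgd.zero_grad()  # clear gradients before the next iteration
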
diff --git a/oneflow/python/test/optimizer/test_adam.py b/oneflow/python/test/modules/test_optim_adam.py
similarity index 93%
rename from oneflow/python/test/optimizer/test_adam.py
rename to oneflow/python/test/modules/test_optim_adam.py
index cbfb2a634..789ff85ac 100644
--- a/oneflow/python/test/optimizer/test_adam.py
+++ b/oneflow/python/test/modules/test_optim_adam.py
@@ -18,7 +18,8 @@ from collections import OrderedDict
 
 import numpy as np
 import oneflow as flow
-from oneflow.python.test.modules.test_util import GenArgList
+
+from test_util import GenArgList
 from oneflow.python.nn.parameter import Parameter
 
 
@@ -41,9 +42,8 @@ def compare_with_numpy_adam(
         def train_one_iter(grad):
             grad_tensor = flow.Tensor(grad, requires_grad=False)
             loss = x * grad_tensor
-            # BUG: loss = flow.sum(x * grad_tensor)
-            grad = flow.Tensor(np.ones(list(loss.shape)))
-            loss.backward(grad)
+            loss = flow.sum(x * grad_tensor)
+            loss.backward()
             adam.step()
             adam.zero_grad()
 
diff --git a/oneflow/python/test/modules/test_optimizers.py b/oneflow/python/test/modules/test_optim_sgd.py
similarity index 95%
rename from oneflow/python/test/modules/test_optimizers.py
rename to oneflow/python/test/modules/test_optim_sgd.py
index 228240b16..268f7f967 100644
--- a/oneflow/python/test/modules/test_optimizers.py
+++ b/oneflow/python/test/modules/test_optim_sgd.py
@@ -44,9 +44,8 @@ def compare_with_numpy_sgd(
         def train_one_iter(grad):
             grad_tensor = flow.Tensor(grad, requires_grad=False)
             loss = x * grad_tensor
-            # BUG: loss = flow.sum(x * grad_tensor)
-            grad = flow.Tensor(np.ones(list(loss.shape)))
-            loss.backward(grad)
+            loss = flow.sum(x * grad_tensor)
+            loss.backward()
             sgd.step()
             sgd.zero_grad()
 
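
For reference, a sketch of the numpy update rule that compare_with_numpy_sgd above checks
against, following the equations in the sgd.py docstring. Treating "scale" as a multiplier
on the raw gradient is an assumption here, matching the "scale" attr passed to
sgd_update / momentum_update:

    import numpy as np

    def numpy_sgd_step(param, grad, momentum_buf, lr, momentum=0.0, scale=1.0):
        # Docstring equations: V_t = beta * V_{t-1} + lr * g_t ; param_new = param_old - V_t
        # Assumption: "scale" rescales the raw gradient before the update.
        g = scale * grad
        if momentum == 0.0:
            return param - lr * g, momentum_buf
        momentum_buf = momentum * momentum_buf + lr * g
        return param - momentum_buf, momentum_buf

    # Example: one momentum step on a 2x2 parameter.
    p = np.ones((2, 2), dtype=np.float32)
    buf = np.zeros_like(p)
    p, buf = numpy_sgd_step(p, 0.5 * np.ones_like(p), buf, lr=0.1, momentum=0.9)
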
-- 
GitLab