Unverified commit 5c7bab46, authored by Yinggang Wang, committed by GitHub

feat(SGD): support weight_decay (actually L2) (#5587)


Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 5341044e
@@ -30,7 +30,8 @@ from .optimizer import Optimizer, ParamGroup
 class SGD(Optimizer):
     r"""Implements SGD algorithm.
-    This algorithm takes a random sample’s gradient as an approximate estimate of the overall gradient in small batch gradient descent.
+    This algorithm takes a random sample’s gradient as an approximate estimate of
+    the overall gradient in small batch gradient descent.

     When the momentum = 0, the equation of parameters updating is:

@@ -42,15 +43,16 @@ class SGD(Optimizer):
     .. math::

-        & V_t = \beta * V_{t-1} + learning\_rate * g_t
-        & param_{new} = param_{old} - V_t
+        & V_t = \beta * V_{t-1} - learning\_rate * (g_t * scale + param_{old} * weight\_decay)
+        & param_{new} = param_{old} + V_t

     Args:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
         lr (float, optional): learning rate (default: 1e-3)
         momentum (float, optional): Momentum factor (default: 0.0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
         scale (float, optional): the scale factor of loss (default: 1.0)
     """
@@ -60,16 +62,19 @@ class SGD(Optimizer):
         parameters: Union[Iterator[Parameter], List[Dict]],
         lr: float = 1e-3,
         momentum: float = 0.0,
+        weight_decay: float = 0.0,  # SGD's weight_decay is actually an L2 penalty
         scale: float = 1.0,
     ):
         super().__init__()
         assert lr >= 0.0, f"Invalid learning rate: {lr}"
         assert momentum >= 0.0, f"Invalid momentum: {momentum}"
         assert scale >= 0.0, f"Invalid scale factor: {scale}"
+        assert weight_decay >= 0.0, f"Invalid weight_decay: {weight_decay}"
         self._default_options["lr"] = lr
         self._default_options["scale"] = scale
         self._default_options["momentum"] = momentum
+        self._default_options["weight_decay"] = weight_decay

         # Add parameters
         if isinstance(parameters, collections.abc.Iterator):
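With the new default option registered, a parameter group can carry its own weight_decay. A construction sketch mirroring the updated test further down (the tensor and values are placeholders, and eager mode is assumed to be enabled as in the test suite):

import numpy as np
import oneflow.experimental as flow
from oneflow.python.nn.parameter import Parameter

x = Parameter(flow.Tensor(np.ones((10,))))  # placeholder parameter
sgd = flow.optim.SGD(
    [{"params": [x], "lr": 0.1, "momentum": 0.9, "scale": 1.0, "weight_decay": 0.9}]
)
# each call to sgd.step() then treats grad * scale + weight_decay * x as the effective gradient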
@@ -93,7 +98,6 @@ class SGD(Optimizer):
             .Input("model_diff")
             .Input("momentum")
             .Attr("l1", 0.0)
-            .Attr("l2", 0.0)
             .Attr("weight_decay", 0.0)
             .Build()
         )
@@ -103,7 +107,6 @@ class SGD(Optimizer):
             .Input("model_diff")
             .Attr("weight_decay", 0.0)
             .Attr("l1", 0.0)
-            .Attr("l2", 0.0)
             .Build()
         )
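Both update ops drop the standalone "l2" attribute; the L2 coefficient now travels through the "weight_decay" attribute that step() fills in below. A hedged sketch of the assumed kernel semantics for the momentum-free path, in plain Python rather than the op-builder API (drawn from the NumPy reference in the test, not from the kernel source):

def sgd_update(model, model_diff, learning_rate, scale=1.0, weight_decay=0.0):
    # assumed semantics: weight_decay acts as an L2 coefficient on the model
    effective_diff = model_diff * scale + weight_decay * model
    return model - learning_rate * effective_diff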
@@ -115,21 +118,24 @@ class SGD(Optimizer):
         for param_group in self.param_groups:
             lr = param_group["lr"]
+            scale = param_group["scale"]
+            l2 = param_group["weight_decay"]
             for param in param_group.parameters:
                 if param.grad is None:
                     continue
                 if param_group["momentum"] == 0.0:
-                    scale = param_group["scale"]
-                    self._sgd(param, param.grad, learning_rate_val=lr, scale=scale)
+                    self._sgd(
+                        param, param.grad, learning_rate_val=lr, l2=l2, scale=scale
+                    )
                 else:
                     momentum_buf = self._state[param]["momentum_buf"]
-                    scale = param_group["scale"]
                     beta = param_group["momentum"]
                     self._momentum_sgd(
                         param,
                         param.grad,
                         momentum_buf,
                         learning_rate_val=lr,
+                        l2=l2,
                         scale=scale,
                         beta=beta,
                     )
...
@@ -19,12 +19,19 @@ from collections import OrderedDict
 import numpy as np
 import oneflow.experimental as flow
-from test_util import GenArgList
+from test_util import GenArgDict
 from oneflow.python.nn.parameter import Parameter


 def compare_with_numpy_sgd(
-    test_case, device, x_shape, scale, momentum, learning_rate, train_iters,
+    test_case,
+    device,
+    x_shape,
+    scale,
+    momentum,
+    weight_decay,
+    learning_rate,
+    train_iters,
 ):
     # generate random number sequences
     random_grad_seq = []
@@ -36,7 +43,15 @@ def compare_with_numpy_sgd(
     def train_by_oneflow():
         x = Parameter(flow.Tensor(init_value, device=flow.device(device)))
         sgd = flow.optim.SGD(
-            [{"params": [x], "lr": learning_rate, "momentum": momentum, "scale": scale}]
+            [
+                {
+                    "params": [x],
+                    "lr": learning_rate,
+                    "momentum": momentum,
+                    "scale": scale,
+                    "weight_decay": weight_decay,
+                }
+            ]
         )

         def train_one_iter(grad):
@@ -57,8 +72,9 @@ def compare_with_numpy_sgd(
         vt = np.zeros_like(x)

         def train_one_iter(grad):
-            v = momentum * vt + learning_rate * scale * grad
-            param = x - v
+            grad = grad * scale + weight_decay * x
+            v = momentum * vt - learning_rate * grad
+            param = x + v
             return param, v

         for i in range(train_iters):
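As a concrete check of this reference, with scale = 1.0, weight_decay = 0.9, learning_rate = 0.1, x = 1.0, vt = 0, and grad = 0.5, one iteration gives grad = 0.5 + 0.9 * 1.0 = 1.4, v = 0 - 0.1 * 1.4 = -0.14, and param = 1.0 + (-0.14) = 0.86, which is the value the OneFlow result is compared against.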
@@ -80,10 +96,11 @@ class TestOptimizers(flow.unittest.TestCase):
         arg_dict["x_shape"] = [(10,)]
         arg_dict["scale"] = [1.0, 0.9]
         arg_dict["momentum"] = [0.0, 0.9]
-        arg_dict["learning_rate"] = [1]
+        arg_dict["weight_decay"] = [0.0, 0.9]
+        arg_dict["learning_rate"] = [1, 0.1]
         arg_dict["train_iters"] = [10]
-        for arg in GenArgList(arg_dict):
-            compare_with_numpy_sgd(test_case, *arg)
+        for arg in GenArgDict(arg_dict):
+            compare_with_numpy_sgd(test_case, **arg)

 if __name__ == "__main__":
...
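The test driver switches from GenArgList to GenArgDict so the new weight_decay field can be passed by keyword. Assuming GenArgDict expands the argument dictionary into the cartesian product of its value lists and yields one kwargs dict per combination (an assumption about the test_util helper, which this commit does not show), an equivalent sketch is:

import itertools

def gen_arg_dict(arg_dict):
    # assumed behavior: one dict per combination of the value lists,
    # suitable for splatting as **kwargs into compare_with_numpy_sgd
    keys = list(arg_dict.keys())
    for values in itertools.product(*(arg_dict[k] for k in keys)):
        yield dict(zip(keys, values))

# e.g. momentum in [0.0, 0.9] and weight_decay in [0.0, 0.9] yield four combinations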