From a26f7080b866c4be0a606a6c82dffa7975233f32 Mon Sep 17 00:00:00 2001 From: Yinggang Wang <wyg19970408@gmail.com> Date: Fri, 14 May 2021 11:39:10 +0800 Subject: [PATCH] Support custom parameters for optimizer (#4881) * feat(Optim): support custom parameters for optimizer * feat(Adam): adam support custom parameters * feat(Adamw): adamw support custom parameters * feat(RMSprop): rmsprop support custom parameters * style(Optim): refine adam and adamw Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> --- oneflow/core/framework/attr_map.cpp | 2 +- oneflow/python/nn/optimizer/adam.py | 53 ++++---- oneflow/python/nn/optimizer/adamw.py | 50 ++++---- oneflow/python/nn/optimizer/optimizer.py | 4 +- oneflow/python/nn/optimizer/rmsprop.py | 116 +++++++++--------- oneflow/python/nn/optimizer/sgd.py | 70 ++++++----- .../python/test/modules/test_optim_adam.py | 19 +-- .../python/test/modules/test_optim_adamw.py | 15 +-- .../python/test/modules/test_optim_rmsprop.py | 21 ++-- oneflow/python/test/modules/test_optim_sgd.py | 5 +- 10 files changed, 185 insertions(+), 170 deletions(-) diff --git a/oneflow/core/framework/attr_map.cpp b/oneflow/core/framework/attr_map.cpp index b0d22e177..b615efc78 100644 --- a/oneflow/core/framework/attr_map.cpp +++ b/oneflow/core/framework/attr_map.cpp @@ -52,7 +52,7 @@ AttrMap::AttrMap(const MutableCfgAttrMap& other) { template<typename T> Maybe<const T&> AttrMap::GetAttr(const std::string& attr_name) const { const auto& it = this->find(attr_name); - CHECK_OR_RETURN(it != this->end()); + CHECK_OR_RETURN(it != this->end()) << attr_name << " not found"; const auto* ptr = dynamic_cast<const user_op::TypedAttrVal<T>*>(it->second.get()); CHECK_NOTNULL_OR_RETURN(ptr); return ptr->val(); diff --git a/oneflow/python/nn/optimizer/adam.py b/oneflow/python/nn/optimizer/adam.py index 31fbe829c..3a81c926f 100644 --- a/oneflow/python/nn/optimizer/adam.py +++ b/oneflow/python/nn/optimizer/adam.py @@ -71,7 +71,7 @@ class Adam(Optimizer): lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, - weight_decay: float = 0, + weight_decay: float = 0, # Adam's weight_decay actually does L2 Normalize amsgrad: bool = False, scale: float = 1.0, ): @@ -90,7 +90,7 @@ class Adam(Optimizer): self._default_options["lr"] = lr self._default_options["eps"] = eps - self._default_options["beta"] = betas + self._default_options["betas"] = betas self._default_options["weight_decay"] = weight_decay self._default_options["amsgrad"] = amsgrad self._default_options["scale"] = scale @@ -116,14 +116,7 @@ class Adam(Optimizer): .Input("learning_rate") .Input("m") .Input("v") - .Attr("scale", self._default_options["scale"]) .Attr("l1", 0.0) - .Attr( - "l2", self._default_options["weight_decay"] - ) # Adam's weight_decay actually does L2 Normalize - .Attr("beta1", self._default_options["beta"][0]) - .Attr("beta2", self._default_options["beta"][1]) - .Attr("epsilon", self._default_options["eps"]) .Attr("weight_decay", 0.0) .Build() ) @@ -135,19 +128,29 @@ class Adam(Optimizer): closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" - loss = None - if closure is not None: - loss = closure() - - for param_group in self._param_groups: - lr_tensor = flow.Tensor([param_group.options["lr"]]) - for param in param_group.parameters: - if param.grad is None: - continue - m_tensor = self._state[param]["exp_avg"] - v_tensor = self._state[param]["exp_avg_sq"] - self._op(param, param.grad, lr_tensor, m_tensor, v_tensor) - - self._state["step"] = self._state["step"] + 1 - - return loss + with flow.no_grad(): + loss = None + if closure is not None: + loss = closure() + + for param_group in self._param_groups: + kwargs = { + "scale": param_group.options["scale"], + "l2": param_group.options["weight_decay"], + "beta1": param_group.options["betas"][0], + "beta2": param_group.options["betas"][1], + "epsilon": param_group.options["eps"], + } + lr_tensor = flow.Tensor([param_group.options["lr"]]) + for param in param_group.parameters: + m_tensor = self._state[param]["exp_avg"] + v_tensor = self._state[param]["exp_avg_sq"] + if param.grad is None: + continue + self._op( + param, param.grad, lr_tensor, m_tensor, v_tensor, **kwargs, + ) + + self._state["step"] = self._state["step"] + 1 + + return loss diff --git a/oneflow/python/nn/optimizer/adamw.py b/oneflow/python/nn/optimizer/adamw.py index fbc662c39..a4b73dc88 100644 --- a/oneflow/python/nn/optimizer/adamw.py +++ b/oneflow/python/nn/optimizer/adamw.py @@ -90,9 +90,10 @@ class AdamW(Optimizer): assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" assert scale > 0.0, f"Invalid scale factor: {scale}" assert amsgrad is False, "Not support AMSGrad now!" + self._default_options["lr"] = lr self._default_options["eps"] = eps - self._default_options["beta"] = betas + self._default_options["betas"] = betas self._default_options["weight_decay"] = weight_decay self._default_options["amsgrad"] = amsgrad self._default_options["scale"] = scale @@ -118,13 +119,8 @@ class AdamW(Optimizer): .Input("learning_rate") .Input("m") .Input("v") - .Attr("scale", self._default_options["scale"]) .Attr("l1", 0.0) .Attr("l2", 0.0) - .Attr("beta1", self._default_options["beta"][0]) - .Attr("beta2", self._default_options["beta"][1]) - .Attr("epsilon", self._default_options["eps"]) - .Attr("weight_decay", self._default_options["weight_decay"]) .Build() ) @@ -135,19 +131,29 @@ class AdamW(Optimizer): closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" - loss = None - if closure is not None: - loss = closure() - - for param_group in self._param_groups: - lr_tensor = flow.Tensor([param_group.options["lr"]]) - for param in param_group.parameters: - if param.grad is None: - continue - m_tensor = self._state[param]["exp_avg"] - v_tensor = self._state[param]["exp_avg_sq"] - self._op(param, param.grad, lr_tensor, m_tensor, v_tensor) - - self._state["step"] = self._state["step"] + 1 - - return loss + with flow.no_grad(): + loss = None + if closure is not None: + loss = closure() + + for param_group in self._param_groups: + kwargs = { + "scale": param_group.options["scale"], + "weight_decay": param_group.options["weight_decay"], + "beta1": param_group.options["betas"][0], + "beta2": param_group.options["betas"][1], + "epsilon": param_group.options["eps"], + } + lr_tensor = flow.Tensor([param_group.options["lr"]]) + for param in param_group.parameters: + if param.grad is None: + continue + m_tensor = self._state[param]["exp_avg"] + v_tensor = self._state[param]["exp_avg_sq"] + self._op( + param, param.grad, lr_tensor, m_tensor, v_tensor, **kwargs, + ) + + self._state["step"] = self._state["step"] + 1 + + return loss diff --git a/oneflow/python/nn/optimizer/optimizer.py b/oneflow/python/nn/optimizer/optimizer.py index 9fb16b1ac..d70496324 100644 --- a/oneflow/python/nn/optimizer/optimizer.py +++ b/oneflow/python/nn/optimizer/optimizer.py @@ -31,8 +31,8 @@ class ParamGroup(object): self._parameters = list(parameters) self._options = default_options else: # Dict - assert "param" in parameters - self._parameters = list(parameters["param"]) + assert "params" in parameters + self._parameters = list(parameters["params"]) self._options = default_options for key in self._options: if key in parameters: diff --git a/oneflow/python/nn/optimizer/rmsprop.py b/oneflow/python/nn/optimizer/rmsprop.py index 0ac928b3f..7a23bd80e 100644 --- a/oneflow/python/nn/optimizer/rmsprop.py +++ b/oneflow/python/nn/optimizer/rmsprop.py @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import List, Dict, Callable, Union, Iterator, Tuple +from typing import List, Dict, Callable, Union, Iterator from types import GeneratorType import oneflow as flow @@ -52,7 +52,7 @@ class RMSprop(Optimizer): v(w, t) = \beta v(w, t-1) + \frac{\eta} {\\sqrt{r(w,t) + \epsilon}} \nabla Q_{i}(w) - + w = w - v(w, t) if centered is True: @@ -65,9 +65,9 @@ class RMSprop(Optimizer): v(w, t) = \beta v(w, t-1) + \frac{\eta} {\\sqrt{r(w,t) - (g(w, t))^2 + \epsilon}} \nabla Q_{i}(w) - + w = w - v(w, t) - + where, :math:`\alpha` is a hyperparameter and typical values are 0.99, 0.95 and so on. :math:`\beta` is the momentum term. :math:`\epsilon` is a smoothing term to avoid division by zero, usually set somewhere in range @@ -103,7 +103,7 @@ class RMSprop(Optimizer): assert eps >= 0.0, f"Invalid epsilon value: {eps}" assert weight_decay >= 0.0, f"Invalid weight_decay value: {weight_decay}" assert scale > 0.0, f"Invalid scale factor: {scale}" - assert momentum == 0.0, f"Not support momentum greater than zeros now!" + assert momentum == 0.0, "Not support momentum greater than zeros now!" 
self._default_options["lr"] = lr self._default_options["alpha"] = alpha @@ -124,41 +124,32 @@ class RMSprop(Optimizer): assert param.is_leaf, "parameters must be leaf tensor" self._state[param] = dict() self._state[param]["square_avg"] = flow.experimental.zeros_like(param) - if "centered" in self._default_options: + if param_group.options["centered"]: self._state[param]["grad_avg"] = flow.experimental.zeros_like(param) - if centered: - self._op = ( - flow.builtin_op("rmsprop_update") - .Input("model") - .Input("model_diff") - .Input("learning_rate") - .Input("mean_square") - .Input("mean_gradient") - .Attr("scale", self._default_options["scale"]) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("centered", self._default_options["centered"]) - .Attr("epsilon", self._default_options["eps"]) - .Attr("decay_rate", self._default_options["alpha"]) - .Attr("weight_decay", self._default_options["weight_decay"]) - .Build() - ) - else: - self._op = ( - flow.builtin_op("rmsprop_update") - .Input("model") - .Input("model_diff") - .Input("learning_rate") - .Input("mean_square") - .Attr("scale", self._default_options["scale"]) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("centered", self._default_options["centered"]) - .Attr("epsilon", self._default_options["eps"]) - .Attr("decay_rate", self._default_options["alpha"]) - .Attr("weight_decay", self._default_options["weight_decay"]) - .Build() - ) + + self._centered_rmsprop = ( + flow.builtin_op("rmsprop_update") + .Input("model") + .Input("model_diff") + .Input("learning_rate") + .Input("mean_square") + .Input("mean_gradient") + .Attr("centered", True) + .Attr("l1", 0.0) + .Attr("l2", 0.0) + .Build() + ) + self._rmsprop = ( + flow.builtin_op("rmsprop_update") + .Input("model") + .Input("model_diff") + .Input("learning_rate") + .Input("mean_square") + .Attr("centered", False) + .Attr("l1", 0.0) + .Attr("l2", 0.0) + .Build() + ) def step(self, closure: Callable = None): """Performs a single optimization step. @@ -167,22 +158,31 @@ class RMSprop(Optimizer): closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" - loss = None - if closure is not None: - loss = closure() - - for param_group in self._param_groups: - lr_tensor = flow.Tensor([param_group.options["lr"]]) - for param in param_group.parameters: - if param.grad is None: - continue - ms_tensor = self._state[param]["square_avg"] - if self._default_options["centered"]: - mg_tensor = self._state[param]["grad_avg"] - self._op(param, param.grad, lr_tensor, ms_tensor, mg_tensor) - else: - self._op(param, param.grad, lr_tensor, ms_tensor) - - self._state["step"] = self._state["step"] + 1 - - return loss + with flow.no_grad(): + loss = None + if closure is not None: + loss = closure() + + for param_group in self._param_groups: + lr_tensor = flow.Tensor([param_group.options["lr"]]) + kwargs = { + "scale": param_group.options["scale"], + "epsilon": param_group.options["eps"], + "decay_rate": param_group.options["alpha"], + "weight_decay": param_group.options["weight_decay"], + } + for param in param_group.parameters: + if param.grad is None: + continue + ms_tensor = self._state[param]["square_avg"] + if param_group.options["centered"]: + mg_tensor = self._state[param]["grad_avg"] + self._centered_rmsprop( + param, param.grad, lr_tensor, ms_tensor, mg_tensor, **kwargs + ) + else: + self._rmsprop(param, param.grad, lr_tensor, ms_tensor, **kwargs) + + self._state["step"] = self._state["step"] + 1 + + return loss diff --git a/oneflow/python/nn/optimizer/sgd.py b/oneflow/python/nn/optimizer/sgd.py index e41a13ba3..9d4d20d6a 100644 --- a/oneflow/python/nn/optimizer/sgd.py +++ b/oneflow/python/nn/optimizer/sgd.py @@ -68,8 +68,7 @@ class SGD(Optimizer): self._default_options["lr"] = lr self._default_options["scale"] = scale - if momentum != 0.0: - self._default_options["momentum"] = momentum + self._default_options["momentum"] = momentum # Add parameters if isinstance(parameters, GeneratorType): @@ -82,37 +81,32 @@ class SGD(Optimizer): for param in param_group.parameters: assert param.is_leaf, "parameters must be leaf tensor" self._state[param] = dict() - if "momentum" in self._default_options: + if param_group.options["momentum"] != 0.0: self._state[param]["momentum_buf"] = flow.experimental.zeros_like( param ) - if "momentum" in self._default_options.keys(): - self._op = ( - flow.builtin_op("momentum_update") - .Input("model") - .Input("model_diff") - .Input("learning_rate") - .Input("momentum") - .Attr("scale", self._default_options["scale"]) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("beta", self._default_options["momentum"]) - .Attr("weight_decay", 0.0) - .Build() - ) - else: - self._op = ( - flow.builtin_op("sgd_update") - .Input("model") - .Input("model_diff") - .Input("learning_rate") - .Attr("scale", self._default_options["scale"]) - .Attr("weight_decay", 0.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Build() - ) + self._momentum_sgd = ( + flow.builtin_op("momentum_update") + .Input("model") + .Input("model_diff") + .Input("learning_rate") + .Input("momentum") + .Attr("l1", 0.0) + .Attr("l2", 0.0) + .Attr("weight_decay", 0.0) + .Build() + ) + self._sgd = ( + flow.builtin_op("sgd_update") + .Input("model") + .Input("model_diff") + .Input("learning_rate") + .Attr("weight_decay", 0.0) + .Attr("l1", 0.0) + .Attr("l2", 0.0) + .Build() + ) def step(self, closure: Callable = None): with flow.no_grad(): @@ -125,11 +119,21 @@ class SGD(Optimizer): for param in param_group.parameters: if param.grad is None: continue - if "momentum" in self._default_options: - momentum_buf = self._state[param]["momentum_buf"] - self._op(param, param.grad, lr_tensor, 
momentum_buf) + if param_group.options["momentum"] == 0.0: + scale = param_group.options["scale"] + self._sgd(param, param.grad, lr_tensor, scale=scale) else: - self._op(param, param.grad, lr_tensor) + momentum_buf = self._state[param]["momentum_buf"] + scale = param_group.options["scale"] + beta = param_group.options["momentum"] + self._momentum_sgd( + param, + param.grad, + lr_tensor, + momentum_buf, + scale=scale, + beta=beta, + ) self._state["step"] = self._state["step"] + 1 return loss diff --git a/oneflow/python/test/modules/test_optim_adam.py b/oneflow/python/test/modules/test_optim_adam.py index fd2aa2c41..273eb115e 100644 --- a/oneflow/python/test/modules/test_optim_adam.py +++ b/oneflow/python/test/modules/test_optim_adam.py @@ -35,20 +35,21 @@ def compare_with_numpy_adam( def train_by_oneflow(): x = Parameter(flow.Tensor(init_value)) - param_list = list() - param_list.append(x) adam = flow.optim.Adam( - [{"param": param_list}], - lr=learning_rate, - betas=betas, - eps=eps, - weight_decay=weight_decay, - scale=scale, + [ + { + "params": [x], + "lr": learning_rate, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "scale": scale, + } + ] ) def train_one_iter(grad): grad_tensor = flow.Tensor(grad, requires_grad=False) - loss = x * grad_tensor loss = flow.sum(x * grad_tensor) loss.backward() adam.step() diff --git a/oneflow/python/test/modules/test_optim_adamw.py b/oneflow/python/test/modules/test_optim_adamw.py index 49bb3c2f4..395565f04 100644 --- a/oneflow/python/test/modules/test_optim_adamw.py +++ b/oneflow/python/test/modules/test_optim_adamw.py @@ -34,18 +34,19 @@ def compare_with_numpy_adamw( def train_by_oneflow(): x = Parameter(flow.Tensor(init_value)) - param_list = list() - param_list.append(x) adam = flow.optim.AdamW( - [{"param": param_list}], - lr=learning_rate, - scale=scale, - weight_decay=weight_decay, + [ + { + "params": [x], + "lr": learning_rate, + "weight_decay": weight_decay, + "scale": scale, + } + ] ) def train_one_iter(grad): grad_tensor = flow.Tensor(grad, requires_grad=False) - loss = x * grad_tensor loss = flow.sum(x * grad_tensor) loss.backward() adam.step() diff --git a/oneflow/python/test/modules/test_optim_rmsprop.py b/oneflow/python/test/modules/test_optim_rmsprop.py index fb549f4b5..e533d0d8c 100644 --- a/oneflow/python/test/modules/test_optim_rmsprop.py +++ b/oneflow/python/test/modules/test_optim_rmsprop.py @@ -46,19 +46,22 @@ def compare_with_numpy_rmsprop( param_list = list() param_list.append(x) rmsprop = flow.optim.RMSprop( - [{"param": param_list}], - lr=learning_rate, - momentum=momentum, - scale=scale, - alpha=alpha, - eps=eps, - weight_decay=weight_decay, - centered=centered, + [ + { + "params": param_list, + "lr": learning_rate, + "alpha": alpha, + "eps": eps, + "weight_decay": weight_decay, + "momentum": momentum, + "centered": centered, + "scale": scale, + } + ] ) def train_one_iter(grad): grad_tensor = flow.Tensor(grad, requires_grad=False) - loss = x * grad_tensor loss = flow.sum(x * grad_tensor) loss.backward() rmsprop.step() diff --git a/oneflow/python/test/modules/test_optim_sgd.py b/oneflow/python/test/modules/test_optim_sgd.py index 18abfbc27..ac4cc15aa 100644 --- a/oneflow/python/test/modules/test_optim_sgd.py +++ b/oneflow/python/test/modules/test_optim_sgd.py @@ -35,15 +35,12 @@ def compare_with_numpy_sgd( def train_by_oneflow(): x = Parameter(flow.Tensor(init_value)) - param_list = list() - param_list.append(x) sgd = flow.optim.SGD( - [{"param": param_list}], lr=learning_rate, momentum=momentum, scale=scale + 
[{"params": [x], "lr": learning_rate, "momentum": momentum, "scale": scale}] ) def train_one_iter(grad): grad_tensor = flow.Tensor(grad, requires_grad=False) - loss = x * grad_tensor loss = flow.sum(x * grad_tensor) loss.backward() sgd.step() -- GitLab