Skip to content
Snippets Groups Projects
Unverified Commit ed82d1da authored by Zhiqiu(Oscar) Xu's avatar Zhiqiu(Oscar) Xu Committed by GitHub
Browse files

Fix optimizer for not supporting all kinds of iterables (#5355)


* added flatten backward

* flatten and softmax backward

* fix bug for not supporting all kinds of iterables in optimizers

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 1c0177ea
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,7 @@ limitations under the License.
"""
from typing import List, Dict, Callable, Union, Iterator, Tuple
from types import GeneratorType
import collections
import oneflow as flow
......@@ -96,7 +96,7 @@ class Adam(Optimizer):
self._default_options["scale"] = scale
# Add parameters
if isinstance(parameters, GeneratorType):
if isinstance(parameters, collections.abc.Iterator):
self.param_groups.append(ParamGroup(parameters, self._default_options))
else: # List[Dict]
for param in parameters:
......
......@@ -15,7 +15,7 @@ limitations under the License.
"""
from typing import List, Dict, Callable, Union, Iterator, Tuple
from types import GeneratorType
import collections
import oneflow as flow
......@@ -99,7 +99,7 @@ class AdamW(Optimizer):
self._default_options["scale"] = scale
# Add parameters
if isinstance(parameters, GeneratorType):
if isinstance(parameters, collections.abc.Iterator):
self.param_groups.append(ParamGroup(parameters, self._default_options))
else: # List[Dict]
for param in parameters:
......
......@@ -16,7 +16,7 @@ limitations under the License.
import warnings
from typing import Dict, Callable, Union, Any, Iterator
from types import GeneratorType
import collections
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.parameter import Parameter
......@@ -29,7 +29,7 @@ class ParamGroup(object):
parameters: Union[Iterator[Parameter], Dict[str, Any]],
default_options: Dict,
):
if isinstance(parameters, GeneratorType):
if isinstance(parameters, collections.abc.Iterator):
self._parameters = list(parameters)
self._options = default_options
else: # Dict
......
......@@ -15,7 +15,7 @@ limitations under the License.
"""
from typing import List, Dict, Callable, Union, Iterator
from types import GeneratorType
import collections
import oneflow as flow
......@@ -113,7 +113,7 @@ class RMSprop(Optimizer):
self._default_options["scale"] = scale
# Add parameters
if isinstance(parameters, GeneratorType):
if isinstance(parameters, collections.abc.Iterator):
self.param_groups.append(ParamGroup(parameters, self._default_options))
else: # List[Dict]
for param in parameters:
......
......@@ -15,7 +15,7 @@ limitations under the License.
"""
from typing import List, Dict, Callable, Union, Iterator
from types import GeneratorType
import collections
import oneflow as flow
......@@ -71,7 +71,7 @@ class SGD(Optimizer):
self._default_options["momentum"] = momentum
# Add parameters
if isinstance(parameters, GeneratorType):
if isinstance(parameters, collections.abc.Iterator):
self.param_groups.append(ParamGroup(parameters, self._default_options))
else: # List[Dict]
for param in parameters:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment