""" Copyright 2020 The OneFlow Authors. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from collections import OrderedDict from typing import Iterator, Optional, Set, Union import oneflow._oneflow_internal import oneflow.framework.graph_build_util as graph_build_util from oneflow.framework.tensor import Tensor from oneflow.nn.module import Module from oneflow.nn.parameter import Parameter from oneflow.nn.util import add_indent class BlockType: NONE = "NONE" MODULE = "MODULE" PARAMETER = "PARAMETER" BUFFER = "BUFFER" class Block(object): def __init__( self, prefix: str = "", name: str = "", value: Union[Module, Parameter, Tensor] = None, ): assert not isinstance(value, Block) self._name = name self._name_prefix = prefix self._type = BlockType.NONE self._origin = value self.config = BlockConfig() self._scope = None self._prev_scope = None if isinstance(value, Module): self._type = BlockType.MODULE self._is_executing_forward = False self._modules = OrderedDict() self._parameters = OrderedDict() self._buffers = OrderedDict() for (n, m) in list(value.named_children()): self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m)) for (n, p) in list(value.named_parameters("", False)): self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p)) for (n, b) in list(value.named_buffers("", False)): self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b)) elif isinstance(value, Parameter): self._type = BlockType.PARAMETER self._lazy_origin = None self._lazy_origin_builder = None elif isinstance(value, Tensor): self._type = BlockType.BUFFER self._lazy_origin = None self._lazy_origin_builder = None else: raise NotImplementedError() @property def name(self): return self._name @property def name_prefix(self): return self._name_prefix @property def type(self): return self._type @property def origin(self): return self._origin @property def lazy_origin(self): assert ( self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER ), "Only Parameter or Buffer Block has lazy_origin" return self._lazy_origin def lazy_origin_builder(self): assert ( self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER ), "Only Parameter or Buffer Block has lazy_origin_builder" return self._lazy_origin_builder def set_lazy_origin_builder(self, fn=None): assert ( self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER ), "Only Parameter or Buffer Block has lazy_origin_builder" self._lazy_origin_builder = fn @property def prev_scope(self): if self._prev_scope is None: self._prev_scope = oneflow._oneflow_internal.GetCurrentScope() return self._prev_scope @property def scope(self): if self._scope is None: self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self) return self._scope def scope_context(self): return graph_build_util.BlockScopeContext(self.prev_scope, self.scope) def __call__(self, *args): assert self._type == BlockType.MODULE result = self._origin.__class__.__call__(self, *args) return result def forward(self, *args): assert self._type == BlockType.MODULE self._is_executing_forward = True with self.scope_context(): result = self._origin.__class__.forward(self, *args) self._is_executing_forward = False return result def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]: assert self._type == BlockType.MODULE if memo is None: memo = set() if self not in memo: memo.add(self) yield self for (name, module) in self._modules.items(): if module is None: continue for m in module.modules(memo): yield m def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]: assert self._type == BlockType.MODULE memo = set() modules = self.modules() if recurse else [self] for module in modules: members = get_members_fn(module) for (k, v) in members: if v is None or v in memo: continue memo.add(v) yield v def parameters(self, recurse: bool = True) -> Iterator["Block"]: assert self._type == BlockType.MODULE gen = self._members(lambda module: module._parameters.items(), recurse=recurse) for elem in gen: yield elem def buffers(self, recurse: bool = True) -> Iterator["Block"]: assert self._type == BlockType.MODULE gen = self._members(lambda module: module._buffers.items(), recurse=recurse) for elem in gen: yield elem def __setattr__(self, name: str, value=None) -> None: if value is None or not isinstance(value, Block): self.__dict__[name] = value else: dicts_or_sets = ( self.__dict__, self._modules, self._parameters, self._buffers, ) for d in dicts_or_sets: if name in d: raise AttributeError( "'{}' object has duplicated attribute named '{}'".format( self._name, name ) ) if value.type == BlockType.MODULE: self._modules[name] = value elif value.type == BlockType.PARAMETER: self._parameters[name] = value elif value.type == BlockType.BUFFER: self._buffers[name] = value else: raise AttributeError( "'{}' object are not allowed to set attribute named '{}'".format( type(self).__name__, name ) ) def __getattr__(self, name: str): if name in self.__dict__: return self.__dict__[name] if self._type == BlockType.MODULE: if "_modules" in self.__dict__: modules = self.__dict__["_modules"] if name in modules: return modules[name] if "_parameters" in self.__dict__: _parameters = self.__dict__["_parameters"] if name in _parameters: p_block = _parameters[name] if self._is_executing_forward: if graph_build_util.lazy_mode.is_enabled(): if p_block._lazy_origin is None: assert p_block._lazy_origin_builder is not None, ( repr(p_block) + " has no lazy Tensor creation function." ) with p_block.scope_context(): p_block._lazy_origin = ( p_block._lazy_origin_builder() ) return p_block._lazy_origin else: return p_block.origin else: return p_block if "_buffers" in self.__dict__: _buffers = self.__dict__["_buffers"] if name in _buffers: b_block = _buffers[name] if self._is_executing_forward: if graph_build_util.lazy_mode.is_enabled(): if b_block._lazy_origin is None: assert b_block._lazy_origin_builder is not None, ( repr(b_block) + " has no lazy Tensor creation function." ) with b_block.scope_context(): b_block._lazy_origin = ( b_block._lazy_origin_builder() ) return b_block._lazy_origin else: return b_block.origin else: return b_block if name in self._origin.__dict__: return self._origin.__dict__[name] raise AttributeError( "'{}' object has no attribute '{}'".format(type(self).__name__, name) ) def __repr__(self): lines = None if self._type == BlockType.MODULE: child_lines = [] def _append_child(d): for (_, n) in d.items(): n_str = repr(n) n_str = add_indent(n_str, 2) child_lines.append(n_str) _append_child(self._modules) _append_child(self._parameters) _append_child(self._buffers) if len(child_lines) > 0: lines = child_lines main_str = ( "(" + self._name_prefix + self._name + ":" + self._origin.__class__.__name__ + ":" + self._type + "): (" ) if lines is not None: main_str += "\n " + "\n ".join(lines) + "\n" main_str += ")" return main_str class BlockConfig(object): def __init__(self): self._stage_id = None self._activation_checkpointing = None @property def stage_id(self): return self._stage_id @stage_id.setter def stage_id(self, value: int = None): self._stage_id = value @property def activation_checkpointing(self): return self._activation_checkpointing @activation_checkpointing.setter def activation_checkpointing(self, value: bool = False): self._activation_checkpointing = value