# graph_block.py
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import OrderedDict
from functools import partial
from typing import Iterator, Optional, Set, Union

import oneflow._oneflow_internal
import oneflow.framework.graph_build_util as graph_build_util
from oneflow.framework.tensor import Tensor
from oneflow.nn.module import Module
from oneflow.nn.parameter import Parameter
from oneflow.nn.util import add_indent


class BlockType:
    """String tags naming what kind of object a Block wraps."""

    # Initial tag of a Block before the wrapped value has been classified.
    NONE = "NONE"
    # The wrapped value is a Module.
    MODULE = "MODULE"
    # The wrapped value is a Parameter.
    PARAMETER = "PARAMETER"
    # The wrapped value is a Tensor that is not a Parameter.
    BUFFER = "BUFFER"


class Block(object):
    """Wrapper around a Module, Parameter or Tensor.

    A Block keeps the wrapped object in ``origin`` and tags itself with a
    ``BlockType``. A MODULE block recursively wraps the module's direct
    children, parameters and buffers into child Blocks, stored in the
    ``_modules``, ``_parameters`` and ``_buffers`` ordered dicts. PARAMETER
    and BUFFER blocks additionally carry a lazily built ``lazy_origin``
    produced by a user supplied builder function.
    """

    def __init__(
        self,
        prefix: str = "",
        name: str = "",
        value: Optional[Union[Module, Parameter, Tensor]] = None,
    ):
        """Wrap ``value`` and, for a Module, all of its direct members.

        Args:
            prefix: Name prefix of this block; combined with ``name`` and a
                trailing ``"."`` it forms the prefix of child blocks.
            name: Name of this block.
            value: The object to wrap. Must be a Module, Parameter or Tensor
                and must not already be a Block.

        Raises:
            NotImplementedError: If ``value`` is none of the supported types
                (this includes the default ``None``).
        """
        assert not isinstance(value, Block)
        self._name = name
        self._name_prefix = prefix
        self._type = BlockType.NONE
        self._origin = value
        self.config = BlockConfig()
        self._scope = None
        self._prev_scope = None
        if isinstance(value, Module):
            self._type = BlockType.MODULE
            self._is_executing_forward = False
            self._modules = OrderedDict()
            self._parameters = OrderedDict()
            self._buffers = OrderedDict()
            # Every direct member shares the same prefix: "<prefix><name>."
            child_prefix = self._name_prefix + self._name + "."
            for (n, m) in list(value.named_children()):
                self.__setattr__(n, Block(child_prefix, n, m))
            for (n, p) in list(value.named_parameters("", False)):
                self.__setattr__(n, Block(child_prefix, n, p))
            for (n, b) in list(value.named_buffers("", False)):
                self.__setattr__(n, Block(child_prefix, n, b))
        elif isinstance(value, Parameter):
            # Checked before Tensor so that a Parameter is never tagged BUFFER.
            self._type = BlockType.PARAMETER
            self._lazy_origin = None
            self._lazy_origin_builder = None
        elif isinstance(value, Tensor):
            self._type = BlockType.BUFFER
            self._lazy_origin = None
            self._lazy_origin_builder = None
        else:
            raise NotImplementedError()

    @property
    def name(self):
        """Name of this block."""
        return self._name

    @property
    def name_prefix(self):
        """Prefix that precedes ``name`` in the block's full name."""
        return self._name_prefix

    @property
    def type(self):
        """The ``BlockType`` tag of this block."""
        return self._type

    @property
    def origin(self):
        """The wrapped Module, Parameter or Tensor."""
        return self._origin

    @property
    def lazy_origin(self):
        """The lazily built counterpart of ``origin`` (``None`` until built).

        Only valid on PARAMETER and BUFFER blocks.
        """
        assert (
            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
        ), "Only Parameter or Buffer Block has lazy_origin"
        return self._lazy_origin

    def lazy_origin_builder(self):
        """Return the function used to build ``lazy_origin`` (may be ``None``).

        Note that, unlike ``lazy_origin``, this is a plain method and must be
        called. Only valid on PARAMETER and BUFFER blocks.
        """
        assert (
            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
        ), "Only Parameter or Buffer Block has lazy_origin_builder"
        return self._lazy_origin_builder

    def set_lazy_origin_builder(self, fn=None):
        """Set the zero-argument function that builds ``lazy_origin``.

        Only valid on PARAMETER and BUFFER blocks.
        """
        assert (
            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
        ), "Only Parameter or Buffer Block has lazy_origin_builder"
        self._lazy_origin_builder = fn

    @property
    def prev_scope(self):
        """The scope that was current on first access; cached afterwards."""
        if self._prev_scope is None:
            self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
        return self._prev_scope

    @property
    def scope(self):
        """This block's own scope, created from ``prev_scope`` on first access."""
        if self._scope is None:
            self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
        return self._scope

    def scope_context(self):
        """Return a ``BlockScopeContext`` built from ``prev_scope`` and ``scope``."""
        return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)

    def __call__(self, *args):
        """Run the wrapped module class's ``__call__`` with this block as ``self``."""
        assert self._type == BlockType.MODULE
        return self._origin.__class__.__call__(self, *args)

    def __iter__(self) -> Iterator["Block"]:
        """Iterate over the direct child module blocks."""
        assert self._type == BlockType.MODULE
        return iter(self._modules.values())

    def forward(self, *args):
        """Run the wrapped module class's ``forward`` with this block as ``self``.

        The call happens inside ``scope_context()`` and with
        ``_is_executing_forward`` set, which changes what parameter/buffer
        attribute lookups return (see ``_get_in_states``).
        """
        assert self._type == BlockType.MODULE
        self._is_executing_forward = True
        try:
            with self.scope_context():
                result = self._origin.__class__.forward(self, *args)
        finally:
            # Always clear the flag, even when forward raises; otherwise later
            # attribute lookups would keep behaving as if inside forward().
            self._is_executing_forward = False
        return result

    def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
        """Yield this block and all descendant module blocks, each once.

        Args:
            memo: Blocks already visited; they are skipped. A fresh set is
                used when omitted.
        """
        assert self._type == BlockType.MODULE
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield self
            for module in self._modules.values():
                if module is None:
                    continue
                yield from module.modules(memo)

    def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
        """Yield unique, non-``None`` member blocks of this block.

        Args:
            get_members_fn: Callable taking a module block and returning an
                iterable of ``(name, member)`` pairs.
            recurse: When true, visit every block from ``modules()``;
                otherwise only this block.
        """
        assert self._type == BlockType.MODULE
        memo = set()
        modules = self.modules() if recurse else [self]
        for module in modules:
            for (_, v) in get_members_fn(module):
                if v is None or v in memo:
                    continue
                memo.add(v)
                yield v

    def parameters(self, recurse: bool = True) -> Iterator["Block"]:
        """Yield parameter blocks of this block (and descendants if ``recurse``)."""
        assert self._type == BlockType.MODULE
        yield from self._members(
            lambda module: module._parameters.items(), recurse=recurse
        )

    def buffers(self, recurse: bool = True) -> Iterator["Block"]:
        """Yield buffer blocks of this block (and descendants if ``recurse``)."""
        assert self._type == BlockType.MODULE
        yield from self._members(
            lambda module: module._buffers.items(), recurse=recurse
        )

    def __setattr__(self, name: str, value=None) -> None:
        """Store an attribute, routing Block values into the member dicts.

        Non-Block values (including ``None``) go straight into the instance
        ``__dict__``. Block values are registered in ``_modules``,
        ``_parameters`` or ``_buffers`` according to their type.

        Raises:
            AttributeError: If a Block is assigned under a name that already
                exists, or if the Block's type is not MODULE, PARAMETER or
                BUFFER.
        """
        if not isinstance(value, Block):
            self.__dict__[name] = value
            return

        dicts_or_sets = (
            self.__dict__,
            self._modules,
            self._parameters,
            self._buffers,
        )
        for d in dicts_or_sets:
            if name in d:
                raise AttributeError(
                    "'{}' object has duplicated attribute named '{}'".format(
                        self._name, name
                    )
                )
        if value.type == BlockType.MODULE:
            self._modules[name] = value
        elif value.type == BlockType.PARAMETER:
            self._parameters[name] = value
        elif value.type == BlockType.BUFFER:
            self._buffers[name] = value
        else:
            raise AttributeError(
                "'{}' object are not allowed to set attribute named '{}'".format(
                    type(self).__name__, name
                )
            )

    def __getattr__(self, name: str):
        """Resolve attributes that normal lookup did not find.

        For a MODULE block the lookup order is: child module blocks,
        parameters, buffers, the wrapped module's instance ``__dict__``, and
        finally attributes of the wrapped module's class, bound to this block
        via ``functools.partial``.

        Raises:
            AttributeError: If the name cannot be resolved.
        """
        own = self.__dict__
        if name in own:
            return own[name]
        # Read _type/_origin from the instance dict rather than via attribute
        # access: on an instance whose __dict__ is still empty (e.g. while
        # copy/pickle probe for __setstate__ on a freshly created object),
        # ``self._type`` would re-enter __getattr__ until RecursionError.
        if own.get("_type") == BlockType.MODULE:
            # support get module
            modules = own.get("_modules")
            if modules is not None and name in modules:
                return modules[name]
            # support get parameter
            p_state = self._get_in_states(name, "_parameters")
            if p_state is not None:
                return p_state
            # support get buffer
            b_state = self._get_in_states(name, "_buffers")
            if b_state is not None:
                return b_state
            origin = own.get("_origin")
            if origin is not None:
                # support get normal attr
                if name in origin.__dict__:
                    return origin.__dict__[name]
                # support get function
                if hasattr(origin, name):
                    return partial(getattr(origin.__class__, name), self)
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
        )

    def _get_in_states(self, name, states_name):
        """Look up ``name`` in the member dict called ``states_name``.

        Returns:
            ``None`` if the dict or the name is missing. Otherwise, depending
            on context:

            * lazy mode enabled: the member's ``lazy_origin``, built on first
              access inside the member's scope context;
            * lazy mode disabled while ``forward`` is executing: the member's
              ``origin``;
            * otherwise: the member Block itself.
        """
        own = self.__dict__
        if states_name not in own:
            return None

        states = own[states_name]
        if name not in states:
            return None

        s_block = states[name]
        if graph_build_util.lazy_mode.is_enabled():
            #  lazy
            if s_block._lazy_origin is None:
                assert s_block._lazy_origin_builder is not None, (
                    repr(s_block) + " has no lazy Tensor creation function."
                )
                assert self._is_executing_forward, (
                    repr(s_block)
                    + "'s first get must happened in it's nn.Module.forward() to generate the right scope."
                )
                with s_block.scope_context():
                    s_block._lazy_origin = s_block._lazy_origin_builder()
            return s_block._lazy_origin
        if self._is_executing_forward:
            # eager and inside nn.Graph.build()
            return s_block.origin
        # outside nn.Graph.build()
        return s_block

    def __repr__(self):
        """Return ``(<prefix><name>:<origin class>:<type>): (...)``.

        For a MODULE block with members, each member's repr is indented and
        listed on its own line between the trailing parentheses.
        """
        child_lines = []
        if self._type == BlockType.MODULE:
            for members in (self._modules, self._parameters, self._buffers):
                for child in members.values():
                    child_lines.append(add_indent(repr(child), 2))
        main_str = (
            "("
            + self._name_prefix
            + self._name
            + ":"
            + self._origin.__class__.__name__
            + ":"
            + self._type
            + "): ("
        )
        if child_lines:
            main_str += "\n  " + "\n  ".join(child_lines) + "\n"
        main_str += ")"
        return main_str


class BlockConfig(object):
    """Holder for the per-block settings attached to a Block as ``config``.

    Both settings start out as ``None`` and are stored exactly as assigned;
    no validation or conversion is performed.
    """

    def __init__(self):
        # ``None`` means the setting has not been assigned yet.
        self._stage_id = None
        self._activation_checkpointing = None

    @property
    def stage_id(self):
        """The stored stage id, or ``None`` if never assigned."""
        return self._stage_id

    @stage_id.setter
    def stage_id(self, value: int = None):
        self._stage_id = value

    @property
    def activation_checkpointing(self):
        """The stored activation-checkpointing flag, or ``None`` if never assigned."""
        return self._activation_checkpointing

    @activation_checkpointing.setter
    def activation_checkpointing(self, value: bool = False):
        self._activation_checkpointing = value