Skip to content
Snippets Groups Projects 10.6 KiB
Newer Older
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import OrderedDict
from functools import partial
from typing import Iterator, Optional, Set, Union

import oneflow._oneflow_internal
import oneflow.framework.graph_build_util as graph_build_util
from oneflow.framework.tensor import Tensor
from oneflow.nn.module import Module
from oneflow.nn.parameter import Parameter
from oneflow.nn.util import add_indent

class BlockType:
    """String tags naming the kind of object a ``Block`` wraps.

    The values are plain strings because ``Block.__repr__`` concatenates the
    tag directly into its output.
    """

    # Default tag, assigned before the wrapped value has been classified.
    NONE = "NONE"
    # The wrapped value is a Module.
    MODULE = "MODULE"
    # The wrapped value is a Parameter.
    PARAMETER = "PARAMETER"
    # The wrapped value is a Tensor registered as a buffer.
    BUFFER = "BUFFER"

class Block(object):
    """Named wrapper around a Module, Parameter or Tensor.

    A Block built from a ``Module`` recursively wraps the module's children,
    its direct parameters and its direct buffers in child Blocks, which are
    registered in ``_modules``, ``_parameters`` and ``_buffers``. Parameter
    and buffer Blocks additionally hold a lazily created counterpart
    (``lazy_origin``) produced on demand by ``lazy_origin_builder``.
    """

    def __init__(
        self,
        prefix: str = "",
        name: str = "",
        value: Union[Module, Parameter, Tensor] = None,
    ):
        """Wrap ``value`` under ``prefix + name``.

        Args:
            prefix: Name prefix of the enclosing block.
            name: Name of this block inside its parent.
            value: The object to wrap; must not already be a Block.

        Raises:
            NotImplementedError: If ``value`` is not a Module, Parameter or
                Tensor.
        """
        assert not isinstance(value, Block)
        self._name = name
        self._name_prefix = prefix
        self._type = BlockType.NONE
        self._origin = value
        self.config = BlockConfig()
        self._scope = None
        self._prev_scope = None
        if isinstance(value, Module):
            self._type = BlockType.MODULE
            self._is_executing_forward = False
            self._modules = OrderedDict()
            self._parameters = OrderedDict()
            self._buffers = OrderedDict()
            # Children are named relative to this block's full name.
            child_prefix = self._name_prefix + self._name + "."
            for (n, m) in list(value.named_children()):
                self.__setattr__(n, Block(child_prefix, n, m))
            for (n, p) in list(value.named_parameters("", False)):
                self.__setattr__(n, Block(child_prefix, n, p))
            for (n, b) in list(value.named_buffers("", False)):
                self.__setattr__(n, Block(child_prefix, n, b))
        elif isinstance(value, Parameter):
            # Checked before Tensor so a Parameter is never tagged as a buffer
            # (presumably Parameter is a Tensor subclass -- TODO confirm).
            self._type = BlockType.PARAMETER
            self._lazy_origin = None
            self._lazy_origin_builder = None
        elif isinstance(value, Tensor):
            self._type = BlockType.BUFFER
            self._lazy_origin = None
            self._lazy_origin_builder = None
        else:
            raise NotImplementedError()

    @property
    def name(self):
        """Name of this block inside its parent."""
        return self._name

    @property
    def name_prefix(self):
        """Prefix that precedes ``name`` in this block's full name."""
        return self._name_prefix

    @property
    def type(self):
        """The ``BlockType`` tag assigned at construction."""
        return self._type

    @property
    def origin(self):
        """The wrapped Module, Parameter or Tensor."""
        return self._origin

    @property
    def lazy_origin(self):
        """Lazily built counterpart of the wrapped value, or None if not built.

        Only valid on parameter and buffer blocks.
        """
        assert (
            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
        ), "Only Parameter or Buffer Block has lazy_origin"
        return self._lazy_origin

    @property
    def lazy_origin_builder(self):
        """Zero-argument callable that creates ``lazy_origin``, or None.

        Only valid on parameter and buffer blocks.
        """
        assert (
            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
        ), "Only Parameter or Buffer Block has lazy_origin_builder"
        return self._lazy_origin_builder

    def set_lazy_origin_builder(self, fn=None):
        """Install ``fn`` as the callable used to create ``lazy_origin``.

        Only valid on parameter and buffer blocks.
        """
        assert (
            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
        ), "Only Parameter or Buffer Block has lazy_origin_builder"
        self._lazy_origin_builder = fn

    @property
    def prev_scope(self):
        """Scope that was current on first access; cached afterwards."""
        if self._prev_scope is None:
            self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
        return self._prev_scope

    @property
    def scope(self):
        """Scope created for this block from ``prev_scope``; cached."""
        if self._scope is None:
            self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
        return self._scope

    def scope_context(self):
        """Return a ``BlockScopeContext`` built from ``prev_scope`` and ``scope``."""
        return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)

    def __call__(self, *args):
        """Invoke the wrapped module class's ``__call__`` with this block as self.

        Only valid on module blocks; only positional arguments are forwarded.
        """
        assert self._type == BlockType.MODULE
        result = self._origin.__class__.__call__(self, *args)
        return result

    def __iter__(self) -> Iterator["Block"]:
        """Iterate over the direct child module blocks."""
        assert self._type == BlockType.MODULE
        return iter(self._modules.values())

    def forward(self, *args):
        """Run the wrapped module class's ``forward`` with this block as self.

        ``_is_executing_forward`` is True for the duration of the call, which
        is what makes attribute access resolve parameters and buffers to
        their underlying values (see ``_get_in_states``).
        """
        assert self._type == BlockType.MODULE
        self._is_executing_forward = True
        try:
            with self.scope_context():
                result = self._origin.__class__.forward(self, *args)
        finally:
            # Reset even when forward raises, so later attribute access is
            # not mistaken for happening inside forward().
            self._is_executing_forward = False
        return result

    def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
        """Yield this block and every descendant module block exactly once.

        Args:
            memo: Blocks already visited; shared across the recursion so a
                block reachable by several paths is yielded a single time.
        """
        assert self._type == BlockType.MODULE
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield self
            for (name, module) in self._modules.items():
                if module is None:
                    continue
                for m in module.modules(memo):
                    yield m

    def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
        """Yield each distinct, non-None member selected by ``get_members_fn``.

        Args:
            get_members_fn: Callable taking a module block and returning an
                iterable of ``(name, member)`` pairs.
            recurse: If True, visit every block from ``modules()``; otherwise
                only this block.
        """
        assert self._type == BlockType.MODULE
        memo = set()
        modules = self.modules() if recurse else [self]
        for module in modules:
            members = get_members_fn(module)
            for (k, v) in members:
                if v is None or v in memo:
                    continue
                memo.add(v)
                yield v

    def parameters(self, recurse: bool = True) -> Iterator["Block"]:
        """Yield the parameter blocks of this block (and descendants if ``recurse``)."""
        assert self._type == BlockType.MODULE
        yield from self._members(
            lambda module: module._parameters.items(), recurse=recurse
        )

    def buffers(self, recurse: bool = True) -> Iterator["Block"]:
        """Yield the buffer blocks of this block (and descendants if ``recurse``)."""
        assert self._type == BlockType.MODULE
        yield from self._members(
            lambda module: module._buffers.items(), recurse=recurse
        )

    def __setattr__(self, name: str, value=None) -> None:
        """Store plain values in ``__dict__`` and Blocks in the typed registries.

        Raises:
            AttributeError: If a Block is assigned under a name that already
                exists, or if the Block's type is not module, parameter or
                buffer.
        """
        if value is None or not isinstance(value, Block):
            self.__dict__[name] = value
        else:
            # A child Block must not collide with any existing attribute or
            # previously registered child.
            dicts_or_sets = (
                self.__dict__,
                self._modules,
                self._parameters,
                self._buffers,
            )
            for d in dicts_or_sets:
                if name in d:
                    raise AttributeError(
                        "'{}' object has duplicated attribute named '{}'".format(
                            self._name, name
                        )
                    )
            if value.type == BlockType.MODULE:
                self._modules[name] = value
            elif value.type == BlockType.PARAMETER:
                self._parameters[name] = value
            elif value.type == BlockType.BUFFER:
                self._buffers[name] = value
            else:
                raise AttributeError(
                    "'{}' object are not allowed to set attribute named '{}'".format(
                        type(self).__name__, name
                    )
                )

    def __getattr__(self, name: str):
        """Resolve attributes not found by normal lookup.

        For module blocks the search order is: child modules, parameters,
        buffers, the wrapped module's instance ``__dict__``, and finally
        attributes of the wrapped module's class, which are returned bound
        to this block.

        Raises:
            AttributeError: If nothing matches.
        """
        if name in self.__dict__:
            return self.__dict__[name]
        # Read _type through __dict__: going through self._type on an instance
        # that has not run __init__ would re-enter __getattr__ forever.
        if self.__dict__.get("_type") == BlockType.MODULE:
            # support get module
            if "_modules" in self.__dict__:
                modules = self.__dict__["_modules"]
                if name in modules:
                    return modules[name]
            # support get parameter
            p_state = self._get_in_states(name, "_parameters")
            if p_state is not None:
                return p_state
            # support get buffer
            b_state = self._get_in_states(name, "_buffers")
            if b_state is not None:
                return b_state
            # support get normal attr
            if name in self._origin.__dict__:
                return self._origin.__dict__[name]
            # support get function
            if hasattr(self._origin, name):
                return partial(getattr(self._origin.__class__, name), self)
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
        )

    def _get_in_states(self, name, states_name):
        """Look up ``name`` in the registry ``states_name`` of this block.

        Args:
            name: Attribute name being resolved.
            states_name: ``"_parameters"`` or ``"_buffers"``.

        Returns:
            None if the registry or the name is absent. Otherwise, in lazy
            mode, the state's ``lazy_origin`` (built on first access); in
            eager mode during ``forward``, the wrapped object; outside
            ``forward``, the state Block itself.
        """
        if states_name not in self.__dict__:
            return None

        _states = self.__dict__[states_name]
        if name not in _states:
            return None

        _s_block = _states[name]
        if graph_build_util.lazy_mode.is_enabled():
            #  lazy
            if _s_block._lazy_origin is None:
                assert _s_block._lazy_origin_builder is not None, (
                    repr(_s_block) + " has no lazy Tensor creation function."
                )
                assert self._is_executing_forward, (
                    repr(_s_block)
                    + "'s first get must happened in it's nn.Module.forward() to generate the right scope."
                )
                with _s_block.scope_context():
                    _s_block._lazy_origin = _s_block._lazy_origin_builder()
            return _s_block._lazy_origin
        if self._is_executing_forward:
            # eager mode, inside forward(): hand back the wrapped object
            return _s_block.origin
        # outside forward(): hand back the Block wrapper itself
        return _s_block

    def __repr__(self):
        """Render ``(prefix+name:OriginClass:TYPE): (...)`` with indented children."""
        lines = None
        if self._type == BlockType.MODULE:
            child_lines = []

            def _append_child(d):
                for (_, n) in d.items():
                    n_str = repr(n)
                    n_str = add_indent(n_str, 2)
                    child_lines.append(n_str)

            _append_child(self._modules)
            _append_child(self._parameters)
            _append_child(self._buffers)
            if len(child_lines) > 0:
                lines = child_lines
        main_str = (
            "("
            + self._name_prefix
            + self._name
            + ":"
            + self._origin.__class__.__name__
            + ":"
            + self._type
            + "): ("
        )
        if lines is not None:
            main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str

class BlockConfig(object):
    """Per-block configuration values.

    Both settings start as None and are exposed as read/write properties.
    """

    def __init__(self):
        self._stage_id = None
        self._activation_checkpointing = None

    @property
    def stage_id(self):
        """The stage id assigned to the block, or None if never set."""
        return self._stage_id

    @stage_id.setter
    def stage_id(self, value: int = None):
        self._stage_id = value

    @property
    def activation_checkpointing(self):
        """The activation-checkpointing flag, or None if never set."""
        return self._activation_checkpointing

    @activation_checkpointing.setter
    def activation_checkpointing(self, value: bool = False):
        self._activation_checkpointing = value