From ab8aab8ba330e499b1f4ea409b8bd95245a07fab Mon Sep 17 00:00:00 2001 From: Xiaoyu Xu <xiaoyulink@gmail.com> Date: Wed, 21 Jul 2021 23:12:25 +0800 Subject: [PATCH] Fea/nn graph/forward graph (#5516) * add test on add input to graph * add var into graph * LazyInterpreter for FetchOutputOpExpr and set op parallel_distribution * refine input var build * split file * rename * mini refine * Add note * LazyInterpret::ApplyImpl for UserOpExpr * refine test scripts * add output to graph * format * Fix bug of LazyInterpret UserOpExpr for change output lbns * Add test user op expr test * fix note mistake * add userop and test * address review * address review * save i/o/s op_name and tensor for c_nn_graph * address review * adjust test * refine build_graph_state Co-authored-by: chengtbf <472491134@qq.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/python/framework/graph_build_util.py | 61 ++- oneflow/python/nn/graph.py | 395 +++--------------- oneflow/python/nn/graph_block.py | 331 +++++++++++++++ oneflow/python/nn/graph_optimizer.py | 37 ++ oneflow/python/nn/utils.py | 28 ++ .../python/test/graph/test_forward_graph.py | 85 ++++ oneflow/python/test/graph/test_graph.py | 14 +- .../python/test/graph/test_input_op_expr.py | 6 - .../test/graph/test_multi_client_session.py | 6 - .../python/test/graph/test_output_op_expr.py | 6 - .../python/test/graph/test_user_op_expr.py | 6 - .../test/graph/test_variable_op_expr.py | 6 - 12 files changed, 598 insertions(+), 383 deletions(-) create mode 100644 oneflow/python/nn/graph_block.py create mode 100644 oneflow/python/nn/graph_optimizer.py create mode 100644 oneflow/python/nn/utils.py create mode 100644 oneflow/python/test/graph/test_forward_graph.py diff --git a/oneflow/python/framework/graph_build_util.py b/oneflow/python/framework/graph_build_util.py index 25977a2b7..718e00cd5 100644 --- a/oneflow/python/framework/graph_build_util.py +++ b/oneflow/python/framework/graph_build_util.py @@ -23,10 +23,11 @@ import oneflow.core.job.scope_pb2 as scope_pb2_util import oneflow.python.framework.attr_util as attr_util import oneflow.python.framework.c_api_util as c_api_util import oneflow.python.framework.placement_util as placement_util -import oneflow.python.framework.runtime_mode as runtime_mode import oneflow.python.framework.scope_util as scope_util import oneflow.python.framework.session_context as session_context import oneflow._oneflow_internal +from oneflow._oneflow_internal import Tensor as InternalTensor +from oneflow.python.framework.tensor import Tensor lazy_mode = oneflow._oneflow_internal.lazy_mode @@ -125,3 +126,61 @@ def make_new_block_scope(prev_scope, block): def scope_to_proto(scope): return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto()) + + +def build_graph_input_arg(op_name, arg): + assert isinstance(arg, (Tensor, InternalTensor)) + input_conf = ( + oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedInputOpConf() + ) + + input_op = oneflow._oneflow_internal.one.FeedInputOpExpr( + op_name, input_conf, ["in_0"], ["out_0"] + ) + attrs = oneflow._oneflow_internal.MutableCfgAttrMap() + + if isinstance(arg, Tensor): + if not arg.is_determined: + arg.determine() + tensor_in_c = arg._local_or_consistent_tensor + else: + tensor_in_c = arg + + lazy_arg = input_op.apply([tensor_in_c], attrs)[0] + return lazy_arg + + +def build_graph_state(op_name, state_tensor): + var_conf = ( + oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedVariableOpConf() + ) + + var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr( + op_name, var_conf, ["in_0"], ["out_0"] + ) + attrs = oneflow._oneflow_internal.MutableCfgAttrMap() + + assert isinstance(state_tensor, Tensor) + if not state_tensor.is_determined: + state_tensor.determine() + tensor_in_c = state_tensor._local_or_consistent_tensor + + lazy_tensor = var_op.apply([tensor_in_c], attrs)[0] + return lazy_tensor + + +def build_graph_output(op_name, out): + assert isinstance(out, InternalTensor) + assert out.is_lazy + + output_conf = ( + oneflow._oneflow_internal.oneflow.core.operator.op_conf.FetchOutputOpConf() + ) + + output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr( + op_name, output_conf, ["in_0"], ["out_0"] + ) + attrs = oneflow._oneflow_internal.MutableCfgAttrMap() + + eager_out = output_op.apply([out], attrs)[0] + return eager_out diff --git a/oneflow/python/nn/graph.py b/oneflow/python/nn/graph.py index 72f23aa9a..1f34da4e4 100644 --- a/oneflow/python/nn/graph.py +++ b/oneflow/python/nn/graph.py @@ -15,19 +15,21 @@ limitations under the License. """ from __future__ import absolute_import from collections import OrderedDict -from typing import Union, Optional, Iterator, Set +from functools import partial import oneflow._oneflow_internal import oneflow.python.framework.c_api_util as c_api_util import oneflow.python.framework.graph_build_util as graph_build_util import oneflow.python.framework.session_context as session_ctx import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util +from oneflow._oneflow_internal import Tensor as InternalTensor from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.framework.multi_client_session import MultiClientSession -from oneflow.python.framework.tensor import Tensor +from oneflow.python.nn.graph_block import Block +from oneflow.python.nn.graph_optimizer import OptimizerConfig from oneflow.python.nn.module import Module -from oneflow.python.nn.parameter import Parameter from oneflow.python.nn.optimizer.optimizer import Optimizer +from oneflow.python.nn.utils import add_indent from oneflow.python.framework.function_util import FunctionConfig @@ -45,6 +47,7 @@ class Graph(object): self._optimizers = OrderedDict() self._is_compiled = False self._state_tensortuple = None + self._job_proto = None @property def name(self): @@ -56,7 +59,7 @@ class Graph(object): @property def _graph_proto(self): - return c_api_util.GetCurrentJob() + return self._job_proto def build(self, *args): raise NotImplementedError() @@ -69,7 +72,7 @@ class Graph(object): grad_clipping_conf=None, weight_decay_conf=None, ): - self._optimizers[name] = self.OptimizerConfig( + self._optimizers[name] = OptimizerConfig( optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf ) @@ -102,19 +105,57 @@ class Graph(object): session.TryInit() with graph_build_util.graph_build_context(self.config.proto, session): + # Deal with input + lazy_args = [] + lazy_arg_op_names = [] + for idx, arg in enumerate(args): + op_name = "_" + self.name + "-input_" + str(idx) + lazy_args.append(graph_build_util.build_graph_input_arg(op_name, arg)) + lazy_arg_op_names.append(op_name) + # Deal with parameter and buffer - for s in self._state(): + state_op_names = [] + state_tensors = [] + for state_block in self._state(): + op_name = state_block.name_prefix + state_block.name + state_tensor = state_block.origin + state_op_names.append(op_name) + state_tensors.append(state_tensor) + state_block.set_lazy_origin_builder( + partial(graph_build_util.build_graph_state, op_name, state_tensor) + ) - def to_lazy(): - # TODO(): Replace repr(s) with OpExpr(s.origin) - lazy_tensor = repr(s) - return lazy_tensor + # Deal with module in self.build(*args) + outputs = self.build(*lazy_args) + + # Deal with outputs + if not (type(outputs) is tuple or type(outputs) is list): + if outputs is None: + outputs = () + else: + assert type(outputs) is InternalTensor + outputs = (outputs,) + eager_outputs = [] + eager_output_op_names = [] + for idx, out in enumerate(outputs): + op_name = "_" + self.name + "-output_" + str(idx) + eager_outputs.append(graph_build_util.build_graph_output(op_name, out)) + eager_output_op_names.append(op_name) + if len(eager_outputs) == 0: + eager_outputs = None + elif len(eager_outputs) == 1: + eager_outputs = eager_outputs[0] + else: + eager_outputs = tuple(eager_outputs) - s.set_lazy_origin_lambda(to_lazy) - outputs = self.build(*args) + # TODO(): call self._c_nn_graph + # register lazy_arg_op_names/state_op_names/state_tensors/eager_output_op_names + + # Save job proto for debug + self._job_proto = c_api_util.GetCurrentJob() self._is_compiled = True - return outputs + return eager_outputs def _launch(self): # TODO(xuxiaoyu) @@ -122,7 +163,6 @@ class Graph(object): ... def __call__(self, *args): - # TODO(xuxiaoyu) # if not self._is_compiled: # self._compile() # return self._launch() @@ -179,7 +219,7 @@ class Graph(object): child_lines = [] for n, m in self._blocks.items(): mod_str = repr(m) - mod_str = _add_indent(mod_str, 2) + mod_str = add_indent(mod_str, 2) child_lines.append(mod_str) lines = child_lines @@ -190,278 +230,6 @@ class Graph(object): return main_str -class BlockType: - NONE = "NONE" - MODULE = "MODULE" - PARAMETER = "PARAMETER" - BUFFER = "BUFFER" - - -@oneflow_export("nn.graph.Block") -@experimental_api -class Block(object): - def __init__( - self, - prefix: str = "", - name: str = "", - value: Union[Module, Parameter, Tensor] = None, - ): - assert not isinstance(value, Block) - self._name = name - self._name_prefix = prefix - self._type = BlockType.NONE - self._origin = value - self.config = BlockConfig() - self._scope = None - self._prev_scope = None - - if isinstance(value, Module): - self._type = BlockType.MODULE - self._is_executing_forward = False - self._modules = OrderedDict() - self._parameters = OrderedDict() - self._buffers = OrderedDict() - for n, m in list(value.named_children()): - self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m)) - for n, p in list(value.named_parameters("", False)): - self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p)) - for n, b in list(value.named_buffers("", False)): - self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b)) - elif isinstance(value, Parameter): - self._type = BlockType.PARAMETER - self._lazy_origin = None - self._lazy_origin_lambda = None - elif isinstance(value, Tensor): - self._type = BlockType.BUFFER - self._lazy_origin = None - self._lazy_origin_lambda = None - else: - raise NotImplementedError() - - @property - def name(self): - return self._name - - @property - def name_prefix(self): - return self._name_prefix - - @property - def type(self): - return self._type - - @property - def origin(self): - return self._origin - - @property - def lazy_origin(self): - assert ( - self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER - ), "Only Parameter or Buffer Block has lazy_origin" - return self._lazy_origin - - def lazy_origin_lambda(self): - assert ( - self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER - ), "Only Parameter or Buffer Block has lazy_origin_lambda" - return self._lazy_origin_lambda - - def set_lazy_origin_lambda(self, fn=None): - assert ( - self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER - ), "Only Parameter or Buffer Block has lazy_origin_lambda" - self._lazy_origin_lambda = fn - - @property - def prev_scope(self): - if self._prev_scope is None: - self._prev_scope = oneflow._oneflow_internal.GetCurrentScope() - return self._prev_scope - - @property - def scope(self): - if self._scope is None: - self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self) - return self._scope - - def scope_context(self): - return graph_build_util.BlockScopeContext(self.prev_scope, self.scope) - - def __call__(self, *args): - assert self._type == BlockType.MODULE - # nn.Module.__call__ will call self.forward() - # so the scope is set in self.forward() - result = self._origin.__class__.__call__(self, *args) - return result - - def forward(self, *args): - assert self._type == BlockType.MODULE - self._is_executing_forward = True - # TODO(xuxiaoyu): only build scope in lazy mode - with self.scope_context(): - result = self._origin.__class__.forward(self, *args) - self._is_executing_forward = False - return result - - def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]: - assert self._type == BlockType.MODULE - if memo is None: - memo = set() - if self not in memo: - memo.add(self) - yield self - for name, module in self._modules.items(): - if module is None: - continue - for m in module.modules(memo): - yield m - - def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]: - assert self._type == BlockType.MODULE - memo = set() - modules = self.modules() if recurse else [self] - for module in modules: - members = get_members_fn(module) - for k, v in members: - if v is None or v in memo: - continue - memo.add(v) - yield v - - def parameters(self, recurse: bool = True) -> Iterator["Block"]: - assert self._type == BlockType.MODULE - gen = self._members(lambda module: module._parameters.items(), recurse=recurse) - for elem in gen: - yield elem - - def buffers(self, recurse: bool = True) -> Iterator["Block"]: - assert self._type == BlockType.MODULE - gen = self._members(lambda module: module._buffers.items(), recurse=recurse) - for elem in gen: - yield elem - - def __setattr__(self, name: str, value=None) -> None: - if value is None or not isinstance(value, Block): - self.__dict__[name] = value - else: - dicts_or_sets = ( - self.__dict__, - self._modules, - self._parameters, - self._buffers, - ) - for d in dicts_or_sets: - if name in d: - raise AttributeError( - "'{}' object has duplicated attribute named '{}'".format( - self._name, name - ) - ) - if value.type == BlockType.MODULE: - self._modules[name] = value - elif value.type == BlockType.PARAMETER: - self._parameters[name] = value - elif value.type == BlockType.BUFFER: - self._buffers[name] = value - else: - raise AttributeError( - "'{}' object are not allowed to set attribute named '{}'".format( - type(self).__name__, name - ) - ) - - def __getattr__(self, name: str): - if name in self.__dict__: - return self.__dict__[name] - - if self._type == BlockType.MODULE: - if "_modules" in self.__dict__: - modules = self.__dict__["_modules"] - if name in modules: - return modules[name] - if "_parameters" in self.__dict__: - _parameters = self.__dict__["_parameters"] - if name in _parameters: - p_block = _parameters[name] - if self._is_executing_forward: - # Return Tensor for running when getattr inside it's father Block's forward() - if graph_build_util.lazy_mode.is_enabled(): - if p_block._lazy_origin is None: - assert p_block._lazy_origin_lambda is not None, ( - repr(p_block) - + " has no lazy Tensor creation function." - ) - # Create and return lazy tensor - with p_block.scope_context(): - p_block._lazy_origin = p_block._lazy_origin_lambda() - return p_block._lazy_origin - else: - return p_block.origin - else: - # Return Block for config when getattr outside it's father Block's forward() - return p_block - if "_buffers" in self.__dict__: - _buffers = self.__dict__["_buffers"] - if name in _buffers: - b_block = _buffers[name] - if self._is_executing_forward: - # Return Tensor for running when getattr inside it's father Block's forward() - if graph_build_util.lazy_mode.is_enabled(): - if b_block._lazy_origin is None: - assert b_block._lazy_origin_lambda is not None, ( - repr(b_block) - + " has no lazy Tensor creation function." - ) - # Create and return lazy tensor - with b_block.scope_context(): - b_block._lazy_origin = b_block._lazy_origin_lambda() - return b_block._lazy_origin - else: - return b_block.origin - else: - # Return Block for config when getattr outside it's father Block's forward() - return b_block - if name in self._origin.__dict__: - return self._origin.__dict__[name] - - raise AttributeError( - "'{}' object has no attribute '{}'".format(type(self).__name__, name) - ) - - def __repr__(self): - lines = None - if self._type == BlockType.MODULE: - child_lines = [] - - def _append_child(d): - for _, n in d.items(): - n_str = repr(n) - n_str = _add_indent(n_str, 2) - child_lines.append(n_str) - - _append_child(self._modules) - _append_child(self._parameters) - _append_child(self._buffers) - if len(child_lines) > 0: - lines = child_lines - - main_str = ( - "(" - + self._name_prefix - + self._name - + ":" - + self._origin.__class__.__name__ - + ":" - + self._type - + "): (" - ) - if lines is not None: - main_str += "\n " + "\n ".join(lines) + "\n" - main_str += ")" - return main_str - - @oneflow_export("nn.graph.GraphConfig") @experimental_api class GraphConfig(FunctionConfig): @@ -486,56 +254,3 @@ class GraphConfig(FunctionConfig): self.function_desc.job_config_proto.mutable_train_conf() else: self.function_desc.job_config_proto.mutable_predict_conf() - - -@oneflow_export("nn.graph.BlockConfig") -@experimental_api -class BlockConfig(object): - def __init__(self): - self._stage_id = None - self._activation_checkpointing = None - - @property - def stage_id(self): - return self._stage_id - - @stage_id.setter - def stage_id(self, value: int = None): - self._stage_id = value - - @property - def activation_checkpointing(self): - return self._activation_checkpointing - - @activation_checkpointing.setter - def activation_checkpointing(self, value: bool = False): - self._activation_checkpointing = value - - -@oneflow_export("nn.graph.OptimizerConfig") -@experimental_api -class OptimizerConfig(object): - def __init__( - self, - name: str, - optimizer: Optimizer = None, - lr_scheduler=None, - grad_clipping_conf=None, - weight_decay_conf=None, - ): - self.name = name - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.grad_clipping_conf = grad_clipping_conf - self.weight_decay_conf = weight_decay_conf - - -def _add_indent(in_s, num_spaces): - s = in_s.split("\n") - if len(s) == 1: - return in_s - first = s.pop(0) - s = [(num_spaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s diff --git a/oneflow/python/nn/graph_block.py b/oneflow/python/nn/graph_block.py new file mode 100644 index 000000000..925b4c5dd --- /dev/null +++ b/oneflow/python/nn/graph_block.py @@ -0,0 +1,331 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import absolute_import +from collections import OrderedDict +from typing import Union, Optional, Iterator, Set + +import oneflow._oneflow_internal +import oneflow.python.framework.graph_build_util as graph_build_util +from oneflow.python.oneflow_export import oneflow_export, experimental_api +from oneflow.python.framework.tensor import Tensor +from oneflow.python.nn.module import Module +from oneflow.python.nn.parameter import Parameter +from oneflow.python.nn.utils import add_indent + + +class BlockType: + NONE = "NONE" + MODULE = "MODULE" + PARAMETER = "PARAMETER" + BUFFER = "BUFFER" + + +@oneflow_export("nn.graph.Block") +@experimental_api +class Block(object): + def __init__( + self, + prefix: str = "", + name: str = "", + value: Union[Module, Parameter, Tensor] = None, + ): + assert not isinstance(value, Block) + self._name = name + self._name_prefix = prefix + self._type = BlockType.NONE + self._origin = value + self.config = BlockConfig() + self._scope = None + self._prev_scope = None + + if isinstance(value, Module): + self._type = BlockType.MODULE + self._is_executing_forward = False + self._modules = OrderedDict() + self._parameters = OrderedDict() + self._buffers = OrderedDict() + for n, m in list(value.named_children()): + self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m)) + for n, p in list(value.named_parameters("", False)): + self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p)) + for n, b in list(value.named_buffers("", False)): + self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b)) + elif isinstance(value, Parameter): + self._type = BlockType.PARAMETER + self._lazy_origin = None + self._lazy_origin_builder = None + elif isinstance(value, Tensor): + self._type = BlockType.BUFFER + self._lazy_origin = None + self._lazy_origin_builder = None + else: + raise NotImplementedError() + + @property + def name(self): + return self._name + + @property + def name_prefix(self): + return self._name_prefix + + @property + def type(self): + return self._type + + @property + def origin(self): + return self._origin + + @property + def lazy_origin(self): + assert ( + self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER + ), "Only Parameter or Buffer Block has lazy_origin" + return self._lazy_origin + + def lazy_origin_builder(self): + assert ( + self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER + ), "Only Parameter or Buffer Block has lazy_origin_builder" + return self._lazy_origin_builder + + def set_lazy_origin_builder(self, fn=None): + assert ( + self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER + ), "Only Parameter or Buffer Block has lazy_origin_builder" + self._lazy_origin_builder = fn + + @property + def prev_scope(self): + if self._prev_scope is None: + self._prev_scope = oneflow._oneflow_internal.GetCurrentScope() + return self._prev_scope + + @property + def scope(self): + if self._scope is None: + self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self) + return self._scope + + def scope_context(self): + return graph_build_util.BlockScopeContext(self.prev_scope, self.scope) + + def __call__(self, *args): + assert self._type == BlockType.MODULE + # nn.Module.__call__ will call self.forward() + # so the scope is set in self.forward() + result = self._origin.__class__.__call__(self, *args) + return result + + def forward(self, *args): + assert self._type == BlockType.MODULE + self._is_executing_forward = True + # TODO(xuxiaoyu): only build scope in lazy mode + with self.scope_context(): + result = self._origin.__class__.forward(self, *args) + self._is_executing_forward = False + return result + + def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]: + assert self._type == BlockType.MODULE + if memo is None: + memo = set() + if self not in memo: + memo.add(self) + yield self + for name, module in self._modules.items(): + if module is None: + continue + for m in module.modules(memo): + yield m + + def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]: + assert self._type == BlockType.MODULE + memo = set() + modules = self.modules() if recurse else [self] + for module in modules: + members = get_members_fn(module) + for k, v in members: + if v is None or v in memo: + continue + memo.add(v) + yield v + + def parameters(self, recurse: bool = True) -> Iterator["Block"]: + assert self._type == BlockType.MODULE + gen = self._members(lambda module: module._parameters.items(), recurse=recurse) + for elem in gen: + yield elem + + def buffers(self, recurse: bool = True) -> Iterator["Block"]: + assert self._type == BlockType.MODULE + gen = self._members(lambda module: module._buffers.items(), recurse=recurse) + for elem in gen: + yield elem + + def __setattr__(self, name: str, value=None) -> None: + if value is None or not isinstance(value, Block): + self.__dict__[name] = value + else: + dicts_or_sets = ( + self.__dict__, + self._modules, + self._parameters, + self._buffers, + ) + for d in dicts_or_sets: + if name in d: + raise AttributeError( + "'{}' object has duplicated attribute named '{}'".format( + self._name, name + ) + ) + if value.type == BlockType.MODULE: + self._modules[name] = value + elif value.type == BlockType.PARAMETER: + self._parameters[name] = value + elif value.type == BlockType.BUFFER: + self._buffers[name] = value + else: + raise AttributeError( + "'{}' object are not allowed to set attribute named '{}'".format( + type(self).__name__, name + ) + ) + + def __getattr__(self, name: str): + if name in self.__dict__: + return self.__dict__[name] + + if self._type == BlockType.MODULE: + if "_modules" in self.__dict__: + modules = self.__dict__["_modules"] + if name in modules: + return modules[name] + if "_parameters" in self.__dict__: + _parameters = self.__dict__["_parameters"] + if name in _parameters: + p_block = _parameters[name] + if self._is_executing_forward: + # Return Tensor for running when getattr is + # inside it's father Block's forward() + if graph_build_util.lazy_mode.is_enabled(): + if p_block._lazy_origin is None: + assert p_block._lazy_origin_builder is not None, ( + repr(p_block) + + " has no lazy Tensor creation function." + ) + # Create and return lazy tensor + with p_block.scope_context(): + p_block._lazy_origin = ( + p_block._lazy_origin_builder() + ) + return p_block._lazy_origin + else: + return p_block.origin + else: + # Return Block for config when getattr is + # outside it's father Block's forward() + return p_block + if "_buffers" in self.__dict__: + _buffers = self.__dict__["_buffers"] + if name in _buffers: + b_block = _buffers[name] + if self._is_executing_forward: + # Return Tensor for running when getattr is + # inside it's father Block's forward() + if graph_build_util.lazy_mode.is_enabled(): + if b_block._lazy_origin is None: + assert b_block._lazy_origin_builder is not None, ( + repr(b_block) + + " has no lazy Tensor creation function." + ) + # Create and return lazy tensor + with b_block.scope_context(): + b_block._lazy_origin = ( + b_block._lazy_origin_builder() + ) + return b_block._lazy_origin + else: + return b_block.origin + else: + # Return Block for config when getattr is + # outside it's father Block's forward() + return b_block + if name in self._origin.__dict__: + return self._origin.__dict__[name] + + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, name) + ) + + def __repr__(self): + lines = None + if self._type == BlockType.MODULE: + child_lines = [] + + def _append_child(d): + for _, n in d.items(): + n_str = repr(n) + n_str = add_indent(n_str, 2) + child_lines.append(n_str) + + _append_child(self._modules) + _append_child(self._parameters) + _append_child(self._buffers) + if len(child_lines) > 0: + lines = child_lines + + main_str = ( + "(" + + self._name_prefix + + self._name + + ":" + + self._origin.__class__.__name__ + + ":" + + self._type + + "): (" + ) + if lines is not None: + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + +@oneflow_export("nn.graph.BlockConfig") +@experimental_api +class BlockConfig(object): + def __init__(self): + self._stage_id = None + self._activation_checkpointing = None + + @property + def stage_id(self): + return self._stage_id + + @stage_id.setter + def stage_id(self, value: int = None): + self._stage_id = value + + @property + def activation_checkpointing(self): + return self._activation_checkpointing + + @activation_checkpointing.setter + def activation_checkpointing(self, value: bool = False): + self._activation_checkpointing = value diff --git a/oneflow/python/nn/graph_optimizer.py b/oneflow/python/nn/graph_optimizer.py new file mode 100644 index 000000000..8502885e2 --- /dev/null +++ b/oneflow/python/nn/graph_optimizer.py @@ -0,0 +1,37 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from __future__ import absolute_import + +from oneflow.python.nn.optimizer.optimizer import Optimizer +from oneflow.python.oneflow_export import oneflow_export, experimental_api + + +@oneflow_export("nn.graph.OptimizerConfig") +@experimental_api +class OptimizerConfig(object): + def __init__( + self, + name: str, + optimizer: Optimizer = None, + lr_scheduler=None, + grad_clipping_conf=None, + weight_decay_conf=None, + ): + self.name = name + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.grad_clipping_conf = grad_clipping_conf + self.weight_decay_conf = weight_decay_conf diff --git a/oneflow/python/nn/utils.py b/oneflow/python/nn/utils.py new file mode 100644 index 000000000..a1f5cddbd --- /dev/null +++ b/oneflow/python/nn/utils.py @@ -0,0 +1,28 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from __future__ import absolute_import + + +def add_indent(in_s, num_spaces): + s = in_s.split("\n") + if len(s) == 1: + return in_s + first = s.pop(0) + s = [(num_spaces * " ") + line for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s diff --git a/oneflow/python/test/graph/test_forward_graph.py b/oneflow/python/test/graph/test_forward_graph.py new file mode 100644 index 000000000..552644e92 --- /dev/null +++ b/oneflow/python/test/graph/test_forward_graph.py @@ -0,0 +1,85 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import os + +import oneflow +import oneflow.experimental as flow + + +class SubModule(flow.nn.Module): + def __init__(self): + super().__init__() + self.weight = flow.nn.Parameter(flow.Tensor(6, 6)) + self.relu = flow.nn.ReLU() + + def forward(self, x, y): + x = oneflow.F.matmul(x, self.weight) + x = self.relu(x) + y = self.relu(y) + return x, y + + +class CustomModule(flow.nn.Module): + def __init__(self): + super().__init__() + self.layer = SubModule() + self.register_buffer( + "dummy_buff", flow.Tensor(6, 8), + ) + + def forward(self, x, y): + x, y = self.layer(x, y) + x = oneflow.F.flatten(x, 1) + x = oneflow.F.matmul(x, self.dummy_buff) + return x, y + + +@flow.unittest.skip_unless_1n1d() +class TestGraph(flow.unittest.TestCase): + def test_forward_graph(test_case): + class CustomGraph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x, y): + out = self.m(x, y) + return out + + m = CustomModule() + m.to("cuda") + g = CustomGraph(m) + + x = flow.Tensor(6, 6) + flow.nn.init.uniform_(x, a=-1.0, b=1.0) + x = x.to("cuda") + + y = flow.Tensor(10, 10) + flow.nn.init.uniform_(y, a=-1.0, b=1.0) + y = y.to("cuda") + + print(repr(g)) + z, a = g._compile(x, y) + test_case.assertEqual(z.shape, (6, 8)) + test_case.assertEqual(z.is_lazy, False) + test_case.assertEqual(a.shape, (10, 10)) + test_case.assertEqual(a.is_lazy, False) + print("graph proto: ", g._graph_proto) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/graph/test_graph.py b/oneflow/python/test/graph/test_graph.py index 816ab199c..0a07ec04b 100644 --- a/oneflow/python/test/graph/test_graph.py +++ b/oneflow/python/test/graph/test_graph.py @@ -18,13 +18,6 @@ import os import numpy as np -# To enable MultiClient -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "12139" -os.environ["WORLD_SIZE"] = "1" -os.environ["RANK"] = "0" -os.environ["LOCAL_RANK"] = "0" - import oneflow import oneflow.experimental as flow import oneflow.python.framework.graph_build_util as graph_build_util @@ -189,7 +182,6 @@ class TestGraph(flow.unittest.TestCase): def build(self): # check lazy mode in nn.Graph._compile test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True) - print("graph proto", self._graph_proto) # check session type import oneflow.python.framework.session_context as session_ctx @@ -217,6 +209,7 @@ class TestGraph(flow.unittest.TestCase): g = CustomGraph() test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False) g._compile() + print("graph proto", g._graph_proto) test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False) def test_block_scope(test_case): @@ -241,7 +234,6 @@ class TestGraph(flow.unittest.TestCase): # weight is not get in conv1's forward, so it will return a Block x = self.conv1.weight test_case.assertEqual(type(x), flow.nn.graph.Block) - return x class SubModule1(flow.nn.Module): def __init__(self): @@ -283,7 +275,6 @@ class TestGraph(flow.unittest.TestCase): test_case.assertEqual( dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id ) - return x class CustomModule1(flow.nn.Module): def __init__(self): @@ -294,7 +285,6 @@ class TestGraph(flow.unittest.TestCase): def forward(self): x = self.layer0() y = self.layer1() - return x, y m = CustomModule1() @@ -313,7 +303,7 @@ class TestGraph(flow.unittest.TestCase): g = CustomGraph1() x = flow.Tensor(1, 1, 10, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) - z = g._compile() + g._compile() if __name__ == "__main__": diff --git a/oneflow/python/test/graph/test_input_op_expr.py b/oneflow/python/test/graph/test_input_op_expr.py index befbb0624..31ecc4dbf 100644 --- a/oneflow/python/test/graph/test_input_op_expr.py +++ b/oneflow/python/test/graph/test_input_op_expr.py @@ -18,12 +18,6 @@ import unittest import numpy as np import os -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "12139" -os.environ["WORLD_SIZE"] = "1" -os.environ["RANK"] = "0" -os.environ["LOCAL_RANK"] = "0" - import oneflow import oneflow.experimental as flow import oneflow.python.framework.session_context as session_ctx diff --git a/oneflow/python/test/graph/test_multi_client_session.py b/oneflow/python/test/graph/test_multi_client_session.py index 7b8506358..f4493f1df 100644 --- a/oneflow/python/test/graph/test_multi_client_session.py +++ b/oneflow/python/test/graph/test_multi_client_session.py @@ -16,12 +16,6 @@ limitations under the License. import unittest import os -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "12139" -os.environ["WORLD_SIZE"] = "1" -os.environ["RANK"] = "0" -os.environ["LOCAL_RANK"] = "0" - import oneflow import oneflow.experimental as flow import oneflow.python.framework.session_context as session_ctx diff --git a/oneflow/python/test/graph/test_output_op_expr.py b/oneflow/python/test/graph/test_output_op_expr.py index 5a20e57ec..0a29c15d5 100644 --- a/oneflow/python/test/graph/test_output_op_expr.py +++ b/oneflow/python/test/graph/test_output_op_expr.py @@ -18,12 +18,6 @@ import unittest import numpy as np import os -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "12139" -os.environ["WORLD_SIZE"] = "1" -os.environ["RANK"] = "0" -os.environ["LOCAL_RANK"] = "0" - import oneflow import oneflow.experimental as flow import oneflow.python.framework.session_context as session_ctx diff --git a/oneflow/python/test/graph/test_user_op_expr.py b/oneflow/python/test/graph/test_user_op_expr.py index db486ad8f..451e0dcf1 100644 --- a/oneflow/python/test/graph/test_user_op_expr.py +++ b/oneflow/python/test/graph/test_user_op_expr.py @@ -18,12 +18,6 @@ import unittest import numpy as np import os -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "12139" -os.environ["WORLD_SIZE"] = "1" -os.environ["RANK"] = "0" -os.environ["LOCAL_RANK"] = "0" - import oneflow import oneflow.experimental as flow import oneflow.python.framework.session_context as session_ctx diff --git a/oneflow/python/test/graph/test_variable_op_expr.py b/oneflow/python/test/graph/test_variable_op_expr.py index 8e9fcd9e6..cebd29ed4 100644 --- a/oneflow/python/test/graph/test_variable_op_expr.py +++ b/oneflow/python/test/graph/test_variable_op_expr.py @@ -18,12 +18,6 @@ import unittest import numpy as np import os -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "12139" -os.environ["WORLD_SIZE"] = "1" -os.environ["RANK"] = "0" -os.environ["LOCAL_RANK"] = "0" - import oneflow import oneflow.experimental as flow import oneflow.python.framework.session_context as session_ctx -- GitLab