From ab8aab8ba330e499b1f4ea409b8bd95245a07fab Mon Sep 17 00:00:00 2001
From: Xiaoyu Xu <xiaoyulink@gmail.com>
Date: Wed, 21 Jul 2021 23:12:25 +0800
Subject: [PATCH] Fea/nn graph/forward graph (#5516)

* add test on adding input to graph

* add var into graph

* LazyInterpreter for FetchOutputOpExpr and set op parallel_distribution

* refine input var build

* split file

* rename

* minor refine

* Add note

* LazyInterpret::ApplyImpl for UserOpExpr

* refine test scripts

* add output to graph

* format

* Fix bug in LazyInterpret for UserOpExpr that changes output lbns

* Add user op expr test

* fix mistake in note

* add user op and test

* address review

* address review

* save i/o/s op_name and tensor for c_nn_graph

* address review

* adjust test

* refine build_graph_state

Co-authored-by: chengtbf <472491134@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/python/framework/graph_build_util.py  |  61 ++-
 oneflow/python/nn/graph.py                    | 397 +++++---------
 oneflow/python/nn/graph_block.py              | 331 +++++++++++++++
 oneflow/python/nn/graph_optimizer.py          |  37 ++
 oneflow/python/nn/utils.py                    |  28 ++
 .../python/test/graph/test_forward_graph.py   |  85 ++++
 oneflow/python/test/graph/test_graph.py       |  14 +-
 .../python/test/graph/test_input_op_expr.py   |   6 -
 .../test/graph/test_multi_client_session.py   |   6 -
 .../python/test/graph/test_output_op_expr.py  |   6 -
 .../python/test/graph/test_user_op_expr.py    |   6 -
 .../test/graph/test_variable_op_expr.py       |   6 -
 12 files changed, 599 insertions(+), 384 deletions(-)
 create mode 100644 oneflow/python/nn/graph_block.py
 create mode 100644 oneflow/python/nn/graph_optimizer.py
 create mode 100644 oneflow/python/nn/utils.py
 create mode 100644 oneflow/python/test/graph/test_forward_graph.py

diff --git a/oneflow/python/framework/graph_build_util.py b/oneflow/python/framework/graph_build_util.py
index 25977a2b7..718e00cd5 100644
--- a/oneflow/python/framework/graph_build_util.py
+++ b/oneflow/python/framework/graph_build_util.py
@@ -23,10 +23,11 @@ import oneflow.core.job.scope_pb2 as scope_pb2_util
 import oneflow.python.framework.attr_util as attr_util
 import oneflow.python.framework.c_api_util as c_api_util
 import oneflow.python.framework.placement_util as placement_util
-import oneflow.python.framework.runtime_mode as runtime_mode
 import oneflow.python.framework.scope_util as scope_util
 import oneflow.python.framework.session_context as session_context
 import oneflow._oneflow_internal
+from oneflow._oneflow_internal import Tensor as InternalTensor
+from oneflow.python.framework.tensor import Tensor
 
 
 lazy_mode = oneflow._oneflow_internal.lazy_mode
@@ -125,3 +126,61 @@ def make_new_block_scope(prev_scope, block):
 
 def scope_to_proto(scope):
     return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto())
+
+
+def build_graph_input_arg(op_name, arg):
+    assert isinstance(arg, (Tensor, InternalTensor))
+    input_conf = (
+        oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedInputOpConf()
+    )
+
+    input_op = oneflow._oneflow_internal.one.FeedInputOpExpr(
+        op_name, input_conf, ["in_0"], ["out_0"]
+    )
+    attrs = oneflow._oneflow_internal.MutableCfgAttrMap()
+
+    if isinstance(arg, Tensor):
+        if not arg.is_determined:
+            arg.determine()
+        tensor_in_c = arg._local_or_consistent_tensor
+    else:
+        tensor_in_c = arg
+
+    lazy_arg = input_op.apply([tensor_in_c], attrs)[0]
+    return lazy_arg
+
+
+def build_graph_state(op_name, state_tensor):
+    var_conf = (
+        oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedVariableOpConf()
+    )
+
+    var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(
+        op_name, var_conf, ["in_0"], ["out_0"]
+    )
+    attrs = oneflow._oneflow_internal.MutableCfgAttrMap()
+
+    assert isinstance(state_tensor, Tensor)
+    if not state_tensor.is_determined:
+        state_tensor.determine()
+    tensor_in_c = state_tensor._local_or_consistent_tensor
+
+    lazy_tensor = var_op.apply([tensor_in_c], attrs)[0]
+    return lazy_tensor
+
+
+def build_graph_output(op_name, out):
+    assert isinstance(out, InternalTensor)
+    assert out.is_lazy
+
+    output_conf = (
+        oneflow._oneflow_internal.oneflow.core.operator.op_conf.FetchOutputOpConf()
+    )
+
+    output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr(
+        op_name, output_conf, ["in_0"], ["out_0"]
+    )
+    attrs = oneflow._oneflow_internal.MutableCfgAttrMap()
+
+    eager_out = output_op.apply([out], attrs)[0]
+    return eager_out
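
The three helpers above share one pattern: construct a Feed/Fetch op conf, wrap
it in a single-input/single-output OpExpr ("in_0"/"out_0"), and apply it to the
tensor under the current build scope. A minimal usage sketch, not verified
end-to-end, assuming a multi-client environment and that GraphConfig and its
.proto behave as nn.Graph._compile uses them below (the op names here are
illustrative; nn.Graph derives them from the graph name and argument index):

    import oneflow.experimental as flow
    import oneflow.python.framework.graph_build_util as graph_build_util
    import oneflow.python.framework.session_context as session_ctx

    session = session_ctx.GetDefaultSession()
    session.TryInit()
    config = flow.nn.graph.GraphConfig()

    x = flow.Tensor(2, 3)
    flow.nn.init.uniform_(x, a=-1.0, b=1.0)

    with graph_build_util.graph_build_context(config.proto, session):
        # Feeding an eager tensor in yields a lazy graph endpoint.
        lazy_x = graph_build_util.build_graph_input_arg("_MyGraph-input_0", x)
        assert lazy_x.is_lazy
        # Fetching a lazy tensor out yields an eager tensor again.
        eager_x = graph_build_util.build_graph_output("_MyGraph-output_0", lazy_x)
        assert not eager_x.is_lazy
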
diff --git a/oneflow/python/nn/graph.py b/oneflow/python/nn/graph.py
index 72f23aa9a..1f34da4e4 100644
--- a/oneflow/python/nn/graph.py
+++ b/oneflow/python/nn/graph.py
@@ -15,19 +15,21 @@ limitations under the License.
 """
 from __future__ import absolute_import
 from collections import OrderedDict
-from typing import Union, Optional, Iterator, Set
+from functools import partial
 
 import oneflow._oneflow_internal
 import oneflow.python.framework.c_api_util as c_api_util
 import oneflow.python.framework.graph_build_util as graph_build_util
 import oneflow.python.framework.session_context as session_ctx
 import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util
+from oneflow._oneflow_internal import Tensor as InternalTensor
 from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.framework.multi_client_session import MultiClientSession
-from oneflow.python.framework.tensor import Tensor
+from oneflow.python.nn.graph_block import Block
+from oneflow.python.nn.graph_optimizer import OptimizerConfig
 from oneflow.python.nn.module import Module
-from oneflow.python.nn.parameter import Parameter
 from oneflow.python.nn.optimizer.optimizer import Optimizer
+from oneflow.python.nn.utils import add_indent
 from oneflow.python.framework.function_util import FunctionConfig
 
 
@@ -45,6 +47,7 @@ class Graph(object):
         self._optimizers = OrderedDict()
         self._is_compiled = False
         self._state_tensortuple = None
+        self._job_proto = None
 
     @property
     def name(self):
@@ -56,7 +59,7 @@ class Graph(object):
 
     @property
     def _graph_proto(self):
-        return c_api_util.GetCurrentJob()
+        return self._job_proto
 
     def build(self, *args):
         raise NotImplementedError()
@@ -69,7 +72,7 @@ class Graph(object):
         grad_clipping_conf=None,
         weight_decay_conf=None,
     ):
-        self._optimizers[name] = self.OptimizerConfig(
-            optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
+        self._optimizers[name] = OptimizerConfig(
+            name, optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
         )
 
@@ -102,19 +105,57 @@ class Graph(object):
         session.TryInit()
 
         with graph_build_util.graph_build_context(self.config.proto, session):
+            # Deal with input
+            lazy_args = []
+            lazy_arg_op_names = []
+            for idx, arg in enumerate(args):
+                op_name = "_" + self.name + "-input_" + str(idx)
+                lazy_args.append(graph_build_util.build_graph_input_arg(op_name, arg))
+                lazy_arg_op_names.append(op_name)
+
             # Deal with parameter and buffer
-            for s in self._state():
+            state_op_names = []
+            state_tensors = []
+            for state_block in self._state():
+                op_name = state_block.name_prefix + state_block.name
+                state_tensor = state_block.origin
+                state_op_names.append(op_name)
+                state_tensors.append(state_tensor)
+                state_block.set_lazy_origin_builder(
+                    partial(graph_build_util.build_graph_state, op_name, state_tensor)
+                )
 
-                def to_lazy():
-                    # TODO(): Replace repr(s) with OpExpr(s.origin)
-                    lazy_tensor = repr(s)
-                    return lazy_tensor
+            # Deal with module in self.build(*args)
+            outputs = self.build(*lazy_args)
+
+            # Deal with outputs
+            if not (type(outputs) is tuple or type(outputs) is list):
+                if outputs is None:
+                    outputs = ()
+                else:
+                    assert type(outputs) is InternalTensor
+                    outputs = (outputs,)
+            eager_outputs = []
+            eager_output_op_names = []
+            for idx, out in enumerate(outputs):
+                op_name = "_" + self.name + "-output_" + str(idx)
+                eager_outputs.append(graph_build_util.build_graph_output(op_name, out))
+                eager_output_op_names.append(op_name)
+            if len(eager_outputs) == 0:
+                eager_outputs = None
+            elif len(eager_outputs) == 1:
+                eager_outputs = eager_outputs[0]
+            else:
+                eager_outputs = tuple(eager_outputs)
 
-                s.set_lazy_origin_lambda(to_lazy)
-            outputs = self.build(*args)
+            # TODO(): call self._c_nn_graph
+            #     register lazy_arg_op_names/state_op_names/state_tensors/eager_output_op_names
+
+            # Save job proto for debug
+            self._job_proto = c_api_util.GetCurrentJob()
 
         self._is_compiled = True
-        return outputs
+        return eager_outputs
 
     def _launch(self):
         # TODO(xuxiaoyu)
@@ -122,7 +163,6 @@ class Graph(object):
         ...
 
     def __call__(self, *args):
-        # TODO(xuxiaoyu)
         # if not self._is_compiled:
         #     self._compile()
         # return self._launch()
@@ -179,7 +219,7 @@ class Graph(object):
             child_lines = []
             for n, m in self._blocks.items():
                 mod_str = repr(m)
-                mod_str = _add_indent(mod_str, 2)
+                mod_str = add_indent(mod_str, 2)
                 child_lines.append(mod_str)
             lines = child_lines
 
@@ -190,278 +230,6 @@ class Graph(object):
         return main_str
 
 
-class BlockType:
-    NONE = "NONE"
-    MODULE = "MODULE"
-    PARAMETER = "PARAMETER"
-    BUFFER = "BUFFER"
-
-
-@oneflow_export("nn.graph.Block")
-@experimental_api
-class Block(object):
-    def __init__(
-        self,
-        prefix: str = "",
-        name: str = "",
-        value: Union[Module, Parameter, Tensor] = None,
-    ):
-        assert not isinstance(value, Block)
-        self._name = name
-        self._name_prefix = prefix
-        self._type = BlockType.NONE
-        self._origin = value
-        self.config = BlockConfig()
-        self._scope = None
-        self._prev_scope = None
-
-        if isinstance(value, Module):
-            self._type = BlockType.MODULE
-            self._is_executing_forward = False
-            self._modules = OrderedDict()
-            self._parameters = OrderedDict()
-            self._buffers = OrderedDict()
-            for n, m in list(value.named_children()):
-                self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m))
-            for n, p in list(value.named_parameters("", False)):
-                self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p))
-            for n, b in list(value.named_buffers("", False)):
-                self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
-        elif isinstance(value, Parameter):
-            self._type = BlockType.PARAMETER
-            self._lazy_origin = None
-            self._lazy_origin_lambda = None
-        elif isinstance(value, Tensor):
-            self._type = BlockType.BUFFER
-            self._lazy_origin = None
-            self._lazy_origin_lambda = None
-        else:
-            raise NotImplementedError()
-
-    @property
-    def name(self):
-        return self._name
-
-    @property
-    def name_prefix(self):
-        return self._name_prefix
-
-    @property
-    def type(self):
-        return self._type
-
-    @property
-    def origin(self):
-        return self._origin
-
-    @property
-    def lazy_origin(self):
-        assert (
-            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
-        ), "Only Parameter or Buffer Block has lazy_origin"
-        return self._lazy_origin
-
-    def lazy_origin_lambda(self):
-        assert (
-            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
-        ), "Only Parameter or Buffer Block has lazy_origin_lambda"
-        return self._lazy_origin_lambda
-
-    def set_lazy_origin_lambda(self, fn=None):
-        assert (
-            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
-        ), "Only Parameter or Buffer Block has lazy_origin_lambda"
-        self._lazy_origin_lambda = fn
-
-    @property
-    def prev_scope(self):
-        if self._prev_scope is None:
-            self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
-        return self._prev_scope
-
-    @property
-    def scope(self):
-        if self._scope is None:
-            self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
-        return self._scope
-
-    def scope_context(self):
-        return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)
-
-    def __call__(self, *args):
-        assert self._type == BlockType.MODULE
-        # nn.Module.__call__ will call self.forward()
-        # so the scope is set in self.forward()
-        result = self._origin.__class__.__call__(self, *args)
-        return result
-
-    def forward(self, *args):
-        assert self._type == BlockType.MODULE
-        self._is_executing_forward = True
-        # TODO(xuxiaoyu): only build scope in lazy mode
-        with self.scope_context():
-            result = self._origin.__class__.forward(self, *args)
-        self._is_executing_forward = False
-        return result
-
-    def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
-        assert self._type == BlockType.MODULE
-        if memo is None:
-            memo = set()
-        if self not in memo:
-            memo.add(self)
-            yield self
-            for name, module in self._modules.items():
-                if module is None:
-                    continue
-                for m in module.modules(memo):
-                    yield m
-
-    def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
-        assert self._type == BlockType.MODULE
-        memo = set()
-        modules = self.modules() if recurse else [self]
-        for module in modules:
-            members = get_members_fn(module)
-            for k, v in members:
-                if v is None or v in memo:
-                    continue
-                memo.add(v)
-                yield v
-
-    def parameters(self, recurse: bool = True) -> Iterator["Block"]:
-        assert self._type == BlockType.MODULE
-        gen = self._members(lambda module: module._parameters.items(), recurse=recurse)
-        for elem in gen:
-            yield elem
-
-    def buffers(self, recurse: bool = True) -> Iterator["Block"]:
-        assert self._type == BlockType.MODULE
-        gen = self._members(lambda module: module._buffers.items(), recurse=recurse)
-        for elem in gen:
-            yield elem
-
-    def __setattr__(self, name: str, value=None) -> None:
-        if value is None or not isinstance(value, Block):
-            self.__dict__[name] = value
-        else:
-            dicts_or_sets = (
-                self.__dict__,
-                self._modules,
-                self._parameters,
-                self._buffers,
-            )
-            for d in dicts_or_sets:
-                if name in d:
-                    raise AttributeError(
-                        "'{}' object has duplicated attribute named '{}'".format(
-                            self._name, name
-                        )
-                    )
-            if value.type == BlockType.MODULE:
-                self._modules[name] = value
-            elif value.type == BlockType.PARAMETER:
-                self._parameters[name] = value
-            elif value.type == BlockType.BUFFER:
-                self._buffers[name] = value
-            else:
-                raise AttributeError(
-                    "'{}' object are not allowed to set attribute named '{}'".format(
-                        type(self).__name__, name
-                    )
-                )
-
-    def __getattr__(self, name: str):
-        if name in self.__dict__:
-            return self.__dict__[name]
-
-        if self._type == BlockType.MODULE:
-            if "_modules" in self.__dict__:
-                modules = self.__dict__["_modules"]
-                if name in modules:
-                    return modules[name]
-            if "_parameters" in self.__dict__:
-                _parameters = self.__dict__["_parameters"]
-                if name in _parameters:
-                    p_block = _parameters[name]
-                    if self._is_executing_forward:
-                        # Return Tensor for running when getattr inside it's father Block's forward()
-                        if graph_build_util.lazy_mode.is_enabled():
-                            if p_block._lazy_origin is None:
-                                assert p_block._lazy_origin_lambda is not None, (
-                                    repr(p_block)
-                                    + " has no lazy Tensor creation function."
-                                )
-                                # Create and return lazy tensor
-                                with p_block.scope_context():
-                                    p_block._lazy_origin = p_block._lazy_origin_lambda()
-                            return p_block._lazy_origin
-                        else:
-                            return p_block.origin
-                    else:
-                        # Return Block for config when getattr outside it's father Block's forward()
-                        return p_block
-            if "_buffers" in self.__dict__:
-                _buffers = self.__dict__["_buffers"]
-                if name in _buffers:
-                    b_block = _buffers[name]
-                    if self._is_executing_forward:
-                        # Return Tensor for running when getattr inside it's father Block's forward()
-                        if graph_build_util.lazy_mode.is_enabled():
-                            if b_block._lazy_origin is None:
-                                assert b_block._lazy_origin_lambda is not None, (
-                                    repr(b_block)
-                                    + " has no lazy Tensor creation function."
-                                )
-                                # Create and return lazy tensor
-                                with b_block.scope_context():
-                                    b_block._lazy_origin = b_block._lazy_origin_lambda()
-                            return b_block._lazy_origin
-                        else:
-                            return b_block.origin
-                    else:
-                        # Return Block for config when getattr outside it's father Block's forward()
-                        return b_block
-            if name in self._origin.__dict__:
-                return self._origin.__dict__[name]
-
-        raise AttributeError(
-            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
-        )
-
-    def __repr__(self):
-        lines = None
-        if self._type == BlockType.MODULE:
-            child_lines = []
-
-            def _append_child(d):
-                for _, n in d.items():
-                    n_str = repr(n)
-                    n_str = _add_indent(n_str, 2)
-                    child_lines.append(n_str)
-
-            _append_child(self._modules)
-            _append_child(self._parameters)
-            _append_child(self._buffers)
-            if len(child_lines) > 0:
-                lines = child_lines
-
-        main_str = (
-            "("
-            + self._name_prefix
-            + self._name
-            + ":"
-            + self._origin.__class__.__name__
-            + ":"
-            + self._type
-            + "): ("
-        )
-        if lines is not None:
-            main_str += "\n  " + "\n  ".join(lines) + "\n"
-        main_str += ")"
-        return main_str
-
-
 @oneflow_export("nn.graph.GraphConfig")
 @experimental_api
 class GraphConfig(FunctionConfig):
@@ -486,56 +254,3 @@ class GraphConfig(FunctionConfig):
             self.function_desc.job_config_proto.mutable_train_conf()
         else:
             self.function_desc.job_config_proto.mutable_predict_conf()
-
-
-@oneflow_export("nn.graph.BlockConfig")
-@experimental_api
-class BlockConfig(object):
-    def __init__(self):
-        self._stage_id = None
-        self._activation_checkpointing = None
-
-    @property
-    def stage_id(self):
-        return self._stage_id
-
-    @stage_id.setter
-    def stage_id(self, value: int = None):
-        self._stage_id = value
-
-    @property
-    def activation_checkpointing(self):
-        return self._activation_checkpointing
-
-    @activation_checkpointing.setter
-    def activation_checkpointing(self, value: bool = False):
-        self._activation_checkpointing = value
-
-
-@oneflow_export("nn.graph.OptimizerConfig")
-@experimental_api
-class OptimizerConfig(object):
-    def __init__(
-        self,
-        name: str,
-        optimizer: Optimizer = None,
-        lr_scheduler=None,
-        grad_clipping_conf=None,
-        weight_decay_conf=None,
-    ):
-        self.name = name
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-        self.grad_clipping_conf = grad_clipping_conf
-        self.weight_decay_conf = weight_decay_conf
-
-
-def _add_indent(in_s, num_spaces):
-    s = in_s.split("\n")
-    if len(s) == 1:
-        return in_s
-    first = s.pop(0)
-    s = [(num_spaces * " ") + line for line in s]
-    s = "\n".join(s)
-    s = first + "\n" + s
-    return s
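
With this change Graph._compile stages four steps: feed each positional arg
through a FeedInput op named "_<name>-input_<idx>", register a FeedVariable
builder on every state Block (run lazily on first access inside build()),
trace self.build(*lazy_args), and fetch each output through a FetchOutput op
named "_<name>-output_<idx>". A minimal sketch of the user-facing flow,
mirroring test_forward_graph.py below but untested as written (a
parameter-free module keeps it small):

    import oneflow.experimental as flow

    class MyGraph(flow.nn.Graph):
        def __init__(self, module):
            super().__init__()
            self.m = module  # wrapped into a Block by nn.Graph

        def build(self, x):
            # Runs once, in lazy mode, during _compile.
            return self.m(x)

    g = MyGraph(flow.nn.ReLU())

    x = flow.Tensor(4, 4)
    flow.nn.init.uniform_(x, a=-1.0, b=1.0)
    x = x.to("cuda")

    out = g._compile(x)    # returns the eager output of the traced job
    print(g._graph_proto)  # job proto captured at the end of _compile
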
diff --git a/oneflow/python/nn/graph_block.py b/oneflow/python/nn/graph_block.py
new file mode 100644
index 000000000..925b4c5dd
--- /dev/null
+++ b/oneflow/python/nn/graph_block.py
@@ -0,0 +1,331 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from __future__ import absolute_import
+from collections import OrderedDict
+from typing import Union, Optional, Iterator, Set
+
+import oneflow._oneflow_internal
+import oneflow.python.framework.graph_build_util as graph_build_util
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.framework.tensor import Tensor
+from oneflow.python.nn.module import Module
+from oneflow.python.nn.parameter import Parameter
+from oneflow.python.nn.utils import add_indent
+
+
+class BlockType:
+    NONE = "NONE"
+    MODULE = "MODULE"
+    PARAMETER = "PARAMETER"
+    BUFFER = "BUFFER"
+
+
+@oneflow_export("nn.graph.Block")
+@experimental_api
+class Block(object):
+    def __init__(
+        self,
+        prefix: str = "",
+        name: str = "",
+        value: Union[Module, Parameter, Tensor] = None,
+    ):
+        assert not isinstance(value, Block)
+        self._name = name
+        self._name_prefix = prefix
+        self._type = BlockType.NONE
+        self._origin = value
+        self.config = BlockConfig()
+        self._scope = None
+        self._prev_scope = None
+
+        if isinstance(value, Module):
+            self._type = BlockType.MODULE
+            self._is_executing_forward = False
+            self._modules = OrderedDict()
+            self._parameters = OrderedDict()
+            self._buffers = OrderedDict()
+            for n, m in list(value.named_children()):
+                self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m))
+            for n, p in list(value.named_parameters("", False)):
+                self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p))
+            for n, b in list(value.named_buffers("", False)):
+                self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
+        elif isinstance(value, Parameter):
+            self._type = BlockType.PARAMETER
+            self._lazy_origin = None
+            self._lazy_origin_builder = None
+        elif isinstance(value, Tensor):
+            self._type = BlockType.BUFFER
+            self._lazy_origin = None
+            self._lazy_origin_builder = None
+        else:
+            raise NotImplementedError()
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def name_prefix(self):
+        return self._name_prefix
+
+    @property
+    def type(self):
+        return self._type
+
+    @property
+    def origin(self):
+        return self._origin
+
+    @property
+    def lazy_origin(self):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin"
+        return self._lazy_origin
+
+    def lazy_origin_builder(self):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin_builder"
+        return self._lazy_origin_builder
+
+    def set_lazy_origin_builder(self, fn=None):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin_builder"
+        self._lazy_origin_builder = fn
+
+    @property
+    def prev_scope(self):
+        if self._prev_scope is None:
+            self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
+        return self._prev_scope
+
+    @property
+    def scope(self):
+        if self._scope is None:
+            self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
+        return self._scope
+
+    def scope_context(self):
+        return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)
+
+    def __call__(self, *args):
+        assert self._type == BlockType.MODULE
+        # nn.Module.__call__ will call self.forward()
+        # so the scope is set in self.forward()
+        result = self._origin.__class__.__call__(self, *args)
+        return result
+
+    def forward(self, *args):
+        assert self._type == BlockType.MODULE
+        self._is_executing_forward = True
+        # TODO(xuxiaoyu): only build scope in lazy mode
+        with self.scope_context():
+            result = self._origin.__class__.forward(self, *args)
+        self._is_executing_forward = False
+        return result
+
+    def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        if memo is None:
+            memo = set()
+        if self not in memo:
+            memo.add(self)
+            yield self
+            for name, module in self._modules.items():
+                if module is None:
+                    continue
+                for m in module.modules(memo):
+                    yield m
+
+    def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        memo = set()
+        modules = self.modules() if recurse else [self]
+        for module in modules:
+            members = get_members_fn(module)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                memo.add(v)
+                yield v
+
+    def parameters(self, recurse: bool = True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        gen = self._members(lambda module: module._parameters.items(), recurse=recurse)
+        for elem in gen:
+            yield elem
+
+    def buffers(self, recurse: bool = True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        gen = self._members(lambda module: module._buffers.items(), recurse=recurse)
+        for elem in gen:
+            yield elem
+
+    def __setattr__(self, name: str, value=None) -> None:
+        if value is None or not isinstance(value, Block):
+            self.__dict__[name] = value
+        else:
+            dicts_or_sets = (
+                self.__dict__,
+                self._modules,
+                self._parameters,
+                self._buffers,
+            )
+            for d in dicts_or_sets:
+                if name in d:
+                    raise AttributeError(
+                        "'{}' object has duplicated attribute named '{}'".format(
+                            self._name, name
+                        )
+                    )
+            if value.type == BlockType.MODULE:
+                self._modules[name] = value
+            elif value.type == BlockType.PARAMETER:
+                self._parameters[name] = value
+            elif value.type == BlockType.BUFFER:
+                self._buffers[name] = value
+            else:
+                raise AttributeError(
+                    "'{}' object is not allowed to set attribute named '{}'".format(
+                        type(self).__name__, name
+                    )
+                )
+
+    def __getattr__(self, name: str):
+        if name in self.__dict__:
+            return self.__dict__[name]
+
+        if self._type == BlockType.MODULE:
+            if "_modules" in self.__dict__:
+                modules = self.__dict__["_modules"]
+                if name in modules:
+                    return modules[name]
+            if "_parameters" in self.__dict__:
+                _parameters = self.__dict__["_parameters"]
+                if name in _parameters:
+                    p_block = _parameters[name]
+                    if self._is_executing_forward:
+                        # Return Tensor for running when getattr is
+                        # inside its parent Block's forward()
+                        if graph_build_util.lazy_mode.is_enabled():
+                            if p_block._lazy_origin is None:
+                                assert p_block._lazy_origin_builder is not None, (
+                                    repr(p_block)
+                                    + " has no lazy Tensor creation function."
+                                )
+                                # Create and return lazy tensor
+                                with p_block.scope_context():
+                                    p_block._lazy_origin = (
+                                        p_block._lazy_origin_builder()
+                                    )
+                            return p_block._lazy_origin
+                        else:
+                            return p_block.origin
+                    else:
+                        # Return Block for config when getattr is
+                        # outside its parent Block's forward()
+                        return p_block
+            if "_buffers" in self.__dict__:
+                _buffers = self.__dict__["_buffers"]
+                if name in _buffers:
+                    b_block = _buffers[name]
+                    if self._is_executing_forward:
+                        # Return Tensor for running when getattr is
+                        # inside its parent Block's forward()
+                        if graph_build_util.lazy_mode.is_enabled():
+                            if b_block._lazy_origin is None:
+                                assert b_block._lazy_origin_builder is not None, (
+                                    repr(b_block)
+                                    + " has no lazy Tensor creation function."
+                                )
+                                # Create and return lazy tensor
+                                with b_block.scope_context():
+                                    b_block._lazy_origin = (
+                                        b_block._lazy_origin_builder()
+                                    )
+                            return b_block._lazy_origin
+                        else:
+                            return b_block.origin
+                    else:
+                        # Return Block for config when getattr is
+                        # outside its parent Block's forward()
+                        return b_block
+            if name in self._origin.__dict__:
+                return self._origin.__dict__[name]
+
+        raise AttributeError(
+            "'{}' object has no attribute '{}'".format(type(self).__name__, name)
+        )
+
+    def __repr__(self):
+        lines = None
+        if self._type == BlockType.MODULE:
+            child_lines = []
+
+            def _append_child(d):
+                for _, n in d.items():
+                    n_str = repr(n)
+                    n_str = add_indent(n_str, 2)
+                    child_lines.append(n_str)
+
+            _append_child(self._modules)
+            _append_child(self._parameters)
+            _append_child(self._buffers)
+            if len(child_lines) > 0:
+                lines = child_lines
+
+        main_str = (
+            "("
+            + self._name_prefix
+            + self._name
+            + ":"
+            + self._origin.__class__.__name__
+            + ":"
+            + self._type
+            + "): ("
+        )
+        if lines is not None:
+            main_str += "\n  " + "\n  ".join(lines) + "\n"
+        main_str += ")"
+        return main_str
+
+
+@oneflow_export("nn.graph.BlockConfig")
+@experimental_api
+class BlockConfig(object):
+    def __init__(self):
+        self._stage_id = None
+        self._activation_checkpointing = None
+
+    @property
+    def stage_id(self):
+        return self._stage_id
+
+    @stage_id.setter
+    def stage_id(self, value: int = None):
+        self._stage_id = value
+
+    @property
+    def activation_checkpointing(self):
+        return self._activation_checkpointing
+
+    @activation_checkpointing.setter
+    def activation_checkpointing(self, value: bool = False):
+        self._activation_checkpointing = value
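
The parameter/buffer branches of __getattr__ above implement a build-once
cache: inside the parent Block's forward() under lazy mode, the first access
runs the builder registered by Graph._compile (a partial of
build_graph_state) under the Block's scope and caches the lazy tensor; later
accesses reuse it, and outside lazy mode the eager origin is returned
unchanged. A framework-free sketch of just that caching pattern (illustrative
names, not OneFlow API):

    class LazySlot:
        """Simplified stand-in for a Parameter/Buffer Block."""

        def __init__(self, origin):
            self._origin = origin
            self._lazy_origin = None
            self._lazy_origin_builder = None

        def set_lazy_origin_builder(self, fn):
            self._lazy_origin_builder = fn

        def get(self, lazy_mode_enabled):
            if not lazy_mode_enabled:
                return self._origin
            if self._lazy_origin is None:
                assert self._lazy_origin_builder is not None
                self._lazy_origin = self._lazy_origin_builder()  # build once
            return self._lazy_origin  # cached on later accesses

    slot = LazySlot(origin="eager-weight")
    slot.set_lazy_origin_builder(lambda: "lazy-weight")
    assert slot.get(lazy_mode_enabled=False) == "eager-weight"
    assert slot.get(lazy_mode_enabled=True) == "lazy-weight"  # built here
    assert slot.get(lazy_mode_enabled=True) == "lazy-weight"  # reused
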
diff --git a/oneflow/python/nn/graph_optimizer.py b/oneflow/python/nn/graph_optimizer.py
new file mode 100644
index 000000000..8502885e2
--- /dev/null
+++ b/oneflow/python/nn/graph_optimizer.py
@@ -0,0 +1,37 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from __future__ import absolute_import
+
+from oneflow.python.nn.optimizer.optimizer import Optimizer
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+
+
+@oneflow_export("nn.graph.OptimizerConfig")
+@experimental_api
+class OptimizerConfig(object):
+    def __init__(
+        self,
+        name: str,
+        optimizer: Optimizer = None,
+        lr_scheduler=None,
+        grad_clipping_conf=None,
+        weight_decay_conf=None,
+    ):
+        self.name = name
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.grad_clipping_conf = grad_clipping_conf
+        self.weight_decay_conf = weight_decay_conf
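
OptimizerConfig is a plain record grouping an optimizer with its scheduler,
gradient-clipping, and weight-decay configs; Graph.add_optimizer (graph.py
above) constructs one per name. A usage sketch, leaving the optimizer object
abstract since its construction is outside this patch:

    import oneflow.experimental as flow

    class TrainGraph(flow.nn.Graph):
        def __init__(self, module, optimizer):
            super().__init__()
            self.m = module
            # Stored as an OptimizerConfig under self._optimizers["sgd"].
            self.add_optimizer("sgd", optimizer=optimizer)

        def build(self, x):
            return self.m(x)
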
diff --git a/oneflow/python/nn/utils.py b/oneflow/python/nn/utils.py
new file mode 100644
index 000000000..a1f5cddbd
--- /dev/null
+++ b/oneflow/python/nn/utils.py
@@ -0,0 +1,28 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from __future__ import absolute_import
+
+
+def add_indent(in_s, num_spaces):
+    s = in_s.split("\n")
+    if len(s) == 1:
+        return in_s
+    first = s.pop(0)
+    s = [(num_spaces * " ") + line for line in s]
+    s = "\n".join(s)
+    s = first + "\n" + s
+    return s
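
add_indent indents every line except the first, so a child repr can be
spliced into its parent's repr right after an opening delimiter. For example
(strings shaped like Block.__repr__ output):

    from oneflow.python.nn.utils import add_indent

    child = "(m.layer:SubModule:MODULE): (\n(m.layer.weight:Parameter:PARAMETER): ()\n)"
    print(add_indent(child, 2))
    # (m.layer:SubModule:MODULE): (
    #   (m.layer.weight:Parameter:PARAMETER): ()
    #   )
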
diff --git a/oneflow/python/test/graph/test_forward_graph.py b/oneflow/python/test/graph/test_forward_graph.py
new file mode 100644
index 000000000..552644e92
--- /dev/null
+++ b/oneflow/python/test/graph/test_forward_graph.py
@@ -0,0 +1,85 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+import os
+
+import oneflow
+import oneflow.experimental as flow
+
+
+class SubModule(flow.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = flow.nn.Parameter(flow.Tensor(6, 6))
+        self.relu = flow.nn.ReLU()
+
+    def forward(self, x, y):
+        x = oneflow.F.matmul(x, self.weight)
+        x = self.relu(x)
+        y = self.relu(y)
+        return x, y
+
+
+class CustomModule(flow.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = SubModule()
+        self.register_buffer(
+            "dummy_buff", flow.Tensor(6, 8),
+        )
+
+    def forward(self, x, y):
+        x, y = self.layer(x, y)
+        x = oneflow.F.flatten(x, 1)
+        x = oneflow.F.matmul(x, self.dummy_buff)
+        return x, y
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestGraph(flow.unittest.TestCase):
+    def test_forward_graph(test_case):
+        class CustomGraph(flow.nn.Graph):
+            def __init__(self, module):
+                super().__init__()
+                self.m = module
+
+            def build(self, x, y):
+                out = self.m(x, y)
+                return out
+
+        m = CustomModule()
+        m.to("cuda")
+        g = CustomGraph(m)
+
+        x = flow.Tensor(6, 6)
+        flow.nn.init.uniform_(x, a=-1.0, b=1.0)
+        x = x.to("cuda")
+
+        y = flow.Tensor(10, 10)
+        flow.nn.init.uniform_(y, a=-1.0, b=1.0)
+        y = y.to("cuda")
+
+        print(repr(g))
+        z, a = g._compile(x, y)
+        test_case.assertEqual(z.shape, (6, 8))
+        test_case.assertEqual(z.is_lazy, False)
+        test_case.assertEqual(a.shape, (10, 10))
+        test_case.assertEqual(a.is_lazy, False)
+        print("graph proto: ", g._graph_proto)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/python/test/graph/test_graph.py b/oneflow/python/test/graph/test_graph.py
index 816ab199c..0a07ec04b 100644
--- a/oneflow/python/test/graph/test_graph.py
+++ b/oneflow/python/test/graph/test_graph.py
@@ -18,13 +18,6 @@ import os
 
 import numpy as np
 
-# To enable MultiClient
-os.environ["MASTER_ADDR"] = "127.0.0.1"
-os.environ["MASTER_PORT"] = "12139"
-os.environ["WORLD_SIZE"] = "1"
-os.environ["RANK"] = "0"
-os.environ["LOCAL_RANK"] = "0"
-
 import oneflow
 import oneflow.experimental as flow
 import oneflow.python.framework.graph_build_util as graph_build_util
@@ -189,7 +182,6 @@ class TestGraph(flow.unittest.TestCase):
             def build(self):
                 # check lazy mode in nn.Graph._compile
                 test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)
-                print("graph proto", self._graph_proto)
 
                 # check session type
                 import oneflow.python.framework.session_context as session_ctx
@@ -217,6 +209,7 @@ class TestGraph(flow.unittest.TestCase):
         g = CustomGraph()
         test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)
         g._compile()
+        print("graph proto", g._graph_proto)
         test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)
 
     def test_block_scope(test_case):
@@ -241,7 +234,6 @@ class TestGraph(flow.unittest.TestCase):
                 # weight is not get in conv1's forward, so it will return a Block
                 x = self.conv1.weight
                 test_case.assertEqual(type(x), flow.nn.graph.Block)
-                return x
 
         class SubModule1(flow.nn.Module):
             def __init__(self):
@@ -283,7 +275,6 @@ class TestGraph(flow.unittest.TestCase):
                 test_case.assertEqual(
                     dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id
                 )
-                return x
 
         class CustomModule1(flow.nn.Module):
             def __init__(self):
@@ -294,7 +285,6 @@ class TestGraph(flow.unittest.TestCase):
             def forward(self):
                 x = self.layer0()
                 y = self.layer1()
-                return x, y
 
         m = CustomModule1()
 
@@ -313,7 +303,7 @@ class TestGraph(flow.unittest.TestCase):
         g = CustomGraph1()
         x = flow.Tensor(1, 1, 10, 10)
         flow.nn.init.uniform_(x, a=-1.0, b=1.0)
-        z = g._compile()
+        g._compile()
 
 
 if __name__ == "__main__":
diff --git a/oneflow/python/test/graph/test_input_op_expr.py b/oneflow/python/test/graph/test_input_op_expr.py
index befbb0624..31ecc4dbf 100644
--- a/oneflow/python/test/graph/test_input_op_expr.py
+++ b/oneflow/python/test/graph/test_input_op_expr.py
@@ -18,12 +18,6 @@ import unittest
 import numpy as np
 import os
 
-os.environ["MASTER_ADDR"] = "127.0.0.1"
-os.environ["MASTER_PORT"] = "12139"
-os.environ["WORLD_SIZE"] = "1"
-os.environ["RANK"] = "0"
-os.environ["LOCAL_RANK"] = "0"
-
 import oneflow
 import oneflow.experimental as flow
 import oneflow.python.framework.session_context as session_ctx
diff --git a/oneflow/python/test/graph/test_multi_client_session.py b/oneflow/python/test/graph/test_multi_client_session.py
index 7b8506358..f4493f1df 100644
--- a/oneflow/python/test/graph/test_multi_client_session.py
+++ b/oneflow/python/test/graph/test_multi_client_session.py
@@ -16,12 +16,6 @@ limitations under the License.
 import unittest
 import os
 
-os.environ["MASTER_ADDR"] = "127.0.0.1"
-os.environ["MASTER_PORT"] = "12139"
-os.environ["WORLD_SIZE"] = "1"
-os.environ["RANK"] = "0"
-os.environ["LOCAL_RANK"] = "0"
-
 import oneflow
 import oneflow.experimental as flow
 import oneflow.python.framework.session_context as session_ctx
diff --git a/oneflow/python/test/graph/test_output_op_expr.py b/oneflow/python/test/graph/test_output_op_expr.py
index 5a20e57ec..0a29c15d5 100644
--- a/oneflow/python/test/graph/test_output_op_expr.py
+++ b/oneflow/python/test/graph/test_output_op_expr.py
@@ -18,12 +18,6 @@ import unittest
 import numpy as np
 import os
 
-os.environ["MASTER_ADDR"] = "127.0.0.1"
-os.environ["MASTER_PORT"] = "12139"
-os.environ["WORLD_SIZE"] = "1"
-os.environ["RANK"] = "0"
-os.environ["LOCAL_RANK"] = "0"
-
 import oneflow
 import oneflow.experimental as flow
 import oneflow.python.framework.session_context as session_ctx
diff --git a/oneflow/python/test/graph/test_user_op_expr.py b/oneflow/python/test/graph/test_user_op_expr.py
index db486ad8f..451e0dcf1 100644
--- a/oneflow/python/test/graph/test_user_op_expr.py
+++ b/oneflow/python/test/graph/test_user_op_expr.py
@@ -18,12 +18,6 @@ import unittest
 import numpy as np
 import os
 
-os.environ["MASTER_ADDR"] = "127.0.0.1"
-os.environ["MASTER_PORT"] = "12139"
-os.environ["WORLD_SIZE"] = "1"
-os.environ["RANK"] = "0"
-os.environ["LOCAL_RANK"] = "0"
-
 import oneflow
 import oneflow.experimental as flow
 import oneflow.python.framework.session_context as session_ctx
diff --git a/oneflow/python/test/graph/test_variable_op_expr.py b/oneflow/python/test/graph/test_variable_op_expr.py
index 8e9fcd9e6..cebd29ed4 100644
--- a/oneflow/python/test/graph/test_variable_op_expr.py
+++ b/oneflow/python/test/graph/test_variable_op_expr.py
@@ -18,12 +18,6 @@ import unittest
 import numpy as np
 import os
 
-os.environ["MASTER_ADDR"] = "127.0.0.1"
-os.environ["MASTER_PORT"] = "12139"
-os.environ["WORLD_SIZE"] = "1"
-os.environ["RANK"] = "0"
-os.environ["LOCAL_RANK"] = "0"
-
 import oneflow
 import oneflow.experimental as flow
 import oneflow.python.framework.session_context as session_ctx
-- 
GitLab