From 2a597c2cd4411b421242614f06cec3f1f1ac337a Mon Sep 17 00:00:00 2001
From: Xiaoyu Xu <xiaoyulink@gmail.com>
Date: Thu, 15 Jul 2021 21:28:39 +0800
Subject: [PATCH] Fea/nn graph/block scope (#5498)

* trigger CI test on graph

* add module scope

* Block.parameter or buffer returns different values based on context

* add test on scope name prefix

* refine scope build

* refine scope build

* get parameter via lambda to create the right scope

* check parameter access

* fix vm::PhysicalRun

* create lazy Tensor only once
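
A minimal usage sketch of the per-block scope config added here, adapted from
test_graph.py in this patch (CustomModule1 is defined in that test):

    class CustomGraph1(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.m = CustomModule1()
            # outside build(), submodules are Blocks and can be configured
            self.m.layer0.config.stage_id = 0
            self.m.layer0.config.activation_checkpointing = True
            self.m.layer1.config.stage_id = 1

        def build(self):
            # inside build(), each Block's forward() runs under its own scope
            return self.m()

    g = CustomGraph1()
    g._compile()  # traces build(); parameters/buffers resolve to lazy Tensors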

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 .github/workflows/test.yml                    |   8 +
 .../core/framework/instructions_builder.cpp   |   2 +-
 oneflow/python/framework/graph_build_util.py  |  79 +++++++-
 oneflow/python/framework/scope_util.py        |   4 -
 oneflow/python/nn/graph.py                    | 176 +++++++++++++++---
 oneflow/python/test/graph/test_graph.py       | 126 +++++++++++--
 6 files changed, 338 insertions(+), 57 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 5c9b1e515..f7110ace7 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -523,6 +523,14 @@ jobs:
             -e ONEFLOW_TEST_DIR=$PWD/oneflow/python/test/tensor \
             ${{ env.image_tag }} \
             bash -c "python3 -m pip config set global.index-url ${{ env.pip_index_mirror }} && bash ci/test/try_install.sh && bash ci/test/generic_test.sh"
+      - name: Graph API test
+        if: matrix.test_suite == 'cuda_new_interface'
+        run: |
+          docker run \
+            ${{ env.extra_docker_args }} ${{ env.pip_cache_docker_args }} \
+            -e ONEFLOW_TEST_DIR=$PWD/oneflow/python/test/graph \
+            ${{ env.image_tag }} \
+            bash -c "python3 -m pip config set global.index-url ${{ env.pip_index_mirror }} && bash ci/test/try_install.sh && bash ci/test/generic_test.sh"
       - name: Op test
         timeout-minutes: 45
         if: matrix.test_suite == 'cpu' || matrix.test_suite == 'cuda_op'
diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp
index 6e6510b74..81a0ebd5d 100644
--- a/oneflow/core/framework/instructions_builder.cpp
+++ b/oneflow/core/framework/instructions_builder.cpp
@@ -1600,7 +1600,7 @@ Maybe<void> LogicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& B
 Maybe<void> PhysicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
   vm::InstructionMsgList instruction_list;
   vm::cfg::EagerSymbolList eager_symbol_list;
-  InstructionsBuilder instructions_builder(std::shared_ptr<vm::PhysicalIdGenerator>(),
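+  // NOTE: make_shared is required here; a default-constructed shared_ptr would be null.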
+  InstructionsBuilder instructions_builder(std::make_shared<vm::PhysicalIdGenerator>(),
                                            &instruction_list, &eager_symbol_list,
                                            _ReleasePhysicalObject);
   JUST(Build(&instructions_builder));
diff --git a/oneflow/python/framework/graph_build_util.py b/oneflow/python/framework/graph_build_util.py
index ad305f85a..25977a2b7 100644
--- a/oneflow/python/framework/graph_build_util.py
+++ b/oneflow/python/framework/graph_build_util.py
@@ -15,23 +15,28 @@ limitations under the License.
 """
 
 from __future__ import absolute_import
-
 from contextlib import contextmanager
 
-import inspect
+from google.protobuf import text_format
+
+import oneflow.core.job.scope_pb2 as scope_pb2_util
+import oneflow.python.framework.attr_util as attr_util
 import oneflow.python.framework.c_api_util as c_api_util
 import oneflow.python.framework.placement_util as placement_util
 import oneflow.python.framework.runtime_mode as runtime_mode
 import oneflow.python.framework.scope_util as scope_util
+import oneflow.python.framework.session_context as session_context
 import oneflow._oneflow_internal
 
+
 lazy_mode = oneflow._oneflow_internal.lazy_mode
 
 
 @contextmanager
 def graph_build_context(config_proto, session):
+    prev_scope = oneflow._oneflow_internal.GetCurrentScope()
     device_tag_and_ids = placement_util.GetDefaultMachineDeviceIds(session.resource)
-    scope = scope_util.MakeInitialScope(
+    new_scope = scope_util.MakeInitialScope(
         config_proto,
         *device_tag_and_ids,
         None,  # TODO(): set hierarchy from user graph config
@@ -40,7 +45,7 @@ def graph_build_context(config_proto, session):
 
     with lazy_mode.gard(True):
         with JobBuildAndInferCtx(config_proto):
-            with scope_util.ScopeContext(scope):
+            with BlockScopeContext(prev_scope, new_scope):
                 yield
 
 
@@ -53,10 +58,70 @@ class JobBuildAndInferCtx(object):
         c_api_util.CurJobBuildAndInferCtx_SetJobConf(self._job_conf)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
+        # TODO(xuxiaoyu): open job optimization pass
+        # oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
+        oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
         if exc_type is None:
-            # TODO(xuxiaoyu): open job optimization pass
-            # oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
-            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
             return True
         else:
             return False
+
+
+class BlockScopeContext(object):
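+    # Push new_scope on enter; pop it and restore prev_scope on exit.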
+    def __init__(self, prev_scope, new_scope):
+        assert prev_scope is not None
+        assert new_scope is not None
+        self._prev_scope = prev_scope
+        self._new_scope = new_scope
+
+    def __enter__(self):
+        oneflow._oneflow_internal.GlobalScopeStackPush(self._new_scope)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        assert oneflow._oneflow_internal.GetCurrentScope() is self._new_scope
+        oneflow._oneflow_internal.GlobalScopeStackPop()
+        assert oneflow._oneflow_internal.GetCurrentScope() is self._prev_scope
+        if exc_type is None:
+            return True
+        else:
+            return False
+
+
+def make_new_block_scope(prev_scope, block):
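+    # Build a child scope of prev_scope that carries the block's config
+    # (pipeline stage id, activation checkpointing) and records the block's
+    # full dotted name as the scope's op-name prefix.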
+    assert prev_scope is not None
+    assert block is not None
+
+    attr_dict = dict()
+    if block.config.stage_id is not None:
+        attr_dict["pipeline_stage_id_hint"] = block.config.stage_id
+    if block.config.activation_checkpointing is not None:
+        attr_dict["checkpointing"] = block.config.activation_checkpointing
+
+    name2default = session_context.GetDefaultSession().scope_attr_name2default_val
+
+    def scope_proto_setter(scope_proto):
+        # set attr
+        for attr_name, py_value in attr_dict.items():
+            assert attr_name in name2default
+            attr_util.SetAttrValue(
+                scope_proto.mutable_attr_name2attr_value()[attr_name],
+                py_value,
+                name2default[attr_name],
+            )
+        # append name prefix
+        scope_proto.clear_scope_op_name_prefixes()
+        scope_proto.add_scope_op_name_prefixes(block.name_prefix + block.name)
+
+    new_scope = None
+
+    def build_scope(builder):
+        nonlocal new_scope
+        new_scope = builder.BuildScopeByProtoSetter(prev_scope, scope_proto_setter)
+        assert new_scope is not None
+
+    oneflow._oneflow_internal.deprecated.LogicalRun(build_scope)
+    return new_scope
+
+
+def scope_to_proto(scope):
+    return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto())
diff --git a/oneflow/python/framework/scope_util.py b/oneflow/python/framework/scope_util.py
index 6db92f187..55e348d25 100644
--- a/oneflow/python/framework/scope_util.py
+++ b/oneflow/python/framework/scope_util.py
@@ -116,7 +116,3 @@ def ScopeContext(scope):
         assert oneflow._oneflow_internal.GetCurrentScope() is scope
         oneflow._oneflow_internal.GlobalScopeStackPop()
         assert oneflow._oneflow_internal.GetCurrentScope() is old_scope
-
-
-def to_proto(scope):
-    return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto())
diff --git a/oneflow/python/nn/graph.py b/oneflow/python/nn/graph.py
index 9df3213bf..72f23aa9a 100644
--- a/oneflow/python/nn/graph.py
+++ b/oneflow/python/nn/graph.py
@@ -15,7 +15,7 @@ limitations under the License.
 """
 from __future__ import absolute_import
 from collections import OrderedDict
-from typing import Union
+from typing import Union, Optional, Iterator, Set
 
 import oneflow._oneflow_internal
 import oneflow.python.framework.c_api_util as c_api_util
@@ -80,21 +80,20 @@ class Graph(object):
         self._name = child_name + "_" + str(Graph._child_init_cnt[child_name])
         Graph._child_init_cnt[child_name] += 1
 
-    def _named_state(self):
+    def _state(self):
         for _, b in self._blocks.items():
-            prefix = b.name + "."
-            p_gen = b.origin.named_parameters()
-            for n, p in p_gen:
-                yield prefix + n, p
-            b_gen = b.origin.named_buffers()
-            for n, b in b_gen:
-                yield prefix + n, b
+            pa_gen = b.parameters(recurse=True)
+            for pa in pa_gen:
+                yield pa
+            bu_gen = b.buffers(recurse=True)
+            for bu in bu_gen:
+                yield bu
 
     def _compile(self, *args):
         assert not self._is_compiled, (
             "nn.Graph " + self._name + " has already been compiled."
         )
-        state = tuple(t for _, t in self._named_state())
+        state = tuple(s.origin for s in self._state())
         if len(state) > 0:
             self._state_tensortuple = tensor_tuple_util.convert_to_tensor_tuple(state)
 
@@ -103,9 +102,19 @@ class Graph(object):
         session.TryInit()
 
         with graph_build_util.graph_build_context(self.config.proto, session):
+            # Deal with parameters and buffers
+            for s in self._state():
+
+                def to_lazy():
+                    # TODO(): Replace repr(s) with OpExpr(s.origin)
+                    lazy_tensor = repr(s)
+                    return lazy_tensor
+
+                s.set_lazy_origin_lambda(to_lazy)
             outputs = self.build(*args)
 
         self._is_compiled = True
+        return outputs
 
     def _launch(self):
         # TODO(xuxiaoyu)
@@ -202,10 +211,13 @@ class Block(object):
         self._name_prefix = prefix
         self._type = BlockType.NONE
         self._origin = value
-        self._config = BlockConfig()
+        self.config = BlockConfig()
+        self._scope = None
+        self._prev_scope = None
 
         if isinstance(value, Module):
             self._type = BlockType.MODULE
+            self._is_executing_forward = False
             self._modules = OrderedDict()
             self._parameters = OrderedDict()
             self._buffers = OrderedDict()
@@ -217,8 +229,12 @@ class Block(object):
                 self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
         elif isinstance(value, Parameter):
             self._type = BlockType.PARAMETER
+            self._lazy_origin = None
+            self._lazy_origin_lambda = None
         elif isinstance(value, Tensor):
             self._type = BlockType.BUFFER
+            self._lazy_origin = None
+            self._lazy_origin_lambda = None
         else:
             raise NotImplementedError()
 
@@ -238,14 +254,92 @@ class Block(object):
     def origin(self):
         return self._origin
 
+    @property
+    def lazy_origin(self):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin"
+        return self._lazy_origin
+
+    def lazy_origin_lambda(self):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin_lambda"
+        return self._lazy_origin_lambda
+
+    def set_lazy_origin_lambda(self, fn=None):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin_lambda"
+        self._lazy_origin_lambda = fn
+
+    @property
+    def prev_scope(self):
+        if self._prev_scope is None:
+            self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
+        return self._prev_scope
+
+    @property
+    def scope(self):
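+        # Build the block's scope lazily so it is created only once.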
+        if self._scope is None:
+            self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
+        return self._scope
+
+    def scope_context(self):
+        return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)
+
     def __call__(self, *args):
         assert self._type == BlockType.MODULE
-        # TODO(): with oneflow_c_api.set_scope(self.config_):
-        return self._origin.__class__.__call__(self, *args)
+        # nn.Module.__call__ will call self.forward()
+        # so the scope is set in self.forward()
+        result = self._origin.__class__.__call__(self, *args)
+        return result
 
     def forward(self, *args):
         assert self._type == BlockType.MODULE
-        return self._origin.__class__.forward(self, *args)
+        self._is_executing_forward = True
+        # TODO(xuxiaoyu): only build scope in lazy mode
+        with self.scope_context():
+            result = self._origin.__class__.forward(self, *args)
+        self._is_executing_forward = False
+        return result
+
+    def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        if memo is None:
+            memo = set()
+        if self not in memo:
+            memo.add(self)
+            yield self
+            for name, module in self._modules.items():
+                if module is None:
+                    continue
+                for m in module.modules(memo):
+                    yield m
+
+    def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
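+        # Yield each member once, deduplicating shared Parameters/Buffers via memo.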
+        assert self._type == BlockType.MODULE
+        memo = set()
+        modules = self.modules() if recurse else [self]
+        for module in modules:
+            members = get_members_fn(module)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                memo.add(v)
+                yield v
+
+    def parameters(self, recurse: bool = True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        gen = self._members(lambda module: module._parameters.items(), recurse=recurse)
+        for elem in gen:
+            yield elem
+
+    def buffers(self, recurse: bool = True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        gen = self._members(lambda module: module._buffers.items(), recurse=recurse)
+        for elem in gen:
+            yield elem
 
     def __setattr__(self, name: str, value=None) -> None:
         if value is None or not isinstance(value, Block):
@@ -289,15 +383,45 @@ class Block(object):
             if "_parameters" in self.__dict__:
                 _parameters = self.__dict__["_parameters"]
                 if name in _parameters:
-                    # TODO(): return block when need config
-                    # return _parameters[name]
-                    return _parameters[name].origin
+                    p_block = _parameters[name]
+                    if self._is_executing_forward:
+                        # Return a Tensor for execution when getattr is called inside its parent Block's forward()
+                        if graph_build_util.lazy_mode.is_enabled():
+                            if p_block._lazy_origin is None:
+                                assert p_block._lazy_origin_lambda is not None, (
+                                    repr(p_block)
+                                    + " has no lazy Tensor creation function."
+                                )
+                                # Create and return lazy tensor
+                                with p_block.scope_context():
+                                    p_block._lazy_origin = p_block._lazy_origin_lambda()
+                            return p_block._lazy_origin
+                        else:
+                            return p_block.origin
+                    else:
+                        # Return a Block for configuration when getattr is called outside its parent Block's forward()
+                        return p_block
             if "_buffers" in self.__dict__:
                 _buffers = self.__dict__["_buffers"]
                 if name in _buffers:
-                    # TODO(): return block when need config
-                    # return _buffers[name]
-                    return _buffers[name].origin
+                    b_block = _buffers[name]
+                    if self._is_executing_forward:
+                        # Return a Tensor for execution when getattr is called inside its parent Block's forward()
+                        if graph_build_util.lazy_mode.is_enabled():
+                            if b_block._lazy_origin is None:
+                                assert b_block._lazy_origin_lambda is not None, (
+                                    repr(b_block)
+                                    + " has no lazy Tensor creation function."
+                                )
+                                # Create and return lazy tensor
+                                with b_block.scope_context():
+                                    b_block._lazy_origin = b_block._lazy_origin_lambda()
+                            return b_block._lazy_origin
+                        else:
+                            return b_block.origin
+                    else:
+                        # Return a Block for configuration when getattr is called outside its parent Block's forward()
+                        return b_block
             if name in self._origin.__dict__:
                 return self._origin.__dict__[name]
 
@@ -324,6 +448,7 @@ class Block(object):
 
         main_str = (
             "("
+            + self._name_prefix
             + self._name
             + ":"
             + self._origin.__class__.__name__
@@ -336,10 +461,6 @@ class Block(object):
         main_str += ")"
         return main_str
 
-    @property
-    def scope(self):
-        return self._config.scope
-
 
 @oneflow_export("nn.graph.GraphConfig")
 @experimental_api
@@ -372,12 +493,7 @@ class GraphConfig(FunctionConfig):
 class BlockConfig(object):
     def __init__(self):
         self._stage_id = None
-        self._activation_checkpointing = False
-
-    @property
-    def scope(self):
-        # TODO(xuxiaoyu): support generating Scope Object
-        print("BlockConfig.scope todo")
+        self._activation_checkpointing = None
 
     @property
     def stage_id(self):
diff --git a/oneflow/python/test/graph/test_graph.py b/oneflow/python/test/graph/test_graph.py
index 8d21ce204..e64633037 100644
--- a/oneflow/python/test/graph/test_graph.py
+++ b/oneflow/python/test/graph/test_graph.py
@@ -86,24 +86,26 @@ class TestGraph(flow.unittest.TestCase):
         test_case.assertEqual(g.name, g._c_nn_graph.name)
         # g.m is Block
         test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block))
+        test_case.assertEqual(g.m.type, "MODULE")
         # g.m.name is "m"
         test_case.assertEqual(g.m.name, "m")
-        # g.m.dummy_buff is Tensor, Graph.build(...) need buffer to be Tensor
-        test_case.assertTrue(isinstance(g.m.dummy_buff, flow.Tensor))
-        # g.m._buffers["dummy_buff"] is Block
-        test_case.assertTrue(
-            isinstance(g.m._buffers["dummy_buff"], flow.nn.graph.Block)
-        )
+        # g.m.dummy_buff is Block
+        test_case.assertTrue(isinstance(g.m.dummy_buff, flow.nn.graph.Block))
+        test_case.assertEqual(g.m.dummy_buff.type, "BUFFER")
+
         # conv1 is Block
         test_case.assertTrue(isinstance(g.m.layer.conv1, flow.nn.graph.Block))
         # conv1.name is "conv1"
         test_case.assertEqual(g.m.layer.conv1.name, "conv1")
+        # conv1.name_prefix is "m.layer."
+        test_case.assertEqual(g.m.layer.conv1.name_prefix, "m.layer.")
+        # conv1.weight is Block
+        test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.nn.graph.Block))
+        test_case.assertEqual(g.m.layer.conv1.weight.type, "PARAMETER")
         # conv1.weight is Tensor, Graph.build(...) need weight to be Tensor
+        g.m.layer.conv1._is_executing_forward = True
         test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.Tensor))
-        # conv1._parameters["weight"] is Block
-        test_case.assertTrue(
-            isinstance(g.m.layer.conv1._parameters["weight"], flow.nn.graph.Block)
-        )
+        g.m.layer.conv1._is_executing_forward = False
         # conv1.kernel_size is original data in original module
         test_case.assertEqual(g.m.layer.conv1.kernel_size, (5, 5))
 
@@ -132,9 +134,8 @@ class TestGraph(flow.unittest.TestCase):
         g.config.enable_fuse_add_to_output(True)
         g.config.enable_fuse_add_to_output(False)
 
-        # check _named_state get the right tensor
-        for n, t in g._named_state():
-            test_case.assertEqual(id(eval("g." + n)), id(t))
+        for s in g._state():
+            print("g state: ", repr(s))
 
         # print repr of nn.Graph
         print(repr(g))
@@ -207,8 +208,7 @@ class TestGraph(flow.unittest.TestCase):
                 import oneflow.python.framework.scope_util as scope_util
 
                 scope = oneflow.current_scope()
-                scope_proto = scope_util.to_proto(scope)
-                print("cur scope in build ", scope_proto)
+                scope_proto = graph_build_util.scope_to_proto(scope)
                 test_case.assertEqual(session.id, scope_proto.session_id)
 
                 # check job_build_and_infer_ctx
@@ -223,6 +223,102 @@ class TestGraph(flow.unittest.TestCase):
         g._compile()
         test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)
 
+    def test_block_scope(test_case):
+        class SubModule0(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = flow.nn.Conv2d(1, 1, 5)
+
+            def forward(self):
+                scope = oneflow.current_scope()
+                scope_proto = graph_build_util.scope_to_proto(scope)
+
+                # check scope activation checkpointing
+                ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool
+                test_case.assertEqual(ck_bool, True)
+                # check scope stage id
+                stage_int = scope_proto.attr_name2attr_value[
+                    "pipeline_stage_id_hint"
+                ].at_int64
+                test_case.assertEqual(stage_int, 0)
+
+                # conv1 is not executing its own forward here, so conv1.weight returns a Block
+                x = self.conv1.weight
+                test_case.assertEqual(type(x), flow.nn.graph.Block)
+                return x
+
+        class SubModule1(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = flow.nn.Linear(36, 4)
+                self.register_buffer(
+                    "dummy_buff", flow.Tensor(1, 4),
+                )
+
+            def forward(self):
+                scope = oneflow.current_scope()
+                scope_proto = graph_build_util.scope_to_proto(scope)
+
+                # check scope symbol id
+                test_case.assertEqual(
+                    scope_proto.parent_scope_symbol_id, self.prev_scope.symbol_id
+                )
+
+                # check scope activation checkpointing
+                ck_bool = scope_proto.attr_name2attr_value["checkpointing"]
+                test_case.assertEqual(ck_bool.WhichOneof("value"), None)
+                # check scope stage id
+                stage_int = scope_proto.attr_name2attr_value[
+                    "pipeline_stage_id_hint"
+                ].at_int64
+                test_case.assertEqual(stage_int, 1)
+
+                name = self.name_prefix + self.name
+                prefixes = []
+                for prefix in scope_proto.scope_op_name_prefixes:
+                    prefixes.append(prefix)
+                name_in_scope = ".".join(prefixes)
+                test_case.assertEqual(name, name_in_scope)
+
+                x = self.dummy_buff
+                dummy_buff_scope_proto = graph_build_util.scope_to_proto(
+                    self._buffers["dummy_buff"].scope
+                )
+                test_case.assertEqual(
+                    dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id
+                )
+                return x
+
+        class CustomModule1(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer0 = SubModule0()
+                self.layer1 = SubModule1()
+
+            def forward(self):
+                x = self.layer0()
+                y = self.layer1()
+                return x, y
+
+        m = CustomModule1()
+
+        class CustomGraph1(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.m = m
+                # config scope
+                self.m.layer0.config.stage_id = 0
+                self.m.layer0.config.activation_checkpointing = True
+                self.m.layer1.config.stage_id = 1
+
+            def build(self):
+                return self.m()
+
+        g = CustomGraph1()
+        x = flow.Tensor(1, 1, 10, 10)
+        flow.nn.init.uniform_(x, a=-1.0, b=1.0)
+        z = g._compile()
+
 
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab