Commit 2a597c2c authored by Xiaoyu Xu, committed by GitHub

Fea/nn graph/block scope (#5498)


* trigger ci test on graph

* add module scope

* Block.parameter or buffer returns different values based on context

* add test on scope name prefix

* refine scope build

* refine scope build

* get parameter via a lambda to create the right scope

* check parameter get

* fix vm:PhysicalRun

* lazy Tensor is only created once

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent d68ffeed
@@ -523,6 +523,14 @@ jobs:
             -e ONEFLOW_TEST_DIR=$PWD/oneflow/python/test/tensor \
             ${{ env.image_tag }} \
             bash -c "python3 -m pip config set global.index-url ${{ env.pip_index_mirror }} && bash ci/test/try_install.sh && bash ci/test/generic_test.sh"
+      - name: Graph API test
+        if: matrix.test_suite == 'cuda_new_interface'
+        run: |
+          docker run \
+            ${{ env.extra_docker_args }} ${{ env.pip_cache_docker_args }} \
+            -e ONEFLOW_TEST_DIR=$PWD/oneflow/python/test/graph \
+            ${{ env.image_tag }} \
+            bash -c "python3 -m pip config set global.index-url ${{ env.pip_index_mirror }} && bash ci/test/try_install.sh && bash ci/test/generic_test.sh"
       - name: Op test
         timeout-minutes: 45
         if: matrix.test_suite == 'cpu' || matrix.test_suite == 'cuda_op'
@@ -1600,7 +1600,7 @@ Maybe<void> LogicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
 Maybe<void> PhysicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
   vm::InstructionMsgList instruction_list;
   vm::cfg::EagerSymbolList eager_symbol_list;
-  InstructionsBuilder instructions_builder(std::shared_ptr<vm::PhysicalIdGenerator>(),
+  InstructionsBuilder instructions_builder(std::make_shared<vm::PhysicalIdGenerator>(),
                                            &instruction_list, &eager_symbol_list,
                                            _ReleasePhysicalObject);
   JUST(Build(&instructions_builder));
@@ -15,23 +15,28 @@ limitations under the License.
 """
 from __future__ import absolute_import
 from contextlib import contextmanager
-import inspect
+from google.protobuf import text_format
+import oneflow.core.job.scope_pb2 as scope_pb2_util
+import oneflow.python.framework.attr_util as attr_util
 import oneflow.python.framework.c_api_util as c_api_util
 import oneflow.python.framework.placement_util as placement_util
 import oneflow.python.framework.runtime_mode as runtime_mode
 import oneflow.python.framework.scope_util as scope_util
+import oneflow.python.framework.session_context as session_context
 import oneflow._oneflow_internal

 lazy_mode = oneflow._oneflow_internal.lazy_mode


 @contextmanager
 def graph_build_context(config_proto, session):
+    prev_scope = oneflow._oneflow_internal.GetCurrentScope()
     device_tag_and_ids = placement_util.GetDefaultMachineDeviceIds(session.resource)
-    scope = scope_util.MakeInitialScope(
+    new_scope = scope_util.MakeInitialScope(
         config_proto,
         *device_tag_and_ids,
         None,  # TODO(): set hierarchy from user graph config
@@ -40,7 +45,7 @@ def graph_build_context(config_proto, session):
     with lazy_mode.gard(True):
         with JobBuildAndInferCtx(config_proto):
-            with scope_util.ScopeContext(scope):
+            with BlockScopeContext(prev_scope, new_scope):
                 yield
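The `BlockScopeContext` introduced in the next hunk replaces `scope_util.ScopeContext` here: it pushes `new_scope` on entry and asserts that popping restores `prev_scope` on exit. A minimal stand-alone sketch of that push/pop discipline, where a plain Python list and sentinel objects stand in for OneFlow's internal global scope stack (illustration only, not the real API):

    # Sketch of the scope push/pop contract BlockScopeContext enforces.
    ROOT, BLOCK = object(), object()
    _scope_stack = [ROOT]  # stand-in for the global scope stack

    class ScopeGuard:
        def __init__(self, prev_scope, new_scope):
            assert prev_scope is not None and new_scope is not None
            self._prev, self._new = prev_scope, new_scope

        def __enter__(self):
            _scope_stack.append(self._new)  # GlobalScopeStackPush(new_scope)

        def __exit__(self, exc_type, exc_val, exc_tb):
            # The scope we entered must still be current before popping,
            # and popping must expose exactly the scope seen before entry.
            assert _scope_stack[-1] is self._new
            _scope_stack.pop()              # GlobalScopeStackPop()
            assert _scope_stack[-1] is self._prev
            return exc_type is None

    with ScopeGuard(ROOT, BLOCK):
        assert _scope_stack[-1] is BLOCK
    assert _scope_stack[-1] is ROOT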
@@ -53,10 +58,70 @@ class JobBuildAndInferCtx(object):
         c_api_util.CurJobBuildAndInferCtx_SetJobConf(self._job_conf)

     def __exit__(self, exc_type, exc_val, exc_tb):
+        # TODO(xuxiaoyu): open job optimization pass
+        # oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
+        oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
         if exc_type is None:
-            # TODO(xuxiaoyu): open job optimization pass
-            # oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
-            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
             return True
         else:
             return False
+
+
+class BlockScopeContext(object):
+    def __init__(self, prev_scope, new_scope):
+        assert prev_scope is not None
+        assert new_scope is not None
+        self._prev_scope = prev_scope
+        self._new_scope = new_scope
+
+    def __enter__(self):
+        oneflow._oneflow_internal.GlobalScopeStackPush(self._new_scope)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        assert oneflow._oneflow_internal.GetCurrentScope() is self._new_scope
+        oneflow._oneflow_internal.GlobalScopeStackPop()
+        assert oneflow._oneflow_internal.GetCurrentScope() is self._prev_scope
+        if exc_type is None:
+            return True
+        else:
+            return False
+
+
+def make_new_block_scope(prev_scope, block):
+    assert prev_scope is not None
+    assert block is not None
+
+    attr_dict = dict()
+    if block.config.stage_id is not None:
+        attr_dict["pipeline_stage_id_hint"] = block.config.stage_id
+    if block.config.activation_checkpointing is not None:
+        attr_dict["checkpointing"] = block.config.activation_checkpointing
+
+    name2default = session_context.GetDefaultSession().scope_attr_name2default_val
+
+    def scope_proto_setter(scope_proto):
+        # set attr
+        for attr_name, py_value in attr_dict.items():
+            assert attr_name in name2default
+            attr_util.SetAttrValue(
+                scope_proto.mutable_attr_name2attr_value()[attr_name],
+                py_value,
+                name2default[attr_name],
+            )
+        # append name prefix
+        scope_proto.clear_scope_op_name_prefixes()
+        scope_proto.add_scope_op_name_prefixes(block.name_prefix + block.name)
+
+    new_scope = None

+    def build_scope(builder):
+        nonlocal new_scope
+        new_scope = builder.BuildScopeByProtoSetter(prev_scope, scope_proto_setter)
+        assert new_scope is not None
+
+    oneflow._oneflow_internal.deprecated.LogicalRun(build_scope)
+    return new_scope
+
+
+def scope_to_proto(scope):
+    return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto())
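`make_new_block_scope` above builds a child scope by handing a copy of the parent scope's proto to a setter callback, which fills in the block's attributes and op-name prefix. A hedged, self-contained sketch of that setter-callback pattern; `ScopeProto` here is a plain stand-in dataclass rather than the real protobuf message, and `build_scope_by_setter` stands in for `builder.BuildScopeByProtoSetter`:

    import copy
    from dataclasses import dataclass, field

    @dataclass
    class ScopeProto:                  # stand-in for the protobuf ScopeProto
        attrs: dict = field(default_factory=dict)
        op_name_prefixes: list = field(default_factory=list)

    def build_scope_by_setter(parent, setter):
        child = copy.deepcopy(parent)  # child starts as a copy of the parent scope
        setter(child)                  # caller customizes attrs and name prefixes
        return child

    parent = ScopeProto(attrs={"checkpointing": False})

    def setter(proto):
        proto.attrs["pipeline_stage_id_hint"] = 0  # stage hint, as in the diff
        proto.attrs["checkpointing"] = True        # activation checkpointing
        proto.op_name_prefixes = ["m.layer0"]      # replace scope_op_name_prefixes

    child = build_scope_by_setter(parent, setter)
    assert child.attrs["checkpointing"] is True
    assert parent.attrs["checkpointing"] is False  # parent is left untouched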
@@ -116,7 +116,3 @@ def ScopeContext(scope):
         assert oneflow._oneflow_internal.GetCurrentScope() is scope
         oneflow._oneflow_internal.GlobalScopeStackPop()
         assert oneflow._oneflow_internal.GetCurrentScope() is old_scope
-
-
-def to_proto(scope):
-    return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto())
@@ -15,7 +15,7 @@ limitations under the License.
 """
 from __future__ import absolute_import
 from collections import OrderedDict
-from typing import Union
+from typing import Union, Optional, Iterator, Set

 import oneflow._oneflow_internal
 import oneflow.python.framework.c_api_util as c_api_util
@@ -80,21 +80,20 @@ class Graph(object):
             self._name = child_name + "_" + str(Graph._child_init_cnt[child_name])
             Graph._child_init_cnt[child_name] += 1

-    def _named_state(self):
+    def _state(self):
         for _, b in self._blocks.items():
-            prefix = b.name + "."
-            p_gen = b.origin.named_parameters()
-            for n, p in p_gen:
-                yield prefix + n, p
-            b_gen = b.origin.named_buffers()
-            for n, b in b_gen:
-                yield prefix + n, b
+            pa_gen = b.parameters(recurse=True)
+            for pa in pa_gen:
+                yield pa
+            bu_gen = b.buffers(recurse=True)
+            for bu in bu_gen:
+                yield bu

     def _compile(self, *args):
         assert not self._is_compiled, (
             "nn.Graph " + self._name + " has already been compiled."
         )
-        state = tuple(t for _, t in self._named_state())
+        state = tuple(s.origin for s in self._state())
         if len(state) > 0:
             self._state_tensortuple = tensor_tuple_util.convert_to_tensor_tuple(state)
@@ -103,9 +102,19 @@ class Graph(object):
         session.TryInit()

         with graph_build_util.graph_build_context(self.config.proto, session):
+            # Deal with parameters and buffers
+            for s in self._state():
+
+                def to_lazy():
+                    # TODO(): Replace repr(s) with OpExpr(s.origin)
+                    lazy_tensor = repr(s)
+                    return lazy_tensor
+
+                s.set_lazy_origin_lambda(to_lazy)
             outputs = self.build(*args)
         self._is_compiled = True
+        return outputs

     def _launch(self):
         # TODO(xuxiaoyu)
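One thing worth noting about the `to_lazy` closure registered per state block above: Python closures capture the loop variable by reference (late binding), so this pattern depends on each lambda being invoked while `s` still names the intended block; the usual defensive variant freezes the value per iteration with a default argument. A small illustration, where `states` is a stand-in list rather than OneFlow code:

    # Late binding vs. per-iteration binding in loop-created closures.
    states = ["w0", "w1", "w2"]

    late = [lambda: s for s in states]       # all three share the final s
    bound = [lambda s=s: s for s in states]  # each captures its own s

    assert [f() for f in late] == ["w2", "w2", "w2"]
    assert [f() for f in bound] == ["w0", "w1", "w2"]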
@@ -202,10 +211,13 @@ class Block(object):
         self._name_prefix = prefix
         self._type = BlockType.NONE
         self._origin = value
-        self._config = BlockConfig()
+        self.config = BlockConfig()
+        self._scope = None
+        self._prev_scope = None
         if isinstance(value, Module):
             self._type = BlockType.MODULE
+            self._is_executing_forward = False
             self._modules = OrderedDict()
             self._parameters = OrderedDict()
             self._buffers = OrderedDict()
@@ -217,8 +229,12 @@ class Block(object):
                 self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
         elif isinstance(value, Parameter):
             self._type = BlockType.PARAMETER
+            self._lazy_origin = None
+            self._lazy_origin_lambda = None
         elif isinstance(value, Tensor):
             self._type = BlockType.BUFFER
+            self._lazy_origin = None
+            self._lazy_origin_lambda = None
         else:
             raise NotImplementedError()
@@ -238,14 +254,92 @@ class Block(object):
     def origin(self):
         return self._origin

+    @property
+    def lazy_origin(self):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin"
+        return self._lazy_origin
+
+    def lazy_origin_lambda(self):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin_lambda"
+        return self._lazy_origin_lambda
+
+    def set_lazy_origin_lambda(self, fn=None):
+        assert (
+            self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
+        ), "Only Parameter or Buffer Block has lazy_origin_lambda"
+        self._lazy_origin_lambda = fn
+
+    @property
+    def prev_scope(self):
+        if self._prev_scope is None:
+            self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
+        return self._prev_scope
+
+    @property
+    def scope(self):
+        if self._scope is None:
+            self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
+        return self._scope
+
+    def scope_context(self):
+        return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)
+
     def __call__(self, *args):
         assert self._type == BlockType.MODULE
-        # TODO(): with oneflow_c_api.set_scope(self.config_):
-        return self._origin.__class__.__call__(self, *args)
+        # nn.Module.__call__ will call self.forward(),
+        # so the scope is set in self.forward()
+        result = self._origin.__class__.__call__(self, *args)
+        return result

     def forward(self, *args):
         assert self._type == BlockType.MODULE
-        return self._origin.__class__.forward(self, *args)
+        self._is_executing_forward = True
+        # TODO(xuxiaoyu): only build scope in lazy mode
+        with self.scope_context():
+            result = self._origin.__class__.forward(self, *args)
+        self._is_executing_forward = False
+        return result
+
+    def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        if memo is None:
+            memo = set()
+        if self not in memo:
+            memo.add(self)
+            yield self
+            for name, module in self._modules.items():
+                if module is None:
+                    continue
+                for m in module.modules(memo):
+                    yield m
+
+    def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        memo = set()
+        modules = self.modules() if recurse else [self]
+        for module in modules:
+            members = get_members_fn(module)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                memo.add(v)
+                yield v
+
+    def parameters(self, recurse: bool = True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        gen = self._members(lambda module: module._parameters.items(), recurse=recurse)
+        for elem in gen:
+            yield elem
+
+    def buffers(self, recurse: bool = True) -> Iterator["Block"]:
+        assert self._type == BlockType.MODULE
+        gen = self._members(lambda module: module._buffers.items(), recurse=recurse)
+        for elem in gen:
+            yield elem
     def __setattr__(self, name: str, value=None) -> None:
         if value is None or not isinstance(value, Block):
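The `modules()`/`_members()` generators above deduplicate with a memo set, so a submodule, parameter, or buffer shared between branches is yielded only once. A stand-alone sketch of that traversal, with `Node` as an illustrative stand-in for `Block`:

    # Memo-based traversal: a shared child is yielded only on first visit.
    class Node:
        def __init__(self, name, children=()):
            self.name = name
            self._children = list(children)

        def modules(self, memo=None):
            if memo is None:
                memo = set()
            if self not in memo:
                memo.add(self)
                yield self
                for child in self._children:
                    yield from child.modules(memo)

    shared = Node("shared")
    root = Node("root", [Node("a", [shared]), Node("b", [shared])])
    # "shared" sits under both branches but appears once in the result:
    assert [n.name for n in root.modules()] == ["root", "a", "shared", "b"]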
@@ -289,15 +383,45 @@ class Block(object):
         if "_parameters" in self.__dict__:
             _parameters = self.__dict__["_parameters"]
             if name in _parameters:
-                # TODO(): return block when need config
-                # return _parameters[name]
-                return _parameters[name].origin
+                p_block = _parameters[name]
+                if self._is_executing_forward:
+                    # Return a Tensor for running when getattr is called inside its parent Block's forward()
+                    if graph_build_util.lazy_mode.is_enabled():
+                        if p_block._lazy_origin is None:
+                            assert p_block._lazy_origin_lambda is not None, (
+                                repr(p_block)
+                                + " has no lazy Tensor creation function."
+                            )
+                            # Create and cache the lazy tensor
+                            with p_block.scope_context():
+                                p_block._lazy_origin = p_block._lazy_origin_lambda()
+                        return p_block._lazy_origin
+                    else:
+                        return p_block.origin
+                else:
+                    # Return the Block for config when getattr is called outside its parent Block's forward()
+                    return p_block
         if "_buffers" in self.__dict__:
             _buffers = self.__dict__["_buffers"]
             if name in _buffers:
-                # TODO(): return block when need config
-                # return _buffers[name]
-                return _buffers[name].origin
+                b_block = _buffers[name]
+                if self._is_executing_forward:
+                    # Return a Tensor for running when getattr is called inside its parent Block's forward()
+                    if graph_build_util.lazy_mode.is_enabled():
+                        if b_block._lazy_origin is None:
+                            assert b_block._lazy_origin_lambda is not None, (
+                                repr(b_block)
+                                + " has no lazy Tensor creation function."
+                            )
+                            # Create and cache the lazy tensor
+                            with b_block.scope_context():
+                                b_block._lazy_origin = b_block._lazy_origin_lambda()
+                        return b_block._lazy_origin
+                    else:
+                        return b_block.origin
+                else:
+                    # Return the Block for config when getattr is called outside its parent Block's forward()
+                    return b_block
         if name in self._origin.__dict__:
             return self._origin.__dict__[name]
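The branch above is what makes the commit message's "lazy Tensor is only created once" hold: the first lazy-mode access runs the stored lambda under the block's scope and caches the result in `_lazy_origin`. A minimal stand-alone sketch of that create-once slot (`LazySlot` is illustrative, not OneFlow API):

    # Create-once lazy slot: the creation function runs on first access only.
    class LazySlot:
        def __init__(self, make_fn=None):
            self._lazy_value = None
            self._make_fn = make_fn

        def get(self):
            if self._lazy_value is None:
                assert self._make_fn is not None, "no lazy creation function"
                self._lazy_value = self._make_fn()  # runs exactly once
            return self._lazy_value

    calls = []
    slot = LazySlot(lambda: calls.append(1) or "lazy-tensor")
    assert slot.get() == "lazy-tensor"
    assert slot.get() == "lazy-tensor"
    assert len(calls) == 1  # the lambda was invoked a single time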
@@ -324,6 +448,7 @@ class Block(object):
         main_str = (
             "("
+            + self._name_prefix
             + self._name
             + ":"
             + self._origin.__class__.__name__
@@ -336,10 +461,6 @@ class Block(object):
         main_str += ")"
         return main_str

-    @property
-    def scope(self):
-        return self._config.scope


 @oneflow_export("nn.graph.GraphConfig")
 @experimental_api
@@ -372,12 +493,7 @@ class GraphConfig(FunctionConfig):
 class BlockConfig(object):
     def __init__(self):
         self._stage_id = None
-        self._activation_checkpointing = False
-
-    @property
-    def scope(self):
-        # TODO(xuxiaoyu): support generating Scope Object
-        print("BlockConfig.scope todo")
+        self._activation_checkpointing = None

     @property
     def stage_id(self):
@@ -86,24 +86,26 @@ class TestGraph(flow.unittest.TestCase):
         test_case.assertEqual(g.name, g._c_nn_graph.name)
         # g.m is Block
         test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block))
+        test_case.assertEqual(g.m.type, "MODULE")
         # g.m.name is "m"
         test_case.assertEqual(g.m.name, "m")
-        # g.m.dummy_buff is Tensor, Graph.build(...) need buffer to be Tensor
-        test_case.assertTrue(isinstance(g.m.dummy_buff, flow.Tensor))
-        # g.m._buffers["dummy_buff"] is Block
-        test_case.assertTrue(
-            isinstance(g.m._buffers["dummy_buff"], flow.nn.graph.Block)
-        )
+        # g.m.dummy_buff is Block
+        test_case.assertTrue(isinstance(g.m.dummy_buff, flow.nn.graph.Block))
+        test_case.assertEqual(g.m.dummy_buff.type, "BUFFER")
         # conv1 is Block
         test_case.assertTrue(isinstance(g.m.layer.conv1, flow.nn.graph.Block))
         # conv1.name is "conv1"
         test_case.assertEqual(g.m.layer.conv1.name, "conv1")
+        # conv1.name_prefix is "m.layer."
+        test_case.assertEqual(g.m.layer.conv1.name_prefix, "m.layer.")
+        # conv1.weight is Block
+        test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.nn.graph.Block))
+        test_case.assertEqual(g.m.layer.conv1.weight.type, "PARAMETER")
         # conv1.weight is Tensor, Graph.build(...) need weight to be Tensor
+        g.m.layer.conv1._is_executing_forward = True
         test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.Tensor))
-        # conv1._parameters["weight"] is Block
-        test_case.assertTrue(
-            isinstance(g.m.layer.conv1._parameters["weight"], flow.nn.graph.Block)
-        )
+        g.m.layer.conv1._is_executing_forward = False
         # conv1.kernel_size is original data in original module
         test_case.assertEqual(g.m.layer.conv1.kernel_size, (5, 5))
@@ -132,9 +134,8 @@ class TestGraph(flow.unittest.TestCase):
         g.config.enable_fuse_add_to_output(True)
         g.config.enable_fuse_add_to_output(False)

-        # check _named_state get the right tensor
-        for n, t in g._named_state():
-            test_case.assertEqual(id(eval("g." + n)), id(t))
+        for s in g._state():
+            print("g state: ", repr(s))

         # print repr of nn.Graph
         print(repr(g))
@@ -207,8 +208,7 @@ class TestGraph(flow.unittest.TestCase):
             import oneflow.python.framework.scope_util as scope_util

             scope = oneflow.current_scope()
-            scope_proto = scope_util.to_proto(scope)
-            print("cur scope in build ", scope_proto)
+            scope_proto = graph_build_util.scope_to_proto(scope)
             test_case.assertEqual(session.id, scope_proto.session_id)

             # check job_build_and_infer_ctx
@@ -223,6 +223,102 @@ class TestGraph(flow.unittest.TestCase):
         g._compile()
         test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)

+    def test_block_scope(test_case):
+        class SubModule0(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = flow.nn.Conv2d(1, 1, 5)
+
+            def forward(self):
+                scope = oneflow.current_scope()
+                scope_proto = graph_build_util.scope_to_proto(scope)
+
+                # check scope activation checkpointing
+                ck_bool = scope_proto.attr_name2attr_value["checkpointing"].at_bool
+                test_case.assertEqual(ck_bool, True)
+                # check scope stage id
+                stage_int = scope_proto.attr_name2attr_value[
+                    "pipeline_stage_id_hint"
+                ].at_int64
+                test_case.assertEqual(stage_int, 0)
+
+                # weight is not fetched in conv1's forward, so it returns a Block
+                x = self.conv1.weight
+                test_case.assertEqual(type(x), flow.nn.graph.Block)
+                return x
+
+        class SubModule1(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = flow.nn.Linear(36, 4)
+                self.register_buffer(
+                    "dummy_buff", flow.Tensor(1, 4),
+                )
+
+            def forward(self):
+                scope = oneflow.current_scope()
+                scope_proto = graph_build_util.scope_to_proto(scope)
+
+                # check scope symbol id
+                test_case.assertEqual(
+                    scope_proto.parent_scope_symbol_id, self.prev_scope.symbol_id
+                )
+
+                # check scope activation checkpointing
+                ck_bool = scope_proto.attr_name2attr_value["checkpointing"]
+                test_case.assertEqual(ck_bool.WhichOneof("value"), None)
+                # check scope stage id
+                stage_int = scope_proto.attr_name2attr_value[
+                    "pipeline_stage_id_hint"
+                ].at_int64
+                test_case.assertEqual(stage_int, 1)
+
+                name = self.name_prefix + self.name
+                prefixes = []
+                for prefix in scope_proto.scope_op_name_prefixes:
+                    prefixes.append(prefix)
+                name_in_scope = ".".join(prefixes)
+                test_case.assertEqual(name, name_in_scope)
+
+                x = self.dummy_buff
+                dummy_buff_scope_proto = graph_build_util.scope_to_proto(
+                    self._buffers["dummy_buff"].scope
+                )
+                test_case.assertEqual(
+                    dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id
+                )
+                return x
+
+        class CustomModule1(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer0 = SubModule0()
+                self.layer1 = SubModule1()
+
+            def forward(self):
+                x = self.layer0()
+                y = self.layer1()
+                return x, y
+
+        m = CustomModule1()
+
+        class CustomGraph1(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.m = m
+                # config scope
+                self.m.layer0.config.stage_id = 0
+                self.m.layer0.config.activation_checkpointing = True
+                self.m.layer1.config.stage_id = 1
+
+            def build(self):
+                return self.m()
+
+        g = CustomGraph1()
+        x = flow.Tensor(1, 1, 10, 10)
+        flow.nn.init.uniform_(x, a=-1.0, b=1.0)
+        z = g._compile()
 if __name__ == "__main__":
     unittest.main()