Unverified commit ab8aab8b authored by Xiaoyu Xu, committed by GitHub

Fea/nn graph/forward graph (#5516)


* add test for adding an input to graph

* add var into graph

* LazyInterpreter for FetchOutputOpExpr and set op parallel_distribution

* refine input var build

* split file

* rename

* minor refine

* Add note

* LazyInterpret::ApplyImpl for UserOpExpr

* refine test scripts

* add output to graph

* format

* Fix bug in LazyInterpret for UserOpExpr when changing output lbns

* Add UserOpExpr test

* fix mistake in note

* add userop and test

* address review

* address review

* save i/o/s op_name and tensor for c_nn_graph

* address review

* adjust test

* refine build_graph_state

Co-authored-by: chengtbf <472491134@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 77755951
Showing with 598 additions and 383 deletions
@@ -23,10 +23,11 @@ import oneflow.core.job.scope_pb2 as scope_pb2_util
import oneflow.python.framework.attr_util as attr_util
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.placement_util as placement_util
import oneflow.python.framework.runtime_mode as runtime_mode
import oneflow.python.framework.scope_util as scope_util
import oneflow.python.framework.session_context as session_context
import oneflow._oneflow_internal
from oneflow._oneflow_internal import Tensor as InternalTensor
from oneflow.python.framework.tensor import Tensor
lazy_mode = oneflow._oneflow_internal.lazy_mode
@@ -125,3 +126,61 @@ def make_new_block_scope(prev_scope, block):
def scope_to_proto(scope):
return text_format.Parse(scope._proto_str, scope_pb2_util.ScopeProto())
def build_graph_input_arg(op_name, arg):
assert isinstance(arg, (Tensor, InternalTensor))
input_conf = (
oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedInputOpConf()
)
input_op = oneflow._oneflow_internal.one.FeedInputOpExpr(
op_name, input_conf, ["in_0"], ["out_0"]
)
attrs = oneflow._oneflow_internal.MutableCfgAttrMap()
if isinstance(arg, Tensor):
if not arg.is_determined:
arg.determine()
tensor_in_c = arg._local_or_consistent_tensor
else:
tensor_in_c = arg
lazy_arg = input_op.apply([tensor_in_c], attrs)[0]
return lazy_arg
def build_graph_state(op_name, state_tensor):
var_conf = (
oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedVariableOpConf()
)
var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr(
op_name, var_conf, ["in_0"], ["out_0"]
)
attrs = oneflow._oneflow_internal.MutableCfgAttrMap()
assert isinstance(state_tensor, Tensor)
if not state_tensor.is_determined:
state_tensor.determine()
tensor_in_c = state_tensor._local_or_consistent_tensor
lazy_tensor = var_op.apply([tensor_in_c], attrs)[0]
return lazy_tensor
def build_graph_output(op_name, out):
assert isinstance(out, InternalTensor)
assert out.is_lazy
output_conf = (
oneflow._oneflow_internal.oneflow.core.operator.op_conf.FetchOutputOpConf()
)
output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr(
op_name, output_conf, ["in_0"], ["out_0"]
)
attrs = oneflow._oneflow_internal.MutableCfgAttrMap()
eager_out = output_op.apply([out], attrs)[0]
return eager_out
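
The three functions above are the new graph-build entry points: build_graph_input_arg feeds an eager tensor into the lazy graph through a FeedInputOpExpr, build_graph_state captures a Parameter or Buffer through a FeedVariableOpExpr, and build_graph_output fetches a lazy result back as an eager tensor through a FetchOutputOpExpr. A minimal sketch of how they chain during tracing (the driver function is illustrative, not part of this commit; the op-name convention matches Graph._compile below):

import oneflow.python.framework.graph_build_util as graph_build_util

def trace_single_io(graph_name, eager_arg, build_fn):
    # Feed the eager input into the lazy graph.
    lazy_arg = graph_build_util.build_graph_input_arg(
        "_" + graph_name + "-input_0", eager_arg
    )
    # Run the user build function; under lazy mode this records graph ops.
    lazy_out = build_fn(lazy_arg)
    # Fetch the lazy output back out as an eager tensor.
    return graph_build_util.build_graph_output(
        "_" + graph_name + "-output_0", lazy_out
    )
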
@@ -15,19 +15,21 @@ limitations under the License.
"""
from __future__ import absolute_import
from collections import OrderedDict
from typing import Union, Optional, Iterator, Set
from functools import partial
import oneflow._oneflow_internal
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.graph_build_util as graph_build_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util
from oneflow._oneflow_internal import Tensor as InternalTensor
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.framework.multi_client_session import MultiClientSession
from oneflow.python.framework.tensor import Tensor
from oneflow.python.nn.graph_block import Block
from oneflow.python.nn.graph_optimizer import OptimizerConfig
from oneflow.python.nn.module import Module
from oneflow.python.nn.parameter import Parameter
from oneflow.python.nn.optimizer.optimizer import Optimizer
from oneflow.python.nn.utils import add_indent
from oneflow.python.framework.function_util import FunctionConfig
@@ -45,6 +47,7 @@ class Graph(object):
self._optimizers = OrderedDict()
self._is_compiled = False
self._state_tensortuple = None
self._job_proto = None
@property
def name(self):
@@ -56,7 +59,7 @@
@property
def _graph_proto(self):
return c_api_util.GetCurrentJob()
return self._job_proto
def build(self, *args):
raise NotImplementedError()
@@ -69,7 +72,7 @@
grad_clipping_conf=None,
weight_decay_conf=None,
):
self._optimizers[name] = self.OptimizerConfig(
self._optimizers[name] = OptimizerConfig(
name, optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
)
@@ -102,19 +105,57 @@
session.TryInit()
with graph_build_util.graph_build_context(self.config.proto, session):
# Deal with input
lazy_args = []
lazy_arg_op_names = []
for idx, arg in enumerate(args):
op_name = "_" + self.name + "-input_" + str(idx)
lazy_args.append(graph_build_util.build_graph_input_arg(op_name, arg))
lazy_arg_op_names.append(op_name)
# Deal with parameter and buffer
for s in self._state():
state_op_names = []
state_tensors = []
for state_block in self._state():
op_name = state_block.name_prefix + state_block.name
state_tensor = state_block.origin
state_op_names.append(op_name)
state_tensors.append(state_tensor)
state_block.set_lazy_origin_builder(
partial(graph_build_util.build_graph_state, op_name, state_tensor)
)
def to_lazy():
# TODO(): Replace repr(s) with OpExpr(s.origin)
lazy_tensor = repr(s)
return lazy_tensor
# Deal with module in self.build(*args)
outputs = self.build(*lazy_args)
# Deal with outputs
if not (type(outputs) is tuple or type(outputs) is list):
if outputs is None:
outputs = ()
else:
assert type(outputs) is InternalTensor
outputs = (outputs,)
eager_outputs = []
eager_output_op_names = []
for idx, out in enumerate(outputs):
op_name = "_" + self.name + "-output_" + str(idx)
eager_outputs.append(graph_build_util.build_graph_output(op_name, out))
eager_output_op_names.append(op_name)
if len(eager_outputs) == 0:
eager_outputs = None
elif len(eager_outputs) == 1:
eager_outputs = eager_outputs[0]
else:
eager_outputs = tuple(eager_outputs)
s.set_lazy_origin_lambda(to_lazy)
outputs = self.build(*args)
# TODO(): call self._c_nn_graph
# register lazy_arg_op_names/state_op_names/state_tensors/eager_output_op_names
# Save job proto for debug
self._job_proto = c_api_util.GetCurrentJob()
self._is_compiled = True
return outputs
return eager_outputs
def _launch(self):
# TODO(xuxiaoyu)
@@ -122,7 +163,6 @@
...
def __call__(self, *args):
# TODO(xuxiaoyu)
# if not self._is_compiled:
# self._compile()
# return self._launch()
@@ -179,7 +219,7 @@
child_lines = []
for n, m in self._blocks.items():
mod_str = repr(m)
mod_str = _add_indent(mod_str, 2)
mod_str = add_indent(mod_str, 2)
child_lines.append(mod_str)
lines = child_lines
@@ -190,278 +230,6 @@
return main_str
class BlockType:
NONE = "NONE"
MODULE = "MODULE"
PARAMETER = "PARAMETER"
BUFFER = "BUFFER"
@oneflow_export("nn.graph.Block")
@experimental_api
class Block(object):
def __init__(
self,
prefix: str = "",
name: str = "",
value: Union[Module, Parameter, Tensor] = None,
):
assert not isinstance(value, Block)
self._name = name
self._name_prefix = prefix
self._type = BlockType.NONE
self._origin = value
self.config = BlockConfig()
self._scope = None
self._prev_scope = None
if isinstance(value, Module):
self._type = BlockType.MODULE
self._is_executing_forward = False
self._modules = OrderedDict()
self._parameters = OrderedDict()
self._buffers = OrderedDict()
for n, m in list(value.named_children()):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m))
for n, p in list(value.named_parameters("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p))
for n, b in list(value.named_buffers("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
elif isinstance(value, Parameter):
self._type = BlockType.PARAMETER
self._lazy_origin = None
self._lazy_origin_lambda = None
elif isinstance(value, Tensor):
self._type = BlockType.BUFFER
self._lazy_origin = None
self._lazy_origin_lambda = None
else:
raise NotImplementedError()
@property
def name(self):
return self._name
@property
def name_prefix(self):
return self._name_prefix
@property
def type(self):
return self._type
@property
def origin(self):
return self._origin
@property
def lazy_origin(self):
assert (
self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
), "Only Parameter or Buffer Block has lazy_origin"
return self._lazy_origin
def lazy_origin_lambda(self):
assert (
self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
), "Only Parameter or Buffer Block has lazy_origin_lambda"
return self._lazy_origin_lambda
def set_lazy_origin_lambda(self, fn=None):
assert (
self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
), "Only Parameter or Buffer Block has lazy_origin_lambda"
self._lazy_origin_lambda = fn
@property
def prev_scope(self):
if self._prev_scope is None:
self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
return self._prev_scope
@property
def scope(self):
if self._scope is None:
self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
return self._scope
def scope_context(self):
return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)
def __call__(self, *args):
assert self._type == BlockType.MODULE
# nn.Module.__call__ will call self.forward()
# so the scope is set in self.forward()
result = self._origin.__class__.__call__(self, *args)
return result
def forward(self, *args):
assert self._type == BlockType.MODULE
self._is_executing_forward = True
# TODO(xuxiaoyu): only build scope in lazy mode
with self.scope_context():
result = self._origin.__class__.forward(self, *args)
self._is_executing_forward = False
return result
def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield self
for name, module in self._modules.items():
if module is None:
continue
for m in module.modules(memo):
yield m
def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
memo = set()
modules = self.modules() if recurse else [self]
for module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
yield v
def parameters(self, recurse: bool = True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
gen = self._members(lambda module: module._parameters.items(), recurse=recurse)
for elem in gen:
yield elem
def buffers(self, recurse: bool = True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
gen = self._members(lambda module: module._buffers.items(), recurse=recurse)
for elem in gen:
yield elem
def __setattr__(self, name: str, value=None) -> None:
if value is None or not isinstance(value, Block):
self.__dict__[name] = value
else:
dicts_or_sets = (
self.__dict__,
self._modules,
self._parameters,
self._buffers,
)
for d in dicts_or_sets:
if name in d:
raise AttributeError(
"'{}' object has duplicated attribute named '{}'".format(
self._name, name
)
)
if value.type == BlockType.MODULE:
self._modules[name] = value
elif value.type == BlockType.PARAMETER:
self._parameters[name] = value
elif value.type == BlockType.BUFFER:
self._buffers[name] = value
else:
raise AttributeError(
"'{}' object are not allowed to set attribute named '{}'".format(
type(self).__name__, name
)
)
def __getattr__(self, name: str):
if name in self.__dict__:
return self.__dict__[name]
if self._type == BlockType.MODULE:
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if name in modules:
return modules[name]
if "_parameters" in self.__dict__:
_parameters = self.__dict__["_parameters"]
if name in _parameters:
p_block = _parameters[name]
if self._is_executing_forward:
# Return Tensor for running when getattr inside its parent Block's forward()
if graph_build_util.lazy_mode.is_enabled():
if p_block._lazy_origin is None:
assert p_block._lazy_origin_lambda is not None, (
repr(p_block)
+ " has no lazy Tensor creation function."
)
# Create and return lazy tensor
with p_block.scope_context():
p_block._lazy_origin = p_block._lazy_origin_lambda()
return p_block._lazy_origin
else:
return p_block.origin
else:
# Return Block for config when getattr outside its parent Block's forward()
return p_block
if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"]
if name in _buffers:
b_block = _buffers[name]
if self._is_executing_forward:
# Return Tensor for running when getattr inside its parent Block's forward()
if graph_build_util.lazy_mode.is_enabled():
if b_block._lazy_origin is None:
assert b_block._lazy_origin_lambda is not None, (
repr(b_block)
+ " has no lazy Tensor creation function."
)
# Create and return lazy tensor
with b_block.scope_context():
b_block._lazy_origin = b_block._lazy_origin_lambda()
return b_block._lazy_origin
else:
return b_block.origin
else:
# Return Block for config when getattr outside its parent Block's forward()
return b_block
if name in self._origin.__dict__:
return self._origin.__dict__[name]
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
def __repr__(self):
lines = None
if self._type == BlockType.MODULE:
child_lines = []
def _append_child(d):
for _, n in d.items():
n_str = repr(n)
n_str = _add_indent(n_str, 2)
child_lines.append(n_str)
_append_child(self._modules)
_append_child(self._parameters)
_append_child(self._buffers)
if len(child_lines) > 0:
lines = child_lines
main_str = (
"("
+ self._name_prefix
+ self._name
+ ":"
+ self._origin.__class__.__name__
+ ":"
+ self._type
+ "): ("
)
if lines is not None:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
@oneflow_export("nn.graph.GraphConfig")
@experimental_api
class GraphConfig(FunctionConfig):
@@ -486,56 +254,3 @@ class GraphConfig(FunctionConfig):
self.function_desc.job_config_proto.mutable_train_conf()
else:
self.function_desc.job_config_proto.mutable_predict_conf()
@oneflow_export("nn.graph.BlockConfig")
@experimental_api
class BlockConfig(object):
def __init__(self):
self._stage_id = None
self._activation_checkpointing = None
@property
def stage_id(self):
return self._stage_id
@stage_id.setter
def stage_id(self, value: int = None):
self._stage_id = value
@property
def activation_checkpointing(self):
return self._activation_checkpointing
@activation_checkpointing.setter
def activation_checkpointing(self, value: bool = False):
self._activation_checkpointing = value
@oneflow_export("nn.graph.OptimizerConfig")
@experimental_api
class OptimizerConfig(object):
def __init__(
self,
name: str,
optimizer: Optimizer = None,
lr_scheduler=None,
grad_clipping_conf=None,
weight_decay_conf=None,
):
self.name = name
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.grad_clipping_conf = grad_clipping_conf
self.weight_decay_conf = weight_decay_conf
def _add_indent(in_s, num_spaces):
s = in_s.split("\n")
if len(s) == 1:
return in_s
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
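
With these changes, _compile feeds inputs, registers state builders, traces build(), fetches outputs, and keeps the job proto on the instance. A hedged end-to-end sketch of the intended usage, mirroring the new test further below (module and shapes are illustrative):

import oneflow.experimental as flow

class MyGraph(flow.nn.Graph):
    def __init__(self, module):
        super().__init__()
        self.m = module  # wrapped into a Block by nn.Graph

    def build(self, x):
        # Traced under lazy mode inside _compile; x arrives as a lazy tensor.
        return self.m(x)

# g = MyGraph(some_module)
# out = g._compile(x)   # out is eager again; g._graph_proto holds the job proto
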
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from collections import OrderedDict
from typing import Union, Optional, Iterator, Set
import oneflow._oneflow_internal
import oneflow.python.framework.graph_build_util as graph_build_util
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.framework.tensor import Tensor
from oneflow.python.nn.module import Module
from oneflow.python.nn.parameter import Parameter
from oneflow.python.nn.utils import add_indent
class BlockType:
NONE = "NONE"
MODULE = "MODULE"
PARAMETER = "PARAMETER"
BUFFER = "BUFFER"
@oneflow_export("nn.graph.Block")
@experimental_api
class Block(object):
def __init__(
self,
prefix: str = "",
name: str = "",
value: Union[Module, Parameter, Tensor] = None,
):
assert not isinstance(value, Block)
self._name = name
self._name_prefix = prefix
self._type = BlockType.NONE
self._origin = value
self.config = BlockConfig()
self._scope = None
self._prev_scope = None
if isinstance(value, Module):
self._type = BlockType.MODULE
self._is_executing_forward = False
self._modules = OrderedDict()
self._parameters = OrderedDict()
self._buffers = OrderedDict()
for n, m in list(value.named_children()):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m))
for n, p in list(value.named_parameters("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p))
for n, b in list(value.named_buffers("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
elif isinstance(value, Parameter):
self._type = BlockType.PARAMETER
self._lazy_origin = None
self._lazy_origin_builder = None
elif isinstance(value, Tensor):
self._type = BlockType.BUFFER
self._lazy_origin = None
self._lazy_origin_builder = None
else:
raise NotImplementedError()
@property
def name(self):
return self._name
@property
def name_prefix(self):
return self._name_prefix
@property
def type(self):
return self._type
@property
def origin(self):
return self._origin
@property
def lazy_origin(self):
assert (
self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
), "Only Parameter or Buffer Block has lazy_origin"
return self._lazy_origin
def lazy_origin_builder(self):
assert (
self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
), "Only Parameter or Buffer Block has lazy_origin_builder"
return self._lazy_origin_builder
def set_lazy_origin_builder(self, fn=None):
assert (
self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER
), "Only Parameter or Buffer Block has lazy_origin_builder"
self._lazy_origin_builder = fn
@property
def prev_scope(self):
if self._prev_scope is None:
self._prev_scope = oneflow._oneflow_internal.GetCurrentScope()
return self._prev_scope
@property
def scope(self):
if self._scope is None:
self._scope = graph_build_util.make_new_block_scope(self.prev_scope, self)
return self._scope
def scope_context(self):
return graph_build_util.BlockScopeContext(self.prev_scope, self.scope)
def __call__(self, *args):
assert self._type == BlockType.MODULE
# nn.Module.__call__ will call self.forward()
# so the scope is set in self.forward()
result = self._origin.__class__.__call__(self, *args)
return result
def forward(self, *args):
assert self._type == BlockType.MODULE
self._is_executing_forward = True
# TODO(xuxiaoyu): only build scope in lazy mode
with self.scope_context():
result = self._origin.__class__.forward(self, *args)
self._is_executing_forward = False
return result
def modules(self, memo: Optional[Set["Block"]] = None) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield self
for name, module in self._modules.items():
if module is None:
continue
for m in module.modules(memo):
yield m
def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
memo = set()
modules = self.modules() if recurse else [self]
for module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
yield v
def parameters(self, recurse: bool = True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
gen = self._members(lambda module: module._parameters.items(), recurse=recurse)
for elem in gen:
yield elem
def buffers(self, recurse: bool = True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
gen = self._members(lambda module: module._buffers.items(), recurse=recurse)
for elem in gen:
yield elem
def __setattr__(self, name: str, value=None) -> None:
if value is None or not isinstance(value, Block):
self.__dict__[name] = value
else:
dicts_or_sets = (
self.__dict__,
self._modules,
self._parameters,
self._buffers,
)
for d in dicts_or_sets:
if name in d:
raise AttributeError(
"'{}' object has duplicated attribute named '{}'".format(
self._name, name
)
)
if value.type == BlockType.MODULE:
self._modules[name] = value
elif value.type == BlockType.PARAMETER:
self._parameters[name] = value
elif value.type == BlockType.BUFFER:
self._buffers[name] = value
else:
raise AttributeError(
"'{}' object are not allowed to set attribute named '{}'".format(
type(self).__name__, name
)
)
def __getattr__(self, name: str):
if name in self.__dict__:
return self.__dict__[name]
if self._type == BlockType.MODULE:
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if name in modules:
return modules[name]
if "_parameters" in self.__dict__:
_parameters = self.__dict__["_parameters"]
if name in _parameters:
p_block = _parameters[name]
if self._is_executing_forward:
# Return Tensor for running when getattr is
# inside its parent Block's forward()
if graph_build_util.lazy_mode.is_enabled():
if p_block._lazy_origin is None:
assert p_block._lazy_origin_builder is not None, (
repr(p_block)
+ " has no lazy Tensor creation function."
)
# Create and return lazy tensor
with p_block.scope_context():
p_block._lazy_origin = (
p_block._lazy_origin_builder()
)
return p_block._lazy_origin
else:
return p_block.origin
else:
# Return Block for config when getattr is
# outside its parent Block's forward()
return p_block
if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"]
if name in _buffers:
b_block = _buffers[name]
if self._is_executing_forward:
# Return Tensor for running when getattr is
# inside its parent Block's forward()
if graph_build_util.lazy_mode.is_enabled():
if b_block._lazy_origin is None:
assert b_block._lazy_origin_builder is not None, (
repr(b_block)
+ " has no lazy Tensor creation function."
)
# Create and return lazy tensor
with b_block.scope_context():
b_block._lazy_origin = (
b_block._lazy_origin_builder()
)
return b_block._lazy_origin
else:
return b_block.origin
else:
# Return Block for config when getattr is
# outside its parent Block's forward()
return b_block
if name in self._origin.__dict__:
return self._origin.__dict__[name]
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
def __repr__(self):
lines = None
if self._type == BlockType.MODULE:
child_lines = []
def _append_child(d):
for _, n in d.items():
n_str = repr(n)
n_str = add_indent(n_str, 2)
child_lines.append(n_str)
_append_child(self._modules)
_append_child(self._parameters)
_append_child(self._buffers)
if len(child_lines) > 0:
lines = child_lines
main_str = (
"("
+ self._name_prefix
+ self._name
+ ":"
+ self._origin.__class__.__name__
+ ":"
+ self._type
+ "): ("
)
if lines is not None:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
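
The __getattr__ paths above are the heart of lazy state capture: the first time a Parameter or Buffer is read inside its parent Block's forward() under lazy mode, the registered builder runs once inside the block's scope and the result is cached for later reads. A plain-Python sketch of that memoized-builder pattern (illustrative only, no OneFlow API):

class LazySlot:
    def __init__(self):
        self._lazy_origin = None
        self._lazy_origin_builder = None

    def set_lazy_origin_builder(self, fn):
        self._lazy_origin_builder = fn

    def lazy(self):
        # Build on first access, then reuse the cached result.
        if self._lazy_origin is None:
            assert self._lazy_origin_builder is not None
            self._lazy_origin = self._lazy_origin_builder()
        return self._lazy_origin
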
@oneflow_export("nn.graph.BlockConfig")
@experimental_api
class BlockConfig(object):
def __init__(self):
self._stage_id = None
self._activation_checkpointing = None
@property
def stage_id(self):
return self._stage_id
@stage_id.setter
def stage_id(self, value: int = None):
self._stage_id = value
@property
def activation_checkpointing(self):
return self._activation_checkpointing
@activation_checkpointing.setter
def activation_checkpointing(self, value: bool = False):
self._activation_checkpointing = value
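
Outside of forward(), reading a wrapped attribute returns the Block itself, so its BlockConfig knobs can be set before compiling. A hedged usage sketch (the attribute path g.m is illustrative):

# g = MyGraph(some_module)
# g.m.config.stage_id = 0                      # pipeline stage for this block
# g.m.config.activation_checkpointing = True   # recompute activations in backward
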
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from oneflow.python.nn.optimizer.optimizer import Optimizer
from oneflow.python.oneflow_export import oneflow_export, experimental_api
@oneflow_export("nn.graph.OptimizerConfig")
@experimental_api
class OptimizerConfig(object):
def __init__(
self,
name: str,
optimizer: Optimizer = None,
lr_scheduler=None,
grad_clipping_conf=None,
weight_decay_conf=None,
):
self.name = name
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.grad_clipping_conf = grad_clipping_conf
self.weight_decay_conf = weight_decay_conf
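
OptimizerConfig is a plain record that nn.Graph stores per optimizer name (the registration method in graph.py above fills self._optimizers with these). A hedged construction sketch with placeholder values:

from oneflow.python.nn.graph_optimizer import OptimizerConfig

conf = OptimizerConfig(
    name="sgd0",               # key under which Graph stores this config
    optimizer=None,            # an oneflow Optimizer instance in practice
    lr_scheduler=None,
    grad_clipping_conf=None,
    weight_decay_conf=None,
)
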
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
def add_indent(in_s, num_spaces):
s = in_s.split("\n")
if len(s) == 1:
return in_s
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
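
add_indent shifts every line after the first by num_spaces, which is what lets nested Block reprs print as an indented tree. For example:

from oneflow.python.nn.utils import add_indent

print(add_indent("Block(\nchild\n)", 2))
# Block(
#   child
#   )
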
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
import oneflow
import oneflow.experimental as flow
class SubModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.weight = flow.nn.Parameter(flow.Tensor(6, 6))
self.relu = flow.nn.ReLU()
def forward(self, x, y):
x = oneflow.F.matmul(x, self.weight)
x = self.relu(x)
y = self.relu(y)
return x, y
class CustomModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.layer = SubModule()
self.register_buffer(
"dummy_buff", flow.Tensor(6, 8),
)
def forward(self, x, y):
x, y = self.layer(x, y)
x = oneflow.F.flatten(x, 1)
x = oneflow.F.matmul(x, self.dummy_buff)
return x, y
@flow.unittest.skip_unless_1n1d()
class TestGraph(flow.unittest.TestCase):
def test_forward_graph(test_case):
class CustomGraph(flow.nn.Graph):
def __init__(self, module):
super().__init__()
self.m = module
def build(self, x, y):
out = self.m(x, y)
return out
m = CustomModule()
m.to("cuda")
g = CustomGraph(m)
x = flow.Tensor(6, 6)
flow.nn.init.uniform_(x, a=-1.0, b=1.0)
x = x.to("cuda")
y = flow.Tensor(10, 10)
flow.nn.init.uniform_(y, a=-1.0, b=1.0)
y = y.to("cuda")
print(repr(g))
z, a = g._compile(x, y)
test_case.assertEqual(z.shape, (6, 8))
test_case.assertEqual(z.is_lazy, False)
test_case.assertEqual(a.shape, (10, 10))
test_case.assertEqual(a.is_lazy, False)
print("graph proto: ", g._graph_proto)
if __name__ == "__main__":
unittest.main()
@@ -18,13 +18,6 @@ import os
import numpy as np
# To enable MultiClient
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12139"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
import oneflow
import oneflow.experimental as flow
import oneflow.python.framework.graph_build_util as graph_build_util
@@ -189,7 +182,6 @@
def build(self):
# check lazy mode in nn.Graph._compile
test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), True)
print("graph proto", self._graph_proto)
# check session type
import oneflow.python.framework.session_context as session_ctx
@@ -217,6 +209,7 @@
g = CustomGraph()
test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)
g._compile()
print("graph proto", g._graph_proto)
test_case.assertEqual(graph_build_util.lazy_mode.is_enabled(), False)
def test_block_scope(test_case):
@@ -241,7 +234,6 @@
# weight is not accessed in conv1's forward, so it will return a Block
x = self.conv1.weight
test_case.assertEqual(type(x), flow.nn.graph.Block)
return x
class SubModule1(flow.nn.Module):
def __init__(self):
@@ -283,7 +275,6 @@
test_case.assertEqual(
dummy_buff_scope_proto.parent_scope_symbol_id, scope.symbol_id
)
return x
class CustomModule1(flow.nn.Module):
def __init__(self):
@@ -294,7 +285,6 @@
def forward(self):
x = self.layer0()
y = self.layer1()
return x, y
m = CustomModule1()
@@ -313,7 +303,7 @@
g = CustomGraph1()
x = flow.Tensor(1, 1, 10, 10)
flow.nn.init.uniform_(x, a=-1.0, b=1.0)
z = g._compile()
g._compile()
if __name__ == "__main__":
......
@@ -18,12 +18,6 @@ import unittest
import numpy as np
import os
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12139"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
import oneflow
import oneflow.experimental as flow
import oneflow.python.framework.session_context as session_ctx
......
@@ -16,12 +16,6 @@ limitations under the License.
import unittest
import os
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12139"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
import oneflow
import oneflow.experimental as flow
import oneflow.python.framework.session_context as session_ctx
......
@@ -18,12 +18,6 @@ import unittest
import numpy as np
import os
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12139"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
import oneflow
import oneflow.experimental as flow
import oneflow.python.framework.session_context as session_ctx
......
@@ -18,12 +18,6 @@ import unittest
import numpy as np
import os
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12139"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
import oneflow
import oneflow.experimental as flow
import oneflow.python.framework.session_context as session_ctx
......
@@ -18,12 +18,6 @@ import unittest
import numpy as np
import os
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12139"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
import oneflow
import oneflow.experimental as flow
import oneflow.python.framework.session_context as session_ctx
......