Unverified commit 0a821433 authored by Xiaoyu Xu, committed by GitHub

nn.Graph python (#5309)


* graph api

* add graph dummy test

* add test

* add recursive module mode

* graph.build test pass

* add detail check on graph inner node

* support config and train

* add repr for debug

* test buffer

* test buffer add

* refine test

* add comment

* refine test

* refactor Node to Block

* add named_state

* refine Graph.named_state()

* add state_tensortuple

* graph._compile()

* add mc session 0

* nn.graph: state tuple to private var; add BlockType; add simple multi client session

* NNGraphIf

* rm old graph.cpp

* nn.graph: add cpp NNGraph; export and call NNGraph

* add comment

* nn.Graph: rm prototype MultiClientSession

* nn.Graph: rm prototype MultiClientSession test

* nn.Graph: add TODO

* nn.Graph: format for review

* nn.Graph: format

* nn.Graph: format

* nn.Graph: pass flake8 check

Co-authored-by: Xinqi Li <lixinqi0703106@163.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: cheng cheng <472491134@qq.com>
parent a28eadca
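For context, a minimal usage sketch of the nn.Graph API introduced by this commit, distilled from the test added below (the module and tensor shapes are illustrative):

    import oneflow.experimental as flow

    class MyGraph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.m = flow.nn.Linear(4, 4)  # assigned Modules are wrapped as Blocks

        def build(self, x):
            return self.m(x)

    g = MyGraph()
    x = flow.Tensor(2, 4)
    flow.nn.init.uniform_(x, a=-1.0, b=1.0)
    y = g.build(x)  # Graph.__call__ and _compile() are still TODO in this commit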
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <string>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/nn_graph_if.h"
namespace py = pybind11;
namespace oneflow {
ONEFLOW_API_PYBIND11_MODULE("", m) {
using namespace oneflow;
py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "NNGraph")
.def(py::init<const std::string&>())
.def_property_readonly("name", &NNGraph::job_name);
}
} // namespace oneflow
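On the Python side, this binding is reached through oneflow._oneflow_internal, as graph.py below does when a Graph is constructed; a minimal sketch (the name string is illustrative):

    import oneflow._oneflow_internal

    c_graph = oneflow._oneflow_internal.NNGraph("MyGraph_0")
    print(c_graph.name)  # read-only property bound to NNGraph::job_name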
@@ -14,5 +14,11 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/nn_graph_if.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
const std::vector<std::string>& NNGraph::inputs_op_names() const { UNIMPLEMENTED(); }
const std::vector<std::string>& NNGraph::outputs_op_names() const { UNIMPLEMENTED(); }
} // namespace oneflow
@@ -33,6 +33,19 @@ class NNGraphIf {
NNGraphIf() = default;
};
class NNGraph final : public NNGraphIf {
public:
NNGraph() = delete;
explicit NNGraph(const std::string& name) : name_(name) {}
const std::string& job_name() const { return name_; }
const std::vector<std::string>& inputs_op_names() const;
const std::vector<std::string>& outputs_op_names() const;
private:
std::string name_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_FRAMEWORK_NN_GRAPH_IF_H_
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from collections import OrderedDict
from typing import Union
import oneflow._oneflow_internal
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.framework.tensor import Tensor
from oneflow.python.nn.parameter import Parameter
from oneflow.python.nn.optimizer.optimizer import Optimizer
from oneflow.python.framework.function_util import FunctionConfig
@oneflow_export("nn.Graph", "nn.graph.Graph")
@experimental_api
class Graph(object):
def __init__(self):
self.config = GraphConfig()
self._name = id_util.UniqueStr(self.__class__.__name__ + "_")
self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name)
self._blocks = OrderedDict()
self._optimizers = OrderedDict()
self._is_compiled = False
self._state_tensortuple = None
self.train(True)
@property
def name(self):
return self._name
@property
def training(self):
return self.config.training
def build(self, *args):
raise NotImplementedError()
def add_optimizer(
self,
name: str,
optimizer: Optimizer = None,
lr_scheduler=None,
grad_clipping_conf=None,
weight_decay_conf=None,
):
        self._optimizers[name] = OptimizerConfig(
            name, optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
        )
def train(self, mode: bool = True):
self.config._train(mode)
for name, block in self._blocks.items():
assert block.type == BlockType.MODULE
block.origin.train(mode)
def _named_state(self):
for _, b in self._blocks.items():
prefix = b.name + "."
p_gen = b.origin.named_parameters()
for n, p in p_gen:
yield prefix + n, p
b_gen = b.origin.named_buffers()
            for n, buf in b_gen:
                yield prefix + n, buf
def _compile(self):
assert not self._is_compiled, (
"nn.Graph " + self._name + " has already been compiled."
)
self._state_tensortuple = tensor_tuple_util.convert_to_tensor_tuple(
tuple(t for _, t in self._named_state())
)
# TODO(xuxiaoyu)
# sess = session_ctx.GetDefaultSession()
# sess.TryInit()
# do job compile
self._is_compiled = True
def _launch(self):
# TODO(xuxiaoyu)
# return self._c_nn_graph.run()
...
def __call__(self, *args):
# TODO(xuxiaoyu)
# if not self._is_compiled:
# self._compile()
# return self._launch()
...
def _add_block(self, name: str, module: Module = None) -> None:
r"""Adds a module to the current graph as a block.
The block can be accessed as an attribute using the given name.
Args:
name (string): name of the child block. The child block can be
accessed from this graph using the given name
module (Module): child module to be added to the graph.
"""
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(type(module)))
elif not isinstance(name, str):
raise TypeError("module name should be a string. Got {}".format(type(name)))
elif hasattr(self, name) and name not in self._blocks:
raise KeyError("attribute '{}' already exists".format(name))
elif "." in name:
raise KeyError('module name can\'t contain ".", got: {}'.format(name))
elif name == "":
raise KeyError('module name can\'t be empty string ""')
self._blocks[name] = Block(self._name + ".", name, module)
def __setattr__(self, name: str, value=None):
if isinstance(value, Module):
self._add_block(name, value)
elif isinstance(value, Optimizer):
raise AttributeError(
"'{}' object are not allowed to set Optimizer attribute named '{}', \
please use add_optimizer(...) instead.".format(
type(self).__name__, name
)
)
else:
object.__setattr__(self, name, value)
def __getattr__(self, name: str):
if "_blocks" in self.__dict__:
if name in self._blocks:
return self._blocks[name]
if name in self.__dict__:
return self.__dict__[name]
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
def __repr__(self):
lines = None
if len(self._blocks) > 0:
child_lines = []
for n, m in self._blocks.items():
mod_str = repr(m)
mod_str = _add_indent(mod_str, 2)
child_lines.append(mod_str)
lines = child_lines
main_str = "(" + self._name + ":" + self.__class__.__name__ + ":GRAPH): ("
if lines is not None:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
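# Note on the Graph attribute plumbing above: __setattr__ routes any Module
# value through _add_block, so an assignment such as `self.m = SomeModule()`
# (any Module subclass, illustrative) inside a Graph subclass's __init__
# stores a Block in self._blocks, and __getattr__ then resolves `graph.m`
# to that Block rather than to a plain instance attribute.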
class BlockType:
NONE = "NONE"
MODULE = "MODULE"
PARAMETER = "PARAMETER"
BUFFER = "BUFFER"
@oneflow_export("nn.graph.Block")
@experimental_api
class Block(object):
def __init__(
self,
prefix: str = "",
name: str = "",
value: Union[Module, Parameter, Tensor] = None,
):
assert not isinstance(value, Block)
self._name = name
self._name_prefix = prefix
self._type = BlockType.NONE
self._origin = value
self._config = BlockConfig()
if isinstance(value, Module):
self._type = BlockType.MODULE
self._modules = OrderedDict()
self._parameters = OrderedDict()
self._buffers = OrderedDict()
for n, m in list(value.named_children()):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, m))
for n, p in list(value.named_parameters("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, p))
for n, b in list(value.named_buffers("", False)):
self.__setattr__(n, Block(self._name_prefix + self._name + ".", n, b))
elif isinstance(value, Parameter):
self._type = BlockType.PARAMETER
elif isinstance(value, Tensor):
self._type = BlockType.BUFFER
else:
raise NotImplementedError()
@property
def name(self):
return self._name
@property
def name_prefix(self):
return self._name_prefix
@property
def type(self):
return self._type
@property
def origin(self):
return self._origin
def __call__(self, *args):
assert self._type == BlockType.MODULE
# TODO(): with oneflow_c_api.set_scope(self.config_):
return self._origin.__class__.__call__(self, *args)
def forward(self, *args):
assert self._type == BlockType.MODULE
return self._origin.__class__.forward(self, *args)
def __setattr__(self, name: str, value=None) -> None:
if value is None or not isinstance(value, Block):
self.__dict__[name] = value
else:
dicts_or_sets = (
self.__dict__,
self._modules,
self._parameters,
self._buffers,
)
for d in dicts_or_sets:
if name in d:
raise AttributeError(
"'{}' object has duplicated attribute named '{}'".format(
self._name, name
)
)
if value.type == BlockType.MODULE:
self._modules[name] = value
elif value.type == BlockType.PARAMETER:
self._parameters[name] = value
elif value.type == BlockType.BUFFER:
self._buffers[name] = value
else:
raise AttributeError(
"'{}' object are not allowed to set attribute named '{}'".format(
type(self).__name__, name
)
)
def __getattr__(self, name: str):
if name in self.__dict__:
return self.__dict__[name]
if self._type == BlockType.MODULE:
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if name in modules:
return modules[name]
if "_parameters" in self.__dict__:
_parameters = self.__dict__["_parameters"]
if name in _parameters:
# TODO(): return block when need config
# return _parameters[name]
return _parameters[name].origin
if "_buffers" in self.__dict__:
_buffers = self.__dict__["_buffers"]
if name in _buffers:
# TODO(): return block when need config
# return _buffers[name]
return _buffers[name].origin
if name in self._origin.__dict__:
return self._origin.__dict__[name]
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, name)
)
def __repr__(self):
lines = None
if self._type == BlockType.MODULE:
child_lines = []
def _append_child(d):
for _, n in d.items():
n_str = repr(n)
n_str = _add_indent(n_str, 2)
child_lines.append(n_str)
_append_child(self._modules)
_append_child(self._parameters)
_append_child(self._buffers)
if len(child_lines) > 0:
lines = child_lines
main_str = (
"("
+ self._name
+ ":"
+ self._origin.__class__.__name__
+ ":"
+ self._type
+ "): ("
)
if lines is not None:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
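# Note on Block dispatch above: Block.__call__ and Block.forward invoke the
# wrapped Module's class-level __call__/forward with the Block itself as
# `self`, so attribute lookups made by the original forward() go through
# Block.__getattr__, which returns plain Tensors for parameters and buffers
# but Blocks for child modules.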
@oneflow_export("nn.graph.GraphConfig")
@experimental_api
class GraphConfig(FunctionConfig):
def __init__(self):
super().__init__()
@property
def proto(self):
return self.function_desc.job_config_proto
@property
def training(self):
if self.function_desc.job_config_proto.has_train_conf():
return True
if self.function_desc.job_config_proto.has_predict_conf():
return False
raise NotImplementedError
def _train(self, mode: bool = True):
if mode:
self.function_desc.job_config_proto.mutable_train_conf()
else:
self.function_desc.job_config_proto.mutable_predict_conf()
@oneflow_export("nn.graph.BlockConfig")
@experimental_api
class BlockConfig(object):
def __init__(self):
# TODO(xuxiaoyu): implement config for block
# support generating Scope Object
pass
@oneflow_export("nn.graph.OptimizerConfig")
@experimental_api
class OptimizerConfig(object):
def __init__(
self,
name: str,
optimizer: Optimizer = None,
lr_scheduler=None,
grad_clipping_conf=None,
weight_decay_conf=None,
):
self.name = name
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.grad_clipping_conf = grad_clipping_conf
self.weight_decay_conf = weight_decay_conf
def _add_indent(in_s, num_spaces):
s = in_s.split("\n")
if len(s) == 1:
return in_s
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
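The test below exercises module nesting, graph config, and the compile-time naming, but not add_optimizer. A minimal sketch of its intended call shape, assuming a Graph subclass instance g and an already-constructed Optimizer instance sgd (both names illustrative):

    g.add_optimizer("sgd", optimizer=sgd)
    # stored as an OptimizerConfig keyed by the given name in g._optimizers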
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow.experimental as flow
import oneflow
class SubModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = flow.nn.Conv2d(1, 1, 5)
self.relu = flow.nn.ReLU()
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
return x
class CustomModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.layer = SubModule()
self.fc1 = flow.nn.Linear(36, 4)
self.register_buffer(
"dummy_buff", flow.Tensor(1, 4),
)
def forward(self, x):
x = self.layer(x)
x = oneflow.F.flatten(x, 1)
x = self.fc1(x) + self.dummy_buff
return x
@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestGraph(flow.unittest.TestCase):
def test_add_nested_module(test_case):
x = flow.Tensor(1, 1, 10, 10)
flow.nn.init.uniform_(x, a=-1.0, b=1.0)
# Module init and call
m = CustomModule()
y = m(x)
class CustomGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = m
def build(self, x):
return self.m(x)
# Graph init
g = CustomGraph()
# g.m is Block
test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block))
# g.m.name is "m"
test_case.assertEqual(g.m.name, "m")
        # g.m.dummy_buff is a Tensor; Graph.build(...) needs buffers to be Tensors
test_case.assertTrue(isinstance(g.m.dummy_buff, flow.Tensor))
# g.m._buffers["dummy_buff"] is Block
test_case.assertTrue(
isinstance(g.m._buffers["dummy_buff"], flow.nn.graph.Block)
)
# conv1 is Block
test_case.assertTrue(isinstance(g.m.layer.conv1, flow.nn.graph.Block))
# conv1.name is "conv1"
test_case.assertEqual(g.m.layer.conv1.name, "conv1")
        # conv1.weight is a Tensor; Graph.build(...) needs weights to be Tensors
test_case.assertTrue(isinstance(g.m.layer.conv1.weight, flow.Tensor))
# conv1._parameters["weight"] is Block
test_case.assertTrue(
isinstance(g.m.layer.conv1._parameters["weight"], flow.nn.graph.Block)
)
        # conv1.kernel_size reads the original attribute of the original module
test_case.assertEqual(g.m.layer.conv1.kernel_size, (5, 5))
# Graph build
z = g.build(x)
        # g produces the same result as m
test_case.assertTrue(np.array_equal(y.numpy(), z.numpy()))
def test_graph_config(test_case):
class CustomGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = CustomModule()
self.config.enable_auto_mixed_precision(True)
def build(self, x):
x = self.m(x)
return x
g = CustomGraph()
# check set train to True
g.train(True)
test_case.assertEqual(g.training, True)
test_case.assertEqual(g.m.training, True)
test_case.assertEqual(g.m.layer.conv1.training, True)
# set graph config
g.config.enable_fuse_add_to_output(True)
# check set train to False
g.train(False)
        test_case.assertEqual(g.training, False)
test_case.assertEqual(g.m.training, False)
test_case.assertEqual(g.m.layer.conv1.training, False)
# set graph config
g.config.enable_fuse_add_to_output(False)
# check _named_state get the right tensor
for n, t in g._named_state():
test_case.assertEqual(id(eval("g." + n)), id(t))
# print repr of nn.Graph
print(repr(g))
def test_graph_compile(test_case):
class CustomGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = CustomModule()
self.config.enable_auto_mixed_precision(True)
def build(self, x):
x = self.m(x)
return x
g = CustomGraph()
test_case.assertEqual(g.name, g._c_nn_graph.name)
if __name__ == "__main__":
unittest.main()