Skip to content
Snippets Groups Projects
Unverified Commit 041e441f authored by Xiaoyu Xu's avatar Xiaoyu Xu Committed by GitHub
Browse files

Fea/nn graph/graph name (#5413)


* graph api

* add graph dummy test

* add test

* add recursive module mode

* graph.build test pass

* add detail check on graph inner node

* support config and train

* add repr for debug

* test buffer

* test buffer add

* refine test

* add comment

* refine test

* refactor Node to Block

* add named_state

* refine Graph.named_state()

* add state_tensortuple

* graph._compile()

* add mc session 0

* nn.graph: state tuple to private var; add BlockType; add simple multi client session

* NNGraphIf

* rm old graph.cpp

* nn.graph: add cpp NNGraph; export and call NNGraph

* add comment

* nn.Graph: rm prototype MultiClientSession

* nn.Graph: rm prototype MultiClientSession test

* nn.Graph: add TODO

* nn.Graph: hack to get Graph object name

* nn.Graph: get obj name

* nn.Graph: get obj name 2

* nn.Graph: format for review

* nn.Graph: format

* nn.Graph: format

* nn.Graph: pass flake8 check

* Update graph.py

* name with init count

* name with init count 2

Co-authored-by: default avatarXinqi Li <lixinqi0703106@163.com>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 1b8bca09
No related branches found
No related tags found
No related merge requests found
...@@ -18,7 +18,6 @@ from collections import OrderedDict ...@@ -18,7 +18,6 @@ from collections import OrderedDict
from typing import Union from typing import Union
import oneflow._oneflow_internal import oneflow._oneflow_internal
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util
from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module from oneflow.python.nn.module import Module
...@@ -31,9 +30,11 @@ from oneflow.python.framework.function_util import FunctionConfig ...@@ -31,9 +30,11 @@ from oneflow.python.framework.function_util import FunctionConfig
@oneflow_export("nn.Graph", "nn.graph.Graph") @oneflow_export("nn.Graph", "nn.graph.Graph")
@experimental_api @experimental_api
class Graph(object): class Graph(object):
_child_init_cnt = dict()
def __init__(self): def __init__(self):
self.config = GraphConfig() self.config = GraphConfig()
self._name = id_util.UniqueStr(self.__class__.__name__ + "_") self._generate_name()
self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name) self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name)
self._blocks = OrderedDict() self._blocks = OrderedDict()
self._optimizers = OrderedDict() self._optimizers = OrderedDict()
...@@ -63,6 +64,13 @@ class Graph(object): ...@@ -63,6 +64,13 @@ class Graph(object):
optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
) )
def _generate_name(self):
child_name = self.__class__.__name__
if Graph._child_init_cnt.get(child_name) is None:
Graph._child_init_cnt[child_name] = 0
self._name = child_name + "_" + str(Graph._child_init_cnt[child_name])
Graph._child_init_cnt[child_name] += 1
def _named_state(self): def _named_state(self):
for _, b in self._blocks.items(): for _, b in self._blocks.items():
prefix = b.name + "." prefix = b.name + "."
......
...@@ -73,6 +73,8 @@ class TestGraph(flow.unittest.TestCase): ...@@ -73,6 +73,8 @@ class TestGraph(flow.unittest.TestCase):
# Graph init # Graph init
g = CustomGraph() g = CustomGraph()
# check _c_nn_graph init
test_case.assertEqual(g.name, g._c_nn_graph.name)
# g.m is Block # g.m is Block
test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block)) test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block))
# g.m.name is "m" # g.m.name is "m"
...@@ -128,19 +130,39 @@ class TestGraph(flow.unittest.TestCase): ...@@ -128,19 +130,39 @@ class TestGraph(flow.unittest.TestCase):
# print repr of nn.Graph # print repr of nn.Graph
print(repr(g)) print(repr(g))
def test_graph_compile(test_case): def test_graph_name(test_case):
class CustomGraph(flow.nn.Graph): class ACustomGraph(flow.nn.Graph):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.m = CustomModule()
self.config.enable_auto_mixed_precision(True)
def build(self, x): def build(self, x):
x = self.m(x)
return x return x
g = CustomGraph() class BCustomGraph(flow.nn.Graph):
test_case.assertEqual(g.name, g._c_nn_graph.name) def __init__(self):
super().__init__()
def build(self, x):
return x
class CBCustomGraph(BCustomGraph):
def __init__(self):
super().__init__()
def create_graph(cnt):
a = ACustomGraph()
test_case.assertEqual(a.name, "ACustomGraph_" + str(cnt))
b = BCustomGraph()
test_case.assertEqual(b.name, "BCustomGraph_" + str(cnt))
cb = CBCustomGraph()
test_case.assertEqual(cb.name, "CBCustomGraph_" + str(cnt))
flow.nn.Graph._child_init_cnt.clear()
for i in range(0, 3):
create_graph(i)
flow.nn.Graph._child_init_cnt.clear()
for i in range(0, 3):
create_graph(i)
if __name__ == "__main__": if __name__ == "__main__":
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment