diff --git a/oneflow/python/nn/graph.py b/oneflow/python/nn/graph.py index 09aefecc6d58cf60a10ed8fad8d6d3aac80a3c2d..940d726b89cb9f29f57a511595b7c9141342cc1b 100644 --- a/oneflow/python/nn/graph.py +++ b/oneflow/python/nn/graph.py @@ -18,7 +18,6 @@ from collections import OrderedDict from typing import Union import oneflow._oneflow_internal -import oneflow.python.framework.id_util as id_util import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.nn.module import Module @@ -31,9 +30,11 @@ from oneflow.python.framework.function_util import FunctionConfig @oneflow_export("nn.Graph", "nn.graph.Graph") @experimental_api class Graph(object): + _child_init_cnt = dict() + def __init__(self): self.config = GraphConfig() - self._name = id_util.UniqueStr(self.__class__.__name__ + "_") + self._generate_name() self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name) self._blocks = OrderedDict() self._optimizers = OrderedDict() @@ -63,6 +64,13 @@ class Graph(object): optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf ) + def _generate_name(self): + child_name = self.__class__.__name__ + if Graph._child_init_cnt.get(child_name) is None: + Graph._child_init_cnt[child_name] = 0 + self._name = child_name + "_" + str(Graph._child_init_cnt[child_name]) + Graph._child_init_cnt[child_name] += 1 + def _named_state(self): for _, b in self._blocks.items(): prefix = b.name + "." diff --git a/oneflow/python/test/graph/test_graph.py b/oneflow/python/test/graph/test_graph.py index b954640de6ed191617f7f01f6feaede74a033804..a762b2d6aab1366d3a7e799608cc6e18a96e6eee 100644 --- a/oneflow/python/test/graph/test_graph.py +++ b/oneflow/python/test/graph/test_graph.py @@ -73,6 +73,8 @@ class TestGraph(flow.unittest.TestCase): # Graph init g = CustomGraph() + # check _c_nn_graph init + test_case.assertEqual(g.name, g._c_nn_graph.name) # g.m is Block test_case.assertTrue(isinstance(g.m, flow.nn.graph.Block)) # g.m.name is "m" @@ -128,19 +130,39 @@ class TestGraph(flow.unittest.TestCase): # print repr of nn.Graph print(repr(g)) - def test_graph_compile(test_case): - class CustomGraph(flow.nn.Graph): + def test_graph_name(test_case): + class ACustomGraph(flow.nn.Graph): def __init__(self): super().__init__() - self.m = CustomModule() - self.config.enable_auto_mixed_precision(True) def build(self, x): - x = self.m(x) return x - g = CustomGraph() - test_case.assertEqual(g.name, g._c_nn_graph.name) + class BCustomGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x): + return x + + class CBCustomGraph(BCustomGraph): + def __init__(self): + super().__init__() + + def create_graph(cnt): + a = ACustomGraph() + test_case.assertEqual(a.name, "ACustomGraph_" + str(cnt)) + b = BCustomGraph() + test_case.assertEqual(b.name, "BCustomGraph_" + str(cnt)) + cb = CBCustomGraph() + test_case.assertEqual(cb.name, "CBCustomGraph_" + str(cnt)) + + flow.nn.Graph._child_init_cnt.clear() + for i in range(0, 3): + create_graph(i) + flow.nn.Graph._child_init_cnt.clear() + for i in range(0, 3): + create_graph(i) if __name__ == "__main__":