Skip to content
Snippets Groups Projects
Unverified Commit a94748f1 authored by Xiaoyu Xu's avatar Xiaoyu Xu Committed by GitHub
Browse files

rm nn.Graph.train (#5424)

* rm nn.Graph.train

* Update graph.py

GraphCofig default predict

* Update test_graph.py
parent 9356c9d7
No related branches found
No related tags found
No related merge requests found
......@@ -39,7 +39,6 @@ class Graph(object):
self._optimizers = OrderedDict()
self._is_compiled = False
self._state_tensortuple = None
self.train(True)
@property
def name(self):
......@@ -64,12 +63,6 @@ class Graph(object):
optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf
)
def train(self, mode: bool = True):
self.config._train(mode)
for name, block in self._blocks.items():
assert block.type == BlockType.MODULE
block.origin.train(mode)
def _named_state(self):
for _, b in self._blocks.items():
prefix = b.name + "."
......@@ -329,6 +322,7 @@ class Block(object):
class GraphConfig(FunctionConfig):
def __init__(self):
super().__init__()
self._train(False)
@property
def proto(self):
......
......@@ -114,23 +114,11 @@ class TestGraph(flow.unittest.TestCase):
g = CustomGraph()
# check set train to True
g.train(True)
test_case.assertEqual(g.training, True)
test_case.assertEqual(g.m.training, True)
test_case.assertEqual(g.m.layer.conv1.training, True)
# check default training is True
test_case.assertEqual(g.config.training, False)
# set graph config
g.config.enable_fuse_add_to_output(True)
# check set train to False
g.train(False)
test_case.assertEqual(g.training, False)
test_case.assertEqual(g.training, False)
test_case.assertEqual(g.m.training, False)
test_case.assertEqual(g.m.layer.conv1.training, False)
# set graph config
g.config.enable_fuse_add_to_output(False)
# check _named_state get the right tensor
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment