Unverified commit a4075797, authored by Xiaoyu Xu, committed by GitHub

Fea/nn graph/block proxy func (#5727)


* pass test on linear with training

* Refactor RuntimeCtx for multi-runtime

* refactor inplace to support nn graph

* block support iterator

* block iter add check

* fix scalar_mul op conf build

* deal with inplace after merge master

* add alexnet graph test

* add cpu test and format

* cout to glog

* deal with Job run finish bug

* refactor lazy deal with inplace

* deal with 0D tensor

* update data path

* address review

* deal with lazy default attr

* mv according to ci

* merge master

* fix for ci

* fix for ci limit

* block proxy func

* support module custom func and refactor get attr of block

* auto format by CI

Co-authored-by: chengtbf <472491134@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
parent a4c87a26
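In short: nn.Graph wraps each nn.Module in a Block proxy, and Block.__getattr__ can now also resolve plain Python methods defined on the wrapped module by rebinding them to the proxy. A minimal sketch of the user-visible behavior, condensed from the tests added in this commit (class and method names here are illustrative, not part of the change):

    import numpy as np
    import oneflow as flow

    class ModuleWithHelper(flow.nn.Module):
        def forward(self, x):
            # resolved through Block.__getattr__ via the new proxy-function path
            return self._helper(x)

        def _helper(self, x):
            return x

    class GraphWithHelper(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.m = ModuleWithHelper()

        def build(self, x):
            return self.m(x)

    g = GraphWithHelper()
    x = flow.tensor(np.ones((10, 10)), dtype=flow.float32)
    out = g(x)  # runs without AttributeError on the helper method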
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 from collections import OrderedDict
+from functools import partial
 from typing import Iterator, Optional, Set, Union
 
 import oneflow._oneflow_internal
@@ -207,56 +208,60 @@ class Block(object):
         if name in self.__dict__:
             return self.__dict__[name]
         if self._type == BlockType.MODULE:
+            # support get module
             if "_modules" in self.__dict__:
                 modules = self.__dict__["_modules"]
                 if name in modules:
                     return modules[name]
-            if "_parameters" in self.__dict__:
-                _parameters = self.__dict__["_parameters"]
-                if name in _parameters:
-                    p_block = _parameters[name]
-                    if self._is_executing_forward:
-                        if graph_build_util.lazy_mode.is_enabled():
-                            if p_block._lazy_origin is None:
-                                assert p_block._lazy_origin_builder is not None, (
-                                    repr(p_block)
-                                    + " has no lazy Tensor creation function."
-                                )
-                                with p_block.scope_context():
-                                    p_block._lazy_origin = (
-                                        p_block._lazy_origin_builder()
-                                    )
-                            return p_block._lazy_origin
-                        else:
-                            return p_block.origin
-                    else:
-                        return p_block
-            if "_buffers" in self.__dict__:
-                _buffers = self.__dict__["_buffers"]
-                if name in _buffers:
-                    b_block = _buffers[name]
-                    if self._is_executing_forward:
-                        if graph_build_util.lazy_mode.is_enabled():
-                            if b_block._lazy_origin is None:
-                                assert b_block._lazy_origin_builder is not None, (
-                                    repr(b_block)
-                                    + " has no lazy Tensor creation function."
-                                )
-                                with b_block.scope_context():
-                                    b_block._lazy_origin = (
-                                        b_block._lazy_origin_builder()
-                                    )
-                            return b_block._lazy_origin
-                        else:
-                            return b_block.origin
-                    else:
-                        return b_block
+            # support get parameter
+            p_state = self._get_in_states(name, "_parameters")
+            if p_state is not None:
+                return p_state
+            # support get buffer
+            b_state = self._get_in_states(name, "_buffers")
+            if b_state is not None:
+                return b_state
+            # support get normal attr
             if name in self._origin.__dict__:
                 return self._origin.__dict__[name]
+            # support get function
+            if hasattr(self._origin, name):
+                return partial(getattr(self._origin.__class__, name), self)
         raise AttributeError(
             "'{}' object has no attribute '{}'".format(type(self).__name__, name)
         )
 
+    def _get_in_states(self, name, states_name):
+        if states_name not in self.__dict__:
+            return None
+        _states = self.__dict__[states_name]
+        if name not in _states:
+            return None
+        _s_block = _states[name]
+        if graph_build_util.lazy_mode.is_enabled():
+            # lazy
+            if _s_block._lazy_origin is None:
+                assert _s_block._lazy_origin_builder is not None, (
+                    repr(_s_block) + " has no lazy Tensor creation function."
+                )
+                assert self._is_executing_forward, (
+                    repr(_s_block)
+                    + "'s first get must happen in its nn.Module.forward() to generate the right scope."
+                )
+                with _s_block.scope_context():
+                    _s_block._lazy_origin = _s_block._lazy_origin_builder()
+            return _s_block._lazy_origin
+        elif (
+            not graph_build_util.lazy_mode.is_enabled()
+        ) and self._is_executing_forward:
+            # eager and inside nn.Graph.build()
+            return _s_block.origin
+        else:
+            # outside nn.Graph.build()
+            return _s_block
+
     def __repr__(self):
         lines = None
         if self._type == BlockType.MODULE:
...
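The key line is the partial rebinding in the new function branch: the method is looked up on the origin module's class, and its self is rebound to the Block, so every attribute access inside the method routes back through Block.__getattr__ and sees proxied submodules, parameters, and buffers. A stripped-down sketch of that pattern in plain Python (class names are illustrative, not OneFlow API):

    from functools import partial

    class Origin:
        def helper(self, x):
            # `self.factor` resolves on whatever object is bound as `self`
            return x * self.factor

    class Proxy:
        def __init__(self, origin):
            self._origin = origin
            self.factor = 3  # the proxy can shadow or instrument state

        def __getattr__(self, name):
            # only called for names missing on the proxy itself
            if hasattr(self._origin, name):
                # fetch the unbound function from the class and rebind its
                # `self` to the proxy, so lookups inside it stay proxied
                return partial(getattr(type(self._origin), name), self)
            raise AttributeError(name)

    p = Proxy(Origin())
    assert p.helper(2) == 6  # `self.factor` was resolved on the proxy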
@@ -51,6 +51,7 @@ class CustomModule(flow.nn.Module):
         return x
 
 
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
 @flow.unittest.skip_unless_1n1d()
 class TestGraph(flow.unittest.TestCase):
     def test_add_nested_module(test_case):
@@ -201,9 +202,10 @@ class TestGraph(flow.unittest.TestCase):
                 "pipeline_stage_id_hint"
             ].at_int64
             test_case.assertEqual(stage_int, 0)
+            out = self.conv1(x)
             weight = self.conv1.weight
-            test_case.assertEqual(type(weight), flow.nn.graph.Block)
-            return self.conv1(x)
+            test_case.assertTrue(weight.is_lazy)
+            return out
 
 
 class SubModule1(flow.nn.Module):
     def __init__(self):
...
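The assertion change above tracks the new _get_in_states contract: under lazy mode inside forward(), self.conv1.weight now resolves to the lazily built Tensor (hence weight.is_lazy) rather than to the flow.nn.graph.Block wrapper that the old type(weight) check expected.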
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest

import numpy as np

import oneflow as flow
import oneflow.unittest


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n1d()
class TestGraphBlock(flow.unittest.TestCase):
    def test_module_has_custom_func(test_case):
        class CustomModuleHasFunc(flow.nn.Module):
            def __init__(self):
                super().__init__()
                self.data_mem = 10

            def forward(self, x):
                return self._custom_func(x)

            def _custom_func(self, x):
                test_case.assertEqual(self.data_mem, 10)
                return x

        class CustomGraphHasFunc(flow.nn.Graph):
            def __init__(self):
                super().__init__()
                self.m = CustomModuleHasFunc()

            def build(self, x):
                return self.m(x)

        g = CustomGraphHasFunc()
        x = np.ones((10, 10))
        x = flow.tensor(x, dtype=flow.float32)
        out = g(x)
        test_case.assertTrue(np.array_equal(x.numpy(), out.numpy()))

    def test_block_with_parameter(test_case):
        device = "cuda"
        linear = flow.nn.Linear(3, 8)
        linear = linear.to(device)
        flow.nn.init.constant_(linear.weight, 2.068758)
        flow.nn.init.constant_(linear.bias, 0.23)
        of_sgd = flow.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)

        x = flow.Tensor(
            [
                [-0.94630778, -0.83378579, -0.87060891],
                [2.0289922, -0.28708987, -2.18369248],
                [0.35217619, -0.67095644, -1.58943879],
                [0.08086036, -1.81075924, 1.20752494],
                [0.8901075, -0.49976737, -1.07153746],
                [-0.44872912, -1.07275683, 0.06256855],
                [-0.22556897, 0.74798368, 0.90416439],
                [0.48339456, -2.32742195, -0.59321527],
            ],
            device=device,
            requires_grad=False,
        )

        class CustomModule(flow.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self._forward_impl(x)

            def _forward_impl(self, x):
                test_case.assertTrue(isinstance(self.linear, flow.nn.graph.Block))
                return self.linear(x)

        class LinearTrainGraph(flow.nn.Graph):
            def __init__(self):
                super().__init__()
                self.m = CustomModule()
                self.add_optimizer("sgd", of_sgd)

            def build(self, x):
                out = self.m(x)
                out = out.sum()
                out.backward()
                test_case.assertTrue(self.m.linear.weight.is_lazy)
                return out

        linear_t_g = LinearTrainGraph()
        linear_t_g(x)


if __name__ == "__main__":
    unittest.main()