diff --git a/oneflow/api/python/framework/nn_graph.cpp b/oneflow/api/python/framework/nn_graph.cpp
index f6fd51b1cace79e57ca06a84aea29369f9583b68..6e63bdd4f1f318b5640ad26358d92394d0e4d3cc 100644
--- a/oneflow/api/python/framework/nn_graph.cpp
+++ b/oneflow/api/python/framework/nn_graph.cpp
@@ -15,7 +15,9 @@ limitations under the License.
 */
 #include <pybind11/pybind11.h>
 #include <string>
+#include "oneflow/api/python/job_build/job_build_and_infer.h"
 #include "oneflow/api/python/of_api_registry.h"
+#include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/nn_graph.h"
 #include "oneflow/core/job/runtime.h"
 #include "oneflow/core/register/blob.h"
@@ -23,9 +25,9 @@ limitations under the License.
 namespace py = pybind11;
 
 namespace oneflow {
-ONEFLOW_API_PYBIND11_MODULE("", m) {
+ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) {
   using namespace oneflow;
-  py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "NNGraph")
+  py::class_<NNGraph, std::shared_ptr<NNGraph>>(m, "CNNGraph")
       .def(py::init<const std::string&>())
       .def_property_readonly("name", &NNGraph::job_name)
      .def("register_input_op_names",
@@ -51,5 +53,7 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
           const std::shared_ptr<NNGraph>& nn_graph) {
          return RunLazyNNGraph(inputs, outputs, parameters, nn_graph).GetOrThrow();
        });
+  m.def("AddTensorAsGraphLoss",
+        [](const std::shared_ptr<one::Tensor>& t) { return AddTensorAsGraphLoss(t).GetOrThrow(); });
 }
 }  // namespace oneflow
diff --git a/oneflow/api/python/job_build/job_build_and_infer.h b/oneflow/api/python/job_build/job_build_and_infer.h
index 05896ac6ea31d89420e1d9f33cc5b7333f61fde6..3d88c27d10980c0fe7a84052f53eed13dc2de409 100644
--- a/oneflow/api/python/job_build/job_build_and_infer.h
+++ b/oneflow/api/python/job_build/job_build_and_infer.h
@@ -18,11 +18,14 @@ limitations under the License.
 
 #include "oneflow/core/job/global_for.h"
 #include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/framework/tensor_name_scope.h"
 #include "oneflow/core/job/job_build_and_infer_ctx.h"
-#include "oneflow/core/record/record.pb.h"
 #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
 #include "oneflow/core/job/job.pb.h"
 #include "oneflow/core/job/job_conf.cfg.h"
+#include "oneflow/core/job/lazy_mode.h"
+#include "oneflow/core/record/record.pb.h"
 
 namespace oneflow {
@@ -166,6 +169,14 @@ inline Maybe<std::string> JobBuildAndInferCtx_GetOpBlobLbn(const std::string& jo
   return job_ctx->GetOpBlobLbn(op_name, bn_in_op);
 }
 
+inline Maybe<void> AddTensorAsGraphLoss(const std::shared_ptr<one::Tensor>& t) {
+  CHECK_OR_RETURN(t->is_lazy());
+  CHECK_OR_RETURN(LazyMode::is_enabled());
+  const std::string& loss_lbn = one::TensorNameScope::Global()->Lookup(t);
+  CHECK_OR_RETURN("" != loss_lbn);
+  return JUST(GetCurInferCtx())->AddLossLogicalBlobName(loss_lbn);
+}
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_API_PYTHON_JOB_BUILD_JOB_BUILD_AND_INFER_H_
diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp
index 2f980ebd9e0d0e2839f4126984790a9ba410d122..5024dbde7c8a05902412e8e3006b6f29bde4831c 100644
--- a/oneflow/core/framework/tensor.cpp
+++ b/oneflow/core/framework/tensor.cpp
@@ -14,7 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/common/maybe.h"
 #include "oneflow/core/job/parallel_desc.h"
+#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
+#include "oneflow/core/job/job_build_and_infer_ctx.h"
 #include "oneflow/core/framework/device.h"
 #include "oneflow/core/framework/dtype.h"
 #include "oneflow/core/framework/tensor_tuple.h"
diff --git a/oneflow/core/framework/tensor_name_scope.cpp b/oneflow/core/framework/tensor_name_scope.cpp
index ecd84b1140588cf2b831d6f2cbb38ec93ddbb2f3..ff0c484cb814f532fefb82a9fc95c0d3f57f4ec8 100644
--- a/oneflow/core/framework/tensor_name_scope.cpp
+++ b/oneflow/core/framework/tensor_name_scope.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/tensor_name_scope.h"
+#include <cstdint>
 
 namespace oneflow {
 namespace one {
@@ -24,8 +25,8 @@ namespace one {
 }
 
 const std::string& TensorNameScope::Lookup(const std::shared_ptr<Tensor>& tensor) const {
-  std::lock_guard<std::mutex> lock(mutex_);
   uint64_t key = reinterpret_cast<uint64_t>(tensor.get());
+  std::lock_guard<std::mutex> lock(mutex_);
   const auto& it = tensor_names_.find(key);
   if (it != tensor_names_.end()) {
     return it->second;
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 58230af9a6c02e3ad42d672155cebc19fabde738..9e2cf2466e991dcd0c428f14faecf0e05740af18 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -19,6 +19,7 @@ from oneflow._oneflow_internal.exception import IndexException
 from oneflow.python.oneflow_export import oneflow_export
 import oneflow.python.framework.remote_blob as remote_blob_util
 import oneflow._oneflow_internal
+import oneflow._oneflow_internal.lazy_mode as lazy_mode
 import numpy as np
 import inspect
 from typing import Union
@@ -397,7 +398,13 @@ class Tensor:
     @_auto_determine
     @register_local_tensor_method()
     def backward(self, gradient=None, retain_graph=False, create_graph=False):
-        flow.autograd.backward(self, gradient, retain_graph, create_graph)
+        if not lazy_mode.is_enabled():
+            flow.autograd.backward(self, gradient, retain_graph, create_graph)
+        else:
+            assert (
+                self.is_lazy
+            ), "nn.Graph only accepts lazy tensors calling backward() in lazy mode."
+            flow._oneflow_internal.nn.graph.AddTensorAsGraphLoss(self)
 
     @register_local_tensor_method()
     def _transform_ellipsis_type(self, key):
diff --git a/oneflow/python/nn/graph.py b/oneflow/python/nn/graph.py
index 1f34da4e49f60c98174c4a514aaafdda27be98a1..b3e46e415ee2d089b37c43815a9f29d02308aeba 100644
--- a/oneflow/python/nn/graph.py
+++ b/oneflow/python/nn/graph.py
@@ -15,17 +15,17 @@ limitations under the License.
""" from __future__ import absolute_import from collections import OrderedDict +from typing import Dict from functools import partial import oneflow._oneflow_internal import oneflow.python.framework.c_api_util as c_api_util import oneflow.python.framework.graph_build_util as graph_build_util import oneflow.python.framework.session_context as session_ctx -import oneflow.python.framework.tensor_tuple_util as tensor_tuple_util from oneflow._oneflow_internal import Tensor as InternalTensor from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.framework.multi_client_session import MultiClientSession -from oneflow.python.nn.graph_block import Block +from oneflow.python.nn.graph_block import Block, BlockType from oneflow.python.nn.graph_optimizer import OptimizerConfig from oneflow.python.nn.module import Module from oneflow.python.nn.optimizer.optimizer import Optimizer @@ -42,11 +42,11 @@ class Graph(object): self.config = GraphConfig() self._generate_name() self.config.proto.set_job_name(self._name) - self._c_nn_graph = oneflow._oneflow_internal.NNGraph(self._name) + self._c_nn_graph = oneflow._oneflow_internal.nn.graph.CNNGraph(self._name) self._blocks = OrderedDict() self._optimizers = OrderedDict() self._is_compiled = False - self._state_tensortuple = None + self._var2var_op_name = dict() self._job_proto = None @property @@ -72,8 +72,14 @@ class Graph(object): grad_clipping_conf=None, weight_decay_conf=None, ): + assert name is not None, "name cannot be None" + assert type(name) is str, "name must be an instance of str" + assert optimizer is not None, "optimizer cannot be None" + assert isinstance( + optimizer, Optimizer + ), "optimizer must be an instance of Optimizer" self._optimizers[name] = OptimizerConfig( - optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf + name, optimizer, lr_scheduler, grad_clipping_conf, weight_decay_conf ) def _generate_name(self): @@ -92,18 +98,34 @@ class Graph(object): for bu in bu_gen: yield bu + def _preprocess_state(self): + state_list = list() + for state_block in self._state(): + state_list.append(state_block.origin) + if state_block.type == BlockType.PARAMETER: + self._var2var_op_name[state_block.origin] = ( + state_block.name_prefix + state_block.name + ) + + def _complete_graph_config(self): + if len(self._optimizers): + self.config._train(True) + # TODO(xuxiaoyu): save variable name and it's l2 if optimizer has weight decay + # which means to used as l2. + for name, opt_config in self._optimizers.items(): + self.config.add_optimizer_config(opt_config, self._var2var_op_name) + def _compile(self, *args): assert not self._is_compiled, ( "nn.Graph " + self._name + " has already been compiled." ) - state = tuple(s.origin for s in self._state()) - if len(state) > 0: - self._state_tensortuple = tensor_tuple_util.convert_to_tensor_tuple(state) + + self._preprocess_state() + self._complete_graph_config() session = session_ctx.GetDefaultSession() assert type(session) is MultiClientSession session.TryInit() - with graph_build_util.graph_build_context(self.config.proto, session): # Deal with input lazy_args = [] @@ -188,6 +210,8 @@ class Graph(object): raise KeyError('module name can\'t contain ".", got: {}'.format(name)) elif name == "": raise KeyError('module name can\'t be empty string ""') + # TODO(xuxiaoyu): Add dict of Parameter id to Parameter Block, for using id + # to query Parameter Block. 
         self._blocks[name] = Block("", name, module)
 
     def __setattr__(self, name: str, value=None):
@@ -195,8 +219,8 @@
             self._add_block(name, value)
         elif isinstance(value, Optimizer):
             raise AttributeError(
-                "'{}' object are not allowed to set Optimizer attribute named '{}', \
-                    please use add_optimizer(...) instead.".format(
+                "'{}' object is not allowed to set Optimizer attribute named '{}', "
+                "please use add_optimizer(...) instead.".format(
                     type(self).__name__, name
                 )
             )
@@ -243,14 +267,22 @@
     @property
     def training(self):
-        if self.function_desc.job_config_proto.has_train_conf():
+        if self.proto.has_train_conf():
             return True
-        if self.function_desc.job_config_proto.has_predict_conf():
+        if self.proto.has_predict_conf():
             return False
         raise NotImplementedError
 
     def _train(self, mode: bool = True):
         if mode:
-            self.function_desc.job_config_proto.mutable_train_conf()
+            self.proto.mutable_train_conf()
+            self.proto.mutable_train_conf().set_loss_scale_factor(1.0)
         else:
-            self.function_desc.job_config_proto.mutable_predict_conf()
+            self.proto.mutable_predict_conf()
+
+    def add_optimizer_config(
+        self, optimizer_config: OptimizerConfig = None, var2var_op_name: Dict = None
+    ):
+        optimizer_config.optimizer.add_to_graph_train_config(
+            self.proto.mutable_train_conf(), var2var_op_name
+        )
 
diff --git a/oneflow/python/nn/optimizer/sgd.py b/oneflow/python/nn/optimizer/sgd.py
index 08f7d2ef5fd736e689cc627a55d28c3607fd3197..35082cf5f038d7d733ab38bc678c87da17be8881 100644
--- a/oneflow/python/nn/optimizer/sgd.py
+++ b/oneflow/python/nn/optimizer/sgd.py
@@ -16,6 +16,7 @@ limitations under the License.
 
 from typing import List, Dict, Callable, Union, Iterator
 import collections
+import math
 
 import oneflow as flow
 
@@ -135,3 +136,29 @@
             self._state["step"] = self._state["step"] + 1
 
         return loss
+
+    def add_to_graph_train_config(self, train_conf, var2var_op_name_dict):
+        for param_group in self.param_groups:
+            optimizer_conf = train_conf.mutable_optimizer_conf().Add()
+            lr = param_group["lr"]
+            beta = param_group["momentum"]
+            scale = param_group["scale"]
+            # TODO(): optimizer_conf needs a loss_scale_factor field to support multiple scale factors
+            base_scale = train_conf.loss_scale_factor()
+            assert math.isclose(base_scale, 1, rel_tol=1e-4) or math.isclose(
+                scale, base_scale, rel_tol=1e-4
+            ), "nn.Graph only supports one scale factor at the moment, base_scale {} vs scale {}".format(
+                base_scale, scale
+            )
+
+            train_conf.set_loss_scale_factor(scale)
+            optimizer_conf.set_base_learning_rate(lr)
+            if beta == 0:
+                optimizer_conf.mutable_naive_conf()
+            else:
+                optimizer_conf.mutable_momentum_conf().set_beta(beta)
+
+            for param in param_group.parameters:
+                if not param.requires_grad:
+                    continue
+                optimizer_conf.add_variable_op_names(var2var_op_name_dict[param])
diff --git a/oneflow/python/test/graph/test_forward_graph.py b/oneflow/python/test/graph/test_forward_graph.py
index 552644e921f091112ef47fa42b8e55815b0a1c65..02177e77bb52136b0869c8de34c9db0d989c2241 100644
--- a/oneflow/python/test/graph/test_forward_graph.py
+++ b/oneflow/python/test/graph/test_forward_graph.py
@@ -49,7 +49,7 @@ class CustomModule(flow.nn.Module):
 
 
 @flow.unittest.skip_unless_1n1d()
-class TestGraph(flow.unittest.TestCase):
+class TestForwardGraph(flow.unittest.TestCase):
     def test_forward_graph(test_case):
         class CustomGraph(flow.nn.Graph):
             def __init__(self, module):
diff --git a/oneflow/python/test/graph/test_graph_optimizer.py b/oneflow/python/test/graph/test_graph_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bc78f94d1c7c9b28fca8e0dd7343b651d240a63
--- /dev/null
+++ b/oneflow/python/test/graph/test_graph_optimizer.py
@@ -0,0 +1,95 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+import os
+
+import numpy as np
+
+import oneflow
+import oneflow.experimental as flow
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestGraphOptimizer(flow.unittest.TestCase):
+    def test_optimizer(test_case):
+        class CustomModule(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.para0 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para1 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para2 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para2.requires_grad_(False)
+                self.para3 = flow.nn.Parameter(flow.Tensor(1, 4))
+                self.para4 = flow.nn.Parameter(flow.Tensor(1, 4))
+
+            def forward(self, x):
+                return x
+
+        m = CustomModule()
+        learning_rate = 0.1
+        momentum = 0.2
+        scale = 0.3
+        sgd0 = flow.optim.SGD(
+            [
+                {
+                    "params": [m.para0, m.para1, m.para2],
+                    "lr": learning_rate,
+                    "momentum": momentum,
+                    "scale": scale,
+                }
+            ]
+        )
+        sgd1 = flow.optim.SGD(
+            [
+                {
+                    "params": [m.para3],
+                    "lr": learning_rate,
+                    "momentum": momentum,
+                    "scale": scale,
+                },
+                {
+                    "params": [m.para4],
+                    "lr": learning_rate,
+                    "momentum": momentum,
+                    "scale": scale,
+                },
+            ]
+        )
+
+        class CustomGraph0(flow.nn.Graph):
+            def __init__(self):
+                super().__init__()
+                self.m = m
+                self.add_optimizer("sgd0", sgd0)
+                self.add_optimizer("sgd1", sgd1)
+
+            def build(self, x):
+                out = self.m(x)
+                out.backward()
+                return out
+
+        g = CustomGraph0()
+        x = flow.Tensor(1, 1, 10, 10)
+        flow.nn.init.uniform_(x, a=-1.0, b=1.0)
+        z = g._compile(x)
+        print(repr(g))
+        print("g.config.proto: \n", g.config.proto)
+        print("graph proto: \n", g._graph_proto)
+
+
+if __name__ == "__main__":
+    unittest.main()
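
Reviewer note: below is a minimal usage sketch (not part of the patch) of how the pieces above compose. add_optimizer() records an OptimizerConfig that _complete_graph_config() later lowers into an optimizer_conf on the job's train conf, and calling backward() on a lazy tensor registers it as the graph loss through AddTensorAsGraphLoss. The Linear module, the hyperparameter values, and the explicit _compile() call are illustrative assumptions that mirror test_graph_optimizer.py; they are not an API guarantee.

import oneflow.experimental as flow

class LinearTrainGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        # Illustrative module; any nn.Module holding Parameters works the same way.
        self.linear = flow.nn.Linear(4, 4)
        # "sgd0" is a user-chosen config name; the param group keys match the
        # ones read back by SGD.add_to_graph_train_config().
        self.add_optimizer(
            "sgd0",
            flow.optim.SGD(
                [
                    {
                        "params": list(self.linear.parameters()),
                        "lr": 0.1,
                        "momentum": 0.9,
                        "scale": 1.0,
                    }
                ]
            ),
        )

    def build(self, x):
        out = self.linear(x)
        # In lazy mode Tensor.backward() does not run eager autograd; it calls
        # flow._oneflow_internal.nn.graph.AddTensorAsGraphLoss(out), adding the
        # tensor's logical blob name to the job as a loss.
        out.backward()
        return out

g = LinearTrainGraph()
x = flow.Tensor(2, 4)
flow.nn.init.uniform_(x, a=-1.0, b=1.0)
g._compile(x)  # builds the job with train_conf and optimizer_conf populated
print(g.config.proto)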