diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake index d9285ed06a4396f065978406d4527dba8b6554ad..356d16539479d78170e081ee5fc2e48c1cd1f3b0 100644 --- a/cmake/oneflow.cmake +++ b/cmake/oneflow.cmake @@ -418,6 +418,7 @@ list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_conte list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_util.cuh") list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/sbp_signature_builder.h") list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/parallel_desc.h") +list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/autograd/autograd_meta.h") copy_files("${OF_CORE_HDRS}" "${PROJECT_SOURCE_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy) add_dependencies(pip_install of_include_copy) diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp index ee531615f2e08ec743b81e3e630c8d9e8b183a32..4559a7548cf64d62e9a3ba688e955d7c3943f83b 100644 --- a/oneflow/api/python/framework/tensor.cpp +++ b/oneflow/api/python/framework/tensor.cpp @@ -39,10 +39,9 @@ struct TensorExportUtil<MirroredTensor> final { static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const Device>& device, - bool is_lazy, bool requires_grad, bool is_leaf, - bool retain_grad) { - return MirroredTensor::MakeTensor(shape, dtype, device, is_lazy, requires_grad, is_leaf, - retain_grad); + bool is_lazy, bool requires_grad, + bool is_leaf) { + return MirroredTensor::MakeTensor(shape, dtype, device, is_lazy, requires_grad, is_leaf); } }; @@ -52,9 +51,9 @@ struct TensorExportUtil<ConsistentTensor> final { const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const compatible_py::Distribute>& distribute, const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad, - bool is_leaf, bool retain_grad) { + bool is_leaf) { return ConsistentTensor::MakeTensor(shape, dtype, distribute, parallel_desc, is_lazy, - requires_grad, is_leaf, retain_grad); + requires_grad, is_leaf); } }; @@ -73,7 +72,10 @@ void ExportTensor(py::module& m, const char* name) { .def_property_readonly("requires_grad", &T::requires_grad) .def_property_readonly("is_leaf", &T::is_leaf) // Methods of pytorch - .def("retain_grad", [](T& t) { t.set_retain_grad(true); }) + .def("retain_grad", + [](T& t) { + if (!t.is_leaf()) { t.set_retain_grad(true); } + }) .def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); }) // OneFlow tensor properties other than pytorch tensor .def_property_readonly("placement", &T::parallel_desc) diff --git a/oneflow/core/autograd/autograd_engine.cpp b/oneflow/core/autograd/autograd_engine.cpp index daf42b737b3a47a7b9f8c32c5c8071338f1dffa1..bda275dc6724a065d41910071cad197dbced4240 100644 --- a/oneflow/core/autograd/autograd_engine.cpp +++ b/oneflow/core/autograd/autograd_engine.cpp @@ -29,36 +29,25 @@ namespace one { namespace { -bool IsReadyToRun(const std::vector<std::shared_ptr<TensorArg>>& out_grads) { - return std::any_of( - out_grads.begin(), out_grads.end(), - [](const std::shared_ptr<TensorArg>& tensor_arg) { return !tensor_arg->Empty(); }); +bool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_datas) { + return std::any_of(out_meta_datas.begin(), out_meta_datas.end(), + [](const std::shared_ptr<AutogradMeta>& meta_data) { + return !meta_data->now_grad_arg()->Empty(); + }); } -Maybe<void> InitEmptyTensorArgs2ZerosTensor(const TensorTuple& outputs, - std::vector<std::shared_ptr<TensorArg>>& out_grads) { - const auto& zero_like = JUST(op_expr_helper::ZeroLikeOp()); - for (int i = 0; i < out_grads.size(); ++i) { - if (out_grads.at(i)->Empty()) { - TensorTuple output(1); - JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*zero_like, {outputs.at(i)}, &output)); - JUST(out_grads.at(i)->PushPartialTensor(output.at(0))); - } - } - return Maybe<void>::Ok(); -} - -Maybe<void> CopyOrAccGrad(Tensor* tensor, bool autograd_mode) { +Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) { autograd::AutoGradMode mode(autograd_mode); - const auto& now_grad = JUST(tensor->now_grad_arg()->GetAccTensor()); - if (tensor->acc_grad()) { - TensorTuple input = {tensor->acc_grad(), now_grad}; + const auto& now_grad = JUST(autograd_meta->now_grad_arg()->GetAccTensor()); + if (!now_grad) { return Maybe<void>::Ok(); } + if (autograd_meta->acc_grad()) { + TensorTuple input = {autograd_meta->acc_grad(), now_grad}; TensorTuple output(1); const auto& add = JUST(op_expr_helper::AddOp()); JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*add, input, &output)); - tensor->set_acc_grad(output.at(0)); + autograd_meta->set_acc_grad(output.at(0)); } else { - tensor->set_acc_grad(now_grad); + autograd_meta->set_acc_grad(now_grad); } return Maybe<void>::Ok(); } @@ -69,16 +58,18 @@ StackFunctionNode::StackFunctionNode( const std::shared_ptr<const std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)>>& backward_fn, const TensorTuple& inputs, const TensorTuple& outputs) { - inputs_ = std::make_shared<TensorTuple>(inputs.size()); + input_meta_datas_.resize(inputs.size()); + input_tensors_.resize(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { - inputs_->at(i) = inputs.at(i); - in_grads_.emplace_back(inputs.at(i)->now_grad_arg()); + input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta(); + if (input_meta_datas_.at(i)->requires_grad()) { input_tensors_.at(i) = inputs.at(i); } } - outputs_ = std::make_shared<TensorTuple>(outputs.size()); + output_meta_datas_.resize(outputs.size()); + output_tensor_infos_.reserve(outputs.size()); for (int i = 0; i < outputs.size(); ++i) { - outputs_->at(i) = outputs.at(i)->detach(); - out_grads_.emplace_back(outputs.at(i)->now_grad_arg()); + output_meta_datas_.at(i) = outputs.at(i)->mut_autograd_meta(); + output_tensor_infos_.emplace_back(TensorInfo(*outputs.at(i))); } backward_fn_ = backward_fn; @@ -86,51 +77,52 @@ StackFunctionNode::StackFunctionNode( } Maybe<void> StackFunctionNode::AccGrad4RetainGradTensor() { - for (int i = 0; i < outputs_->size(); ++i) { - if (outputs_->at(i)->retain_grad() && outputs_->at(i)->requires_grad()) { - JUST(CopyOrAccGrad(outputs_->at(i).get(), /*autograd_mode=*/false)); - } + for (const std::shared_ptr<AutogradMeta>& out : output_meta_datas_) { + if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); } } return Maybe<void>::Ok(); } Maybe<void> StackFunctionNode::AccGrad4LeafTensor(bool create_graph) { - for (int i = 0; i < outputs_->size(); ++i) { - if (outputs_->at(i)->is_leaf() && outputs_->at(i)->requires_grad()) { - JUST(CopyOrAccGrad(outputs_->at(i).get(), /*autograd_mode=*/create_graph)); + for (const std::shared_ptr<AutogradMeta>& out : output_meta_datas_) { + if (out->is_leaf() && out->requires_grad()) { + JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); } } return Maybe<void>::Ok(); } void StackFunctionNode::ReleaseOutTensorArgs() { - for (const std::shared_ptr<TensorArg>& tensor_arg : out_grads_) { tensor_arg->Release(); } + for (const std::shared_ptr<AutogradMeta>& meta_data : output_meta_datas_) { + meta_data->now_grad_arg()->Release(); + } } void StackFunctionNode::ReleaseData() { - if (!inputs_->empty()) { - inputs_.reset(); - outputs_.reset(); - in_grads_.clear(); - out_grads_.clear(); - backward_fn_.reset(); - } + // Releases backward function and makes useless tensors release as early as possible + if (!input_meta_datas_.empty()) { backward_fn_.reset(); } + input_tensors_.clear(); is_in_stack_ = false; } Maybe<bool> StackFunctionNode::Apply(bool create_graph) { CHECK_NOTNULL_OR_RETURN(backward_fn_.get()) << "This FunctionNode with name `" << GetOpName() << "` has been released."; - if (!IsReadyToRun(out_grads_)) { return false; } - JUST(InitEmptyTensorArgs2ZerosTensor(*outputs_, out_grads_)); - TensorTuple input_grads(in_grads_.size()); - TensorTuple output_grads(out_grads_.size()); - for (int i = 0; i < out_grads_.size(); ++i) { - output_grads.at(i) = JUST(out_grads_.at(i)->GetAccTensor()); + if (!IsReadyToRun(output_meta_datas_)) { return false; } + TensorTuple input_grads(input_meta_datas_.size()); + TensorTuple output_grads(output_meta_datas_.size()); + for (int i = 0; i < output_meta_datas_.size(); ++i) { + if (output_meta_datas_.at(i)->now_grad_arg()->Empty()) { + output_grads.at(i) = JUST(output_tensor_infos_.at(i).zeros()); + } else { + output_grads.at(i) = JUST(output_meta_datas_.at(i)->now_grad_arg()->GetAccTensor()); + } } JUST((*backward_fn_)(output_grads, &input_grads, create_graph)); - for (int i = 0; i < in_grads_.size(); ++i) { - JUST(in_grads_.at(i)->PushPartialTensor(input_grads.at(i))); + for (int i = 0; i < input_meta_datas_.size(); ++i) { + if (input_grads.at(i)) { + JUST(input_meta_datas_.at(i)->now_grad_arg()->PushPartialTensor(input_grads.at(i))); + } } return true; } diff --git a/oneflow/core/autograd/autograd_engine.h b/oneflow/core/autograd/autograd_engine.h index cba0b5f29bb88b31e5f152947f88eeb88987b86d..53f9c3e1361db492b14b0a57dad5636178e913a1 100644 --- a/oneflow/core/autograd/autograd_engine.h +++ b/oneflow/core/autograd/autograd_engine.h @@ -27,9 +27,10 @@ namespace oneflow { namespace one { -class TensorArg; class Tensor; class TensorTuple; +class AutogradMeta; +class TensorInfo; // Calculates one backward op class FunctionNode { @@ -101,12 +102,10 @@ class StackFunctionNode final : public FunctionNode { void set_is_in_stack(bool in_stack) { is_in_stack_ = in_stack; } private: - // FunctionNode shares Tensor with `inputs_`, and only shares TensorImpl with `outputs_`. - // The reference link is `output tensors -> node -> inputs_/input tensors`. - std::shared_ptr<TensorTuple> inputs_; - std::shared_ptr<TensorTuple> outputs_; - std::vector<std::shared_ptr<TensorArg>> in_grads_; - std::vector<std::shared_ptr<TensorArg>> out_grads_; + std::vector<std::shared_ptr<Tensor>> input_tensors_; + std::vector<std::shared_ptr<AutogradMeta>> input_meta_datas_; + std::vector<std::shared_ptr<AutogradMeta>> output_meta_datas_; + std::vector<TensorInfo> output_tensor_infos_; // Actual backward function builds in `AutogradInterpreter` to calculate one backward op std::shared_ptr<const std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)>> backward_fn_; diff --git a/oneflow/core/autograd/autograd_meta.cpp b/oneflow/core/autograd/autograd_meta.cpp new file mode 100644 index 0000000000000000000000000000000000000000..873810a5d504fb30653e5b0dc9fab07261e597e9 --- /dev/null +++ b/oneflow/core/autograd/autograd_meta.cpp @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/op_expr_helper.h" +#include "oneflow/core/framework/dtype.h" +#include "oneflow/core/autograd/autograd_meta.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" + +namespace oneflow { + +namespace one { + +TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {} + +Maybe<Tensor> TensorInfo::zeros() const { + const auto& interpreter = JUST(OpInterpUtil::GetInterpreter()); + const auto& zeros_op = JUST(op_expr_helper::ZerosOp(*shape_.get(), dtype_->data_type())); + TensorTuple outputs(1); + JUST(interpreter->Apply(*zeros_op, {}, &outputs)); + return outputs.at(0); +} + +} // namespace one + +} // namespace oneflow diff --git a/oneflow/core/autograd/autograd_meta.h b/oneflow/core/autograd/autograd_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..4fa8d4a4815f37ee8611b9f078457535c792c3db --- /dev/null +++ b/oneflow/core/autograd/autograd_meta.h @@ -0,0 +1,86 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_ +#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_ + +#include <memory> +#include "oneflow/core/framework/tensor_arg.h" +#include "oneflow/core/common/util.h" + +namespace oneflow { + +class Shape; +class DType; + +namespace one { + +class Tensor; +class TensorArg; + +class AutogradMeta final { + public: + AutogradMeta() = delete; + AutogradMeta(bool requires_grad, bool is_leaf) + : is_leaf_(is_leaf), + requires_grad_(requires_grad), + retain_grad_(false), + now_grad_arg_(new TensorArg) {} + + // Getters + const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; } + const std::shared_ptr<TensorArg>& now_grad_arg() const { return now_grad_arg_; } + bool requires_grad() const { return requires_grad_; } + bool is_leaf() const { return is_leaf_; } + bool retain_grad() const { return retain_grad_; } + + // Setters + void set_acc_grad(const std::shared_ptr<Tensor>& grad) { acc_grad_ = grad; } + std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; } + void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } + void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; } + void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; } + + private: + bool is_leaf_; + + // Only meaningful on leaf Tensors (must be false otherwise) + bool requires_grad_; + + // Oney meaningful on non_leaf Tensors (must be false otherwise) + bool retain_grad_; + + std::shared_ptr<Tensor> acc_grad_; + std::shared_ptr<TensorArg> now_grad_arg_; +}; + +class TensorInfo final { + public: + TensorInfo() = delete; + explicit TensorInfo(const Tensor& tensor); + + Maybe<Tensor> zeros() const; + + private: + std::shared_ptr<const Shape> shape_; + std::shared_ptr<const DType> dtype_; + // TODO: Add device info +}; + +} // namespace one +} // namespace oneflow + +#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_ diff --git a/oneflow/core/autograd/autograd_mode.h b/oneflow/core/autograd/autograd_mode.h index 26c5687cd283f772e8f0b22a7fbe7f009752d6f1..96c095fe22bdba7ad28e0008a94e981c96ec9946 100644 --- a/oneflow/core/autograd/autograd_mode.h +++ b/oneflow/core/autograd/autograd_mode.h @@ -14,8 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -namespace oneflow { +#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_ +#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_ +namespace oneflow { namespace autograd { struct GradMode { @@ -41,5 +43,6 @@ class NoGradGuard : public AutoGradMode { }; } // namespace autograd - } // namespace oneflow + +#endif // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_ diff --git a/oneflow/core/framework/op_expr_grad_function.h b/oneflow/core/framework/op_expr_grad_function.h index ea637d931d389808441823a91a203a73636d9f9c..beb3715ad27eaf6008ce55c1259dec7a0e671e54 100644 --- a/oneflow/core/framework/op_expr_grad_function.h +++ b/oneflow/core/framework/op_expr_grad_function.h @@ -54,7 +54,9 @@ class OpExprGradFunction : public OpExprGradFunctionIf { const TensorTuple& outputs, const AttrMap& attrs) const override { StateT* state = dynamic_cast<StateT*>(ctx); CHECK_NOTNULL_OR_RETURN(state); - return Capture(state, inputs, outputs, attrs); + TensorTuple detach_outputs(outputs.size()); + for (int i = 0; i < outputs.size(); ++i) { detach_outputs.at(i) = outputs.at(i)->detach(); } + return Capture(state, inputs, detach_outputs, attrs); } Maybe<void> ApplyIf(const OpExprInterpState* ctx, const TensorTuple& out_grads, diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp index 1c1460c801f76b0bf200c3f777be82a16c06552f..b0b8ee8c7348bb5e1d18e7d036b68379081b7b69 100644 --- a/oneflow/core/framework/op_expr_helper.cpp +++ b/oneflow/core/framework/op_expr_helper.cpp @@ -80,26 +80,26 @@ OF_PP_FOR_EACH_TUPLE(DEFINE_FLOATING_CONSTATNT_OP, FLOATING_DATA_TYPE_SEQ); OF_PP_FOR_EACH_TUPLE(DEFINE_INTEGER_CONSTATNT_OP, INT_DATA_TYPE_SEQ) #undef DEFINE_INTEGER_CONSTATNT_OP -Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype) { +Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype) { return OnesOp(shape, dtype, UniqueOpName("constant")); } -Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype, const std::string& name) { +Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name) { switch (dtype) { #define CONSTANT_DATA_TYPE_CASE(cpp_type, data_type) \ - case data_type: return ConstantOp(shape, (cpp_type)1, name); + case data_type: return ConstantOp(shape, (cpp_type)0, name); OF_PP_FOR_EACH_TUPLE(CONSTANT_DATA_TYPE_CASE, FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ); #undef CONSTANT_DATA_TYPE_CASE default: UNIMPLEMENTED_THEN_RETURN(); } } -Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype) { - return ZerosOp(shape, dtype, UniqueOpName("constant")); +Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype) { + return OnesOp(shape, dtype, UniqueOpName("constant")); } -Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name) { +Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype, const std::string& name) { switch (dtype) { #define CONSTANT_DATA_TYPE_CASE(cpp_type, data_type) \ - case data_type: return ConstantOp(shape, (cpp_type)0, name); + case data_type: return ConstantOp(shape, (cpp_type)1, name); OF_PP_FOR_EACH_TUPLE(CONSTANT_DATA_TYPE_CASE, FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ); #undef CONSTANT_DATA_TYPE_CASE default: UNIMPLEMENTED_THEN_RETURN(); diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h index 7eadedc7dcb98bed15114249fa929a58fa8fb6a7..6d2f7268c891bdaebc0c28e20b5605f5c6f5eecd 100644 --- a/oneflow/core/framework/op_expr_helper.h +++ b/oneflow/core/framework/op_expr_helper.h @@ -28,6 +28,9 @@ Maybe<one::UserOpExpr> AddNOp(int32_t n, const std::string& name); Maybe<one::UserOpExpr> AddOp(); Maybe<one::UserOpExpr> AddOp(const std::string& name); +Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype); +Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name); + Maybe<one::UserOpExpr> ZeroLikeOp(); Maybe<one::UserOpExpr> ZeroLikeOp(const std::string& name); @@ -39,9 +42,6 @@ Maybe<one::UserOpExpr> ConstantOp(const Shape& shape, const T& value, const std: Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype); Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype, const std::string& name); -Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype); -Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name); - Maybe<one::UserOpExpr> IdentityOp(); Maybe<one::UserOpExpr> IdentityOp(const std::string& name); diff --git a/oneflow/core/framework/op_interpreter.h b/oneflow/core/framework/op_interpreter.h index caf73af34dcc150e92c385ed94c13d567dfeea45..7c4cbc7307b3ca93e859b6922acd31df1e4b2bf3 100644 --- a/oneflow/core/framework/op_interpreter.h +++ b/oneflow/core/framework/op_interpreter.h @@ -33,7 +33,7 @@ class OpExprInterpState { size_t SaveTensorForBackward(const std::shared_ptr<Tensor>& tensor) { size_t offset = saved_tensors_.size(); - saved_tensors_.push_back(tensor->detach()); + saved_tensors_.push_back(tensor); return offset; } diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp index f36fb8a116aa1ba71f6791406e25d0033de70948..dd85a53ff4b51a1731664d5f0e0dd738ddbbc6ae 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp @@ -160,13 +160,12 @@ using Bn2BlobObjectMap = HashMap<std::string, std::shared_ptr<compatible_py::Blo const auto& device = JUST(Device::MakeDeviceByParallelDesc(*parallel_attr->parallel_desc_symbol())); return static_cast<std::shared_ptr<Tensor>>(MirroredTensor::MakeTensor( - blob_attr->shape(), dtype, device, is_lazy, /*requires_grad=*/false, /*is_leaf=*/false, - /*retain_grad=*/false)); + blob_attr->shape(), dtype, device, is_lazy, /*requires_grad=*/false, /*is_leaf=*/false)); } else { const auto& distribute = JUST(compatible_py::MakeDistribute(*(parallel_attr->sbp_parallel()))); return static_cast<std::shared_ptr<Tensor>>(ConsistentTensor::MakeTensor( blob_attr->shape(), dtype, distribute, parallel_attr->parallel_desc_symbol(), is_lazy, - /*requires_grad=*/false, /*is_leaf=*/false, /*retain_grad=*/false)); + /*requires_grad=*/false, /*is_leaf=*/false)); } } @@ -174,8 +173,7 @@ using Bn2BlobObjectMap = HashMap<std::string, std::shared_ptr<compatible_py::Blo const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object, const std::shared_ptr<const Device>& device) { auto tensor = MirroredTensor::MakeEagerTensor(eager_blob_object, device, - /* requires_grad */ false, /* is_leaf */ false, - /* retain_grad */ false); + /* requires_grad */ false, /* is_leaf */ false); return std::static_pointer_cast<Tensor>(tensor); } diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 6ba8841df5e2d6e34e385f44b39db2f4ceb24c90..7e3522eae4a1b133a2643cd6186653034af9bfcf 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -26,27 +26,24 @@ namespace one { std::shared_ptr<MirroredTensor> MirroredTensor::MakeTensor( const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype, - const std::shared_ptr<const Device>& device, bool is_lazy, bool requires_grad, bool is_leaf, - bool retain_grad) { + const std::shared_ptr<const Device>& device, bool is_lazy, bool requires_grad, bool is_leaf) { std::shared_ptr<MirroredTensorImpl> impl; if (is_lazy) { - impl = std::make_shared<LazyMirroredTensorImpl>(shape, dtype, device, requires_grad, is_leaf, - retain_grad); + impl = std::make_shared<LazyMirroredTensorImpl>(shape, dtype, device, requires_grad, is_leaf); } else { const auto eager_blob_object = CHECK_JUST(GenerateAllocatedEagerBlobObject(dtype->data_type(), *shape)); impl = std::make_shared<EagerMirroredTensorImpl>(eager_blob_object, device, requires_grad, - is_leaf, retain_grad); + is_leaf); } return std::make_shared<MirroredTensor>(impl); } std::shared_ptr<MirroredTensor> MirroredTensor::MakeEagerTensor( const std::shared_ptr<vm::EagerBlobObject> eager_blob_object, - const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf, - bool retain_grad) { - std::shared_ptr<MirroredTensorImpl> impl = std::make_shared<EagerMirroredTensorImpl>( - eager_blob_object, device, requires_grad, is_leaf, retain_grad); + const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf) { + std::shared_ptr<MirroredTensorImpl> impl = + std::make_shared<EagerMirroredTensorImpl>(eager_blob_object, device, requires_grad, is_leaf); return std::make_shared<MirroredTensor>(impl); } @@ -60,7 +57,7 @@ int64_t MirroredTensor::nelement() const { return shape()->elem_cnt(); } std::shared_ptr<MirroredTensor> MirroredTensor::data() const { std::shared_ptr<MirroredTensor> t = - MakeTensor(shape(), dtype(), device(), is_lazy(), false, is_leaf(), false); + MakeTensor(shape(), dtype(), device(), is_lazy(), false, is_leaf()); t->set_blob_object(blob_object()); return t; } @@ -74,14 +71,14 @@ std::shared_ptr<ConsistentTensor> ConsistentTensor::MakeTensor( const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const compatible_py::Distribute>& distribute, const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad, - bool is_leaf, bool retain_grad) { + bool is_leaf) { std::shared_ptr<ConsistentTensorImpl> impl; if (is_lazy) { impl = std::make_shared<LazyConsistentTensorImpl>(shape, dtype, distribute, parallel_desc, - requires_grad, is_leaf, retain_grad); + requires_grad, is_leaf); } else { impl = std::make_shared<EagerConsistentTensorImpl>(shape, dtype, distribute, parallel_desc, - requires_grad, is_leaf, retain_grad); + requires_grad, is_leaf); } return std::make_shared<ConsistentTensor>(impl); } @@ -97,8 +94,8 @@ int64_t ConsistentTensor::nelement() const { return shape()->elem_cnt(); } int64_t ConsistentTensor::ndim() const { return shape()->NumAxes(); } std::shared_ptr<ConsistentTensor> ConsistentTensor::data() const { - std::shared_ptr<ConsistentTensor> t = MakeTensor(shape(), dtype(), distribute(), parallel_desc(), - is_lazy(), false, is_leaf(), false); + std::shared_ptr<ConsistentTensor> t = + MakeTensor(shape(), dtype(), distribute(), parallel_desc(), is_lazy(), false, is_leaf()); t->set_blob_object(blob_object()); return t; } diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index ebf0d47953b0c8d58e584d5322c8dbddaefd1be6..d13c523ccb169d5e71ec3dba54295cec2bf2bd73 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -100,6 +100,7 @@ class Tensor { virtual void set_acc_grad(const std::shared_ptr<Tensor>& grad) = 0; virtual std::shared_ptr<Tensor> mut_acc_grad() = 0; virtual void set_is_leaf(bool is_leaf) = 0; + virtual std::shared_ptr<AutogradMeta> mut_autograd_meta() = 0; protected: Tensor() = default; @@ -218,9 +219,10 @@ class MirroredTensor final : public TensorIf<MirroredTensor> { // Setters for autograd void set_acc_grad(const std::shared_ptr<Tensor>& grad) override { impl_->set_acc_grad(grad); } void set_requires_grad(bool requires_grad) override { impl_->set_requires_grad(requires_grad); } - void set_retain_grad(bool retain_grad) override { impl_->set_requires_grad(retain_grad); } + void set_retain_grad(bool retain_grad) override { impl_->set_retain_grad(retain_grad); } std::shared_ptr<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); } void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); } + std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); } // Operators for tensor std::shared_ptr<Tensor> detach() const override; @@ -239,13 +241,11 @@ class MirroredTensor final : public TensorIf<MirroredTensor> { static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const Device>& device, - bool is_lazy, bool requires_grad, bool is_leaf, - bool retain_grad); + bool is_lazy, bool requires_grad, bool is_leaf); static std::shared_ptr<MirroredTensor> MakeEagerTensor( const std::shared_ptr<vm::EagerBlobObject> eager_blob_object, - const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf, - bool retain_grad); + const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf); private: std::shared_ptr<MirroredTensorImpl> impl_; @@ -308,8 +308,9 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> { void set_acc_grad(const std::shared_ptr<Tensor>& grad) override { impl_->set_acc_grad(grad); } std::shared_ptr<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); } void set_requires_grad(bool requires_grad) override { impl_->set_requires_grad(requires_grad); } - void set_retain_grad(bool retain_grad) override { impl_->set_requires_grad(retain_grad); } + void set_retain_grad(bool retain_grad) override { impl_->set_retain_grad(retain_grad); } void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); } + std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); } // Operators for tensor std::shared_ptr<Tensor> detach() const override; @@ -329,7 +330,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> { const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const compatible_py::Distribute>& distribute, const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad, - bool is_leaf, bool retain_grad); + bool is_leaf); private: std::shared_ptr<ConsistentTensorImpl> impl_; diff --git a/oneflow/core/framework/tensor_impl.cpp b/oneflow/core/framework/tensor_impl.cpp index e9e1c9cb4d072471d3bd4c6e038071b65a0739c6..a9e9e0ab554bb5e6e2e53bc2eccdb05d4d899387 100644 --- a/oneflow/core/framework/tensor_impl.cpp +++ b/oneflow/core/framework/tensor_impl.cpp @@ -72,9 +72,8 @@ Maybe<void> EagerConsistentTensorImpl::set_blob_object( EagerMirroredTensorImpl::EagerMirroredTensorImpl( const std::shared_ptr<vm::EagerBlobObject> eager_blob_object, - const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf, bool retain_grad) - : MirroredTensorImpl(device, requires_grad, is_leaf, retain_grad), - eager_blob_object_(eager_blob_object) { + const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf) + : MirroredTensorImpl(device, requires_grad, is_leaf), eager_blob_object_(eager_blob_object) { dtype_ = CHECK_JUST(DType::GetDTypeByDataType(eager_blob_object->blob_desc().data_type())); tensor_storage_ = std::make_shared<TensorStorage>(eager_blob_object->tensor_buffer()); const auto& parallel_desc = this->parallel_desc(); diff --git a/oneflow/core/framework/tensor_impl.h b/oneflow/core/framework/tensor_impl.h index 2f1e9df1211e24f3b730ba936183334244886dda..b470de2ee4c8e19d1d5bc7bf08037ecbcf25679f 100644 --- a/oneflow/core/framework/tensor_impl.h +++ b/oneflow/core/framework/tensor_impl.h @@ -19,11 +19,10 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" -#include "oneflow/core/common/shape.h" #include "oneflow/core/job/placement.cfg.h" #include "oneflow/core/framework/object.h" -#include "oneflow/core/framework/tensor_arg.h" #include "oneflow/core/framework/tensor_storage.h" +#include "oneflow/core/autograd/autograd_meta.h" namespace oneflow { @@ -35,6 +34,7 @@ namespace compatible_py { class Distribute; } +class Shape; class Device; class DType; @@ -63,11 +63,11 @@ class TensorImpl { virtual Maybe<VmLocalDepObject> compute_local_dep_object() const = 0; // Getters for autograd - const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; } - const std::shared_ptr<TensorArg>& now_grad_arg() const { return now_grad_arg_; } - bool requires_grad() const { return requires_grad_; } - bool is_leaf() const { return is_leaf_; } - bool retain_grad() const { return retain_grad_; } + const std::shared_ptr<Tensor>& acc_grad() const { return autograd_meta_->acc_grad(); } + const std::shared_ptr<TensorArg>& now_grad_arg() const { return autograd_meta_->now_grad_arg(); } + bool requires_grad() const { return autograd_meta_->requires_grad(); } + bool is_leaf() const { return autograd_meta_->is_leaf(); } + bool retain_grad() const { return autograd_meta_->retain_grad(); } // Setters virtual void set_shape(const std::shared_ptr<const Shape>& shape) = 0; @@ -76,11 +76,12 @@ class TensorImpl { const std::shared_ptr<const ParallelDesc>& parallel_desc) = 0; // Setters for autograd - void set_acc_grad(const std::shared_ptr<Tensor>& grad) { acc_grad_ = grad; } - std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; } - void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } - void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; } - void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; } + void set_acc_grad(const std::shared_ptr<Tensor>& grad) { autograd_meta_->set_acc_grad(grad); } + std::shared_ptr<Tensor> mut_acc_grad() { return autograd_meta_->mut_acc_grad(); } + void set_requires_grad(bool requires_grad) { autograd_meta_->set_requires_grad(requires_grad); } + void set_retain_grad(bool retain_grad) { autograd_meta_->set_retain_grad(retain_grad); } + void set_is_leaf(bool is_leaf) { autograd_meta_->set_is_leaf(is_leaf); } + std::shared_ptr<AutogradMeta> mut_autograd_meta() { return autograd_meta_; } // Getters to be deprecated virtual const std::shared_ptr<compatible_py::BlobObject>& blob_object() const = 0; @@ -90,20 +91,13 @@ class TensorImpl { const std::shared_ptr<compatible_py::BlobObject>& blob_object) = 0; protected: - TensorImpl(bool requires_grad, bool is_leaf, bool retain_grad) - : requires_grad_(requires_grad), - is_leaf_(is_leaf), - retain_grad_(retain_grad), - now_grad_arg_(new TensorArg) {} + TensorImpl(bool requires_grad, bool is_leaf) + : autograd_meta_(new AutogradMeta(requires_grad, is_leaf)) {} Maybe<void> SyncBlobObject2Attributes( const std::shared_ptr<compatible_py::BlobObject>& blob_object); - // For autograd - bool requires_grad_; - bool is_leaf_; - bool retain_grad_; - std::shared_ptr<Tensor> acc_grad_; - std::shared_ptr<TensorArg> now_grad_arg_; + protected: + std::shared_ptr<AutogradMeta> autograd_meta_; }; class MirroredTensorImpl : public TensorImpl { @@ -123,9 +117,8 @@ class MirroredTensorImpl : public TensorImpl { std::shared_ptr<vm::EagerBlobObject> eager_blob_object) = 0; protected: - MirroredTensorImpl(const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf, - bool retain_grad) - : TensorImpl(requires_grad, is_leaf, retain_grad) { + MirroredTensorImpl(const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf) + : TensorImpl(requires_grad, is_leaf) { set_device(device); } @@ -152,8 +145,8 @@ class ConsistentTensorImpl : public TensorImpl { protected: ConsistentTensorImpl(const std::shared_ptr<const ParallelDesc>& parallel_desc, bool requires_grad, - bool is_leaf, bool retain_grad) - : TensorImpl(requires_grad, is_leaf, retain_grad), parallel_desc_(parallel_desc) {} + bool is_leaf) + : TensorImpl(requires_grad, is_leaf), parallel_desc_(parallel_desc) {} const std::shared_ptr<const Device> device_; // always nullptr std::shared_ptr<const ParallelDesc> parallel_desc_; @@ -165,10 +158,8 @@ class LazyMirroredTensorImpl final : public MirroredTensorImpl { LazyMirroredTensorImpl(const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const Device>& device, bool requires_grad, - bool is_leaf, bool retain_grad) - : MirroredTensorImpl(device, requires_grad, is_leaf, retain_grad), - shape_(shape), - dtype_(dtype) {} + bool is_leaf) + : MirroredTensorImpl(device, requires_grad, is_leaf), shape_(shape), dtype_(dtype) {} ~LazyMirroredTensorImpl() override = default; // Getters @@ -211,7 +202,12 @@ class EagerMirroredTensorImpl final : public MirroredTensorImpl { OF_DISALLOW_COPY_AND_MOVE(EagerMirroredTensorImpl); EagerMirroredTensorImpl(const std::shared_ptr<vm::EagerBlobObject> eager_blob_object, const std::shared_ptr<const Device>& device, bool requires_grad, - bool is_leaf, bool retain_grad); + bool is_leaf); + EagerMirroredTensorImpl(const std::shared_ptr<const Shape>& shape, + const std::shared_ptr<const DType>& dtype, + const std::shared_ptr<const Device>& device, + const std::shared_ptr<TensorStorage>& tensor_storage, bool requires_grad, + bool is_leaf); ~EagerMirroredTensorImpl() override; // Getters @@ -257,8 +253,8 @@ class LazyConsistentTensorImpl final : public ConsistentTensorImpl { const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const compatible_py::Distribute>& distribute, const std::shared_ptr<const ParallelDesc>& parallel_desc, - bool requires_grad, bool is_leaf, bool retain_grad) - : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf, retain_grad), + bool requires_grad, bool is_leaf) + : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf), shape_(shape), dtype_(dtype), distribute_(distribute) {} @@ -308,8 +304,8 @@ class EagerConsistentTensorImpl final : public ConsistentTensorImpl { const std::shared_ptr<const DType>& dtype, const std::shared_ptr<const compatible_py::Distribute>& distribute, const std::shared_ptr<const ParallelDesc>& parallel_desc, - bool requires_grad, bool is_leaf, bool retain_grad) - : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf, retain_grad), + bool requires_grad, bool is_leaf) + : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf), shape_(shape), dtype_(dtype), distribute_(distribute) {} diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index 96f6bc70c41c704a868caf139fd63b6434153b12..cdf4a7ed644b4317a1e37700463353c71ddbd939 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -82,7 +82,6 @@ class Tensor: dtype=None, device=None, requires_grad=False, - retain_grad=False, placement=None, sbp=None, is_consistent=False, @@ -110,7 +109,6 @@ class Tensor: dtype=dtype, device=device, requires_grad=requires_grad, - retain_grad=retain_grad, placement=placement, sbp=sbp, is_consistent=is_consistent, @@ -124,7 +122,6 @@ class Tensor: dtype, device=device, requires_grad=requires_grad, - retain_grad=retain_grad, placement=placement, sbp=sbp, is_consistent=is_consistent, @@ -511,7 +508,6 @@ class Tensor: dtype=None, device=None, requires_grad=False, - retain_grad=False, placement=None, sbp=None, is_consistent=False, @@ -533,7 +529,6 @@ class Tensor: dtype, device=device, requires_grad=requires_grad, - retain_grad=retain_grad, placement=placement, sbp=sbp, is_consistent=is_consistent, @@ -549,7 +544,6 @@ class UndeterminedTensor: dtype, device=None, requires_grad=False, - retain_grad=False, placement=None, sbp=None, is_consistent=False, @@ -573,7 +567,6 @@ class UndeterminedTensor: self.dtype = dtype self.device = device self.requires_grad = requires_grad - self.retain_grad = retain_grad self.placement = placement self.sbp = sbp self.is_consistent = is_consistent @@ -623,7 +616,6 @@ def _default_initializer_for_determining(tensor): undetermined_tensor.is_lazy, undetermined_tensor.requires_grad, True, - undetermined_tensor.retain_grad, ) determined_tensor._set_blob_object( _create_blob_object( @@ -643,7 +635,6 @@ def _default_initializer_for_determining(tensor): undetermined_tensor.is_lazy, undetermined_tensor.requires_grad, True, - undetermined_tensor.retain_grad, ) _init_eager_local_tensor_by_initializer_conf( determined_tensor, undetermined_tensor.data_initializer @@ -682,7 +673,6 @@ def _numpy_initializer_for_determining(tensor): undetermined_tensor.is_lazy, undetermined_tensor.requires_grad, True, - undetermined_tensor.retain_grad, ) determined_tensor._set_blob_object(blob.blob_object) else: @@ -693,7 +683,6 @@ def _numpy_initializer_for_determining(tensor): undetermined_tensor.is_lazy, undetermined_tensor.requires_grad, True, - undetermined_tensor.retain_grad, ) _copy_from_numpy_to_eager_local_tensor(determined_tensor, numpy_data)