Skip to content
Snippets Groups Projects
Unverified Commit 2e75393e authored by poohRui's avatar poohRui Committed by GitHub
Browse files

Support inplace operations (#5204)


* support inplace forward

* support inplace backward

* add test case

* add test case for clone

* inplace is not supported for leaf nodes

* refine clone

* add checks

* refine

* forbid clone with no grad

* Separate autograd meta to tensor (#5267)

* separate autograd meta

* minor fix

* fix acc_grad interface

* fix acc_grad with null

* minor fix

* inplace without clone

* refine

* minor fix

* remove maybe from constructor

* change from create to set

* fix merge bugs

* fix merge bug

* remove inplace flag in local_call_opkernel_phy_instr_operand

* remove out-date codes

* refine code

* add JUST

* fix merge master bug

* revert autograd engine input_grad check

* fix bug in tensor_hook

Co-authored-by: wyg1997 <wyg19970408@gmail.com>
Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 0a821433
No related branches found
No related tags found
No related merge requests found
Showing
with 245 additions and 90 deletions
......@@ -215,6 +215,7 @@ void SpecializedDef(py::class_<MirroredTensor, Tensor, std::shared_ptr<MirroredT
api->def("zeros_", &ApiEagerMirroredTensorZeros);
api->def("_register_hook",
[](const std::shared_ptr<MirroredTensor>& self, const AutogradMeta::Hook& hook) -> void {
if (!self->grad_fn_node()) { CHECK_JUST(AddAccumulateFunctionNode(self)); }
self->mut_autograd_meta()->add_hook(hook);
});
}
......@@ -256,9 +257,10 @@ void ExportTensor(py::module& m, const char* name) {
// Methods of pytorch
.def("retain_grad",
[](T& t) {
if (!t.is_leaf()) { t.set_retain_grad(true); }
if (!t.is_leaf()) { t.set_retain_grad(true).GetOrThrow(); }
})
.def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); })
.def("clone", [](const T& t) { return t.api_clone().GetPtrOrThrow(); })
// OneFlow tensor properties other than pytorch tensor
.def_property_readonly("is_lazy", &T::is_lazy)
.def_property_readonly("is_consistent", &T::is_consistent);
......
......@@ -64,8 +64,8 @@ StackFunctionNode::StackFunctionNode(
input_meta_datas_.resize(inputs.size());
next_functions_->reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta();
if (input_meta_datas_.at(i)->requires_grad()) {
if (inputs.at(i)->requires_grad()) {
input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta();
next_functions_->emplace_back(inputs.at(i)->mut_grad_fn_node());
}
}
......@@ -73,6 +73,9 @@ StackFunctionNode::StackFunctionNode(
output_meta_datas_.resize(outputs.size());
output_tensor_infos_.reserve(outputs.size());
for (int i = 0; i < outputs.size(); ++i) {
const auto& autograd_meta =
NewAutogradMeta(outputs.at(i)->requires_grad(), outputs.at(i)->is_leaf());
outputs.at(i)->set_autograd_meta(autograd_meta);
output_meta_datas_.at(i) = outputs.at(i)->mut_autograd_meta();
output_tensor_infos_.emplace_back(TensorInfo(*outputs.at(i)));
}
......@@ -126,6 +129,7 @@ Maybe<bool> FunctionNode::Apply(bool create_graph) {
JUST((*backward_fn_)(output_grads, &input_grads, create_graph));
for (int i = 0; i < input_meta_datas_.size(); ++i) {
if (input_grads.at(i)) {
CHECK_NOTNULL_OR_RETURN(input_meta_datas_.at(i));
JUST(input_meta_datas_.at(i)->now_grad_arg()->PushPartialTensor(input_grads.at(i)));
}
}
......@@ -148,7 +152,7 @@ Maybe<void> StackAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
bool create_graph) {
ClearReleasedFunctionNodes();
for (int i = 0; i < outputs.size(); ++i) {
JUST(outputs.at(i)->now_grad_arg()->PushPartialTensor(out_grads.at(i)));
JUST(JUST(outputs.at(i)->now_grad_arg())->PushPartialTensor(out_grads.at(i)));
}
// Runs each FunctionNode
for (const auto& weak_func_node : node_list_) {
......@@ -173,10 +177,10 @@ Maybe<TensorTuple> StackAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
std::vector<bool> ori_retain_grad(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
ori_retain_grad.at(i) = inputs.at(i)->retain_grad();
inputs.at(i)->set_retain_grad(true);
JUST(inputs.at(i)->set_retain_grad(true));
}
for (int i = 0; i < outputs.size(); ++i) {
JUST(outputs.at(i)->now_grad_arg()->PushPartialTensor(out_grads.at(i)));
JUST(JUST(outputs.at(i)->now_grad_arg())->PushPartialTensor(out_grads.at(i)));
}
// Runs each FunctionNode
for (const auto& weak_func_node : node_list_) {
......@@ -190,10 +194,10 @@ Maybe<TensorTuple> StackAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
}
// Gets input grads and resume retain_grad
for (int i = 0; i < inputs.size(); ++i) {
input_now_grads->at(i) = inputs.at(i)->acc_grad();
input_now_grads->at(i) = JUST(inputs.at(i)->acc_grad());
if (!ori_retain_grad.at(i)) {
inputs.at(i)->set_acc_grad(nullptr);
inputs.at(i)->set_retain_grad(false);
JUST(inputs.at(i)->set_acc_grad(nullptr));
JUST(inputs.at(i)->set_retain_grad(false));
}
}
if (!retain_graph) { ClearEngine(); }
......@@ -241,8 +245,8 @@ GraphFunctionNode::GraphFunctionNode(
input_meta_datas_.resize(inputs.size());
next_functions_->reserve(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta();
if (input_meta_datas_.at(i)->requires_grad()) {
if (inputs.at(i)->requires_grad()) {
input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta();
next_functions_->emplace_back(inputs.at(i)->mut_grad_fn_node());
}
}
......@@ -250,6 +254,9 @@ GraphFunctionNode::GraphFunctionNode(
output_meta_datas_.resize(outputs.size());
output_tensor_infos_.reserve(outputs.size());
for (int i = 0; i < outputs.size(); ++i) {
const auto& autograd_meta =
NewAutogradMeta(outputs.at(i)->requires_grad(), outputs.at(i)->is_leaf());
outputs.at(i)->set_autograd_meta(autograd_meta);
output_meta_datas_.at(i) = outputs.at(i)->mut_autograd_meta();
output_tensor_infos_.emplace_back(TensorInfo(*outputs.at(i)));
}
......@@ -373,7 +380,7 @@ Maybe<void> GraphAutogradEngine::RunBackwardAndSaveGrads4LeafTensor(const Tensor
bool retain_graph,
bool create_graph) {
for (int i = 0; i < outputs.size(); ++i) {
JUST(outputs.at(i)->now_grad_arg()->PushPartialTensor(out_grads.at(i)));
JUST(JUST(outputs.at(i)->now_grad_arg())->PushPartialTensor(out_grads.at(i)));
}
GraphTask graph_task(outputs, retain_graph, create_graph);
JUST(graph_task.ComputeDependencies());
......@@ -389,10 +396,10 @@ Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
std::vector<bool> ori_retain_grad(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
ori_retain_grad.at(i) = inputs.at(i)->retain_grad();
inputs.at(i)->set_retain_grad(true);
JUST(inputs.at(i)->set_retain_grad(true));
}
for (int i = 0; i < outputs.size(); ++i) {
JUST(outputs.at(i)->now_grad_arg()->PushPartialTensor(out_grads.at(i)));
JUST(JUST(outputs.at(i)->now_grad_arg())->PushPartialTensor(out_grads.at(i)));
}
JUST(graph_task.ComputeDependenciesAndPruneNode(inputs));
......@@ -400,10 +407,10 @@ Maybe<TensorTuple> GraphAutogradEngine::RunBackwardAndReturnInputsTensorGrad(
// Gets input grads and resume retain_grad
for (int i = 0; i < inputs.size(); ++i) {
input_now_grads->at(i) = inputs.at(i)->acc_grad();
input_now_grads->at(i) = JUST(inputs.at(i)->acc_grad());
if (!ori_retain_grad.at(i)) {
inputs.at(i)->set_acc_grad(nullptr);
inputs.at(i)->set_retain_grad(false);
JUST(inputs.at(i)->set_acc_grad(nullptr));
JUST(inputs.at(i)->set_retain_grad(false));
}
}
return input_now_grads;
......
......@@ -512,7 +512,7 @@ struct LocalCallOpKernelUtil final {
static inline Maybe<void> InitOutputBlobs(LocalCallOpKernelPhyInstrOperand* operand) {
JUST(operand->ForEachOutputTensor([&](vm::EagerBlobObject* blob_object) -> Maybe<void> {
CHECK_OR_RETURN(static_cast<bool>(blob_object));
JUST(blob_object->InitBlob());
JUST(blob_object->TryInitBlob());
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
......
......@@ -59,14 +59,17 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
}
input_eager_blob_objects->at(i) = JUST(inputs.at(i)->eager_blob_object());
}
std::shared_ptr<EagerBlobObjectList> output_eager_blob_objects =
std::make_shared<EagerBlobObjectList>(outputs->size());
for (int i = 0; i < outputs->size(); i++) {
if (!outputs->at(i)) {
outputs->at(i) =
std::make_shared<MirroredTensor>(std::make_shared<EagerMirroredTensorImpl>());
}
if (JUST(outputs->at(i)->has_eager_blob_object())) {
output_eager_blob_objects->at(i) = JUST(outputs->at(i)->eager_blob_object());
}
}
std::shared_ptr<EagerBlobObjectList> output_eager_blob_objects =
std::make_shared<EagerBlobObjectList>(outputs->size());
Symbol<Device> op_device;
std::shared_ptr<const ParallelDesc> op_parallel_desc;
bool need_check_mem_case = true;
......@@ -102,9 +105,11 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
}));
for (int i = 0; i < output_eager_blob_objects->size(); i++) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
JUST(tensor_impl->InitEagerBlobObject(JUST(outputs->at(i)->device())->mem_case()));
output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
if (!output_eager_blob_objects->at(i)) {
auto* tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i)));
JUST(tensor_impl->InitEagerBlobObject(JUST(outputs->at(i)->device())->mem_case()));
output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object());
}
}
const auto& kernel = JUST(user_op_expr.MutKernel4Device(*op_device));
......
......@@ -20,6 +20,9 @@ limitations under the License.
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
namespace oneflow {
......@@ -51,8 +54,7 @@ namespace one {
const auto& blob_desc = eager_blob_object->blob_desc();
const auto& tensor_meta =
std::make_shared<MirroredTensorMeta>(blob_desc.shape_ptr(), blob_desc.data_type(), device);
const auto& autograd_meta = std::make_shared<AutogradMeta>(requires_grad, is_leaf);
auto* tensor_impl = new EagerMirroredTensorImpl(tensor_meta, autograd_meta);
auto* tensor_impl = new EagerMirroredTensorImpl(tensor_meta, requires_grad, is_leaf);
JUST(tensor_impl->InitEagerBlobObjectAndTensorStorage(eager_blob_object, tensor_storage));
return std::make_shared<MirroredTensor>(std::shared_ptr<MirroredTensorImpl>(tensor_impl));
}
......@@ -74,6 +76,21 @@ Maybe<MirroredTensor> MirroredTensor::api_detach() const {
return std::make_shared<MirroredTensor>(JUST(impl_->detach()));
}
// Deep-copies this tensor by dispatching a "copy" op on the tensor's own
// device. The output is produced by the op interpreter like any other op
// result (the test below expects it to be a non-leaf with a grad function).
Maybe<Tensor> MirroredTensor::clone() const {
const auto& device_type = JUST(this->device())->type();
int64_t device_id = JUST(this->device())->device_id();
// Build a one-input/one-output "copy" op pinned to this tensor's device.
std::shared_ptr<OpExpr> copy_op_ = JUST(one::OpBuilder("copy")
.Input("in", 1)
.Attr("device_type", device_type)
.Attr("device_id", device_id)
.Output("out", 1)
.Build());
// Dispatch needs a non-const shared_ptr to *this; shared_from_this() on a
// const method yields a pointer-to-const, so cast the constness away
// (the copy op reads but does not mutate its input).
std::shared_ptr<MirroredTensor> input =
std::const_pointer_cast<MirroredTensor>(shared_from_this());
const auto& output = JUST(OpInterpUtil::Dispatch<Tensor>(*copy_op_, {input}));
return output;
}
Maybe<ConsistentTensor> ConsistentTensor::MakeTensor(
const std::shared_ptr<const Shape>& shape, DataType dtype,
Symbol<cfg::ParallelDistribution> parallel_distribution, Symbol<ParallelDesc> parallel_desc,
......
......@@ -58,6 +58,7 @@ class Tensor {
virtual Maybe<EagerMirroredTensorImpl*> mut_eager_mirrored_tensor_impl() { OF_UNIMPLEMENTED(); }
virtual Maybe<vm::EagerBlobObject> eager_blob_object() const = 0;
virtual Maybe<VmLocalDepObject> compute_local_dep_object() const = 0;
virtual Maybe<bool> has_eager_blob_object() const = 0;
virtual Maybe<TensorStorage> tensor_storage() const { OF_UNIMPLEMENTED(); }
// Getters/Setters valid only for EagerConsistentTensor
......@@ -76,19 +77,22 @@ class Tensor {
virtual bool is_leaf() const = 0;
virtual bool retain_grad() const = 0;
virtual std::shared_ptr<const FunctionNode> grad_fn_node() const = 0;
virtual const std::shared_ptr<Tensor>& acc_grad() const = 0;
virtual const std::shared_ptr<TensorArg>& now_grad_arg() const = 0;
virtual Maybe<Tensor> acc_grad() const = 0;
virtual Maybe<TensorArg> now_grad_arg() const = 0;
virtual Maybe<Tensor> detach() const = 0;
virtual Maybe<Tensor> clone() const = 0;
// Setters for autograd
virtual void set_requires_grad(bool requires_grad) = 0;
virtual void set_retain_grad(bool retain_grad) = 0;
virtual Maybe<void> set_retain_grad(bool retain_grad) = 0;
virtual void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) = 0;
virtual const std::shared_ptr<FunctionNode>& mut_grad_fn_node() = 0;
virtual void set_acc_grad(const std::shared_ptr<Tensor>& grad) = 0;
virtual std::shared_ptr<Tensor> mut_acc_grad() = 0;
virtual Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) = 0;
virtual Maybe<Tensor> mut_acc_grad() = 0;
virtual void set_is_leaf(bool is_leaf) = 0;
virtual std::shared_ptr<AutogradMeta> mut_autograd_meta() = 0;
virtual bool has_autograd_meta() const = 0;
virtual void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) = 0;
virtual user_op::TensorDesc* mut_tensor_meta() = 0;
......@@ -97,7 +101,7 @@ class Tensor {
};
template<typename DerivedT>
class TensorIf : public Tensor, public std::enable_shared_from_this<TensorIf<DerivedT>> {
class TensorIf : public Tensor {
public:
virtual ~TensorIf() = default;
......@@ -113,8 +117,12 @@ class TensorIf : public Tensor, public std::enable_shared_from_this<TensorIf<Der
std::shared_ptr<const FunctionNode> grad_fn_node() const override { return grad_fn_node_; }
// used by pybind11 only
Maybe<DerivedT> api_acc_grad() const {
const std::shared_ptr<Tensor>& tensor = acc_grad();
return cast_for_api(tensor);
if (has_autograd_meta()) {
const std::shared_ptr<Tensor>& tensor = JUST(acc_grad());
return cast_for_api(tensor);
} else {
return std::shared_ptr<DerivedT>();
}
}
// Setters for autograd
......@@ -130,6 +138,10 @@ class TensorIf : public Tensor, public std::enable_shared_from_this<TensorIf<Der
// Operators for tensor
// used by pybind11 only
virtual Maybe<DerivedT> api_detach() const = 0;
Maybe<DerivedT> api_clone() const {
const std::shared_ptr<Tensor>& tensor = JUST(clone());
return cast_for_api(tensor);
}
protected:
TensorIf() = default;
......@@ -144,7 +156,8 @@ class TensorIf : public Tensor, public std::enable_shared_from_this<TensorIf<Der
}
};
class MirroredTensor final : public TensorIf<MirroredTensor> {
class MirroredTensor final : public TensorIf<MirroredTensor>,
public std::enable_shared_from_this<MirroredTensor> {
public:
OF_DISALLOW_COPY_AND_MOVE(MirroredTensor);
MirroredTensor() = default;
......@@ -177,24 +190,34 @@ class MirroredTensor final : public TensorIf<MirroredTensor> {
return impl_->compute_local_dep_object();
}
Maybe<TensorStorage> tensor_storage() const override { return impl_->tensor_storage(); }
Maybe<bool> has_eager_blob_object() const override { return impl_->has_eager_blob_object(); }
// Getters for autograd
const std::shared_ptr<Tensor>& acc_grad() const override { return impl_->acc_grad(); }
const std::shared_ptr<TensorArg>& now_grad_arg() const override { return impl_->now_grad_arg(); }
Maybe<Tensor> acc_grad() const override { return impl_->acc_grad(); }
Maybe<TensorArg> now_grad_arg() const override { return impl_->now_grad_arg(); }
bool requires_grad() const override { return impl_->requires_grad(); }
bool is_leaf() const override { return impl_->is_leaf(); }
bool retain_grad() const override { return impl_->retain_grad(); }
bool has_autograd_meta() const override { return impl_->has_autograd_meta(); }
// Setters for autograd
void set_acc_grad(const std::shared_ptr<Tensor>& grad) override { impl_->set_acc_grad(grad); }
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) override {
return impl_->set_acc_grad(grad);
}
void set_requires_grad(bool requires_grad) override { impl_->set_requires_grad(requires_grad); }
void set_retain_grad(bool retain_grad) override { impl_->set_retain_grad(retain_grad); }
std::shared_ptr<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }
Maybe<void> set_retain_grad(bool retain_grad) override {
return impl_->set_retain_grad(retain_grad);
}
Maybe<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }
void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); }
std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); }
void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) override {
impl_->set_autograd_meta(autograd_meta);
}
// Operators for tensor
Maybe<MirroredTensor> api_detach() const override;
Maybe<Tensor> clone() const override;
static Maybe<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, DataType dtype,
const Symbol<Device>& device, bool is_lazy,
......@@ -234,7 +257,9 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
const override {
return impl_->consumer_parallel_distribution_constraint();
}
Maybe<MirroredTensor> cur_rank_phy_tensor() const { return impl_->cur_rank_phy_tensor(); }
Maybe<MirroredTensor> cur_rank_phy_tensor() const override {
return impl_->cur_rank_phy_tensor();
}
int64_t ndim() const override;
bool is_cuda() const override;
int64_t dim(int64_t index) const override;
......@@ -249,6 +274,8 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
return impl_->compute_local_dep_object();
}
const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); }
Maybe<TensorStorage> tensor_storage() const override { return impl_->tensor_storage(); }
Maybe<bool> has_eager_blob_object() const override { return impl_->has_eager_blob_object(); }
// Setters
Maybe<void> set_consumer_parallel_distribution_constraint(
......@@ -258,22 +285,31 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
}
// Getters for autograd
const std::shared_ptr<Tensor>& acc_grad() const override { return impl_->acc_grad(); }
const std::shared_ptr<TensorArg>& now_grad_arg() const override { return impl_->now_grad_arg(); }
Maybe<Tensor> acc_grad() const override { return impl_->acc_grad(); }
Maybe<TensorArg> now_grad_arg() const override { return impl_->now_grad_arg(); }
bool requires_grad() const override { return impl_->requires_grad(); }
bool is_leaf() const override { return impl_->is_leaf(); }
bool retain_grad() const override { return impl_->retain_grad(); }
bool has_autograd_meta() const override { return impl_->has_autograd_meta(); }
// Setters for autograd
void set_acc_grad(const std::shared_ptr<Tensor>& grad) override { impl_->set_acc_grad(grad); }
std::shared_ptr<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) override {
return impl_->set_acc_grad(grad);
}
Maybe<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }
void set_requires_grad(bool requires_grad) override { impl_->set_requires_grad(requires_grad); }
void set_retain_grad(bool retain_grad) override { impl_->set_retain_grad(retain_grad); }
Maybe<void> set_retain_grad(bool retain_grad) override {
return impl_->set_retain_grad(retain_grad);
}
void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); }
std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); }
void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) override {
impl_->set_autograd_meta(autograd_meta);
}
// Operators for tensor
virtual Maybe<ConsistentTensor> api_detach() const override;
Maybe<Tensor> clone() const override { return Error::Unimplemented(); }
static Maybe<ConsistentTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
DataType dtype,
......
......@@ -32,6 +32,33 @@ limitations under the License.
namespace oneflow {
namespace one {
// Returns the accumulated gradient stored in this tensor's AutogradMeta.
// Returns an error if autograd meta has not been attached yet (autograd_meta_
// is now lazily set; see set_autograd_meta in the header hunk).
Maybe<Tensor> TensorImpl::acc_grad() const {
CHECK_NOTNULL_OR_RETURN(autograd_meta_);
return autograd_meta_->acc_grad();
}
// Returns the TensorArg that accumulates the current backward pass's partial
// gradients. Errors out when autograd meta is absent instead of dereferencing
// a null pointer (previously this accessor could not fail).
Maybe<TensorArg> TensorImpl::now_grad_arg() const {
CHECK_NOTNULL_OR_RETURN(autograd_meta_);
return autograd_meta_->now_grad_arg();
}
// Stores `grad` as the accumulated gradient. `grad` may be nullptr — callers
// use that to drop the grad when retain_grad was only set temporarily (see
// RunBackwardAndReturnInputsTensorGrad). Fails if autograd meta is absent.
Maybe<void> TensorImpl::set_acc_grad(const std::shared_ptr<Tensor>& grad) {
CHECK_NOTNULL_OR_RETURN(autograd_meta_);
autograd_meta_->set_acc_grad(grad);
return Maybe<void>::Ok();
}
// Mutable access to the accumulated gradient tensor; errors out when autograd
// meta has not been attached.
Maybe<Tensor> TensorImpl::mut_acc_grad() {
CHECK_NOTNULL_OR_RETURN(autograd_meta_);
return autograd_meta_->mut_acc_grad();
}
// Toggles retain_grad on the autograd meta (whether a non-leaf keeps its
// grad after backward). Returns Maybe<void> so callers can JUST() it; fails
// when autograd meta is absent.
Maybe<void> TensorImpl::set_retain_grad(bool retain_grad) {
CHECK_NOTNULL_OR_RETURN(autograd_meta_);
autograd_meta_->set_retain_grad(retain_grad);
return Maybe<void>::Ok();
}
namespace {
std::shared_ptr<const MirroredTensorMeta> NewDefaultMirroredTensorMeta() {
......@@ -48,24 +75,19 @@ Maybe<MirroredTensorImpl> LazyMirroredTensorImpl::detach() const {
}
EagerMirroredTensorImpl::EagerMirroredTensorImpl()
: MirroredTensorImpl(NewDefaultMirroredTensorMeta(), NewAutogradMeta(false, false)) {}
EagerMirroredTensorImpl::EagerMirroredTensorImpl(
const std::shared_ptr<const MirroredTensorMeta>& tensor_meta,
const std::shared_ptr<AutogradMeta>& autograd_meta)
: MirroredTensorImpl(tensor_meta, autograd_meta) {}
: MirroredTensorImpl(NewDefaultMirroredTensorMeta(), false, false) {}
EagerMirroredTensorImpl::EagerMirroredTensorImpl(
const std::shared_ptr<const MirroredTensorMeta>& tensor_meta, bool requires_grad, bool is_leaf)
: MirroredTensorImpl(tensor_meta, NewAutogradMeta(requires_grad, is_leaf)) {}
: MirroredTensorImpl(tensor_meta, requires_grad, is_leaf) {}
EagerMirroredTensorImpl::~EagerMirroredTensorImpl() {}
EagerMirroredTensorImpl::EagerMirroredTensorImpl(
const std::shared_ptr<const MirroredTensorMeta>& tensor_meta,
std::shared_ptr<TensorStorage> tensor_storage, bool requires_grad, bool is_leaf)
: MirroredTensorImpl(tensor_meta, NewAutogradMeta(requires_grad, is_leaf)),
tensor_storage_(tensor_storage) {}
: MirroredTensorImpl(tensor_meta, requires_grad, is_leaf), tensor_storage_(tensor_storage) {}
Maybe<void> EagerMirroredTensorImpl::UpdateTensorStorage() {
const auto& eager_blob_object = eager_blob_object_;
tensor_storage_ = std::make_shared<TensorStorage>(eager_blob_object->tensor_buffer());
......@@ -169,10 +191,10 @@ size_t ConsistentTensorMeta::CalcHashValue() const {
}
EagerConsistentTensorImpl::EagerConsistentTensorImpl(
Symbol<ConsistentTensorMeta> consistent_tensor_meta,
const std::shared_ptr<AutogradMeta>& autograd_meta,
Symbol<ConsistentTensorMeta> consistent_tensor_meta, bool requires_grad, bool is_leaf,
const std::shared_ptr<MirroredTensor>& cur_rank_phy_tensor)
: ConsistentTensorImpl(consistent_tensor_meta, autograd_meta),
: ConsistentTensorImpl(consistent_tensor_meta, cur_rank_phy_tensor->requires_grad(),
cur_rank_phy_tensor->is_leaf()),
cur_rank_phy_tensor_(cur_rank_phy_tensor) {}
/*static*/ Maybe<EagerConsistentTensorImpl> EagerConsistentTensorImpl::New(
......@@ -193,8 +215,9 @@ EagerConsistentTensorImpl::EagerConsistentTensorImpl(
const auto& dtype = cur_rank_phy_tensor->dtype();
Symbol<ConsistentTensorMeta> consistent_tensor_meta(
ConsistentTensorMeta(shape, dtype, parallel_distribution, parallel_desc));
return std::shared_ptr<EagerConsistentTensorImpl>(new EagerConsistentTensorImpl(
consistent_tensor_meta, cur_rank_phy_tensor->mut_autograd_meta(), cur_rank_phy_tensor));
return std::shared_ptr<EagerConsistentTensorImpl>(
new EagerConsistentTensorImpl(consistent_tensor_meta, cur_rank_phy_tensor->requires_grad(),
cur_rank_phy_tensor->is_leaf(), cur_rank_phy_tensor));
}
/*static*/ Maybe<EagerConsistentTensorImpl> EagerConsistentTensorImpl::New(
......@@ -219,21 +242,20 @@ EagerConsistentTensorImpl::EagerConsistentTensorImpl(
JUST(GetPhysicalShape(*shape, *parallel_distribution, *parallel_desc, parallel_id));
const auto& cur_rank_phy_tensor_meta =
std::make_shared<MirroredTensorMeta>(cur_rank_phy_shape, dtype, device);
const auto& autograd_meta = NewAutogradMeta(requires_grad, is_leaf);
auto cur_rank_phy_tensor_impl =
std::make_shared<EagerMirroredTensorImpl>(cur_rank_phy_tensor_meta, autograd_meta);
std::make_shared<EagerMirroredTensorImpl>(cur_rank_phy_tensor_meta, requires_grad, is_leaf);
JUST(cur_rank_phy_tensor_impl->InitEagerBlobObject(device->mem_case()));
const auto& cur_rank_phy_tensor = std::make_shared<MirroredTensor>(cur_rank_phy_tensor_impl);
auto* tensor_impl = new EagerConsistentTensorImpl(
consistent_tensor_meta, cur_rank_phy_tensor->mut_autograd_meta(), cur_rank_phy_tensor);
auto* tensor_impl =
new EagerConsistentTensorImpl(consistent_tensor_meta, cur_rank_phy_tensor->requires_grad(),
cur_rank_phy_tensor->is_leaf(), cur_rank_phy_tensor);
return std::shared_ptr<EagerConsistentTensorImpl>(tensor_impl);
}
/*static*/ Maybe<EagerConsistentTensorImpl> EagerConsistentTensorImpl::NewWithoutPhyTensor(
Symbol<ConsistentTensorMeta> consistent_tensor_meta, Symbol<Device> device, int64_t parallel_id,
bool requires_grad, bool is_leaf) {
const auto& autograd_meta = NewAutogradMeta(requires_grad, is_leaf);
auto* tensor_impl = new EagerConsistentTensorImpl(consistent_tensor_meta, autograd_meta,
auto* tensor_impl = new EagerConsistentTensorImpl(consistent_tensor_meta, requires_grad, is_leaf,
std::shared_ptr<MirroredTensor>());
return std::shared_ptr<EagerConsistentTensorImpl>(tensor_impl);
}
......
......@@ -62,26 +62,33 @@ class TensorImpl {
virtual Maybe<vm::EagerBlobObject> eager_blob_object() const = 0;
virtual Maybe<VmLocalDepObject> compute_local_dep_object() const = 0;
virtual Maybe<TensorStorage> tensor_storage() const { OF_UNIMPLEMENTED(); }
virtual Maybe<bool> has_eager_blob_object() const = 0;
// Getters for autograd
const std::shared_ptr<Tensor>& acc_grad() const { return autograd_meta_->acc_grad(); }
const std::shared_ptr<TensorArg>& now_grad_arg() const { return autograd_meta_->now_grad_arg(); }
bool requires_grad() const { return autograd_meta_->requires_grad(); }
bool is_leaf() const { return autograd_meta_->is_leaf(); }
Maybe<Tensor> acc_grad() const;
Maybe<TensorArg> now_grad_arg() const;
bool requires_grad() const { return requires_grad_; }
bool is_leaf() const { return is_leaf_; }
bool retain_grad() const { return autograd_meta_->retain_grad(); }
// Setters for autograd
void set_acc_grad(const std::shared_ptr<Tensor>& grad) { autograd_meta_->set_acc_grad(grad); }
std::shared_ptr<Tensor> mut_acc_grad() { return autograd_meta_->mut_acc_grad(); }
void set_requires_grad(bool requires_grad) { autograd_meta_->set_requires_grad(requires_grad); }
void set_retain_grad(bool retain_grad) { autograd_meta_->set_retain_grad(retain_grad); }
void set_is_leaf(bool is_leaf) { autograd_meta_->set_is_leaf(is_leaf); }
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad);
Maybe<Tensor> mut_acc_grad();
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
Maybe<void> set_retain_grad(bool retain_grad);
void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }
std::shared_ptr<AutogradMeta> mut_autograd_meta() { return autograd_meta_; }
void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) {
autograd_meta_ = autograd_meta;
}
bool has_autograd_meta() const { return autograd_meta_.get(); }
protected:
TensorImpl(const std::shared_ptr<AutogradMeta>& autograd_meta) : autograd_meta_(autograd_meta) {}
TensorImpl(bool requires_grad, bool is_leaf) : requires_grad_(requires_grad), is_leaf_(is_leaf) {}
protected:
bool requires_grad_;
bool is_leaf_;
std::shared_ptr<AutogradMeta> autograd_meta_;
};
......@@ -106,8 +113,8 @@ class MirroredTensorImpl : public TensorImpl {
protected:
MirroredTensorImpl(const std::shared_ptr<const MirroredTensorMeta>& tensor_meta,
const std::shared_ptr<AutogradMeta>& autograd_meta)
: TensorImpl(autograd_meta), tensor_meta_(tensor_meta) {}
bool requires_grad, bool is_leaf)
: TensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta) {}
std::shared_ptr<const MirroredTensorMeta> tensor_meta_;
};
......@@ -134,6 +141,7 @@ class ConsistentTensorImpl : public TensorImpl {
// Getters valid only for EagerMirroredTensorImpl
Maybe<vm::EagerBlobObject> eager_blob_object() const override { OF_UNIMPLEMENTED(); }
Maybe<VmLocalDepObject> compute_local_dep_object() const override { OF_UNIMPLEMENTED(); }
Maybe<bool> has_eager_blob_object() const override { OF_UNIMPLEMENTED(); }
// Setters
void set_consumer_parallel_distribution_constraint(Symbol<cfg::ParallelDistribution> val) {
......@@ -146,9 +154,8 @@ class ConsistentTensorImpl : public TensorImpl {
}
protected:
ConsistentTensorImpl(Symbol<ConsistentTensorMeta> tensor_meta,
const std::shared_ptr<AutogradMeta>& autograd_meta)
: TensorImpl(autograd_meta),
ConsistentTensorImpl(Symbol<ConsistentTensorMeta> tensor_meta, bool requires_grad, bool is_leaf)
: TensorImpl(requires_grad, is_leaf),
tensor_meta_(tensor_meta),
consumer_parallel_distribution_constraint_() {}
......@@ -161,17 +168,18 @@ class LazyMirroredTensorImpl final : public MirroredTensorImpl {
OF_DISALLOW_COPY_AND_MOVE(LazyMirroredTensorImpl);
LazyMirroredTensorImpl(const std::shared_ptr<const MirroredTensorMeta>& tensor_meta,
bool requires_grad, bool is_leaf)
: MirroredTensorImpl(tensor_meta, NewAutogradMeta(requires_grad, is_leaf)) {}
: MirroredTensorImpl(tensor_meta, requires_grad, is_leaf) {}
~LazyMirroredTensorImpl() override = default;
// Getters
const std::shared_ptr<const Shape>& shape() const { return tensor_meta()->shape_ptr(); }
const std::shared_ptr<const Shape>& shape() const override { return tensor_meta()->shape_ptr(); }
bool is_lazy() const override { return true; }
// Getters valid only for EagerMirroredTensorImpl
Maybe<vm::EagerBlobObject> eager_blob_object() const override { OF_UNIMPLEMENTED(); }
Maybe<VmLocalDepObject> compute_local_dep_object() const override { OF_UNIMPLEMENTED(); }
Maybe<TensorStorage> tensor_storage() const override { OF_UNIMPLEMENTED(); }
Maybe<bool> has_eager_blob_object() const override { OF_UNIMPLEMENTED(); }
Maybe<MirroredTensorImpl> detach() const override;
};
......@@ -179,8 +187,6 @@ class EagerMirroredTensorImpl final : public MirroredTensorImpl {
public:
OF_DISALLOW_COPY_AND_MOVE(EagerMirroredTensorImpl);
EagerMirroredTensorImpl();
EagerMirroredTensorImpl(const std::shared_ptr<const MirroredTensorMeta>& tensor_meta,
const std::shared_ptr<AutogradMeta>& autograd_meta);
EagerMirroredTensorImpl(const std::shared_ptr<const MirroredTensorMeta>& tensor_meta,
bool requires_grad, bool is_leaf);
EagerMirroredTensorImpl(const std::shared_ptr<const MirroredTensorMeta>& tensor_meta,
......@@ -203,6 +209,7 @@ class EagerMirroredTensorImpl final : public MirroredTensorImpl {
CHECK_OR_RETURN(eager_blob_object_);
return tensor_storage_;
}
Maybe<bool> has_eager_blob_object() const override { return eager_blob_object_.get(); }
// Setters
TensorStorage* mut_tensor_storage() { return tensor_storage_.get(); }
......@@ -226,7 +233,7 @@ class LazyConsistentTensorImpl final : public ConsistentTensorImpl {
OF_DISALLOW_COPY_AND_MOVE(LazyConsistentTensorImpl);
LazyConsistentTensorImpl(Symbol<ConsistentTensorMeta> consistent_tensor_meta, bool requires_grad,
bool is_leaf)
: ConsistentTensorImpl(consistent_tensor_meta, NewAutogradMeta(requires_grad, is_leaf)) {}
: ConsistentTensorImpl(consistent_tensor_meta, requires_grad, is_leaf) {}
~LazyConsistentTensorImpl() override = default;
// Getters
......@@ -262,8 +269,8 @@ class EagerConsistentTensorImpl final : public ConsistentTensorImpl {
Symbol<Device>, int64_t, bool, bool);
private:
EagerConsistentTensorImpl(Symbol<ConsistentTensorMeta> consistent_tensor_meta,
const std::shared_ptr<AutogradMeta>& autograd_meta,
EagerConsistentTensorImpl(Symbol<ConsistentTensorMeta> consistent_tensor_meta, bool requires_grad,
bool is_leaf,
const std::shared_ptr<MirroredTensor>& cur_rank_phy_tensor);
std::shared_ptr<MirroredTensor> cur_rank_phy_tensor_;
......
......@@ -153,7 +153,7 @@
bind_python: True
- name: "relu"
signature: "Tensor Relu(Tensor x)"
signature: "Tensor Relu(Tensor x, *, Bool inplace=False)"
bind_python: True
- name: "relu_grad"
......
......@@ -25,6 +25,7 @@ limitations under the License.
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/scalar.h"
#include "oneflow/core/autograd/autograd_mode.h"
namespace oneflow {
namespace one {
......@@ -32,9 +33,24 @@ namespace functional {
namespace impl {
class ReluFunctor : public UnaryFunctor {
class ReluFunctor {
public:
ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("in").Output("out").Build()); }
ReluFunctor() {
op_ = CHECK_JUST(one::OpBuilder("relu").Input("in", 1).Output("out", 1).Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
if (inplace) {
std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
outputs->at(0) = x;
JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*op_, {x}, outputs.get(), AttrMap{}));
return outputs->at(0);
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
}
}
private:
std::shared_ptr<OpExpr> op_;
};
class ReluGradFunctor : public BinaryFunctor {
......
......@@ -322,6 +322,13 @@ class Tensor:
else:
return None
# Returns a copy of this tensor via the underlying C++ clone() (which
# dispatches a "copy" op), wrapped back into a flow.Tensor. Returns None
# when no local/consistent tensor has been determined yet — mirroring the
# pattern of the neighbouring accessor above.
@_auto_determine
def clone(self):
if self._local_or_consistent_tensor is not None:
return flow.Tensor(self._local_or_consistent_tensor.clone())
else:
return None
def requires_grad_(self, requires_grad=True):
self.requires_grad = requires_grad
......
......@@ -133,8 +133,15 @@ class ReLU(Module):
# Stores the inplace flag; when True, forward() applies ReLU in place on
# its input instead of allocating a new output tensor.
def __init__(self, inplace: bool = False):
super().__init__()
self._inplace = inplace
# Applies ReLU. With inplace=True the op writes into `x` itself via
# flow.F.relu(x, inplace=True); in-place mutation of a leaf tensor that
# requires grad is forbidden (it would clobber data needed for backward),
# so that case raises RuntimeError — same message PyTorch uses.
def forward(self, x):
if self._inplace:
if x.requires_grad and x.is_leaf:
raise RuntimeError(
"a leaf Variable that requires grad is being used in an in-place operation."
)
return flow.F.relu(x, inplace=True)
return flow.F.relu(x)
......
......@@ -38,6 +38,19 @@ def _test_relu_impl(test_case, shape, device):
of_out.backward()
test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_out > 0, 1e-5, 1e-5))
inplace_m = flow.nn.ReLU(inplace=True)
of_input = flow.Tensor(
np_input, dtype=flow.float32, device=flow.device(device), requires_grad=True
)
of_input_inplace = of_input + 1
inplace_m(of_input_inplace)
np_out = np.maximum(0, np_input + 1)
test_case.assertTrue(np.allclose(of_input_inplace.numpy(), np_out, 1e-5, 1e-5))
of_out_inplace = of_input_inplace.sum()
of_out_inplace.backward()
test_case.assertTrue(np.allclose(of_input.grad.numpy(), np_out > 0, 1e-5, 1e-5))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
......
......@@ -763,6 +763,22 @@ class TestTensor(flow.unittest.TestCase):
test_case.assertEqual(z.is_leaf, True)
test_case.assertEqual(z.grad_fn, None)
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"numpy doesn't work in lazy mode",
)
# Verifies clone(): values match the source, requires_grad is propagated,
# and the clone is a non-leaf with a grad_fn (i.e. clone is recorded in the
# autograd graph via the "copy" op).
def test_tensor_clone(test_case):
shape = (2, 3, 4, 5)
x = flow.Tensor(
np.random.randn(*shape), dtype=flow.float32, requires_grad=True,
)
y = x.clone()
test_case.assertTrue(np.allclose(y.numpy(), x.numpy(), 1e-4, 1e-4))
test_case.assertEqual(y.requires_grad, True)
test_case.assertEqual(y.is_leaf, False)
# Cannot print Copy grad function
test_case.assertTrue(y.grad_fn != None)
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"numpy doesn't work in lazy mode",
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment