diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake
index d9285ed06a4396f065978406d4527dba8b6554ad..356d16539479d78170e081ee5fc2e48c1cd1f3b0 100644
--- a/cmake/oneflow.cmake
+++ b/cmake/oneflow.cmake
@@ -418,6 +418,7 @@ list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_conte
 list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_util.cuh")
 list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/sbp_signature_builder.h")
 list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/parallel_desc.h")
+list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/autograd/autograd_meta.h")
 copy_files("${OF_CORE_HDRS}" "${PROJECT_SOURCE_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy)
 
 add_dependencies(pip_install of_include_copy)
diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp
index ee531615f2e08ec743b81e3e630c8d9e8b183a32..4559a7548cf64d62e9a3ba688e955d7c3943f83b 100644
--- a/oneflow/api/python/framework/tensor.cpp
+++ b/oneflow/api/python/framework/tensor.cpp
@@ -39,10 +39,9 @@ struct TensorExportUtil<MirroredTensor> final {
   static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
                                                     const std::shared_ptr<const DType>& dtype,
                                                     const std::shared_ptr<const Device>& device,
-                                                    bool is_lazy, bool requires_grad, bool is_leaf,
-                                                    bool retain_grad) {
-    return MirroredTensor::MakeTensor(shape, dtype, device, is_lazy, requires_grad, is_leaf,
-                                      retain_grad);
+                                                    bool is_lazy, bool requires_grad,
+                                                    bool is_leaf) {
+    return MirroredTensor::MakeTensor(shape, dtype, device, is_lazy, requires_grad, is_leaf);
   }
 };
 
@@ -52,9 +51,9 @@ struct TensorExportUtil<ConsistentTensor> final {
       const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype,
       const std::shared_ptr<const compatible_py::Distribute>& distribute,
       const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad,
-      bool is_leaf, bool retain_grad) {
+      bool is_leaf) {
     return ConsistentTensor::MakeTensor(shape, dtype, distribute, parallel_desc, is_lazy,
-                                        requires_grad, is_leaf, retain_grad);
+                                        requires_grad, is_leaf);
   }
 };
 
@@ -73,7 +72,10 @@ void ExportTensor(py::module& m, const char* name) {
       .def_property_readonly("requires_grad", &T::requires_grad)
       .def_property_readonly("is_leaf", &T::is_leaf)
       // Methods of pytorch
-      .def("retain_grad", [](T& t) { t.set_retain_grad(true); })
+      .def("retain_grad",
+           [](T& t) {
+             if (!t.is_leaf()) { t.set_retain_grad(true); }
+           })
       .def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); })
       // OneFlow tensor properties other than pytorch tensor
       .def_property_readonly("placement", &T::parallel_desc)
diff --git a/oneflow/core/autograd/autograd_engine.cpp b/oneflow/core/autograd/autograd_engine.cpp
index daf42b737b3a47a7b9f8c32c5c8071338f1dffa1..bda275dc6724a065d41910071cad197dbced4240 100644
--- a/oneflow/core/autograd/autograd_engine.cpp
+++ b/oneflow/core/autograd/autograd_engine.cpp
@@ -29,36 +29,25 @@ namespace one {
 
 namespace {
 
-bool IsReadyToRun(const std::vector<std::shared_ptr<TensorArg>>& out_grads) {
-  return std::any_of(
-      out_grads.begin(), out_grads.end(),
-      [](const std::shared_ptr<TensorArg>& tensor_arg) { return !tensor_arg->Empty(); });
+bool IsReadyToRun(const std::vector<std::shared_ptr<AutogradMeta>>& out_meta_datas) {
+  return std::any_of(out_meta_datas.begin(), out_meta_datas.end(),
+                     [](const std::shared_ptr<AutogradMeta>& meta_data) {
+                       return !meta_data->now_grad_arg()->Empty();
+                     });
 }
 
-Maybe<void> InitEmptyTensorArgs2ZerosTensor(const TensorTuple& outputs,
-                                            std::vector<std::shared_ptr<TensorArg>>& out_grads) {
-  const auto& zero_like = JUST(op_expr_helper::ZeroLikeOp());
-  for (int i = 0; i < out_grads.size(); ++i) {
-    if (out_grads.at(i)->Empty()) {
-      TensorTuple output(1);
-      JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*zero_like, {outputs.at(i)}, &output));
-      JUST(out_grads.at(i)->PushPartialTensor(output.at(0)));
-    }
-  }
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> CopyOrAccGrad(Tensor* tensor, bool autograd_mode) {
+Maybe<void> CopyOrAccGrad(AutogradMeta* autograd_meta, bool autograd_mode) {
   autograd::AutoGradMode mode(autograd_mode);
-  const auto& now_grad = JUST(tensor->now_grad_arg()->GetAccTensor());
-  if (tensor->acc_grad()) {
-    TensorTuple input = {tensor->acc_grad(), now_grad};
+  const auto& now_grad = JUST(autograd_meta->now_grad_arg()->GetAccTensor());
+  if (!now_grad) { return Maybe<void>::Ok(); }
+  if (autograd_meta->acc_grad()) {
+    TensorTuple input = {autograd_meta->acc_grad(), now_grad};
     TensorTuple output(1);
     const auto& add = JUST(op_expr_helper::AddOp());
     JUST(JUST(OpInterpUtil::GetInterpreter())->Apply(*add, input, &output));
-    tensor->set_acc_grad(output.at(0));
+    autograd_meta->set_acc_grad(output.at(0));
   } else {
-    tensor->set_acc_grad(now_grad);
+    autograd_meta->set_acc_grad(now_grad);
   }
   return Maybe<void>::Ok();
 }
@@ -69,16 +58,18 @@ StackFunctionNode::StackFunctionNode(
     const std::shared_ptr<const std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)>>&
         backward_fn,
     const TensorTuple& inputs, const TensorTuple& outputs) {
-  inputs_ = std::make_shared<TensorTuple>(inputs.size());
+  input_meta_datas_.resize(inputs.size());
+  input_tensors_.resize(inputs.size());
   for (int i = 0; i < inputs.size(); ++i) {
-    inputs_->at(i) = inputs.at(i);
-    in_grads_.emplace_back(inputs.at(i)->now_grad_arg());
+    input_meta_datas_.at(i) = inputs.at(i)->mut_autograd_meta();
+    if (input_meta_datas_.at(i)->requires_grad()) { input_tensors_.at(i) = inputs.at(i); }
   }
 
-  outputs_ = std::make_shared<TensorTuple>(outputs.size());
+  output_meta_datas_.resize(outputs.size());
+  output_tensor_infos_.reserve(outputs.size());
   for (int i = 0; i < outputs.size(); ++i) {
-    outputs_->at(i) = outputs.at(i)->detach();
-    out_grads_.emplace_back(outputs.at(i)->now_grad_arg());
+    output_meta_datas_.at(i) = outputs.at(i)->mut_autograd_meta();
+    output_tensor_infos_.emplace_back(TensorInfo(*outputs.at(i)));
   }
 
   backward_fn_ = backward_fn;
@@ -86,51 +77,52 @@ StackFunctionNode::StackFunctionNode(
 }
 
 Maybe<void> StackFunctionNode::AccGrad4RetainGradTensor() {
-  for (int i = 0; i < outputs_->size(); ++i) {
-    if (outputs_->at(i)->retain_grad() && outputs_->at(i)->requires_grad()) {
-      JUST(CopyOrAccGrad(outputs_->at(i).get(), /*autograd_mode=*/false));
-    }
+  for (const std::shared_ptr<AutogradMeta>& out : output_meta_datas_) {
+    if (out->retain_grad()) { JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false)); }
   }
   return Maybe<void>::Ok();
 }
 
 Maybe<void> StackFunctionNode::AccGrad4LeafTensor(bool create_graph) {
-  for (int i = 0; i < outputs_->size(); ++i) {
-    if (outputs_->at(i)->is_leaf() && outputs_->at(i)->requires_grad()) {
-      JUST(CopyOrAccGrad(outputs_->at(i).get(), /*autograd_mode=*/create_graph));
+  for (const std::shared_ptr<AutogradMeta>& out : output_meta_datas_) {
+    if (out->is_leaf() && out->requires_grad()) {
+      JUST(CopyOrAccGrad(out.get(), /*autograd_mode=*/false));
     }
   }
   return Maybe<void>::Ok();
 }
 
 void StackFunctionNode::ReleaseOutTensorArgs() {
-  for (const std::shared_ptr<TensorArg>& tensor_arg : out_grads_) { tensor_arg->Release(); }
+  for (const std::shared_ptr<AutogradMeta>& meta_data : output_meta_datas_) {
+    meta_data->now_grad_arg()->Release();
+  }
 }
 
 void StackFunctionNode::ReleaseData() {
-  if (!inputs_->empty()) {
-    inputs_.reset();
-    outputs_.reset();
-    in_grads_.clear();
-    out_grads_.clear();
-    backward_fn_.reset();
-  }
+  // Release the backward function and input tensors so unused tensors are freed early
+  if (!input_meta_datas_.empty()) { backward_fn_.reset(); }
+  input_tensors_.clear();
   is_in_stack_ = false;
 }
 
 Maybe<bool> StackFunctionNode::Apply(bool create_graph) {
   CHECK_NOTNULL_OR_RETURN(backward_fn_.get())
       << "This FunctionNode with name `" << GetOpName() << "` has been released.";
-  if (!IsReadyToRun(out_grads_)) { return false; }
-  JUST(InitEmptyTensorArgs2ZerosTensor(*outputs_, out_grads_));
-  TensorTuple input_grads(in_grads_.size());
-  TensorTuple output_grads(out_grads_.size());
-  for (int i = 0; i < out_grads_.size(); ++i) {
-    output_grads.at(i) = JUST(out_grads_.at(i)->GetAccTensor());
+  if (!IsReadyToRun(output_meta_datas_)) { return false; }
+  TensorTuple input_grads(input_meta_datas_.size());
+  TensorTuple output_grads(output_meta_datas_.size());
+  for (int i = 0; i < output_meta_datas_.size(); ++i) {
+    if (output_meta_datas_.at(i)->now_grad_arg()->Empty()) {
+      output_grads.at(i) = JUST(output_tensor_infos_.at(i).zeros());
+    } else {
+      output_grads.at(i) = JUST(output_meta_datas_.at(i)->now_grad_arg()->GetAccTensor());
+    }
   }
   JUST((*backward_fn_)(output_grads, &input_grads, create_graph));
-  for (int i = 0; i < in_grads_.size(); ++i) {
-    JUST(in_grads_.at(i)->PushPartialTensor(input_grads.at(i)));
+  for (int i = 0; i < input_meta_datas_.size(); ++i) {
+    if (input_grads.at(i)) {
+      JUST(input_meta_datas_.at(i)->now_grad_arg()->PushPartialTensor(input_grads.at(i)));
+    }
   }
   return true;
 }
diff --git a/oneflow/core/autograd/autograd_engine.h b/oneflow/core/autograd/autograd_engine.h
index cba0b5f29bb88b31e5f152947f88eeb88987b86d..53f9c3e1361db492b14b0a57dad5636178e913a1 100644
--- a/oneflow/core/autograd/autograd_engine.h
+++ b/oneflow/core/autograd/autograd_engine.h
@@ -27,9 +27,10 @@ namespace oneflow {
 
 namespace one {
 
-class TensorArg;
 class Tensor;
 class TensorTuple;
+class AutogradMeta;
+class TensorInfo;
 
 // Calculates one backward op
 class FunctionNode {
@@ -101,12 +102,10 @@ class StackFunctionNode final : public FunctionNode {
   void set_is_in_stack(bool in_stack) { is_in_stack_ = in_stack; }
 
  private:
-  // FunctionNode shares Tensor with `inputs_`, and only shares TensorImpl with `outputs_`.
-  // The reference link is `output tensors -> node -> inputs_/input tensors`.
-  std::shared_ptr<TensorTuple> inputs_;
-  std::shared_ptr<TensorTuple> outputs_;
-  std::vector<std::shared_ptr<TensorArg>> in_grads_;
-  std::vector<std::shared_ptr<TensorArg>> out_grads_;
+  std::vector<std::shared_ptr<Tensor>> input_tensors_;
+  std::vector<std::shared_ptr<AutogradMeta>> input_meta_datas_;
+  std::vector<std::shared_ptr<AutogradMeta>> output_meta_datas_;
+  std::vector<TensorInfo> output_tensor_infos_;
   // Actual backward function builds in `AutogradInterpreter` to calculate one backward op
   std::shared_ptr<const std::function<Maybe<void>(const TensorTuple&, TensorTuple*, bool)>>
       backward_fn_;
diff --git a/oneflow/core/autograd/autograd_meta.cpp b/oneflow/core/autograd/autograd_meta.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..873810a5d504fb30653e5b0dc9fab07261e597e9
--- /dev/null
+++ b/oneflow/core/autograd/autograd_meta.cpp
@@ -0,0 +1,39 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/framework/dtype.h"
+#include "oneflow/core/autograd/autograd_meta.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+
+namespace oneflow {
+
+namespace one {
+
+TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {}
+
+Maybe<Tensor> TensorInfo::zeros() const {
+  const auto& interpreter = JUST(OpInterpUtil::GetInterpreter());
+  const auto& zeros_op = JUST(op_expr_helper::ZerosOp(*shape_.get(), dtype_->data_type()));
+  TensorTuple outputs(1);
+  JUST(interpreter->Apply(*zeros_op, {}, &outputs));
+  return outputs.at(0);
+}
+
+}  // namespace one
+
+}  // namespace oneflow
diff --git a/oneflow/core/autograd/autograd_meta.h b/oneflow/core/autograd/autograd_meta.h
new file mode 100644
index 0000000000000000000000000000000000000000..4fa8d4a4815f37ee8611b9f078457535c792c3db
--- /dev/null
+++ b/oneflow/core/autograd/autograd_meta.h
@@ -0,0 +1,86 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_
+#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_
+
+#include <memory>
+#include "oneflow/core/framework/tensor_arg.h"
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
+class Shape;
+class DType;
+
+namespace one {
+
+class Tensor;
+class TensorArg;
+
+class AutogradMeta final {
+ public:
+  AutogradMeta() = delete;
+  AutogradMeta(bool requires_grad, bool is_leaf)
+      : is_leaf_(is_leaf),
+        requires_grad_(requires_grad),
+        retain_grad_(false),
+        now_grad_arg_(new TensorArg) {}
+
+  // Getters
+  const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }
+  const std::shared_ptr<TensorArg>& now_grad_arg() const { return now_grad_arg_; }
+  bool requires_grad() const { return requires_grad_; }
+  bool is_leaf() const { return is_leaf_; }
+  bool retain_grad() const { return retain_grad_; }
+
+  // Setters
+  void set_acc_grad(const std::shared_ptr<Tensor>& grad) { acc_grad_ = grad; }
+  std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }
+  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
+  void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }
+  void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }
+
+ private:
+  bool is_leaf_;
+
+  // Only meaningful on leaf Tensors (must be false otherwise)
+  bool requires_grad_;
+
+  // Only meaningful on non-leaf Tensors (must be false otherwise)
+  bool retain_grad_;
+
+  std::shared_ptr<Tensor> acc_grad_;
+  std::shared_ptr<TensorArg> now_grad_arg_;
+};
+
+class TensorInfo final {
+ public:
+  TensorInfo() = delete;
+  explicit TensorInfo(const Tensor& tensor);
+
+  Maybe<Tensor> zeros() const;
+
+ private:
+  std::shared_ptr<const Shape> shape_;
+  std::shared_ptr<const DType> dtype_;
+  // TODO: Add device info
+};
+
+}  // namespace one
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_META_H_
diff --git a/oneflow/core/autograd/autograd_mode.h b/oneflow/core/autograd/autograd_mode.h
index 26c5687cd283f772e8f0b22a7fbe7f009752d6f1..96c095fe22bdba7ad28e0008a94e981c96ec9946 100644
--- a/oneflow/core/autograd/autograd_mode.h
+++ b/oneflow/core/autograd/autograd_mode.h
@@ -14,8 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
-namespace oneflow {
+#ifndef ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
+#define ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
 
+namespace oneflow {
 namespace autograd {
 
 struct GradMode {
@@ -41,5 +43,6 @@ class NoGradGuard : public AutoGradMode {
 };
 
 }  // namespace autograd
-
 }  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_AUTOGRAD_AUTOGRAD_MODE_H_
diff --git a/oneflow/core/framework/op_expr_grad_function.h b/oneflow/core/framework/op_expr_grad_function.h
index ea637d931d389808441823a91a203a73636d9f9c..beb3715ad27eaf6008ce55c1259dec7a0e671e54 100644
--- a/oneflow/core/framework/op_expr_grad_function.h
+++ b/oneflow/core/framework/op_expr_grad_function.h
@@ -54,7 +54,9 @@ class OpExprGradFunction : public OpExprGradFunctionIf {
                         const TensorTuple& outputs, const AttrMap& attrs) const override {
     StateT* state = dynamic_cast<StateT*>(ctx);
     CHECK_NOTNULL_OR_RETURN(state);
-    return Capture(state, inputs, outputs, attrs);
+    TensorTuple detach_outputs(outputs.size());
+    for (int i = 0; i < outputs.size(); ++i) { detach_outputs.at(i) = outputs.at(i)->detach(); }
+    return Capture(state, inputs, detach_outputs, attrs);
   }
 
   Maybe<void> ApplyIf(const OpExprInterpState* ctx, const TensorTuple& out_grads,
diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp
index 1c1460c801f76b0bf200c3f777be82a16c06552f..b0b8ee8c7348bb5e1d18e7d036b68379081b7b69 100644
--- a/oneflow/core/framework/op_expr_helper.cpp
+++ b/oneflow/core/framework/op_expr_helper.cpp
@@ -80,26 +80,26 @@ OF_PP_FOR_EACH_TUPLE(DEFINE_FLOATING_CONSTATNT_OP, FLOATING_DATA_TYPE_SEQ);
 OF_PP_FOR_EACH_TUPLE(DEFINE_INTEGER_CONSTATNT_OP, INT_DATA_TYPE_SEQ)
 #undef DEFINE_INTEGER_CONSTATNT_OP
 
-Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype) {
+Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype) {
-  return OnesOp(shape, dtype, UniqueOpName("constant"));
+  return ZerosOp(shape, dtype, UniqueOpName("constant"));
 }
-Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype, const std::string& name) {
+Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name) {
   switch (dtype) {
 #define CONSTANT_DATA_TYPE_CASE(cpp_type, data_type) \
-  case data_type: return ConstantOp(shape, (cpp_type)1, name);
+  case data_type: return ConstantOp(shape, (cpp_type)0, name);
     OF_PP_FOR_EACH_TUPLE(CONSTANT_DATA_TYPE_CASE, FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ);
 #undef CONSTANT_DATA_TYPE_CASE
     default: UNIMPLEMENTED_THEN_RETURN();
   }
 }
 
-Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype) {
-  return ZerosOp(shape, dtype, UniqueOpName("constant"));
+Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype) {
+  return OnesOp(shape, dtype, UniqueOpName("constant"));
 }
-Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name) {
+Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype, const std::string& name) {
   switch (dtype) {
 #define CONSTANT_DATA_TYPE_CASE(cpp_type, data_type) \
-  case data_type: return ConstantOp(shape, (cpp_type)0, name);
+  case data_type: return ConstantOp(shape, (cpp_type)1, name);
     OF_PP_FOR_EACH_TUPLE(CONSTANT_DATA_TYPE_CASE, FLOATING_DATA_TYPE_SEQ INT_DATA_TYPE_SEQ);
 #undef CONSTANT_DATA_TYPE_CASE
     default: UNIMPLEMENTED_THEN_RETURN();
diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h
index 7eadedc7dcb98bed15114249fa929a58fa8fb6a7..6d2f7268c891bdaebc0c28e20b5605f5c6f5eecd 100644
--- a/oneflow/core/framework/op_expr_helper.h
+++ b/oneflow/core/framework/op_expr_helper.h
@@ -28,6 +28,9 @@ Maybe<one::UserOpExpr> AddNOp(int32_t n, const std::string& name);
 Maybe<one::UserOpExpr> AddOp();
 Maybe<one::UserOpExpr> AddOp(const std::string& name);
 
+Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype);
+Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name);
+
 Maybe<one::UserOpExpr> ZeroLikeOp();
 Maybe<one::UserOpExpr> ZeroLikeOp(const std::string& name);
 
@@ -39,9 +42,6 @@ Maybe<one::UserOpExpr> ConstantOp(const Shape& shape, const T& value, const std:
 Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype);
 Maybe<one::UserOpExpr> OnesOp(const Shape& shape, const DataType& dtype, const std::string& name);
 
-Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype);
-Maybe<one::UserOpExpr> ZerosOp(const Shape& shape, const DataType& dtype, const std::string& name);
-
 Maybe<one::UserOpExpr> IdentityOp();
 Maybe<one::UserOpExpr> IdentityOp(const std::string& name);
 
diff --git a/oneflow/core/framework/op_interpreter.h b/oneflow/core/framework/op_interpreter.h
index caf73af34dcc150e92c385ed94c13d567dfeea45..7c4cbc7307b3ca93e859b6922acd31df1e4b2bf3 100644
--- a/oneflow/core/framework/op_interpreter.h
+++ b/oneflow/core/framework/op_interpreter.h
@@ -33,7 +33,7 @@ class OpExprInterpState {
 
   size_t SaveTensorForBackward(const std::shared_ptr<Tensor>& tensor) {
     size_t offset = saved_tensors_.size();
-    saved_tensors_.push_back(tensor->detach());
+    saved_tensors_.push_back(tensor);
     return offset;
   }
 
diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
index f36fb8a116aa1ba71f6791406e25d0033de70948..dd85a53ff4b51a1731664d5f0e0dd738ddbbc6ae 100644
--- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
+++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp
@@ -160,13 +160,12 @@ using Bn2BlobObjectMap = HashMap<std::string, std::shared_ptr<compatible_py::Blo
     const auto& device =
         JUST(Device::MakeDeviceByParallelDesc(*parallel_attr->parallel_desc_symbol()));
     return static_cast<std::shared_ptr<Tensor>>(MirroredTensor::MakeTensor(
-        blob_attr->shape(), dtype, device, is_lazy, /*requires_grad=*/false, /*is_leaf=*/false,
-        /*retain_grad=*/false));
+        blob_attr->shape(), dtype, device, is_lazy, /*requires_grad=*/false, /*is_leaf=*/false));
   } else {
     const auto& distribute = JUST(compatible_py::MakeDistribute(*(parallel_attr->sbp_parallel())));
     return static_cast<std::shared_ptr<Tensor>>(ConsistentTensor::MakeTensor(
         blob_attr->shape(), dtype, distribute, parallel_attr->parallel_desc_symbol(), is_lazy,
-        /*requires_grad=*/false, /*is_leaf=*/false, /*retain_grad=*/false));
+        /*requires_grad=*/false, /*is_leaf=*/false));
   }
 }
 
@@ -174,8 +173,7 @@ using Bn2BlobObjectMap = HashMap<std::string, std::shared_ptr<compatible_py::Blo
     const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
     const std::shared_ptr<const Device>& device) {
   auto tensor = MirroredTensor::MakeEagerTensor(eager_blob_object, device,
-                                                /* requires_grad */ false, /* is_leaf */ false,
-                                                /* retain_grad */ false);
+                                                /* requires_grad */ false, /* is_leaf */ false);
   return std::static_pointer_cast<Tensor>(tensor);
 }
 
diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp
index 6ba8841df5e2d6e34e385f44b39db2f4ceb24c90..7e3522eae4a1b133a2643cd6186653034af9bfcf 100644
--- a/oneflow/core/framework/tensor.cpp
+++ b/oneflow/core/framework/tensor.cpp
@@ -26,27 +26,24 @@ namespace one {
 
 std::shared_ptr<MirroredTensor> MirroredTensor::MakeTensor(
     const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype,
-    const std::shared_ptr<const Device>& device, bool is_lazy, bool requires_grad, bool is_leaf,
-    bool retain_grad) {
+    const std::shared_ptr<const Device>& device, bool is_lazy, bool requires_grad, bool is_leaf) {
   std::shared_ptr<MirroredTensorImpl> impl;
   if (is_lazy) {
-    impl = std::make_shared<LazyMirroredTensorImpl>(shape, dtype, device, requires_grad, is_leaf,
-                                                    retain_grad);
+    impl = std::make_shared<LazyMirroredTensorImpl>(shape, dtype, device, requires_grad, is_leaf);
   } else {
     const auto eager_blob_object =
         CHECK_JUST(GenerateAllocatedEagerBlobObject(dtype->data_type(), *shape));
     impl = std::make_shared<EagerMirroredTensorImpl>(eager_blob_object, device, requires_grad,
-                                                     is_leaf, retain_grad);
+                                                     is_leaf);
   }
   return std::make_shared<MirroredTensor>(impl);
 }
 
 std::shared_ptr<MirroredTensor> MirroredTensor::MakeEagerTensor(
     const std::shared_ptr<vm::EagerBlobObject> eager_blob_object,
-    const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf,
-    bool retain_grad) {
-  std::shared_ptr<MirroredTensorImpl> impl = std::make_shared<EagerMirroredTensorImpl>(
-      eager_blob_object, device, requires_grad, is_leaf, retain_grad);
+    const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf) {
+  std::shared_ptr<MirroredTensorImpl> impl =
+      std::make_shared<EagerMirroredTensorImpl>(eager_blob_object, device, requires_grad, is_leaf);
   return std::make_shared<MirroredTensor>(impl);
 }
 
@@ -60,7 +57,7 @@ int64_t MirroredTensor::nelement() const { return shape()->elem_cnt(); }
 
 std::shared_ptr<MirroredTensor> MirroredTensor::data() const {
   std::shared_ptr<MirroredTensor> t =
-      MakeTensor(shape(), dtype(), device(), is_lazy(), false, is_leaf(), false);
+      MakeTensor(shape(), dtype(), device(), is_lazy(), false, is_leaf());
   t->set_blob_object(blob_object());
   return t;
 }
@@ -74,14 +71,14 @@ std::shared_ptr<ConsistentTensor> ConsistentTensor::MakeTensor(
     const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype,
     const std::shared_ptr<const compatible_py::Distribute>& distribute,
     const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad,
-    bool is_leaf, bool retain_grad) {
+    bool is_leaf) {
   std::shared_ptr<ConsistentTensorImpl> impl;
   if (is_lazy) {
     impl = std::make_shared<LazyConsistentTensorImpl>(shape, dtype, distribute, parallel_desc,
-                                                      requires_grad, is_leaf, retain_grad);
+                                                      requires_grad, is_leaf);
   } else {
     impl = std::make_shared<EagerConsistentTensorImpl>(shape, dtype, distribute, parallel_desc,
-                                                       requires_grad, is_leaf, retain_grad);
+                                                       requires_grad, is_leaf);
   }
   return std::make_shared<ConsistentTensor>(impl);
 }
@@ -97,8 +94,8 @@ int64_t ConsistentTensor::nelement() const { return shape()->elem_cnt(); }
 int64_t ConsistentTensor::ndim() const { return shape()->NumAxes(); }
 
 std::shared_ptr<ConsistentTensor> ConsistentTensor::data() const {
-  std::shared_ptr<ConsistentTensor> t = MakeTensor(shape(), dtype(), distribute(), parallel_desc(),
-                                                   is_lazy(), false, is_leaf(), false);
+  std::shared_ptr<ConsistentTensor> t =
+      MakeTensor(shape(), dtype(), distribute(), parallel_desc(), is_lazy(), false, is_leaf());
   t->set_blob_object(blob_object());
   return t;
 }
diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h
index ebf0d47953b0c8d58e584d5322c8dbddaefd1be6..d13c523ccb169d5e71ec3dba54295cec2bf2bd73 100644
--- a/oneflow/core/framework/tensor.h
+++ b/oneflow/core/framework/tensor.h
@@ -100,6 +100,7 @@ class Tensor {
   virtual void set_acc_grad(const std::shared_ptr<Tensor>& grad) = 0;
   virtual std::shared_ptr<Tensor> mut_acc_grad() = 0;
   virtual void set_is_leaf(bool is_leaf) = 0;
+  virtual std::shared_ptr<AutogradMeta> mut_autograd_meta() = 0;
 
  protected:
   Tensor() = default;
@@ -218,9 +219,10 @@ class MirroredTensor final : public TensorIf<MirroredTensor> {
   // Setters for autograd
   void set_acc_grad(const std::shared_ptr<Tensor>& grad) override { impl_->set_acc_grad(grad); }
   void set_requires_grad(bool requires_grad) override { impl_->set_requires_grad(requires_grad); }
-  void set_retain_grad(bool retain_grad) override { impl_->set_requires_grad(retain_grad); }
+  void set_retain_grad(bool retain_grad) override { impl_->set_retain_grad(retain_grad); }
   std::shared_ptr<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }
   void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); }
+  std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); }
 
   // Operators for tensor
   std::shared_ptr<Tensor> detach() const override;
@@ -239,13 +241,11 @@ class MirroredTensor final : public TensorIf<MirroredTensor> {
   static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
                                                     const std::shared_ptr<const DType>& dtype,
                                                     const std::shared_ptr<const Device>& device,
-                                                    bool is_lazy, bool requires_grad, bool is_leaf,
-                                                    bool retain_grad);
+                                                    bool is_lazy, bool requires_grad, bool is_leaf);
 
   static std::shared_ptr<MirroredTensor> MakeEagerTensor(
       const std::shared_ptr<vm::EagerBlobObject> eager_blob_object,
-      const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf,
-      bool retain_grad);
+      const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf);
 
  private:
   std::shared_ptr<MirroredTensorImpl> impl_;
@@ -308,8 +308,9 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
   void set_acc_grad(const std::shared_ptr<Tensor>& grad) override { impl_->set_acc_grad(grad); }
   std::shared_ptr<Tensor> mut_acc_grad() override { return impl_->mut_acc_grad(); }
   void set_requires_grad(bool requires_grad) override { impl_->set_requires_grad(requires_grad); }
-  void set_retain_grad(bool retain_grad) override { impl_->set_requires_grad(retain_grad); }
+  void set_retain_grad(bool retain_grad) override { impl_->set_retain_grad(retain_grad); }
   void set_is_leaf(bool is_leaf) override { impl_->set_is_leaf(is_leaf); }
+  std::shared_ptr<AutogradMeta> mut_autograd_meta() override { return impl_->mut_autograd_meta(); }
 
   // Operators for tensor
   std::shared_ptr<Tensor> detach() const override;
@@ -329,7 +330,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
       const std::shared_ptr<const Shape>& shape, const std::shared_ptr<const DType>& dtype,
       const std::shared_ptr<const compatible_py::Distribute>& distribute,
       const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad,
-      bool is_leaf, bool retain_grad);
+      bool is_leaf);
 
  private:
   std::shared_ptr<ConsistentTensorImpl> impl_;
diff --git a/oneflow/core/framework/tensor_impl.cpp b/oneflow/core/framework/tensor_impl.cpp
index e9e1c9cb4d072471d3bd4c6e038071b65a0739c6..a9e9e0ab554bb5e6e2e53bc2eccdb05d4d899387 100644
--- a/oneflow/core/framework/tensor_impl.cpp
+++ b/oneflow/core/framework/tensor_impl.cpp
@@ -72,9 +72,8 @@ Maybe<void> EagerConsistentTensorImpl::set_blob_object(
 
 EagerMirroredTensorImpl::EagerMirroredTensorImpl(
     const std::shared_ptr<vm::EagerBlobObject> eager_blob_object,
-    const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf, bool retain_grad)
-    : MirroredTensorImpl(device, requires_grad, is_leaf, retain_grad),
-      eager_blob_object_(eager_blob_object) {
+    const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf)
+    : MirroredTensorImpl(device, requires_grad, is_leaf), eager_blob_object_(eager_blob_object) {
   dtype_ = CHECK_JUST(DType::GetDTypeByDataType(eager_blob_object->blob_desc().data_type()));
   tensor_storage_ = std::make_shared<TensorStorage>(eager_blob_object->tensor_buffer());
   const auto& parallel_desc = this->parallel_desc();
diff --git a/oneflow/core/framework/tensor_impl.h b/oneflow/core/framework/tensor_impl.h
index 2f1e9df1211e24f3b730ba936183334244886dda..b470de2ee4c8e19d1d5bc7bf08037ecbcf25679f 100644
--- a/oneflow/core/framework/tensor_impl.h
+++ b/oneflow/core/framework/tensor_impl.h
@@ -19,11 +19,10 @@ limitations under the License.
 
 #include "oneflow/core/common/util.h"
 #include "oneflow/core/common/data_type.h"
-#include "oneflow/core/common/shape.h"
 #include "oneflow/core/job/placement.cfg.h"
 #include "oneflow/core/framework/object.h"
-#include "oneflow/core/framework/tensor_arg.h"
 #include "oneflow/core/framework/tensor_storage.h"
+#include "oneflow/core/autograd/autograd_meta.h"
 
 namespace oneflow {
 
@@ -35,6 +34,7 @@ namespace compatible_py {
 class Distribute;
 }
 
+class Shape;
 class Device;
 class DType;
 
@@ -63,11 +63,11 @@ class TensorImpl {
   virtual Maybe<VmLocalDepObject> compute_local_dep_object() const = 0;
 
   // Getters for autograd
-  const std::shared_ptr<Tensor>& acc_grad() const { return acc_grad_; }
-  const std::shared_ptr<TensorArg>& now_grad_arg() const { return now_grad_arg_; }
-  bool requires_grad() const { return requires_grad_; }
-  bool is_leaf() const { return is_leaf_; }
-  bool retain_grad() const { return retain_grad_; }
+  const std::shared_ptr<Tensor>& acc_grad() const { return autograd_meta_->acc_grad(); }
+  const std::shared_ptr<TensorArg>& now_grad_arg() const { return autograd_meta_->now_grad_arg(); }
+  bool requires_grad() const { return autograd_meta_->requires_grad(); }
+  bool is_leaf() const { return autograd_meta_->is_leaf(); }
+  bool retain_grad() const { return autograd_meta_->retain_grad(); }
 
   // Setters
   virtual void set_shape(const std::shared_ptr<const Shape>& shape) = 0;
@@ -76,11 +76,12 @@ class TensorImpl {
       const std::shared_ptr<const ParallelDesc>& parallel_desc) = 0;
 
   // Setters for autograd
-  void set_acc_grad(const std::shared_ptr<Tensor>& grad) { acc_grad_ = grad; }
-  std::shared_ptr<Tensor> mut_acc_grad() { return acc_grad_; }
-  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
-  void set_retain_grad(bool retain_grad) { retain_grad_ = retain_grad; }
-  void set_is_leaf(bool is_leaf) { is_leaf_ = is_leaf; }
+  void set_acc_grad(const std::shared_ptr<Tensor>& grad) { autograd_meta_->set_acc_grad(grad); }
+  std::shared_ptr<Tensor> mut_acc_grad() { return autograd_meta_->mut_acc_grad(); }
+  void set_requires_grad(bool requires_grad) { autograd_meta_->set_requires_grad(requires_grad); }
+  void set_retain_grad(bool retain_grad) { autograd_meta_->set_retain_grad(retain_grad); }
+  void set_is_leaf(bool is_leaf) { autograd_meta_->set_is_leaf(is_leaf); }
+  std::shared_ptr<AutogradMeta> mut_autograd_meta() { return autograd_meta_; }
 
   // Getters to be deprecated
   virtual const std::shared_ptr<compatible_py::BlobObject>& blob_object() const = 0;
@@ -90,20 +91,13 @@ class TensorImpl {
       const std::shared_ptr<compatible_py::BlobObject>& blob_object) = 0;
 
  protected:
-  TensorImpl(bool requires_grad, bool is_leaf, bool retain_grad)
-      : requires_grad_(requires_grad),
-        is_leaf_(is_leaf),
-        retain_grad_(retain_grad),
-        now_grad_arg_(new TensorArg) {}
+  TensorImpl(bool requires_grad, bool is_leaf)
+      : autograd_meta_(new AutogradMeta(requires_grad, is_leaf)) {}
   Maybe<void> SyncBlobObject2Attributes(
       const std::shared_ptr<compatible_py::BlobObject>& blob_object);
 
-  // For autograd
-  bool requires_grad_;
-  bool is_leaf_;
-  bool retain_grad_;
-  std::shared_ptr<Tensor> acc_grad_;
-  std::shared_ptr<TensorArg> now_grad_arg_;
+ protected:
+  std::shared_ptr<AutogradMeta> autograd_meta_;
 };
 
 class MirroredTensorImpl : public TensorImpl {
@@ -123,9 +117,8 @@ class MirroredTensorImpl : public TensorImpl {
       std::shared_ptr<vm::EagerBlobObject> eager_blob_object) = 0;
 
  protected:
-  MirroredTensorImpl(const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf,
-                     bool retain_grad)
-      : TensorImpl(requires_grad, is_leaf, retain_grad) {
+  MirroredTensorImpl(const std::shared_ptr<const Device>& device, bool requires_grad, bool is_leaf)
+      : TensorImpl(requires_grad, is_leaf) {
     set_device(device);
   }
 
@@ -152,8 +145,8 @@ class ConsistentTensorImpl : public TensorImpl {
 
  protected:
   ConsistentTensorImpl(const std::shared_ptr<const ParallelDesc>& parallel_desc, bool requires_grad,
-                       bool is_leaf, bool retain_grad)
-      : TensorImpl(requires_grad, is_leaf, retain_grad), parallel_desc_(parallel_desc) {}
+                       bool is_leaf)
+      : TensorImpl(requires_grad, is_leaf), parallel_desc_(parallel_desc) {}
 
   const std::shared_ptr<const Device> device_;  // always nullptr
   std::shared_ptr<const ParallelDesc> parallel_desc_;
@@ -165,10 +158,8 @@ class LazyMirroredTensorImpl final : public MirroredTensorImpl {
   LazyMirroredTensorImpl(const std::shared_ptr<const Shape>& shape,
                          const std::shared_ptr<const DType>& dtype,
                          const std::shared_ptr<const Device>& device, bool requires_grad,
-                         bool is_leaf, bool retain_grad)
-      : MirroredTensorImpl(device, requires_grad, is_leaf, retain_grad),
-        shape_(shape),
-        dtype_(dtype) {}
+                         bool is_leaf)
+      : MirroredTensorImpl(device, requires_grad, is_leaf), shape_(shape), dtype_(dtype) {}
   ~LazyMirroredTensorImpl() override = default;
 
   // Getters
@@ -211,7 +202,12 @@ class EagerMirroredTensorImpl final : public MirroredTensorImpl {
   OF_DISALLOW_COPY_AND_MOVE(EagerMirroredTensorImpl);
   EagerMirroredTensorImpl(const std::shared_ptr<vm::EagerBlobObject> eager_blob_object,
                           const std::shared_ptr<const Device>& device, bool requires_grad,
-                          bool is_leaf, bool retain_grad);
+                          bool is_leaf);
+  EagerMirroredTensorImpl(const std::shared_ptr<const Shape>& shape,
+                          const std::shared_ptr<const DType>& dtype,
+                          const std::shared_ptr<const Device>& device,
+                          const std::shared_ptr<TensorStorage>& tensor_storage, bool requires_grad,
+                          bool is_leaf);
   ~EagerMirroredTensorImpl() override;
 
   // Getters
@@ -257,8 +253,8 @@ class LazyConsistentTensorImpl final : public ConsistentTensorImpl {
                            const std::shared_ptr<const DType>& dtype,
                            const std::shared_ptr<const compatible_py::Distribute>& distribute,
                            const std::shared_ptr<const ParallelDesc>& parallel_desc,
-                           bool requires_grad, bool is_leaf, bool retain_grad)
-      : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf, retain_grad),
+                           bool requires_grad, bool is_leaf)
+      : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf),
         shape_(shape),
         dtype_(dtype),
         distribute_(distribute) {}
@@ -308,8 +304,8 @@ class EagerConsistentTensorImpl final : public ConsistentTensorImpl {
                             const std::shared_ptr<const DType>& dtype,
                             const std::shared_ptr<const compatible_py::Distribute>& distribute,
                             const std::shared_ptr<const ParallelDesc>& parallel_desc,
-                            bool requires_grad, bool is_leaf, bool retain_grad)
-      : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf, retain_grad),
+                            bool requires_grad, bool is_leaf)
+      : ConsistentTensorImpl(parallel_desc, requires_grad, is_leaf),
         shape_(shape),
         dtype_(dtype),
         distribute_(distribute) {}
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 96f6bc70c41c704a868caf139fd63b6434153b12..cdf4a7ed644b4317a1e37700463353c71ddbd939 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -82,7 +82,6 @@ class Tensor:
         dtype=None,
         device=None,
         requires_grad=False,
-        retain_grad=False,
         placement=None,
         sbp=None,
         is_consistent=False,
@@ -110,7 +109,6 @@ class Tensor:
                 dtype=dtype,
                 device=device,
                 requires_grad=requires_grad,
-                retain_grad=retain_grad,
                 placement=placement,
                 sbp=sbp,
                 is_consistent=is_consistent,
@@ -124,7 +122,6 @@ class Tensor:
                 dtype,
                 device=device,
                 requires_grad=requires_grad,
-                retain_grad=retain_grad,
                 placement=placement,
                 sbp=sbp,
                 is_consistent=is_consistent,
@@ -511,7 +508,6 @@ class Tensor:
         dtype=None,
         device=None,
         requires_grad=False,
-        retain_grad=False,
         placement=None,
         sbp=None,
         is_consistent=False,
@@ -533,7 +529,6 @@ class Tensor:
             dtype,
             device=device,
             requires_grad=requires_grad,
-            retain_grad=retain_grad,
             placement=placement,
             sbp=sbp,
             is_consistent=is_consistent,
@@ -549,7 +544,6 @@ class UndeterminedTensor:
         dtype,
         device=None,
         requires_grad=False,
-        retain_grad=False,
         placement=None,
         sbp=None,
         is_consistent=False,
@@ -573,7 +567,6 @@ class UndeterminedTensor:
         self.dtype = dtype
         self.device = device
         self.requires_grad = requires_grad
-        self.retain_grad = retain_grad
         self.placement = placement
         self.sbp = sbp
         self.is_consistent = is_consistent
@@ -623,7 +616,6 @@ def _default_initializer_for_determining(tensor):
             undetermined_tensor.is_lazy,
             undetermined_tensor.requires_grad,
             True,
-            undetermined_tensor.retain_grad,
         )
         determined_tensor._set_blob_object(
             _create_blob_object(
@@ -643,7 +635,6 @@ def _default_initializer_for_determining(tensor):
             undetermined_tensor.is_lazy,
             undetermined_tensor.requires_grad,
             True,
-            undetermined_tensor.retain_grad,
         )
         _init_eager_local_tensor_by_initializer_conf(
             determined_tensor, undetermined_tensor.data_initializer
@@ -682,7 +673,6 @@ def _numpy_initializer_for_determining(tensor):
             undetermined_tensor.is_lazy,
             undetermined_tensor.requires_grad,
             True,
-            undetermined_tensor.retain_grad,
         )
         determined_tensor._set_blob_object(blob.blob_object)
     else:
@@ -693,7 +683,6 @@ def _numpy_initializer_for_determining(tensor):
             undetermined_tensor.is_lazy,
             undetermined_tensor.requires_grad,
             True,
-            undetermined_tensor.retain_grad,
         )
         _copy_from_numpy_to_eager_local_tensor(determined_tensor, numpy_data)