diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp
index 97a83ae9f59a9beaa3814206e4a9a55d40d41169..a9c191af02b1ddab8e8d0b52244c14fc72b6a459 100644
--- a/oneflow/api/python/framework/tensor.cpp
+++ b/oneflow/api/python/framework/tensor.cpp
@@ -38,42 +38,31 @@ namespace one {
 
 namespace {
 
-template<typename T>
-const DType* GetTensorDType(const T& tensor) {
+const DType* GetTensorDType(const Tensor& tensor) {
   return DType::Get(tensor.dtype()).GetOrThrow().get();
 }
 
-template<typename T>
-struct TensorExportUtil final {};
-
-template<>
-struct TensorExportUtil<MirroredTensor> final {
-  static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
-                                                    const DType* dtype,
-                                                    const Symbol<Device>& device, bool is_lazy,
-                                                    bool requires_grad, bool is_leaf) {
-    return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad,
-                                      is_leaf)
-        .GetPtrOrThrow();
-  }
-};
-
-template<>
-struct TensorExportUtil<ConsistentTensor> final {
-  static std::shared_ptr<ConsistentTensor> MakeTensor(
-      const std::shared_ptr<const Shape>& shape, const DType* dtype,
-      const std::shared_ptr<const cfg::ParallelDistribution>& parallel_distribution,
-      const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad,
-      bool is_leaf) {
-    return ConsistentTensor::MakeTensor(shape, dtype->data_type(), SymbolOf(*parallel_distribution),
-                                        SymbolOf(*parallel_desc), is_lazy, requires_grad, is_leaf)
-        .GetPtrOrThrow();
-  }
-};
+std::shared_ptr<Tensor> MakeLocalTensor(const std::shared_ptr<const Shape>& shape,
+                                        const DType* dtype, const Symbol<Device>& device,
+                                        bool is_lazy, bool requires_grad, bool is_leaf) {
+  const auto data_type = dtype->data_type();
+  return MirroredTensor::MakeTensor(shape, data_type, device, is_lazy, requires_grad, is_leaf)
+      .GetPtrOrThrow();
+}
 
-namespace {
+std::shared_ptr<Tensor> MakeConsistentTensor(
+    const std::shared_ptr<const Shape>& shape, const DType* dtype,
+    const Symbol<cfg::ParallelDistribution>& parallel_distribution,  // read-only, forwarded
+    Symbol<ParallelDesc> parallel_desc, bool is_lazy, bool requires_grad, bool is_leaf) {
+  return ConsistentTensor::MakeTensor(shape, dtype->data_type(), parallel_distribution,
+                                      parallel_desc, is_lazy, requires_grad, is_leaf)
+      .GetPtrOrThrow();
+}
 
-Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) {
+Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
+  const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
+  CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
+  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
     JUST(builder->AccessBlobByCallback(
         tensor,
@@ -84,19 +73,21 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tens
         "mut"));
     return Maybe<void>::Ok();
   }));
-
   return Maybe<void>::Ok();
 }
 
-void ApiEagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) {
+void ApiEagerMirroredTensorZeros(const std::shared_ptr<Tensor>& tensor) {
   return EagerMirroredTensorZeros(tensor).GetOrThrow();
 }
 
 template<typename T>
-Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTensor>& tensor,
+Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<Tensor>& t,
                                               py::array_t<T> array,
                                               void (*Copy)(uint64_t, py::array_t<T>),
                                               const std::string& modifier) {
+  const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
+  CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
+  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
   std::atomic<bool> synced(false);
 
   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
@@ -118,15 +109,13 @@ Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTens
 }
 
 template<typename T>
-void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<MirroredTensor>& tensor,
-                                  py::array_t<T> array) {
+void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) {
   return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyToBuffer, "const")
       .GetOrThrow();
 }
 
 template<typename T>
-void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<MirroredTensor>& tensor,
-                                    py::array_t<T> array) {
+void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) {
   return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyFromBuffer, "mut")
       .GetOrThrow();
 }
@@ -161,21 +150,22 @@ const std::string& ApiGetCopyMirroredTensorFromNumpyFuncName(const Tensor& tenso
   return *GetCopyMirroredTensorFromNumpyFuncName(tensor.dtype()).GetPtrOrThrow();
 }
 
-Symbol<Device> TensorGetDevice(const MirroredTensor& tensor) {
-  return tensor.device().GetOrThrow();
-}
+Symbol<Device> TensorGetDevice(const Tensor& tensor) { return tensor.device().GetOrThrow(); }
 
-std::shared_ptr<const ParallelDesc> TensorGetParallelDesc(const ConsistentTensor& tensor) {
-  return tensor.parallel_desc().GetOrThrow().shared_from_symbol();
+Symbol<ParallelDesc> TensorGetParallelDesc(const Tensor& tensor) {
+  return tensor.parallel_desc().GetOrThrow();
 }
 
-std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes(
-    const std::shared_ptr<MirroredTensor>& tensor) {
+Maybe<std::tuple<std::vector<Shape>, std::vector<const DType*>>>
+MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
+  const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
+  CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
+  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
   std::vector<Shape> shapes;
   std::vector<const DType*> dtypes;
   std::atomic<bool> synced(false);
 
-  CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
+  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
     JUST(builder->AccessBlobByCallback(
         tensor, [&synced](uint64_t of_blob_ptr) { synced = true; }, "const"));
     return Maybe<void>::Ok();
@@ -185,7 +175,7 @@ std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesA
     while (!synced) {}
   });
 
-  const Blob& blob = CHECK_JUST(tensor->eager_blob_object())->blob();
+  const Blob& blob = JUST(tensor->eager_blob_object())->blob();
   const Shape& blob_shape = blob.static_shape();
   const auto* tensor_buffer_ptr = blob.dptr<TensorBuffer>();
   for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) {
@@ -193,63 +183,56 @@ std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesA
     shapes.push_back(tensor_buffer->shape());
     dtypes.push_back(DType::Get(tensor_buffer->data_type()).GetOrThrow().get());
   }
-
   return std::make_tuple(shapes, dtypes);
 }
 
-}  // namespace
-
-void SpecializedDef(py::class_<MirroredTensor, Tensor, std::shared_ptr<MirroredTensor>>* api) {
-  using T = MirroredTensor;
-  api->def_property_readonly("device", &TensorGetDevice);
-  api->def_property_readonly("data", &T::data);
-  api->def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes);
-#define DEFINE_TENSOR_METHOD(T, type_proto)                         \
-  api->def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>); \
-  api->def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>);
-  OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ);
-
-#undef DEFINE_TENSOR_METHOD
-  api->def("_get_copy_mirrored_tensor_to_numpy_func_name",
-           &ApiGetCopyMirroredTensorToNumpyFuncName);
-  api->def("_get_copy_mirrored_tensor_from_numpy_func_name",
-           &ApiGetCopyMirroredTensorFromNumpyFuncName);
-  api->def("zeros_", &ApiEagerMirroredTensorZeros);
-  api->def("_register_hook",
-           [](const std::shared_ptr<MirroredTensor>& self, const AutogradMeta::Hook& hook) -> void {
-             if (!self->grad_fn_node()) { CHECK_JUST(AddAccumulateFunctionNode(self)); }
-             self->mut_autograd_meta()->add_hook(hook);
-           });
+std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes(
+    const std::shared_ptr<Tensor>& tensor) {
+  return MaybeGetTensorBufferShapesAndDTypes(tensor).GetOrThrow();
 }
 
-void SpecializedDef(py::class_<ConsistentTensor, Tensor, std::shared_ptr<ConsistentTensor>>* api) {
-  api->def_property_readonly("placement", &TensorGetParallelDesc);
+Maybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self,
+                               const AutogradMeta::Hook& hook) {
+  if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }
+  self->mut_autograd_meta()->add_hook(hook);
+  return Maybe<void>::Ok();
+}
+void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMeta::Hook& hook) {
+  return RegisterTensorHook(self, hook).GetOrThrow();
 }
 
-template<typename T>
-void ExportTensor(py::module& m, const char* name) {
-  py::class_<T, Tensor, std::shared_ptr<T>> tensor_api(m, name);
-  tensor_api
-      .def(py::init(&TensorExportUtil<T>::MakeTensor))
+}  // namespace
+
+ONEFLOW_API_PYBIND11_MODULE("", m) {
+  py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor")
+      .def(py::init(&MakeLocalTensor))
+      .def(py::init(&MakeConsistentTensor))
       // Properties of pytorch
-      .def_property_readonly("shape", &T::shape)
-      .def_property_readonly("dtype", &GetTensorDType<T>)
-      .def_property_readonly("is_cuda", &T::is_cuda)
-      .def_property_readonly("grad", [](const T& t) { return t.api_acc_grad().GetPtrOrThrow(); })
+      .def_property_readonly("shape", &Tensor::shape)
+      .def_property_readonly("dtype", &GetTensorDType)
+      .def_property_readonly("is_cuda", &Tensor::is_cuda)
+      .def_property_readonly("grad",
+                             [](const Tensor& t) -> std::shared_ptr<Tensor> {
+                               // Guard first: a tensor without autograd meta has no accumulated
+                               // grad to hand out, so return an empty pointer for it (pybind11
+                               // converts an empty shared_ptr to Python None).
+                               if (!t.has_autograd_meta()) { return std::shared_ptr<Tensor>(); }
+                               return t.acc_grad().GetPtrOrThrow();
+                             })
       // setter of grad
       .def("set_grad",
-           [](T& t, const std::shared_ptr<T>& grad) {
+           [](Tensor& t, const std::shared_ptr<Tensor>& grad) {
              if (t.is_leaf()) {
-               t.set_acc_grad(grad);
+               t.set_acc_grad(grad).GetOrThrow();
              } else {
                throw std::runtime_error("You can only change gradient of leaf tensors.");
              }
            })
-      .def_property_readonly("grad_fn", &T::grad_fn_node)
-      .def_property_readonly("is_leaf", &T::is_leaf)
+      .def_property_readonly("grad_fn", &Tensor::grad_fn_node)
+      .def_property_readonly("is_leaf", &Tensor::is_leaf)
       .def_property(
-          "requires_grad", &T::requires_grad,
-          [](T& t, bool requires_grad) {
+          "requires_grad", &Tensor::requires_grad,
+          [](Tensor& t, bool requires_grad) {
             if (t.is_leaf()) {
               t.set_requires_grad(requires_grad);
             } else {
@@ -258,23 +241,32 @@ void ExportTensor(py::module& m, const char* name) {
           })
       // Methods of pytorch
       .def("retain_grad",
-           [](T& t) {
+           [](Tensor& t) {
              if (!t.is_leaf()) { t.set_retain_grad(true).GetOrThrow(); }
            })
-      .def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); })
-      .def("clone", [](const T& t) { return t.api_clone().GetPtrOrThrow(); })
+      .def("detach", [](const Tensor& t) { return t.detach().GetPtrOrThrow(); })
+      .def("clone", [](const Tensor& t) { return t.clone().GetPtrOrThrow(); })
       // OneFlow tensor properties other than pytorch tensor
-      .def_property_readonly("is_lazy", &T::is_lazy)
-      .def_property_readonly("is_consistent", &T::is_consistent);
-  SpecializedDef(&tensor_api);
-}
-
-}  // namespace
-
-ONEFLOW_API_PYBIND11_MODULE("", m) {
-  py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor");
-  ExportTensor<MirroredTensor>(m, "LocalTensor");
-  ExportTensor<ConsistentTensor>(m, "ConsistentTensor");
+      .def_property_readonly("is_lazy", &Tensor::is_lazy)
+      .def_property_readonly("is_eager", &Tensor::is_eager)
+      .def_property_readonly("is_consistent", &Tensor::is_consistent)
+      .def_property_readonly("is_local", &Tensor::is_local)
+      .def("zeros_", &ApiEagerMirroredTensorZeros)
+      .def("_register_hook", &ApiRegisterTensorHook)
+      // local tensor only
+      .def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
+      .def_property_readonly("device", &TensorGetDevice)
+      .def_property_readonly("data", &Tensor::data)
+#define DEFINE_TENSOR_METHOD(T, type_proto)                    \
+  .def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>) \
+      .def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>)
+          OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ)
+#undef DEFINE_TENSOR_METHOD
+      .def("_get_copy_mirrored_tensor_to_numpy_func_name", &ApiGetCopyMirroredTensorToNumpyFuncName)
+      .def("_get_copy_mirrored_tensor_from_numpy_func_name",
+           &ApiGetCopyMirroredTensorFromNumpyFuncName)
+      // consistent tensor only
+      .def_property_readonly("placement", &TensorGetParallelDesc);
 }
 
 }  // namespace one
diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp
index 4fccd92de1c87fe0947c2d89d9b5e49fb2dd7f7d..920486f8313f699a8ea659d29a44268098d842e8 100644
--- a/oneflow/core/framework/tensor.cpp
+++ b/oneflow/core/framework/tensor.cpp
@@ -67,13 +67,14 @@ int64_t MirroredTensor::dim(int64_t index) const { return shape()->At(index); }
 
 int64_t MirroredTensor::nelement() const { return shape()->elem_cnt(); }
 
-std::shared_ptr<MirroredTensor> MirroredTensor::data() const {
+std::shared_ptr<Tensor> MirroredTensor::data() const {
   std::shared_ptr<MirroredTensor> t = std::make_shared<MirroredTensor>(impl_);
   return t;
 }
 
-Maybe<MirroredTensor> MirroredTensor::api_detach() const {
-  return std::make_shared<MirroredTensor>(JUST(impl_->detach()));
+Maybe<Tensor> MirroredTensor::detach() const {
+  std::shared_ptr<Tensor> tensor = std::make_shared<MirroredTensor>(JUST(impl_->detach()));
+  return tensor;
 }
 
 Maybe<Tensor> MirroredTensor::clone() const {
@@ -117,13 +118,13 @@ int64_t ConsistentTensor::nelement() const { return shape()->elem_cnt(); }
 
 int64_t ConsistentTensor::ndim() const { return shape()->NumAxes(); }
 
-std::shared_ptr<ConsistentTensor> ConsistentTensor::data() const {
+std::shared_ptr<Tensor> ConsistentTensor::data() const {
   std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_);
   return t;
 }
 
-Maybe<ConsistentTensor> ConsistentTensor::api_detach() const {
-  std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_);
+Maybe<Tensor> ConsistentTensor::detach() const {
+  std::shared_ptr<Tensor> t = std::make_shared<ConsistentTensor>(impl_);
   return t;
 }
 
diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h
index fa3f744ee73143621440379361db3e40c6954891..b81a32107cb7434a650bc28eadfe558fdf3e688e 100644
--- a/oneflow/core/framework/tensor.h
+++ b/oneflow/core/framework/tensor.h
@@ -48,9 +48,13 @@ class Tensor {
   virtual Maybe<Symbol<cfg::ParallelDistribution>> parallel_distribution() const = 0;
   virtual Maybe<Symbol<ParallelDesc>> parallel_desc() const = 0;
   virtual Maybe<Symbol<Device>> device() const = 0;
-  virtual Maybe<Symbol<Device>*> mut_device() { OF_UNIMPLEMENTED(); }
+  virtual Maybe<Symbol<Device>*> mut_device() = 0;
+  virtual int64_t ndim() const = 0;
+  virtual bool is_cuda() const = 0;
   virtual bool is_consistent() const = 0;
+  virtual bool is_local() const { return !is_consistent(); }
   virtual bool is_lazy() const = 0;
+  virtual bool is_eager() const { return !is_lazy(); }
   virtual const TensorMeta& tensor_meta() const = 0;
   virtual Maybe<Symbol<ConsistentTensorMeta>> consistent_tensor_meta() const { OF_UNIMPLEMENTED(); }
 
@@ -81,6 +85,7 @@ class Tensor {
   virtual Maybe<TensorArg> now_grad_arg() const = 0;
   virtual Maybe<Tensor> detach() const = 0;
   virtual Maybe<Tensor> clone() const = 0;
+  virtual std::shared_ptr<Tensor> data() const = 0;
 
   // Setters for autograd
   virtual void set_requires_grad(bool requires_grad) = 0;
@@ -106,8 +111,6 @@ class TensorIf : public Tensor {
   virtual ~TensorIf() = default;
 
   // Getters
-  virtual int64_t ndim() const = 0;
-  virtual bool is_cuda() const = 0;
   virtual int64_t nelement() const = 0;
   virtual int64_t dim(int64_t index) const = 0;
 
@@ -115,15 +118,6 @@ class TensorIf : public Tensor {
   // acc_grad is tensor's accumulated grad in more than once backward operation,
   // and now_grad_arg is temporary grad to shared data with different FunctionNode
   std::shared_ptr<const FunctionNode> grad_fn_node() const override { return grad_fn_node_; }
-  // used by pybind11 only
-  Maybe<DerivedT> api_acc_grad() const {
-    if (has_autograd_meta()) {
-      const std::shared_ptr<Tensor>& tensor = JUST(acc_grad());
-      return cast_for_api(tensor);
-    } else {
-      return std::shared_ptr<DerivedT>();
-    }
-  }
 
   // Setters for autograd
   void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {
@@ -131,29 +125,9 @@ class TensorIf : public Tensor {
   }
   const std::shared_ptr<FunctionNode>& mut_grad_fn_node() override { return grad_fn_node_; }
 
-  Maybe<Tensor> detach() const override {
-    return std::static_pointer_cast<Tensor>(JUST(api_detach()));
-  }
-
-  // Operators for tensor
-  // used by pybind11 only
-  virtual Maybe<DerivedT> api_detach() const = 0;
-  Maybe<DerivedT> api_clone() const {
-    const std::shared_ptr<Tensor>& tensor = JUST(clone());
-    return cast_for_api(tensor);
-  }
-
  protected:
   TensorIf() = default;
   std::shared_ptr<FunctionNode> grad_fn_node_;
-
- private:
-  Maybe<DerivedT> cast_for_api(const std::shared_ptr<Tensor>& tensor) const {
-    if (!tensor) { return std::shared_ptr<DerivedT>(); }
-    const auto& ptr = std::dynamic_pointer_cast<DerivedT>(tensor);
-    CHECK_OR_RETURN(ptr) << Error::ValueError("Tensor Cast Error");
-    return ptr;
-  }
 };
 
 class MirroredTensor final : public TensorIf<MirroredTensor>,
@@ -179,7 +153,7 @@ class MirroredTensor final : public TensorIf<MirroredTensor>,
   bool is_cuda() const override;
   int64_t dim(int64_t index) const override;
   int64_t nelement() const override;
-  std::shared_ptr<MirroredTensor> data() const;
+  std::shared_ptr<Tensor> data() const override;
   const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); }
 
   // Getters valid only for EagerMirroredTensor
@@ -216,7 +190,7 @@ class MirroredTensor final : public TensorIf<MirroredTensor>,
   }
 
   // Operators for tensor
-  Maybe<MirroredTensor> api_detach() const override;
+  Maybe<Tensor> detach() const override;
   Maybe<Tensor> clone() const override;
 
   static Maybe<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, DataType dtype,
@@ -251,6 +225,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
   }
   Maybe<Symbol<ParallelDesc>> parallel_desc() const override { return impl_->parallel_desc(); }
   Maybe<Symbol<Device>> device() const override { OF_UNIMPLEMENTED(); }
+  Maybe<Symbol<Device>*> mut_device() override { OF_UNIMPLEMENTED(); }
   bool is_lazy() const override { return impl_->is_lazy(); }
   bool is_consistent() const override { return true; }
   Maybe<Symbol<cfg::ParallelDistribution>> consumer_parallel_distribution_constraint()
@@ -264,7 +239,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
   bool is_cuda() const override;
   int64_t dim(int64_t index) const override;
   int64_t nelement() const override;
-  std::shared_ptr<ConsistentTensor> data() const;
+  std::shared_ptr<Tensor> data() const override;
 
   // Getters valid only for EagerMirroredTensor
   Maybe<vm::EagerBlobObject> eager_blob_object() const override {
@@ -308,7 +283,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
   }
 
   // Operators for tensor
-  virtual Maybe<ConsistentTensor> api_detach() const override;
+  Maybe<Tensor> detach() const override;
   Maybe<Tensor> clone() const override { return Error::Unimplemented(); }
 
   static Maybe<ConsistentTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 604809852fbe522e8e09d2398d165c942a9dc9a5..12307e7c03f4682d40ce0f1a8e9c6111b129efe3 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -39,7 +39,7 @@ def register_local_tensor_method(name=None):
             op_name = method.__name__
         else:
             op_name = name
-        setattr(oneflow._oneflow_internal.LocalTensor, op_name, method)
+        setattr(oneflow._oneflow_internal.Tensor, op_name, method)
         return method
 
     return decorator
@@ -868,7 +868,7 @@ def _default_initializer_for_determining(tensor):
     else:
         shape = undetermined_tensor.shape
         dtype = undetermined_tensor.dtype
-        determined_tensor = oneflow._oneflow_internal.LocalTensor(
+        determined_tensor = oneflow._oneflow_internal.Tensor(
             shape,
             dtype,
             undetermined_tensor.device,
@@ -891,7 +891,7 @@ def _numpy_initializer_for_determining(tensor):
     if undetermined_tensor.is_consistent:
         raise NotImplementedError()
     else:
-        determined_tensor = oneflow._oneflow_internal.LocalTensor(
+        determined_tensor = oneflow._oneflow_internal.Tensor(
             undetermined_tensor.shape,
             undetermined_tensor.dtype,
             undetermined_tensor.device,
@@ -913,13 +913,7 @@ def _input_args_is_numpy(*args):
 
 
 def _input_args_is_consistent_or_local(*args):
-    return len(args) == 1 and isinstance(
-        args[0],
-        (
-            oneflow._oneflow_internal.ConsistentTensor,
-            oneflow._oneflow_internal.LocalTensor,
-        ),
-    )
+    return len(args) == 1 and isinstance(args[0], oneflow._oneflow_internal.Tensor)
 
 
 def _input_args_is_tensor(*args):
@@ -937,7 +931,7 @@ def _input_args_is_shape(*args):
 def register_tensor_op(op_name):
     def set_tensor_op(method):
         setattr(Tensor, op_name, method)
-        setattr(oneflow._oneflow_internal.LocalTensor, op_name, method)
+        setattr(oneflow._oneflow_internal.Tensor, op_name, method)
         return method
 
     return set_tensor_op
diff --git a/oneflow/python/framework/tensor_tuple_util.py b/oneflow/python/framework/tensor_tuple_util.py
index 75408c04c863e6daff9cbcf70d4cb090a363d96f..a1cee4d66eb91c556e126a49c2f08ada41ae9df3 100644
--- a/oneflow/python/framework/tensor_tuple_util.py
+++ b/oneflow/python/framework/tensor_tuple_util.py
@@ -17,17 +17,17 @@ limitations under the License.
 import collections
 from typing import Union, Sequence, Tuple, Optional
 
-from oneflow.python.framework.tensor import Tensor
-from oneflow._oneflow_internal import TensorTuple, LocalTensor
+from oneflow.python.framework.tensor import Tensor as PyTensor
+from oneflow._oneflow_internal import TensorTuple, Tensor
 
 
 def convert_to_tensor_tuple(
-    args: Optional[Union[Tensor, Sequence[Tensor], LocalTensor, Sequence[LocalTensor]]]
+    args: Optional[Union[PyTensor, Sequence[PyTensor], Tensor, Sequence[Tensor]]]
 ):
     if args is None:
         return TensorTuple()
     elif isinstance(args, collections.abc.Sequence):
-        if isinstance(args[0], Tensor):
+        if isinstance(args[0], PyTensor):
             for tensor in args:
                 if not tensor.is_determined:
                     tensor.determine()
@@ -35,7 +35,7 @@ def convert_to_tensor_tuple(
         return TensorTuple(args)
     else:
         tensor_tuple = TensorTuple()
-        if isinstance(args, Tensor):
+        if isinstance(args, PyTensor):
             if not args.is_determined:
                 args.determine()
             tensor_tuple.append(args._local_or_consistent_tensor)
diff --git a/oneflow/python/nn/modules/eq.py b/oneflow/python/nn/modules/eq.py
index 7faefb3b672ae4ea63ef80df341b97c7cdff2916..458a9fb913db379f1dc2f37ffaefc5bfb27f90a2 100644
--- a/oneflow/python/nn/modules/eq.py
+++ b/oneflow/python/nn/modules/eq.py
@@ -25,7 +25,7 @@ class Eq(Module):
 
     def forward(self, input, other):
         if isinstance(other, flow.Tensor) or isinstance(
-            other, flow._oneflow_internal.LocalTensor
+            other, flow._oneflow_internal.Tensor
         ):
             for i in range(len(input.size())):
                 assert (
diff --git a/oneflow/python/nn/modules/ne.py b/oneflow/python/nn/modules/ne.py
index 08345281602ac46168e0ade4eefb982e3b1e1dd8..74ad5552a661bd73949b3cb2cda24d2b854e31ab 100644
--- a/oneflow/python/nn/modules/ne.py
+++ b/oneflow/python/nn/modules/ne.py
@@ -25,7 +25,7 @@ class Ne(Module):
 
     def forward(self, input, other):
         if isinstance(other, flow.Tensor) or isinstance(
-            other, flow._oneflow_internal.LocalTensor
+            other, flow._oneflow_internal.Tensor
         ):
             for i in range(len(input.size())):
                 assert (