diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp index 97a83ae9f59a9beaa3814206e4a9a55d40d41169..a9c191af02b1ddab8e8d0b52244c14fc72b6a459 100644 --- a/oneflow/api/python/framework/tensor.cpp +++ b/oneflow/api/python/framework/tensor.cpp @@ -38,42 +38,31 @@ namespace one { namespace { -template<typename T> -const DType* GetTensorDType(const T& tensor) { +const DType* GetTensorDType(const Tensor& tensor) { return DType::Get(tensor.dtype()).GetOrThrow().get(); } -template<typename T> -struct TensorExportUtil final {}; - -template<> -struct TensorExportUtil<MirroredTensor> final { - static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, - const DType* dtype, - const Symbol<Device>& device, bool is_lazy, - bool requires_grad, bool is_leaf) { - return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad, - is_leaf) - .GetPtrOrThrow(); - } -}; - -template<> -struct TensorExportUtil<ConsistentTensor> final { - static std::shared_ptr<ConsistentTensor> MakeTensor( - const std::shared_ptr<const Shape>& shape, const DType* dtype, - const std::shared_ptr<const cfg::ParallelDistribution>& parallel_distribution, - const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad, - bool is_leaf) { - return ConsistentTensor::MakeTensor(shape, dtype->data_type(), SymbolOf(*parallel_distribution), - SymbolOf(*parallel_desc), is_lazy, requires_grad, is_leaf) - .GetPtrOrThrow(); - } -}; +std::shared_ptr<Tensor> MakeLocalTensor(const std::shared_ptr<const Shape>& shape, + const DType* dtype, const Symbol<Device>& device, + bool is_lazy, bool requires_grad, bool is_leaf) { + return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad, + is_leaf) + .GetPtrOrThrow(); +} -namespace { +std::shared_ptr<Tensor> MakeConsistentTensor( + const std::shared_ptr<const Shape>& shape, const DType* dtype, + Symbol<cfg::ParallelDistribution>& parallel_distribution, Symbol<ParallelDesc> parallel_desc, + bool is_lazy, bool requires_grad, bool is_leaf) { + return ConsistentTensor::MakeTensor(shape, dtype->data_type(), parallel_distribution, + parallel_desc, is_lazy, requires_grad, is_leaf) + .GetPtrOrThrow(); +} -Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) { +Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) { + const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t); + CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only"; + CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only"; JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> { JUST(builder->AccessBlobByCallback( tensor, @@ -84,19 +73,21 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tens "mut")); return Maybe<void>::Ok(); })); - return Maybe<void>::Ok(); } -void ApiEagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) { +void ApiEagerMirroredTensorZeros(const std::shared_ptr<Tensor>& tensor) { return EagerMirroredTensorZeros(tensor).GetOrThrow(); } template<typename T> -Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTensor>& tensor, +Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<Tensor>& t, py::array_t<T> array, void (*Copy)(uint64_t, py::array_t<T>), const std::string& modifier) { + const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t); + CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only"; + CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only"; std::atomic<bool> synced(false); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> { @@ -118,15 +109,13 @@ Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTens } template<typename T> -void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<MirroredTensor>& tensor, - py::array_t<T> array) { +void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) { return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyToBuffer, "const") .GetOrThrow(); } template<typename T> -void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<MirroredTensor>& tensor, - py::array_t<T> array) { +void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) { return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyFromBuffer, "mut") .GetOrThrow(); } @@ -161,21 +150,22 @@ const std::string& ApiGetCopyMirroredTensorFromNumpyFuncName(const Tensor& tenso return *GetCopyMirroredTensorFromNumpyFuncName(tensor.dtype()).GetPtrOrThrow(); } -Symbol<Device> TensorGetDevice(const MirroredTensor& tensor) { - return tensor.device().GetOrThrow(); -} +Symbol<Device> TensorGetDevice(const Tensor& tensor) { return tensor.device().GetOrThrow(); } -std::shared_ptr<const ParallelDesc> TensorGetParallelDesc(const ConsistentTensor& tensor) { - return tensor.parallel_desc().GetOrThrow().shared_from_symbol(); +Symbol<ParallelDesc> TensorGetParallelDesc(const Tensor& tensor) { + return tensor.parallel_desc().GetOrThrow(); } -std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes( - const std::shared_ptr<MirroredTensor>& tensor) { +Maybe<std::tuple<std::vector<Shape>, std::vector<const DType*>>> +MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) { + const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t); + CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only"; + CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only"; std::vector<Shape> shapes; std::vector<const DType*> dtypes; std::atomic<bool> synced(false); - CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> { + JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> { JUST(builder->AccessBlobByCallback( tensor, [&synced](uint64_t of_blob_ptr) { synced = true; }, "const")); return Maybe<void>::Ok(); @@ -185,7 +175,7 @@ std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesA while (!synced) {} }); - const Blob& blob = CHECK_JUST(tensor->eager_blob_object())->blob(); + const Blob& blob = JUST(tensor->eager_blob_object())->blob(); const Shape& blob_shape = blob.static_shape(); const auto* tensor_buffer_ptr = blob.dptr<TensorBuffer>(); for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) { @@ -193,63 +183,56 @@ std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesA shapes.push_back(tensor_buffer->shape()); dtypes.push_back(DType::Get(tensor_buffer->data_type()).GetOrThrow().get()); } - return std::make_tuple(shapes, dtypes); } -} // namespace - -void SpecializedDef(py::class_<MirroredTensor, Tensor, std::shared_ptr<MirroredTensor>>* api) { - using T = MirroredTensor; - api->def_property_readonly("device", &TensorGetDevice); - api->def_property_readonly("data", &T::data); - api->def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes); -#define DEFINE_TENSOR_METHOD(T, type_proto) \ - api->def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>); \ - api->def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>); - OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ); - -#undef DEFINE_TENSOR_METHOD - api->def("_get_copy_mirrored_tensor_to_numpy_func_name", - &ApiGetCopyMirroredTensorToNumpyFuncName); - api->def("_get_copy_mirrored_tensor_from_numpy_func_name", - &ApiGetCopyMirroredTensorFromNumpyFuncName); - api->def("zeros_", &ApiEagerMirroredTensorZeros); - api->def("_register_hook", - [](const std::shared_ptr<MirroredTensor>& self, const AutogradMeta::Hook& hook) -> void { - if (!self->grad_fn_node()) { CHECK_JUST(AddAccumulateFunctionNode(self)); } - self->mut_autograd_meta()->add_hook(hook); - }); +std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes( + const std::shared_ptr<Tensor>& tensor) { + return MaybeGetTensorBufferShapesAndDTypes(tensor).GetOrThrow(); } -void SpecializedDef(py::class_<ConsistentTensor, Tensor, std::shared_ptr<ConsistentTensor>>* api) { - api->def_property_readonly("placement", &TensorGetParallelDesc); +Maybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self, + const AutogradMeta::Hook& hook) { + if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); } + self->mut_autograd_meta()->add_hook(hook); + return Maybe<void>::Ok(); +} +void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMeta::Hook& hook) { + return RegisterTensorHook(self, hook).GetOrThrow(); } -template<typename T> -void ExportTensor(py::module& m, const char* name) { - py::class_<T, Tensor, std::shared_ptr<T>> tensor_api(m, name); - tensor_api - .def(py::init(&TensorExportUtil<T>::MakeTensor)) +} // namespace + +ONEFLOW_API_PYBIND11_MODULE("", m) { + py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor") + .def(py::init(&MakeLocalTensor)) + .def(py::init(&MakeConsistentTensor)) // Properties of pytorch - .def_property_readonly("shape", &T::shape) - .def_property_readonly("dtype", &GetTensorDType<T>) - .def_property_readonly("is_cuda", &T::is_cuda) - .def_property_readonly("grad", [](const T& t) { return t.api_acc_grad().GetPtrOrThrow(); }) + .def_property_readonly("shape", &Tensor::shape) + .def_property_readonly("dtype", &GetTensorDType) + .def_property_readonly("is_cuda", &Tensor::is_cuda) + .def_property_readonly("grad", + [](const Tensor& t) -> std::shared_ptr<Tensor> { + if (t.has_autograd_meta()) { + return t.acc_grad().GetPtrOrThrow(); + } else { + return std::shared_ptr<Tensor>(); + } + }) // setter of grad .def("set_grad", - [](T& t, const std::shared_ptr<T>& grad) { + [](Tensor& t, const std::shared_ptr<Tensor>& grad) { if (t.is_leaf()) { - t.set_acc_grad(grad); + t.set_acc_grad(grad).GetOrThrow(); } else { throw std::runtime_error("You can only change gradient of leaf tensors."); } }) - .def_property_readonly("grad_fn", &T::grad_fn_node) - .def_property_readonly("is_leaf", &T::is_leaf) + .def_property_readonly("grad_fn", &Tensor::grad_fn_node) + .def_property_readonly("is_leaf", &Tensor::is_leaf) .def_property( - "requires_grad", &T::requires_grad, - [](T& t, bool requires_grad) { + "requires_grad", &Tensor::requires_grad, + [](Tensor& t, bool requires_grad) { if (t.is_leaf()) { t.set_requires_grad(requires_grad); } else { @@ -258,23 +241,32 @@ void ExportTensor(py::module& m, const char* name) { }) // Methods of pytorch .def("retain_grad", - [](T& t) { + [](Tensor& t) { if (!t.is_leaf()) { t.set_retain_grad(true).GetOrThrow(); } }) - .def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); }) - .def("clone", [](const T& t) { return t.api_clone().GetPtrOrThrow(); }) + .def("detach", [](const Tensor& t) { return t.detach().GetPtrOrThrow(); }) + .def("clone", [](const Tensor& t) { return t.clone().GetPtrOrThrow(); }) // OneFlow tensor properties other than pytorch tensor - .def_property_readonly("is_lazy", &T::is_lazy) - .def_property_readonly("is_consistent", &T::is_consistent); - SpecializedDef(&tensor_api); -} - -} // namespace - -ONEFLOW_API_PYBIND11_MODULE("", m) { - py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor"); - ExportTensor<MirroredTensor>(m, "LocalTensor"); - ExportTensor<ConsistentTensor>(m, "ConsistentTensor"); + .def_property_readonly("is_lazy", &Tensor::is_lazy) + .def_property_readonly("is_eager", &Tensor::is_eager) + .def_property_readonly("is_consistent", &Tensor::is_consistent) + .def_property_readonly("is_local", &Tensor::is_local) + .def("zeros_", &ApiEagerMirroredTensorZeros) + .def("_register_hook", &ApiRegisterTensorHook) + // local tensor only + .def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes) + .def_property_readonly("device", &TensorGetDevice) + .def_property_readonly("data", &Tensor::data) +#define DEFINE_TENSOR_METHOD(T, type_proto) \ + .def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>) \ + .def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>) + OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ) +#undef DEFINE_TENSOR_METHOD + .def("_get_copy_mirrored_tensor_to_numpy_func_name", &ApiGetCopyMirroredTensorToNumpyFuncName) + .def("_get_copy_mirrored_tensor_from_numpy_func_name", + &ApiGetCopyMirroredTensorFromNumpyFuncName) + // consistent tensor only + .def_property_readonly("placement", &TensorGetParallelDesc); } } // namespace one diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 4fccd92de1c87fe0947c2d89d9b5e49fb2dd7f7d..920486f8313f699a8ea659d29a44268098d842e8 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -67,13 +67,14 @@ int64_t MirroredTensor::dim(int64_t index) const { return shape()->At(index); } int64_t MirroredTensor::nelement() const { return shape()->elem_cnt(); } -std::shared_ptr<MirroredTensor> MirroredTensor::data() const { +std::shared_ptr<Tensor> MirroredTensor::data() const { std::shared_ptr<MirroredTensor> t = std::make_shared<MirroredTensor>(impl_); return t; } -Maybe<MirroredTensor> MirroredTensor::api_detach() const { - return std::make_shared<MirroredTensor>(JUST(impl_->detach())); +Maybe<Tensor> MirroredTensor::detach() const { + std::shared_ptr<Tensor> tensor = std::make_shared<MirroredTensor>(JUST(impl_->detach())); + return tensor; } Maybe<Tensor> MirroredTensor::clone() const { @@ -117,13 +118,13 @@ int64_t ConsistentTensor::nelement() const { return shape()->elem_cnt(); } int64_t ConsistentTensor::ndim() const { return shape()->NumAxes(); } -std::shared_ptr<ConsistentTensor> ConsistentTensor::data() const { +std::shared_ptr<Tensor> ConsistentTensor::data() const { std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_); return t; } -Maybe<ConsistentTensor> ConsistentTensor::api_detach() const { - std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_); +Maybe<Tensor> ConsistentTensor::detach() const { + std::shared_ptr<Tensor> t = std::make_shared<ConsistentTensor>(impl_); return t; } diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index fa3f744ee73143621440379361db3e40c6954891..b81a32107cb7434a650bc28eadfe558fdf3e688e 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -48,9 +48,13 @@ class Tensor { virtual Maybe<Symbol<cfg::ParallelDistribution>> parallel_distribution() const = 0; virtual Maybe<Symbol<ParallelDesc>> parallel_desc() const = 0; virtual Maybe<Symbol<Device>> device() const = 0; - virtual Maybe<Symbol<Device>*> mut_device() { OF_UNIMPLEMENTED(); } + virtual Maybe<Symbol<Device>*> mut_device() = 0; + virtual int64_t ndim() const = 0; + virtual bool is_cuda() const = 0; virtual bool is_consistent() const = 0; + virtual bool is_local() const { return !is_consistent(); } virtual bool is_lazy() const = 0; + virtual bool is_eager() const { return !is_lazy(); } virtual const TensorMeta& tensor_meta() const = 0; virtual Maybe<Symbol<ConsistentTensorMeta>> consistent_tensor_meta() const { OF_UNIMPLEMENTED(); } @@ -81,6 +85,7 @@ class Tensor { virtual Maybe<TensorArg> now_grad_arg() const = 0; virtual Maybe<Tensor> detach() const = 0; virtual Maybe<Tensor> clone() const = 0; + virtual std::shared_ptr<Tensor> data() const = 0; // Setters for autograd virtual void set_requires_grad(bool requires_grad) = 0; @@ -106,8 +111,6 @@ class TensorIf : public Tensor { virtual ~TensorIf() = default; // Getters - virtual int64_t ndim() const = 0; - virtual bool is_cuda() const = 0; virtual int64_t nelement() const = 0; virtual int64_t dim(int64_t index) const = 0; @@ -115,15 +118,6 @@ class TensorIf : public Tensor { // acc_grad is tensor's accumulated grad in more than once backward operation, // and now_grad_arg is temporary grad to shared data with different FunctionNode std::shared_ptr<const FunctionNode> grad_fn_node() const override { return grad_fn_node_; } - // used by pybind11 only - Maybe<DerivedT> api_acc_grad() const { - if (has_autograd_meta()) { - const std::shared_ptr<Tensor>& tensor = JUST(acc_grad()); - return cast_for_api(tensor); - } else { - return std::shared_ptr<DerivedT>(); - } - } // Setters for autograd void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override { @@ -131,29 +125,9 @@ class TensorIf : public Tensor { } const std::shared_ptr<FunctionNode>& mut_grad_fn_node() override { return grad_fn_node_; } - Maybe<Tensor> detach() const override { - return std::static_pointer_cast<Tensor>(JUST(api_detach())); - } - - // Operators for tensor - // used by pybind11 only - virtual Maybe<DerivedT> api_detach() const = 0; - Maybe<DerivedT> api_clone() const { - const std::shared_ptr<Tensor>& tensor = JUST(clone()); - return cast_for_api(tensor); - } - protected: TensorIf() = default; std::shared_ptr<FunctionNode> grad_fn_node_; - - private: - Maybe<DerivedT> cast_for_api(const std::shared_ptr<Tensor>& tensor) const { - if (!tensor) { return std::shared_ptr<DerivedT>(); } - const auto& ptr = std::dynamic_pointer_cast<DerivedT>(tensor); - CHECK_OR_RETURN(ptr) << Error::ValueError("Tensor Cast Error"); - return ptr; - } }; class MirroredTensor final : public TensorIf<MirroredTensor>, @@ -179,7 +153,7 @@ class MirroredTensor final : public TensorIf<MirroredTensor>, bool is_cuda() const override; int64_t dim(int64_t index) const override; int64_t nelement() const override; - std::shared_ptr<MirroredTensor> data() const; + std::shared_ptr<Tensor> data() const override; const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); } // Getters valid only for EagerMirroredTensor @@ -216,7 +190,7 @@ class MirroredTensor final : public TensorIf<MirroredTensor>, } // Operators for tensor - Maybe<MirroredTensor> api_detach() const override; + Maybe<Tensor> detach() const override; Maybe<Tensor> clone() const override; static Maybe<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, DataType dtype, @@ -251,6 +225,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> { } Maybe<Symbol<ParallelDesc>> parallel_desc() const override { return impl_->parallel_desc(); } Maybe<Symbol<Device>> device() const override { OF_UNIMPLEMENTED(); } + Maybe<Symbol<Device>*> mut_device() override { OF_UNIMPLEMENTED(); } bool is_lazy() const override { return impl_->is_lazy(); } bool is_consistent() const override { return true; } Maybe<Symbol<cfg::ParallelDistribution>> consumer_parallel_distribution_constraint() @@ -264,7 +239,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> { bool is_cuda() const override; int64_t dim(int64_t index) const override; int64_t nelement() const override; - std::shared_ptr<ConsistentTensor> data() const; + std::shared_ptr<Tensor> data() const override; // Getters valid only for EagerMirroredTensor Maybe<vm::EagerBlobObject> eager_blob_object() const override { @@ -308,7 +283,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> { } // Operators for tensor - virtual Maybe<ConsistentTensor> api_detach() const override; + Maybe<Tensor> detach() const override; Maybe<Tensor> clone() const override { return Error::Unimplemented(); } static Maybe<ConsistentTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index 604809852fbe522e8e09d2398d165c942a9dc9a5..12307e7c03f4682d40ce0f1a8e9c6111b129efe3 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -39,7 +39,7 @@ def register_local_tensor_method(name=None): op_name = method.__name__ else: op_name = name - setattr(oneflow._oneflow_internal.LocalTensor, op_name, method) + setattr(oneflow._oneflow_internal.Tensor, op_name, method) return method return decorator @@ -868,7 +868,7 @@ def _default_initializer_for_determining(tensor): else: shape = undetermined_tensor.shape dtype = undetermined_tensor.dtype - determined_tensor = oneflow._oneflow_internal.LocalTensor( + determined_tensor = oneflow._oneflow_internal.Tensor( shape, dtype, undetermined_tensor.device, @@ -891,7 +891,7 @@ def _numpy_initializer_for_determining(tensor): if undetermined_tensor.is_consistent: raise NotImplementedError() else: - determined_tensor = oneflow._oneflow_internal.LocalTensor( + determined_tensor = oneflow._oneflow_internal.Tensor( undetermined_tensor.shape, undetermined_tensor.dtype, undetermined_tensor.device, @@ -913,13 +913,7 @@ def _input_args_is_numpy(*args): def _input_args_is_consistent_or_local(*args): - return len(args) == 1 and isinstance( - args[0], - ( - oneflow._oneflow_internal.ConsistentTensor, - oneflow._oneflow_internal.LocalTensor, - ), - ) + return len(args) == 1 and isinstance(args[0], oneflow._oneflow_internal.Tensor) def _input_args_is_tensor(*args): @@ -937,7 +931,7 @@ def _input_args_is_shape(*args): def register_tensor_op(op_name): def set_tensor_op(method): setattr(Tensor, op_name, method) - setattr(oneflow._oneflow_internal.LocalTensor, op_name, method) + setattr(oneflow._oneflow_internal.Tensor, op_name, method) return method return set_tensor_op diff --git a/oneflow/python/framework/tensor_tuple_util.py b/oneflow/python/framework/tensor_tuple_util.py index 75408c04c863e6daff9cbcf70d4cb090a363d96f..a1cee4d66eb91c556e126a49c2f08ada41ae9df3 100644 --- a/oneflow/python/framework/tensor_tuple_util.py +++ b/oneflow/python/framework/tensor_tuple_util.py @@ -17,17 +17,17 @@ limitations under the License. import collections from typing import Union, Sequence, Tuple, Optional -from oneflow.python.framework.tensor import Tensor -from oneflow._oneflow_internal import TensorTuple, LocalTensor +from oneflow.python.framework.tensor import Tensor as PyTensor +from oneflow._oneflow_internal import TensorTuple, Tensor def convert_to_tensor_tuple( - args: Optional[Union[Tensor, Sequence[Tensor], LocalTensor, Sequence[LocalTensor]]] + args: Optional[Union[PyTensor, Sequence[PyTensor], Tensor, Sequence[Tensor]]] ): if args is None: return TensorTuple() elif isinstance(args, collections.abc.Sequence): - if isinstance(args[0], Tensor): + if isinstance(args[0], PyTensor): for tensor in args: if not tensor.is_determined: tensor.determine() @@ -35,7 +35,7 @@ def convert_to_tensor_tuple( return TensorTuple(args) else: tensor_tuple = TensorTuple() - if isinstance(args, Tensor): + if isinstance(args, PyTensor): if not args.is_determined: args.determine() tensor_tuple.append(args._local_or_consistent_tensor) diff --git a/oneflow/python/nn/modules/eq.py b/oneflow/python/nn/modules/eq.py index 7faefb3b672ae4ea63ef80df341b97c7cdff2916..458a9fb913db379f1dc2f37ffaefc5bfb27f90a2 100644 --- a/oneflow/python/nn/modules/eq.py +++ b/oneflow/python/nn/modules/eq.py @@ -25,7 +25,7 @@ class Eq(Module): def forward(self, input, other): if isinstance(other, flow.Tensor) or isinstance( - other, flow._oneflow_internal.LocalTensor + other, flow._oneflow_internal.Tensor ): for i in range(len(input.size())): assert ( diff --git a/oneflow/python/nn/modules/ne.py b/oneflow/python/nn/modules/ne.py index 08345281602ac46168e0ade4eefb982e3b1e1dd8..74ad5552a661bd73949b3cb2cda24d2b854e31ab 100644 --- a/oneflow/python/nn/modules/ne.py +++ b/oneflow/python/nn/modules/ne.py @@ -25,7 +25,7 @@ class Ne(Module): def forward(self, input, other): if isinstance(other, flow.Tensor) or isinstance( - other, flow._oneflow_internal.LocalTensor + other, flow._oneflow_internal.Tensor ): for i in range(len(input.size())): assert (