Unverified Commit da9e9de3 authored by Li Xinqi, committed by GitHub

export Tensor only to python (#5440)


* export Tensor only to python

* address review comments

* address review comments

* refine

* Update tensor_tuple_util.py

fix bug

* Update tensor.cpp

fix bug

* Update tensor.cpp

fix bug

* auto format by CI

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: clackhan <han_binbin@163.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
parent 5e29d7b8
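The user-visible effect of this commit on the Python side: the internal module previously exported two concrete classes, LocalTensor and ConsistentTensor, and Python code had to branch on both; it now exports the single class Tensor, constructible as either a local or a consistent tensor through the two py::init overloads in the diff below. A minimal sketch of the simplification (the helper function is hypothetical; the class names come from the diff):

    import oneflow._oneflow_internal as internal

    def is_internal_tensor(x):
        # Before this commit: isinstance(x, (internal.LocalTensor, internal.ConsistentTensor))
        return isinstance(x, internal.Tensor)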
@@ -38,42 +38,31 @@ namespace one {
 namespace {
 
-template<typename T>
-const DType* GetTensorDType(const T& tensor) {
+const DType* GetTensorDType(const Tensor& tensor) {
   return DType::Get(tensor.dtype()).GetOrThrow().get();
 }
 
-template<typename T>
-struct TensorExportUtil final {};
-
-template<>
-struct TensorExportUtil<MirroredTensor> final {
-  static std::shared_ptr<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
-                                                    const DType* dtype,
-                                                    const Symbol<Device>& device, bool is_lazy,
-                                                    bool requires_grad, bool is_leaf) {
-    return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad,
-                                      is_leaf)
-        .GetPtrOrThrow();
-  }
-};
-
-template<>
-struct TensorExportUtil<ConsistentTensor> final {
-  static std::shared_ptr<ConsistentTensor> MakeTensor(
-      const std::shared_ptr<const Shape>& shape, const DType* dtype,
-      const std::shared_ptr<const cfg::ParallelDistribution>& parallel_distribution,
-      const std::shared_ptr<const ParallelDesc>& parallel_desc, bool is_lazy, bool requires_grad,
-      bool is_leaf) {
-    return ConsistentTensor::MakeTensor(shape, dtype->data_type(), SymbolOf(*parallel_distribution),
-                                        SymbolOf(*parallel_desc), is_lazy, requires_grad, is_leaf)
-        .GetPtrOrThrow();
-  }
-};
+std::shared_ptr<Tensor> MakeLocalTensor(const std::shared_ptr<const Shape>& shape,
+                                        const DType* dtype, const Symbol<Device>& device,
+                                        bool is_lazy, bool requires_grad, bool is_leaf) {
+  return MirroredTensor::MakeTensor(shape, dtype->data_type(), device, is_lazy, requires_grad,
+                                    is_leaf)
+      .GetPtrOrThrow();
+}
 
 namespace {
 
+std::shared_ptr<Tensor> MakeConsistentTensor(
+    const std::shared_ptr<const Shape>& shape, const DType* dtype,
+    Symbol<cfg::ParallelDistribution>& parallel_distribution, Symbol<ParallelDesc> parallel_desc,
+    bool is_lazy, bool requires_grad, bool is_leaf) {
+  return ConsistentTensor::MakeTensor(shape, dtype->data_type(), parallel_distribution,
+                                      parallel_desc, is_lazy, requires_grad, is_leaf)
+      .GetPtrOrThrow();
+}
+
-Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) {
+Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<Tensor>& t) {
+  const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
+  CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
+  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
     JUST(builder->AccessBlobByCallback(
         tensor,
@@ -84,19 +73,21 @@ Maybe<void> EagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tens
         "mut"));
     return Maybe<void>::Ok();
   }));
   return Maybe<void>::Ok();
 }
 
-void ApiEagerMirroredTensorZeros(const std::shared_ptr<MirroredTensor>& tensor) {
+void ApiEagerMirroredTensorZeros(const std::shared_ptr<Tensor>& tensor) {
   return EagerMirroredTensorZeros(tensor).GetOrThrow();
 }
 
 template<typename T>
-Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTensor>& tensor,
+Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<Tensor>& t,
                                               py::array_t<T> array,
                                               void (*Copy)(uint64_t, py::array_t<T>),
                                               const std::string& modifier) {
+  const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
+  CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
+  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
   std::atomic<bool> synced(false);
   JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
@@ -118,15 +109,13 @@ Maybe<void> CopyBetweenMirroredTensorAndNumpy(const std::shared_ptr<MirroredTens
 }
 
 template<typename T>
-void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<MirroredTensor>& tensor,
-                                  py::array_t<T> array) {
+void ApiCopyMirroredTensorToNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) {
   return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyToBuffer, "const")
       .GetOrThrow();
 }
 
 template<typename T>
-void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<MirroredTensor>& tensor,
-                                    py::array_t<T> array) {
+void ApiCopyMirroredTensorFromNumpy(const std::shared_ptr<Tensor>& tensor, py::array_t<T> array) {
   return CopyBetweenMirroredTensorAndNumpy(tensor, array, OfBlob_CopyFromBuffer, "mut")
       .GetOrThrow();
 }
@@ -161,21 +150,22 @@ const std::string& ApiGetCopyMirroredTensorFromNumpyFuncName(const Tensor& tenso
   return *GetCopyMirroredTensorFromNumpyFuncName(tensor.dtype()).GetPtrOrThrow();
 }
 
-Symbol<Device> TensorGetDevice(const MirroredTensor& tensor) {
-  return tensor.device().GetOrThrow();
-}
+Symbol<Device> TensorGetDevice(const Tensor& tensor) { return tensor.device().GetOrThrow(); }
 
-std::shared_ptr<const ParallelDesc> TensorGetParallelDesc(const ConsistentTensor& tensor) {
-  return tensor.parallel_desc().GetOrThrow().shared_from_symbol();
+Symbol<ParallelDesc> TensorGetParallelDesc(const Tensor& tensor) {
+  return tensor.parallel_desc().GetOrThrow();
 }
 
-std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes(
-    const std::shared_ptr<MirroredTensor>& tensor) {
+Maybe<std::tuple<std::vector<Shape>, std::vector<const DType*>>>
+MaybeGetTensorBufferShapesAndDTypes(const std::shared_ptr<Tensor>& t) {
+  const auto& tensor = std::dynamic_pointer_cast<MirroredTensor>(t);
+  CHECK_NOTNULL_OR_RETURN(tensor) << "local tensors supported only";
+  CHECK_OR_RETURN(tensor->is_eager()) << "eager tensors supported only";
   std::vector<Shape> shapes;
   std::vector<const DType*> dtypes;
   std::atomic<bool> synced(false);
-  CHECK_JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
+  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
     JUST(builder->AccessBlobByCallback(
         tensor, [&synced](uint64_t of_blob_ptr) { synced = true; }, "const"));
     return Maybe<void>::Ok();
@@ -185,7 +175,7 @@ std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesA
     while (!synced) {}
   });
 
-  const Blob& blob = CHECK_JUST(tensor->eager_blob_object())->blob();
+  const Blob& blob = JUST(tensor->eager_blob_object())->blob();
   const Shape& blob_shape = blob.static_shape();
   const auto* tensor_buffer_ptr = blob.dptr<TensorBuffer>();
   for (int64_t i = 0; i < blob_shape.elem_cnt(); ++i) {
@@ -193,63 +183,56 @@
     shapes.push_back(tensor_buffer->shape());
     dtypes.push_back(DType::Get(tensor_buffer->data_type()).GetOrThrow().get());
   }
   return std::make_tuple(shapes, dtypes);
 }
 
 }  // namespace
 
-void SpecializedDef(py::class_<MirroredTensor, Tensor, std::shared_ptr<MirroredTensor>>* api) {
-  using T = MirroredTensor;
-  api->def_property_readonly("device", &TensorGetDevice);
-  api->def_property_readonly("data", &T::data);
-  api->def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes);
-#define DEFINE_TENSOR_METHOD(T, type_proto)                             \
-  api->def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>);     \
-  api->def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>);
-  OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ);
-#undef DEFINE_TENSOR_METHOD
-  api->def("_get_copy_mirrored_tensor_to_numpy_func_name",
-           &ApiGetCopyMirroredTensorToNumpyFuncName);
-  api->def("_get_copy_mirrored_tensor_from_numpy_func_name",
-           &ApiGetCopyMirroredTensorFromNumpyFuncName);
-  api->def("zeros_", &ApiEagerMirroredTensorZeros);
-  api->def("_register_hook",
-           [](const std::shared_ptr<MirroredTensor>& self, const AutogradMeta::Hook& hook) -> void {
-             if (!self->grad_fn_node()) { CHECK_JUST(AddAccumulateFunctionNode(self)); }
-             self->mut_autograd_meta()->add_hook(hook);
-           });
-}
+std::tuple<std::vector<Shape>, std::vector<const DType*>> GetTensorBufferShapesAndDTypes(
+    const std::shared_ptr<Tensor>& tensor) {
+  return MaybeGetTensorBufferShapesAndDTypes(tensor).GetOrThrow();
+}
 
-void SpecializedDef(py::class_<ConsistentTensor, Tensor, std::shared_ptr<ConsistentTensor>>* api) {
-  api->def_property_readonly("placement", &TensorGetParallelDesc);
+Maybe<void> RegisterTensorHook(const std::shared_ptr<Tensor>& self,
+                               const AutogradMeta::Hook& hook) {
+  if (!self->grad_fn_node()) { JUST(AddAccumulateFunctionNode(self)); }
+  self->mut_autograd_meta()->add_hook(hook);
+  return Maybe<void>::Ok();
 }
 
-template<typename T>
-void ExportTensor(py::module& m, const char* name) {
-  py::class_<T, Tensor, std::shared_ptr<T>> tensor_api(m, name);
-  tensor_api
-      .def(py::init(&TensorExportUtil<T>::MakeTensor))
+void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMeta::Hook& hook) {
+  return RegisterTensorHook(self, hook).GetOrThrow();
+}
+
+}  // namespace
+
+ONEFLOW_API_PYBIND11_MODULE("", m) {
+  py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor")
+      .def(py::init(&MakeLocalTensor))
+      .def(py::init(&MakeConsistentTensor))
       // Properties of pytorch
-      .def_property_readonly("shape", &T::shape)
-      .def_property_readonly("dtype", &GetTensorDType<T>)
-      .def_property_readonly("is_cuda", &T::is_cuda)
-      .def_property_readonly("grad", [](const T& t) { return t.api_acc_grad().GetPtrOrThrow(); })
+      .def_property_readonly("shape", &Tensor::shape)
+      .def_property_readonly("dtype", &GetTensorDType)
+      .def_property_readonly("is_cuda", &Tensor::is_cuda)
+      .def_property_readonly("grad",
+                             [](const Tensor& t) -> std::shared_ptr<Tensor> {
+                               if (t.has_autograd_meta()) {
+                                 return t.acc_grad().GetPtrOrThrow();
+                               } else {
+                                 return std::shared_ptr<Tensor>();
+                               }
+                             })
      // setter of grad
      .def("set_grad",
-           [](T& t, const std::shared_ptr<T>& grad) {
+           [](Tensor& t, const std::shared_ptr<Tensor>& grad) {
             if (t.is_leaf()) {
-              t.set_acc_grad(grad);
+              t.set_acc_grad(grad).GetOrThrow();
            } else {
              throw std::runtime_error("You can only change gradient of leaf tensors.");
            }
          })
-      .def_property_readonly("grad_fn", &T::grad_fn_node)
-      .def_property_readonly("is_leaf", &T::is_leaf)
+      .def_property_readonly("grad_fn", &Tensor::grad_fn_node)
+      .def_property_readonly("is_leaf", &Tensor::is_leaf)
      .def_property(
-          "requires_grad", &T::requires_grad,
-          [](T& t, bool requires_grad) {
+          "requires_grad", &Tensor::requires_grad,
+          [](Tensor& t, bool requires_grad) {
            if (t.is_leaf()) {
              t.set_requires_grad(requires_grad);
            } else {
@@ -258,23 +241,32 @@ void ExportTensor(py::module& m, const char* name) {
          })
      // Methods of pytorch
      .def("retain_grad",
-           [](T& t) {
+           [](Tensor& t) {
            if (!t.is_leaf()) { t.set_retain_grad(true).GetOrThrow(); }
          })
-      .def("detach", [](const T& t) { return t.api_detach().GetPtrOrThrow(); })
-      .def("clone", [](const T& t) { return t.api_clone().GetPtrOrThrow(); })
+      .def("detach", [](const Tensor& t) { return t.detach().GetPtrOrThrow(); })
+      .def("clone", [](const Tensor& t) { return t.clone().GetPtrOrThrow(); })
      // OneFlow tensor properties other than pytorch tensor
-      .def_property_readonly("is_lazy", &T::is_lazy)
-      .def_property_readonly("is_consistent", &T::is_consistent);
-  SpecializedDef(&tensor_api);
-}
-
-}  // namespace
-
-ONEFLOW_API_PYBIND11_MODULE("", m) {
-  py::class_<Tensor, std::shared_ptr<Tensor>>(m, "Tensor");
-  ExportTensor<MirroredTensor>(m, "LocalTensor");
-  ExportTensor<ConsistentTensor>(m, "ConsistentTensor");
+      .def_property_readonly("is_lazy", &Tensor::is_lazy)
+      .def_property_readonly("is_eager", &Tensor::is_eager)
+      .def_property_readonly("is_consistent", &Tensor::is_consistent)
+      .def_property_readonly("is_local", &Tensor::is_local)
+      .def("zeros_", &ApiEagerMirroredTensorZeros)
+      .def("_register_hook", &ApiRegisterTensorHook)
+      // local tensor only
+      .def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
+      .def_property_readonly("device", &TensorGetDevice)
+      .def_property_readonly("data", &Tensor::data)
+#define DEFINE_TENSOR_METHOD(T, type_proto)                     \
+  .def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>)  \
+  .def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>)
+          OF_PP_FOR_EACH_TUPLE(DEFINE_TENSOR_METHOD, POD_DATA_TYPE_SEQ)
+#undef DEFINE_TENSOR_METHOD
+      .def("_get_copy_mirrored_tensor_to_numpy_func_name", &ApiGetCopyMirroredTensorToNumpyFuncName)
+      .def("_get_copy_mirrored_tensor_from_numpy_func_name",
+           &ApiGetCopyMirroredTensorFromNumpyFuncName)
+      // consistent tensor only
+      .def_property_readonly("placement", &TensorGetParallelDesc);
 }
 
 }  // namespace one
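The _copy_to_numpy_<dtype> and _copy_from_numpy_<dtype> methods generated by the DEFINE_TENSOR_METHOD macro above are dispatched by name from Python; a sketch of that dispatch, using only the method and getter names bound in this file (the wrapper function itself is hypothetical):

    import numpy as np

    def copy_tensor_to_numpy(tensor, array: np.ndarray):
        # The binding reports the dtype-specific method name (e.g. a float tensor
        # would resolve to something like "_copy_to_numpy_float"); getattr dispatches to it.
        func_name = tensor._get_copy_mirrored_tensor_to_numpy_func_name()
        getattr(tensor, func_name)(array)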
@@ -67,13 +67,14 @@ int64_t MirroredTensor::dim(int64_t index) const { return shape()->At(index); }
 int64_t MirroredTensor::nelement() const { return shape()->elem_cnt(); }
 
-std::shared_ptr<MirroredTensor> MirroredTensor::data() const {
+std::shared_ptr<Tensor> MirroredTensor::data() const {
   std::shared_ptr<MirroredTensor> t = std::make_shared<MirroredTensor>(impl_);
   return t;
 }
 
-Maybe<MirroredTensor> MirroredTensor::api_detach() const {
-  return std::make_shared<MirroredTensor>(JUST(impl_->detach()));
+Maybe<Tensor> MirroredTensor::detach() const {
+  std::shared_ptr<Tensor> tensor = std::make_shared<MirroredTensor>(JUST(impl_->detach()));
+  return tensor;
 }
 
 Maybe<Tensor> MirroredTensor::clone() const {
@@ -117,13 +118,13 @@ int64_t ConsistentTensor::nelement() const { return shape()->elem_cnt(); }
 int64_t ConsistentTensor::ndim() const { return shape()->NumAxes(); }
 
-std::shared_ptr<ConsistentTensor> ConsistentTensor::data() const {
+std::shared_ptr<Tensor> ConsistentTensor::data() const {
   std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_);
   return t;
 }
 
-Maybe<ConsistentTensor> ConsistentTensor::api_detach() const {
-  std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_);
+Maybe<Tensor> ConsistentTensor::detach() const {
+  std::shared_ptr<Tensor> t = std::make_shared<ConsistentTensor>(impl_);
   return t;
 }
@@ -48,9 +48,13 @@ class Tensor {
   virtual Maybe<Symbol<cfg::ParallelDistribution>> parallel_distribution() const = 0;
   virtual Maybe<Symbol<ParallelDesc>> parallel_desc() const = 0;
   virtual Maybe<Symbol<Device>> device() const = 0;
-  virtual Maybe<Symbol<Device>*> mut_device() { OF_UNIMPLEMENTED(); }
+  virtual Maybe<Symbol<Device>*> mut_device() = 0;
+  virtual int64_t ndim() const = 0;
+  virtual bool is_cuda() const = 0;
   virtual bool is_consistent() const = 0;
+  virtual bool is_local() const { return !is_consistent(); }
   virtual bool is_lazy() const = 0;
+  virtual bool is_eager() const { return !is_lazy(); }
   virtual const TensorMeta& tensor_meta() const = 0;
   virtual Maybe<Symbol<ConsistentTensorMeta>> consistent_tensor_meta() const { OF_UNIMPLEMENTED(); }
@@ -81,6 +85,7 @@ class Tensor {
   virtual Maybe<TensorArg> now_grad_arg() const = 0;
   virtual Maybe<Tensor> detach() const = 0;
   virtual Maybe<Tensor> clone() const = 0;
+  virtual std::shared_ptr<Tensor> data() const = 0;
 
   // Setters for autograd
   virtual void set_requires_grad(bool requires_grad) = 0;
@@ -106,8 +111,6 @@ class TensorIf : public Tensor {
   virtual ~TensorIf() = default;
 
   // Getters
-  virtual int64_t ndim() const = 0;
-  virtual bool is_cuda() const = 0;
   virtual int64_t nelement() const = 0;
   virtual int64_t dim(int64_t index) const = 0;
@@ -115,15 +118,6 @@ class TensorIf : public Tensor {
   // acc_grad is tensor's accumulated grad in more than once backward operation,
   // and now_grad_arg is temporary grad to shared data with different FunctionNode
   std::shared_ptr<const FunctionNode> grad_fn_node() const override { return grad_fn_node_; }
-  // used by pybind11 only
-  Maybe<DerivedT> api_acc_grad() const {
-    if (has_autograd_meta()) {
-      const std::shared_ptr<Tensor>& tensor = JUST(acc_grad());
-      return cast_for_api(tensor);
-    } else {
-      return std::shared_ptr<DerivedT>();
-    }
-  }
 
   // Setters for autograd
   void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {
@@ -131,29 +125,9 @@ }
   }
 
   const std::shared_ptr<FunctionNode>& mut_grad_fn_node() override { return grad_fn_node_; }
 
-  Maybe<Tensor> detach() const override {
-    return std::static_pointer_cast<Tensor>(JUST(api_detach()));
-  }
-
   // Operators for tensor
-  // used by pybind11 only
-  virtual Maybe<DerivedT> api_detach() const = 0;
-  Maybe<DerivedT> api_clone() const {
-    const std::shared_ptr<Tensor>& tensor = JUST(clone());
-    return cast_for_api(tensor);
-  }
 
  protected:
  TensorIf() = default;
  std::shared_ptr<FunctionNode> grad_fn_node_;
-
- private:
-  Maybe<DerivedT> cast_for_api(const std::shared_ptr<Tensor>& tensor) const {
-    if (!tensor) { return std::shared_ptr<DerivedT>(); }
-    const auto& ptr = std::dynamic_pointer_cast<DerivedT>(tensor);
-    CHECK_OR_RETURN(ptr) << Error::ValueError("Tensor Cast Error");
-    return ptr;
-  }
 };
 
 class MirroredTensor final : public TensorIf<MirroredTensor>,
@@ -179,7 +153,7 @@ class MirroredTensor final : public TensorIf<MirroredTensor>,
   bool is_cuda() const override;
   int64_t dim(int64_t index) const override;
   int64_t nelement() const override;
-  std::shared_ptr<MirroredTensor> data() const;
+  std::shared_ptr<Tensor> data() const override;
   const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); }
 
   // Getters valid only for EagerMirroredTensor
@@ -216,7 +190,7 @@ class MirroredTensor final : public TensorIf<MirroredTensor>,
   }
 
   // Operators for tensor
-  Maybe<MirroredTensor> api_detach() const override;
+  Maybe<Tensor> detach() const override;
   Maybe<Tensor> clone() const override;
 
   static Maybe<MirroredTensor> MakeTensor(const std::shared_ptr<const Shape>& shape, DataType dtype,
@@ -251,6 +225,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
   }
   Maybe<Symbol<ParallelDesc>> parallel_desc() const override { return impl_->parallel_desc(); }
   Maybe<Symbol<Device>> device() const override { OF_UNIMPLEMENTED(); }
+  Maybe<Symbol<Device>*> mut_device() override { OF_UNIMPLEMENTED(); }
   bool is_lazy() const override { return impl_->is_lazy(); }
   bool is_consistent() const override { return true; }
   Maybe<Symbol<cfg::ParallelDistribution>> consumer_parallel_distribution_constraint()
@@ -264,7 +239,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
   bool is_cuda() const override;
   int64_t dim(int64_t index) const override;
   int64_t nelement() const override;
-  std::shared_ptr<ConsistentTensor> data() const;
+  std::shared_ptr<Tensor> data() const override;
 
   // Getters valid only for EagerMirroredTensor
   Maybe<vm::EagerBlobObject> eager_blob_object() const override {
@@ -308,7 +283,7 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor> {
   }
 
   // Operators for tensor
-  virtual Maybe<ConsistentTensor> api_detach() const override;
+  Maybe<Tensor> detach() const override;
   Maybe<Tensor> clone() const override { return Error::Unimplemented(); }
 
   static Maybe<ConsistentTensor> MakeTensor(const std::shared_ptr<const Shape>& shape,
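With the header changes above, detach(), clone(), and data() all return base-class Tensor handles, which is what lets the pybind module in the first file expose them on a single class. A hedged sketch of the Python-side consequence (note the diff leaves ConsistentTensor::clone() unimplemented):

    def detach_and_clone(t):
        # Both calls now yield oneflow._oneflow_internal.Tensor regardless of
        # whether t is local or consistent; clone() raises for consistent tensors.
        return t.detach(), t.clone()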
@@ -39,7 +39,7 @@ def register_local_tensor_method(name=None):
             op_name = method.__name__
         else:
             op_name = name
-        setattr(oneflow._oneflow_internal.LocalTensor, op_name, method)
+        setattr(oneflow._oneflow_internal.Tensor, op_name, method)
         return method
 
     return decorator
@@ -868,7 +868,7 @@ def _default_initializer_for_determining(tensor):
     else:
         shape = undetermined_tensor.shape
         dtype = undetermined_tensor.dtype
-    determined_tensor = oneflow._oneflow_internal.LocalTensor(
+    determined_tensor = oneflow._oneflow_internal.Tensor(
         shape,
         dtype,
         undetermined_tensor.device,
@@ -891,7 +891,7 @@ def _numpy_initializer_for_determining(tensor):
     if undetermined_tensor.is_consistent:
         raise NotImplementedError()
     else:
-        determined_tensor = oneflow._oneflow_internal.LocalTensor(
+        determined_tensor = oneflow._oneflow_internal.Tensor(
             undetermined_tensor.shape,
             undetermined_tensor.dtype,
             undetermined_tensor.device,
@@ -913,13 +913,7 @@ def _input_args_is_numpy(*args):
 def _input_args_is_consistent_or_local(*args):
-    return len(args) == 1 and isinstance(
-        args[0],
-        (
-            oneflow._oneflow_internal.ConsistentTensor,
-            oneflow._oneflow_internal.LocalTensor,
-        ),
-    )
+    return len(args) == 1 and isinstance(args[0], oneflow._oneflow_internal.Tensor)
 
 
 def _input_args_is_tensor(*args):
@@ -937,7 +931,7 @@ def _input_args_is_shape(*args):
 def register_tensor_op(op_name):
     def set_tensor_op(method):
         setattr(Tensor, op_name, method)
-        setattr(oneflow._oneflow_internal.LocalTensor, op_name, method)
+        setattr(oneflow._oneflow_internal.Tensor, op_name, method)
         return method
 
     return set_tensor_op
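register_tensor_op (and register_local_tensor_method above it) monkey-patch a method onto both the Python-level Tensor wrapper and, after this commit, the single internal Tensor class. A hypothetical registration for illustration, assuming these helpers live in oneflow.python.framework.tensor as the hunk context suggests:

    from oneflow.python.framework.tensor import register_tensor_op

    @register_tensor_op("shape_list")  # hypothetical op name
    def shape_list(self):
        # Afterwards available as t.shape_list() on wrapper and internal tensors alike.
        return list(self.shape)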
@@ -17,17 +17,17 @@ limitations under the License.
 import collections
 from typing import Union, Sequence, Tuple, Optional
 
-from oneflow.python.framework.tensor import Tensor
-from oneflow._oneflow_internal import TensorTuple, LocalTensor
+from oneflow.python.framework.tensor import Tensor as PyTensor
+from oneflow._oneflow_internal import TensorTuple, Tensor
 
 
 def convert_to_tensor_tuple(
-    args: Optional[Union[Tensor, Sequence[Tensor], LocalTensor, Sequence[LocalTensor]]]
+    args: Optional[Union[PyTensor, Sequence[PyTensor], Tensor, Sequence[Tensor]]]
 ):
     if args is None:
         return TensorTuple()
     elif isinstance(args, collections.abc.Sequence):
-        if isinstance(args[0], Tensor):
+        if isinstance(args[0], PyTensor):
             for tensor in args:
                 if not tensor.is_determined:
                     tensor.determine()
@@ -35,7 +35,7 @@ def convert_to_tensor_tuple(
         return TensorTuple(args)
     else:
         tensor_tuple = TensorTuple()
-        if isinstance(args, Tensor):
+        if isinstance(args, PyTensor):
             if not args.is_determined:
                 args.determine()
             tensor_tuple.append(args._local_or_consistent_tensor)
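For context, convert_to_tensor_tuple now accepts any mix of the wrapper class (imported as PyTensor) and internal Tensor handles; a small usage sketch:

    from oneflow.python.framework.tensor_tuple_util import convert_to_tensor_tuple

    # None maps to an empty TensorTuple; sequences are converted element-wise,
    # determining undetermined wrapper tensors first.
    empty = convert_to_tensor_tuple(None)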
@@ -25,7 +25,7 @@ class Eq(Module):
     def forward(self, input, other):
         if isinstance(other, flow.Tensor) or isinstance(
-            other, flow._oneflow_internal.LocalTensor
+            other, flow._oneflow_internal.Tensor
         ):
             for i in range(len(input.size())):
                 assert (
@@ -25,7 +25,7 @@ class Ne(Module):
     def forward(self, input, other):
         if isinstance(other, flow.Tensor) or isinstance(
-            other, flow._oneflow_internal.LocalTensor
+            other, flow._oneflow_internal.Tensor
         ):
             for i in range(len(input.size())):
                 assert (