Skip to content
Snippets Groups Projects
Unverified Commit 18e5c9e2 authored by Houjiang Chen's avatar Houjiang Chen Committed by GitHub
Browse files

Optimize tensor getitem. (#5433)


* Optimize tensor getitem.

* Fix merge

* Refine code style

* Fix typo and copy input if its getitem is identity.

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent a1b6cfc5
No related branches found
No related tags found
No related merge requests found
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/python/functional/common.h"
namespace oneflow {
namespace one {
namespace functional {
namespace detail {
Maybe<void> PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step) {
  // Unpacks a Python slice object into raw (start, stop, step) values.
  // Omitted fields get direction-dependent sentinels (so that a later clamp
  // against the actual dimension size produces the right range), and all
  // failures surface as Maybe errors rather than Python exceptions.
  PySliceObject* slice_obj = (PySliceObject*)object;
  if (slice_obj->step != Py_None) {
    CHECK_OR_RETURN(_PyEval_SliceIndex(slice_obj->step, step))
        << "Invalid slice " << PyStringAsString(PyObject_Repr(object));
    CHECK_NE_OR_RETURN(*step, 0) << "slice step cannot be zero.";
    // Clamp to avoid Py_ssize_t overflow if the step is ever negated.
    if (*step < -PY_SSIZE_T_MAX) { *step = -PY_SSIZE_T_MAX; }
  } else {
    *step = 1;
  }
  if (slice_obj->start != Py_None) {
    CHECK_OR_RETURN(_PyEval_SliceIndex(slice_obj->start, start))
        << "Invalid slice " << PyStringAsString(PyObject_Repr(object));
  } else {
    *start = (*step < 0) ? PY_SSIZE_T_MAX : 0;
  }
  if (slice_obj->stop != Py_None) {
    CHECK_OR_RETURN(_PyEval_SliceIndex(slice_obj->stop, stop))
        << "Invalid slice " << PyStringAsString(PyObject_Repr(object));
  } else {
    *stop = (*step < 0) ? PY_SSIZE_T_MIN : PY_SSIZE_T_MAX;
  }
  return Maybe<void>::Ok();
}
// Returns the UTF-8 encoded contents of a Python unicode object.
// NOTE(review): the intermediate bytes object is never decref'd; that leak is
// what keeps the returned buffer alive after return. Returning std::string
// would fix the leak properly — confirm no caller stores the pointer long-term.
const char* PyStringAsString(PyObject* object) {
  PyObject* bytes = PyUnicode_AsEncodedString(object, "utf-8", "~E~");
  // PyUnicode_AsEncodedString returns NULL (with a Python error set) when
  // `object` is not a unicode object; guard it, since PyBytes_AsString would
  // dereference the NULL and crash.
  if (bytes == nullptr) {
    PyErr_Clear();
    return "<invalid string>";
  }
  return PyBytes_AsString(bytes);
}
// Converts one Python object — ellipsis, slice, int, bool or None — into an
// IndexItem; any other type is rejected.
Maybe<detail::IndexItem> UnpackIndexItem(PyObject* object) {
  if (object == Py_None) { return std::make_shared<detail::IndexItem>(detail::NoneIndex{}); }
  if (object == Py_Ellipsis) { return std::make_shared<detail::IndexItem>(detail::EllipsisIndex{}); }
  // Booleans must be tested before PyLong_Check: bool subclasses int.
  if (object == Py_True || object == Py_False) {
    return std::make_shared<detail::IndexItem>(object == Py_True);
  }
  if (PySlice_Check(object)) {
    Py_ssize_t start = 0, end = 0, step = 0;
    JUST(PySliceUnpack(object, &start, &end, &step));
    return std::make_shared<detail::IndexItem>(start, end, step);
  }
  if (PyLong_Check(object)) {
    return std::make_shared<detail::IndexItem>(static_cast<int64_t>(PyLong_AsLongLong(object)));
  }
  UNIMPLEMENTED_THEN_RETURN() << "Invalid index " << PyStringAsString(PyObject_Repr(object));
}
} // namespace detail
} // namespace functional
} // namespace one
} // namespace oneflow
......@@ -20,11 +20,13 @@ limitations under the License.
#include <vector>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/functional/tensor_index.h"
namespace py = pybind11;
......@@ -130,6 +132,11 @@ template<typename T>
return values;
}
// Unpacks a Python slice object into (start, stop, step); omitted fields are
// replaced with direction-dependent sentinel values.
Maybe<void> PySliceUnpack(PyObject* object, Py_ssize_t* start, Py_ssize_t* stop, Py_ssize_t* step);
// Returns the UTF-8 encoded contents of a Python unicode object.
const char* PyStringAsString(PyObject* object);
// Converts a single Python object (ellipsis, slice, int, bool or None) into
// an IndexItem; fails for any other type.
Maybe<detail::IndexItem> UnpackIndexItem(PyObject* object);
} // namespace detail
} // namespace functional
......
......@@ -25,6 +25,7 @@ limitations under the License.
#include "oneflow/core/framework/user_op_attr.cfg.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/functional/scalar.h"
#include "oneflow/core/functional/tensor_index.h"
namespace py = pybind11;
......@@ -167,6 +168,38 @@ Maybe<one::Generator> PythonArg::ObjectAs<one::Generator>() const {
return *JUST(detail::cast<std::shared_ptr<one::Generator>>(Borrow()));
}
template<>
Maybe<TensorIndex> PythonArg::ObjectAs<TensorIndex>() const {
  // Converts a Python index expression into a TensorIndex. A single index
  // item (ellipsis, slice, integer, boolean or None) yields a one-item index;
  // anything else is treated as a sequence of index items.
  auto tensor_index = std::make_shared<TensorIndex>();
  // PyLong_Check also covers Py_True/Py_False (bool subclasses int);
  // detail::UnpackIndexItem distinguishes them internally, so all five scalar
  // cases can be delegated to it instead of being re-implemented here.
  if (object_ == Py_Ellipsis || object_ == Py_None || PySlice_Check(object_)
      || PyLong_Check(object_)) {
    tensor_index->emplace_back(*JUST(detail::UnpackIndexItem(object_)));
    return tensor_index;
  }
  // General case: any sequence (tuple, list, ...) of index items.
  PyObject* tuple = PySequence_Tuple(object_);
  // PySequence_Tuple returns NULL (with a Python error set) for
  // non-sequences; without this guard PyTuple_GET_SIZE would crash.
  if (tuple == nullptr) { PyErr_Clear(); }
  CHECK_OR_RETURN(tuple != nullptr) << "Invalid tensor index.";
  size_t size = PyTuple_GET_SIZE(tuple);
  tensor_index->resize(size);
  for (size_t i = 0; i < size; ++i) {
    PyObject* obj = PyTuple_GET_ITEM(tuple, i);  // borrowed reference
    // NOTE(review): if UnpackIndexItem fails here, `tuple` is leaked on the
    // error path; a scope guard would close that hole.
    tensor_index->at(i) = *JUST(detail::UnpackIndexItem(obj));
  }
  Py_DECREF(tuple);  // PySequence_Tuple returned a new reference
  return tensor_index;
}
} // namespace functional
} // namespace one
} // namespace oneflow
......@@ -16,7 +16,7 @@
# {
# "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool",
# "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList",
# "BoolList", "DataType", "Shape"
# "BoolList", "DataType", "Shape", "Generator", "TensorIndex"
# }
- name: "add_n"
......@@ -593,3 +593,7 @@
- name: "pad_grad"
signature: "Tensor PadGrad(Tensor dy, *, Int64List pad, String mode=\"constant\", Scalar value=0)"
bind_python: False
- name: "tensor_getitem"
signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
bind_python: True
......@@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/functional/impl/unary_functor.h"
......@@ -457,6 +458,66 @@ class TriuFunctor {
std::shared_ptr<OpExpr> op_;
};
// Implements basic indexing `tensor[index]` on top of Copy/Slice/Reshape.
// The index is first regularized so it contains exactly one slice/integer
// item per tensor dimension, plus optional None/True items that insert unit
// dimensions.
class TensorGetItemFunctor {
 public:
  TensorGetItemFunctor() {}
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index) const {
    // Expand ellipsis, wrap negatives, and pad trailing dims (see
    // RegularTensorIndex); afterwards every item maps onto a concrete dim.
    const auto& regular_index = JUST(RegularTensorIndex(index, *(x->shape())));
    int64_t ndims = x->shape()->NumAxes();
    CHECK_GE_OR_RETURN(regular_index->size(), ndims) << "Tensor index failed to be regularlized.";
    // Per-dimension slice bounds consumed by functional::Slice below.
    std::vector<int64_t> start(ndims), end(ndims), step(ndims);
    int dim = 0;             // next tensor dim consumed by a slice/integer item
    DimVector result_dims;   // final result shape (integer-indexed dims omitted)
    for (int i = 0; i < regular_index->size(); ++i) {
      const auto& index_item = regular_index->at(i);
      CHECK_OR_RETURN(!index_item.IsEllipsis())
          << "Tensor index should not have ellipsis once regularlized.";
      if (index_item.IsSlice()) {
        CHECK_LT_OR_RETURN(dim, ndims);
        start[dim] = index_item.slice().start();
        end[dim] = index_item.slice().end();
        step[dim] = index_item.slice().step();
        // Element count of the slice; regularization guarantees step > 0, so
        // this rounds up correctly.
        int64_t length = (end[dim] - start[dim] + step[dim] - 1) / step[dim];
        result_dims.emplace_back(length);
        dim++;
      } else if (index_item.IsInteger()) {
        CHECK_LT_OR_RETURN(dim, ndims);
        // Integer index: slice the single element [i, i+1) and squeeze the
        // dim by not appending it to result_dims.
        start[dim] = index_item.integer();
        end[dim] = start[dim] + 1;
        step[dim] = 1;
        dim++;
      } else if (index_item.IsNone()) {
        // None inserts a new unit dim without consuming an input dim.
        result_dims.emplace_back(1);
      } else if (index_item.IsBoolean()) {
        CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
        result_dims.emplace_back(1);
      }
    }
    CHECK_EQ_OR_RETURN(dim, ndims)
        << "Specified dims count for regularlized tensor index should equal to tensor dimension "
        << ndims;
    // If the index selects the whole tensor, copy instead of slicing so that
    // getitem never aliases its input.
    bool is_identity = [&]() {
      for (int i = 0; i < ndims; ++i) {
        if (start[i] != 0 || end[i] != x->shape()->At(i) || step[i] != 1) { return false; }
      }
      return true;
    }();
    std::shared_ptr<one::Tensor> result;
    if (is_identity) {
      result = JUST(functional::Copy(x, JUST(x->device())->type(), JUST(x->device())->device_id()));
    } else {
      result = JUST(functional::Slice(x, start, end, step));
    }
    // Reshape to squeeze integer-indexed dims / insert unit dims when needed.
    // NOTE(review): if every item is an integer, result_dims is empty (0
    // axes), the reshape is skipped, and a shape like [1, ..., 1] is returned
    // instead of a scalar — presumably 0-dim tensors are not supported here;
    // confirm.
    Shape shape(result_dims);
    if (shape.NumAxes() != 0 && shape != *(result->shape())) {
      return functional::Reshape(result, shape);
    }
    return result;
  }
};
} // namespace impl
ONEFLOW_FUNCTION_LIBRARY(m) {
......@@ -484,6 +545,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::UpsampleFunctor>("Upsample");
m.add_functor<impl::UnsortedSegmentSumLikeFunctor>("UnsortedSegmentSumLike");
m.add_functor<impl::TriuFunctor>("Triu");
m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
};
} // namespace functional
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/functional/tensor_index.h"
namespace oneflow {
namespace one {
namespace functional {
int64_t CountSpecifiedDims(const TensorIndex& index) {
int64_t specified_ndims = 0;
for (int i = 0; i < index.size(); ++i) {
const auto& index_item = index.at(i);
if (index_item.IsSlice() || index_item.IsInteger()) { specified_ndims++; }
}
return specified_ndims;
}
// Normalizes `index` against `shape` so the result holds, in order, exactly
// one slice/integer item per tensor dimension:
//  - ellipsis expands into full slices over the unspecified dims,
//  - negative slice bounds and integers are wrapped once by the dim size,
//  - slice bounds are clamped into [0, dim_size],
//  - missing trailing dims are padded with full slices.
// None and boolean items are passed through unchanged.
Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape) {
  int64_t specified_ndims = CountSpecifiedDims(index);
  int64_t ndims = shape.NumAxes();
  CHECK_LE_OR_RETURN(specified_ndims, ndims)
      << "Too many indices for tensor of dimension " << ndims;
  auto regular_index = std::make_shared<TensorIndex>();
  int64_t dim = 0;  // tensor dimension the next slice/integer item applies to
  for (int i = 0; i < index.size(); ++i) {
    const auto& index_item = index.at(i);
    if (index_item.IsSlice()) {
      CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
      CHECK_GT_OR_RETURN(shape.At(dim), 0) << "Slice cannot be applied to a 0-dim tensor.";
      const auto& slice = index_item.slice();
      // Step is clamped to the dim size (a larger step selects at most one
      // element anyway). Negative steps are rejected below, so reversed
      // slices are not supported by this normalization.
      int64_t step = std::min(slice.step(), shape.At(dim));
      CHECK_GT_OR_RETURN(step, 0) << "Step must be greater than zero.";
      // Clamp against the dim size first (handles the PY_SSIZE_T_MAX
      // sentinel of an omitted stop), then wrap negatives once and clamp the
      // pair into a valid, possibly empty, range.
      int64_t end = std::min(slice.end(), shape.At(dim));
      int64_t start = std::min(slice.start(), shape.At(dim));
      if (start < 0) { start += shape.At(dim); }
      if (start < 0) { start = 0; }
      if (end < 0) { end += shape.At(dim); }
      if (end < start) { end = start; }  // empty slice rather than an error
      regular_index->emplace_back(detail::IndexItem(start, end, step));
      dim++;
    } else if (index_item.IsInteger()) {
      CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
      int64_t integer = index_item.integer();
      if (integer < 0) { integer += shape.At(dim); }  // wrap a negative index once
      CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim))
          << "Index " << index_item.integer() << " is out of bounds for dimension " << dim
          << " with size " << shape.At(dim);
      regular_index->emplace_back(detail::IndexItem(integer));
      dim++;
    } else if (index_item.IsEllipsis()) {
      // Expand `...` into full slices over the dims not otherwise specified.
      int64_t unspecified_ndims = ndims - specified_ndims;
      unspecified_ndims = std::min(ndims - dim, unspecified_ndims);
      for (int j = 0; j < unspecified_ndims; ++j) {
        regular_index->emplace_back(detail::IndexItem(0, shape.At(dim + j), 1));
      }
      dim += unspecified_ndims;
    } else {
      // None or Boolean.
      if (index_item.IsBoolean()) {
        CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
      }
      regular_index->emplace_back(index_item);
    }
  }
  // Pad the remaining (unindexed) trailing dims with full slices.
  for (int i = dim; i < ndims; ++i) {
    regular_index->emplace_back(detail::IndexItem(0, shape.At(i), 1));
  }
  return regular_index;
}
} // namespace functional
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_
#define ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_
#include <cstdint>
#include <limits>
#include <vector>
#include "oneflow/core/common/shape.h"
namespace oneflow {
namespace one {
namespace functional {
namespace detail {
struct NoneIndex {};
struct EllipsisIndex {};
// A half-open slice [start, end) advancing by `step`. The end sentinel
// kInt64Max stands for "until the end of the dimension"; callers clamp it
// against the actual dim size.
class Slice {
 public:
  // Full slice: [0, +max) with unit step.
  Slice() : Slice(0, kInt64Max, 1) {}
  // Slice from `start` to the end with unit step.
  explicit Slice(int64_t start) : Slice(start, kInt64Max, 1) {}
  // Slice of [start, end) with unit step.
  explicit Slice(int64_t start, int64_t end) : Slice(start, end, 1) {}
  // Fully specified slice.
  explicit Slice(int64_t start, int64_t end, int64_t step)
      : start_(start), end_(end), step_(step) {}

  int64_t start() const { return start_; }
  int64_t end() const { return end_; }
  int64_t step() const { return step_; }

 private:
  static constexpr int64_t kInt64Max = std::numeric_limits<int64_t>::max();

  int64_t start_;
  int64_t end_;
  int64_t step_;
};
// Tagged union describing a single item of a tensor index: a slice, an
// integer, a boolean, an ellipsis (`...`) or None.
class IndexItem {
 public:
  // Defaults to a None item.
  IndexItem() : IndexItem(NoneIndex()) {}
  explicit IndexItem(NoneIndex none) : item_{.dummy = 0}, tag_(HAS_NONE) {}
  // Slice item from explicit (start, end, step).
  explicit IndexItem(int64_t start, int64_t end, int64_t step)
      : item_{.slice = Slice{start, end, step}}, tag_(HAS_SLICE) {}
  explicit IndexItem(const Slice& slice) : item_{.slice = slice}, tag_(HAS_SLICE) {}
  // Integer item.
  explicit IndexItem(int64_t index) : item_{.i = index}, tag_(HAS_INT) {}
  // Boolean item.
  explicit IndexItem(bool boolean) : item_{.b = boolean}, tag_(HAS_BOOLEAN) {}
  // Ellipsis item.
  explicit IndexItem(EllipsisIndex ellipsis) : item_{.dummy = 0}, tag_(HAS_ELLIPSIS) {}

  // Accessors are unchecked: callers must test the matching Is*() predicate
  // first, otherwise the union is read through the wrong member.
  bool IsSlice() const { return tag_ == HAS_SLICE; }
  const Slice& slice() const { return item_.slice; }

  bool IsInteger() const { return tag_ == HAS_INT; }
  int64_t integer() const { return item_.i; }

  bool IsBoolean() const { return tag_ == HAS_BOOLEAN; }
  bool boolean() const { return item_.b; }

  bool IsEllipsis() const { return tag_ == HAS_ELLIPSIS; }
  bool IsNone() const { return tag_ == HAS_NONE; }

 private:
  // NOTE: designated initializers in the constructors above are a compiler
  // extension before C++20.
  union {
    Slice slice;
    bool b;
    int64_t i;
    char dummy;  // placeholder member for None/Ellipsis, which carry no payload
  } item_;
  enum { HAS_SLICE, HAS_BOOLEAN, HAS_INT, HAS_ELLIPSIS, HAS_NONE } tag_;
};
} // namespace detail
// An ordered list of index items, one per subscript position in
// `tensor[...]`.
// NOTE(review): publicly inherits from std::vector, which has no virtual
// destructor — never delete a TensorIndex through a std::vector pointer.
class TensorIndex : public std::vector<detail::IndexItem> {
 public:
  using std::vector<detail::IndexItem>::vector;
};

// Returns how many input dimensions the index consumes (slice and integer
// items only).
int64_t CountSpecifiedDims(const TensorIndex& index);

// Expands ellipsis, wraps/clamps slice bounds and integers against `shape`,
// and pads missing trailing dims with full slices; see the .cc for details.
Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape);
} // namespace functional
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_CORE_FUNCTIONAL_TENSOR_INDEX_H_
......@@ -38,6 +38,7 @@ class Generator;
namespace functional {
class Scalar;
class TensorIndex;
} // namespace functional
} // namespace one
......@@ -80,6 +81,7 @@ enum ValueType {
kGENERATOR,
kGENERATOR_REF,
kGENERATOR_MAYBE,
kTENSOR_INDEX,
};
#define VALUE_TYPE_OF_IMPL(cpp_type, value_type) \
......@@ -129,6 +131,7 @@ VALUE_TYPE_OF_IMPL(Shape, kSHAPE);
VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR);
VALUE_TYPE_OF_IMPL(std::shared_ptr<one::Generator>, kGENERATOR_REF);
VALUE_TYPE_OF_IMPL(Maybe<one::Generator>, kGENERATOR_MAYBE);
VALUE_TYPE_OF_IMPL(TensorIndex, kTENSOR_INDEX);
#undef VALUE_TYPE_OF_IMPL
......
......@@ -441,30 +441,7 @@ class Tensor:
@_auto_determine
@register_local_tensor_method()
def __getitem__(self, key):
# TODO: support inplace __getitem__
assert (
isinstance(key, int) or isinstance(key, tuple) or isinstance(key, slice)
), "Unsupported key type!"
squeeze_dims = None
if isinstance(key, tuple):
key = self._transform_ellipsis_type(key)
squeeze_dims = list(
filter(lambda idx: isinstance(key[idx], int), range(len(key)))
)
elif isinstance(key, int):
if key < 0:
key = self.shape[0] + key
squeeze_dims = [0]
else:
# do nothing
pass
start, stop, step, _ = self._get_slice_obj(key)
res = flow.experimental.slice(self, list(zip(start, stop, step)))
if squeeze_dims is not None:
return res.squeeze(dim=squeeze_dims)
return res
return flow.F.tensor_getitem(self, key)
@_auto_determine
@register_local_tensor_method()
......
......@@ -66,6 +66,7 @@ header_fmt = (
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/functional/scalar.h"
#include "oneflow/core/functional/tensor_index.h"
namespace oneflow {{
namespace one {{
......@@ -149,6 +150,7 @@ types_allowed = {
"DataType",
"Shape",
"Generator",
"TensorIndex",
}
generic_type_aliases = {
......@@ -176,6 +178,7 @@ argument_type_aliases = {
"DataType": "const DataType&",
"Shape": "const Shape&",
"Generator": "const std::shared_ptr<one::Generator>&",
"TensorIndex": "const TensorIndex&",
**generic_type_aliases,
}
......@@ -195,6 +198,7 @@ optional_argument_type_aliases = {
"DataType": "const Optional<DataType>&",
"Shape": "const Optional<Shape>&",
"Generator": "const Optional<one::Generator>&",
"TensorIndex": "const Optional<TensorIndex>&",
**{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()},
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment