Unverified Commit 61a951e3 authored by Houjiang Chen, committed by GitHub

Dev optimize tensor setitem (#5501)


* Refactor expand and tensor setitem functional APIs.

* Raise IndexException in tensor getitem.

* Bugfix

* Bugfix

* Fix conflict with xrt

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent ec0d02c1
......@@ -36,7 +36,7 @@ class Throw final {
#define CHECK_OR_THROW(expr) \
if (!(expr)) \
Throw(oneflow::Error::CheckFailedError().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)).error() \
<< " Check failed: " << OF_PP_STRINGIZE(expr) << "\t"
<< " Check failed: " << OF_PP_STRINGIZE(expr) << ": "
#define CHECK_EQ_OR_THROW(lhs, rhs) \
CHECK_OR_THROW((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
......
......@@ -70,6 +70,12 @@ Error Error::ValueError(const std::string& error_summary) {
return error;
}
Error Error::IndexError() {
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_index_error();
return error;
}
Error Error::JobNameExistError() {
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_name_exist_error();
......
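For context, a minimal sketch of what the new error type means for users (hypothetical session; assumes the eager flow.Tensor API of this version): an out-of-range integer index now surfaces in Python as IndexError rather than ValueError.

import oneflow as flow

x = flow.Tensor(2, 3)
try:
    y = x[5]  # out of bounds for dimension 0 (size 2)
except IndexError as e:
    print(e)  # e.g. "Index 5 is out of bounds for dimension 0 with size 2"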
......@@ -43,6 +43,7 @@ class Error final {
static Error JobSetEmptyError();
static Error DeviceTagNotFoundError();
static Error ValueError(const std::string& error_summary);
static Error IndexError();
static Error JobNameExistError();
static Error JobNameEmptyError();
static Error JobNameNotEqualError();
......
......@@ -125,6 +125,8 @@ message SymbolIdUninitializedError {}
message ValueError {}
message IndexError {}
message ErrorProto {
optional string error_summary = 1 [default = ""];
optional string msg = 2 [default = ""];
......@@ -145,6 +147,7 @@ message ErrorProto {
JobSetEmptyError job_set_empty_error = 25;
DeviceTagNotFoundError device_tag_not_found_error = 26;
ValueError value_error = 27;
IndexError index_error = 28;
JobNameExistError job_name_exist_error = 100;
JobNameEmptyError job_name_empty_error = 101;
JobNameNotEqualError job_name_not_equal_error = 102;
......
......@@ -69,6 +69,7 @@ class Exception : public std::exception {
OF_PP_MAKE_TUPLE_SEQ(Unknown) \
OF_PP_MAKE_TUPLE_SEQ(CompileOptionWrong) \
OF_PP_MAKE_TUPLE_SEQ(Value) \
OF_PP_MAKE_TUPLE_SEQ(Index) \
OF_PP_MAKE_TUPLE_SEQ(InputDeviceNotMatch)
#define DEFINE_EXCEPTION_CLASS(cls) \
......
......@@ -184,4 +184,14 @@ bool Shape::Containing(const Shape& small_shape) const {
return true;
}
Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {
CHECK_OR_RETURN(start_dim >= 0 && end_dim >= start_dim);
int64_t ndims = this->NumAxes();
if (start_dim > ndims) { start_dim = ndims; }
if (end_dim > ndims) { end_dim = ndims; }
DimVector dim_vec;
for (int64_t i = start_dim; i < end_dim && i < ndims; ++i) { dim_vec.push_back(this->At(i)); }
return std::make_shared<Shape>(dim_vec);
}
} // namespace oneflow
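A standalone sketch of the new Shape::Slice semantics (plain Python, no oneflow dependency; shape_slice is a hypothetical name): both bounds are clamped to the number of axes, so an end_dim past the last axis is tolerated rather than rejected.

def shape_slice(dims, start_dim, end_dim):
    # Mirrors Shape::Slice: require 0 <= start_dim <= end_dim, clamp both
    # bounds to ndims, then copy the half-open range [start_dim, end_dim).
    assert start_dim >= 0 and end_dim >= start_dim
    ndims = len(dims)
    start_dim = min(start_dim, ndims)
    end_dim = min(end_dim, ndims)
    return dims[start_dim:end_dim]

print(shape_slice([2, 3, 4], 1, 10))  # [3, 4]; end_dim clamped to 3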
......@@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/common/shape.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/shape_vec.h"
namespace oneflow {
......@@ -68,6 +69,8 @@ class Shape final {
bool Containing(const Shape& small_shape) const;
Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const;
private:
void UpdateElemCnt();
......
......@@ -313,7 +313,7 @@
bind_python: False
- name: "expand"
signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)"
signature: "Tensor Expand(Tensor x, *, Shape shape)"
bind_python: True
- name: "expand_dims"
......@@ -694,3 +694,6 @@
signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
bind_python: True
- name: "tensor_setitem"
signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)"
bind_python: True
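A hedged sketch of the new call site (assuming the Python binding converts a tuple to Shape, as the module code further below suggests): the single shape argument replaces the old in_shape/out_shape/stride triple, which the functor now derives internally.

import oneflow as flow

x = flow.Tensor(1, 3)
# Old: flow.F.expand(x, in_shape=[1, 3], out_shape=[4, 3], stride=[0, 1])
y = flow.F.expand(x, (4, 3))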
......@@ -167,10 +167,42 @@ class ConcatFunctor {
class ExpandFunctor {
public:
ExpandFunctor() { op_ = CHECK_JUST(one::OpBuilder("expand").Input("in").Output("out").Build()); }
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::vector<int32_t>& in_shape,
const std::vector<int32_t>& out_shape,
const std::vector<int32_t>& stride) const {
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {
CHECK_GE_OR_RETURN(shape.NumAxes(), x->shape()->NumAxes())
<< "The desired expanded dims should not be less than the input dims.";
std::vector<int32_t> in_shape(x->shape()->NumAxes());
for (int i = 0; i < in_shape.size(); ++i) { in_shape[i] = x->shape()->At(i); }
// calculate the original stride.
std::vector<int32_t> original_stride(in_shape.size(), 1);
for (int i = x->shape()->NumAxes() - 2; i >= 0; --i) {
original_stride[i] = in_shape.at(i + 1) * original_stride.at(i + 1);
}
std::vector<int32_t> out_shape(shape.NumAxes());
std::vector<int32_t> stride(shape.NumAxes());
int shift = out_shape.size() - in_shape.size();
for (int i = out_shape.size() - 1; i >= 0; --i) {
int index = i - shift;
if (index >= 0) {
if (shape.At(i) == -1 || shape.At(i) == in_shape.at(index)) {
out_shape[i] = in_shape.at(index);
stride[i] = original_stride.at(index);
} else {
CHECK_OR_RETURN(shape.At(i) > 0 && in_shape.at(index) == 1)
<< "Invalid expand shape " << shape.ToString();
out_shape[i] = shape.At(i);
stride[i] = 0;
}
} else {
CHECK_GT_OR_RETURN(shape.At(i), 0) << "Invalid expand shape " << shape.ToString();
out_shape[i] = shape.At(i);
if (shape.At(i) == 1 && i < out_shape.size() - 1) {
stride[i] = stride.at(i + 1);
} else {
stride[i] = 0;
}
}
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("in_shape", in_shape));
JUST(attrs.SetAttr<std::vector<int32_t>>("out_shape", out_shape));
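The derivation above, restated as a standalone Python sketch (no oneflow dependency; expand_params is a hypothetical name): a target dim that is -1 or matches the input keeps the input size and its contiguous stride; a broadcast dim (input size 1, target larger) gets stride 0; a new leading dim gets stride 0 unless it is size 1, in which case it inherits the stride to its right.

def expand_params(in_shape, shape):
    # Mirrors ExpandFunctor: derive out_shape and broadcast strides from the
    # target shape. stride 0 marks a broadcast (repeated) dimension.
    assert len(shape) >= len(in_shape)
    original_stride = [1] * len(in_shape)
    for i in range(len(in_shape) - 2, -1, -1):
        original_stride[i] = in_shape[i + 1] * original_stride[i + 1]
    out_shape = [0] * len(shape)
    stride = [0] * len(shape)
    shift = len(shape) - len(in_shape)
    for i in range(len(shape) - 1, -1, -1):
        idx = i - shift
        if idx >= 0:  # dim also exists in the input
            if shape[i] == -1 or shape[i] == in_shape[idx]:
                out_shape[i], stride[i] = in_shape[idx], original_stride[idx]
            else:
                assert shape[i] > 0 and in_shape[idx] == 1, "invalid expand shape"
                out_shape[i], stride[i] = shape[i], 0
        else:  # new leading dim
            assert shape[i] > 0, "invalid expand shape"
            out_shape[i] = shape[i]
            stride[i] = stride[i + 1] if shape[i] == 1 and i < len(shape) - 1 else 0
    return out_shape, stride

print(expand_params([3, 1], [2, 3, 4]))  # ([2, 3, 4], [0, 1, 0])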
......@@ -773,41 +805,24 @@ class TensorGetItemFunctor {
public:
TensorGetItemFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index) const {
const auto& regular_index = JUST(RegularTensorIndex(index, *(x->shape())));
int64_t ndims = x->shape()->NumAxes();
CHECK_GE_OR_RETURN(regular_index->size(), ndims) << "Tensor index failed to be regularlized.";
std::vector<detail::Slice> slice_indices;
std::vector<std::shared_ptr<one::Tensor>> tensor_indices;
std::vector<int64_t> target_dims;
JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
CHECK_GT_OR_RETURN(target_shape.Count(0), 0)
<< "Target shape is zero shape which was not supported yet.";
std::vector<int64_t> start(ndims), end(ndims), step(ndims);
int dim = 0;
DimVector result_dims;
for (int i = 0; i < regular_index->size(); ++i) {
const auto& index_item = regular_index->at(i);
CHECK_OR_RETURN(!index_item.IsEllipsis())
<< "Tensor index should not have ellipsis once regularlized.";
if (index_item.IsSlice()) {
CHECK_LT_OR_RETURN(dim, ndims);
start[dim] = index_item.slice().start();
end[dim] = index_item.slice().end();
step[dim] = index_item.slice().step();
int64_t length = (end[dim] - start[dim] + step[dim] - 1) / step[dim];
result_dims.emplace_back(length);
dim++;
} else if (index_item.IsInteger()) {
CHECK_LT_OR_RETURN(dim, ndims);
start[dim] = index_item.integer();
end[dim] = start[dim] + 1;
step[dim] = 1;
dim++;
} else if (index_item.IsNone()) {
result_dims.emplace_back(1);
} else if (index_item.IsBoolean()) {
CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
result_dims.emplace_back(1);
}
for (int i = 0; i < ndims; ++i) {
const auto& slice = slice_indices.at(i);
start[i] = slice.start();
end[i] = slice.end();
step[i] = slice.step();
}
CHECK_EQ_OR_RETURN(dim, ndims)
<< "Specified dims count for regularlized tensor index should equal to tensor dimension "
<< ndims;
bool is_identity = [&]() {
for (int i = 0; i < ndims; ++i) {
if (start[i] != 0 || end[i] != x->shape()->At(i) || step[i] != 1) { return false; }
......@@ -821,14 +836,69 @@ class TensorGetItemFunctor {
result = JUST(functional::Slice(x, start, end, step));
}
Shape shape(result_dims);
Shape shape(DimVector(target_dims.begin(), target_dims.end()));
if (shape.NumAxes() != 0 && shape != *(result->shape())) {
return functional::Reshape(result, shape);
result = JUST(functional::Reshape(result, shape));
}
return result;
}
};
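A quick behavioral sketch of the rewritten getitem (hypothetical session; eager flow.Tensor API assumed): target_dims determines the result shape, so a slice keeps a dim, an integer drops one, and None inserts a size-1 dim.

import oneflow as flow

x = flow.Tensor(4, 5)
print(x[1:4:2, 2].shape)   # shape (2,): the slice keeps a length-2 dim, the integer drops its dim
print(x[None, ...].shape)  # shape (1, 4, 5): None inserts a 1-dim, the ellipsis fills the rest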
class TensorSetItemFunctor {
public:
TensorSetItemFunctor() {}
Maybe<void> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index,
const std::shared_ptr<one::Tensor>& value) const {
int64_t ndims = x->shape()->NumAxes();
std::vector<detail::Slice> slice_indices;
std::vector<std::shared_ptr<one::Tensor>> tensor_indices;
std::vector<int64_t> target_dims;
JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
if (target_shape.Count(0) == 0) { return Maybe<void>::Ok(); }
const auto& value_shape = value->shape();
bool matched = [&]() {
for (int i = 0; i < value_shape->NumAxes() - target_shape.NumAxes(); ++i) {
if (value_shape->At(i) != 1) { return false; }
}
return true;
}();
CHECK_OR_RETURN(matched) << "The tensor size mismatch. Target sizes: "
<< target_shape.ToString()
<< ", value sizes: " << value_shape->ToString();
std::shared_ptr<one::Tensor> value_tensor(value);
if (target_shape.NumAxes() != 0 && // NOLINT
/*need_expand=*/value_shape->Count(0) != target_shape.Count(0)) {
// Remove the beginning redundant 1-dimensions.
if (value_shape->NumAxes() > target_shape.NumAxes()) {
int64_t start_axis = value_shape->NumAxes() - target_shape.NumAxes();
const auto& shape = JUST(value_shape->Slice(start_axis, value_shape->NumAxes()));
value_tensor = JUST(functional::Reshape(value, *shape));
}
value_tensor = JUST(functional::Expand(value_tensor, target_shape));
}
std::vector<int64_t> start(ndims), end(ndims), step(ndims);
DimVector slice_dims(ndims);
for (int i = 0; i < ndims; ++i) {
const auto& slice = slice_indices.at(i);
start[i] = slice.start();
end[i] = slice.end();
step[i] = slice.step();
slice_dims[i] = (end[i] - start[i] + step[i] - 1) / step[i];
}
Shape slice_shape(slice_dims);
if (slice_shape != *(value_tensor->shape())) {
value_tensor = JUST(functional::Reshape(value_tensor, slice_shape));
}
JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
return Maybe<void>::Ok();
}
};
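In user terms, the broadcast logic above means the value may carry extra leading 1-dims (reshaped away) and is then expanded to the target shape before LogicalSliceAssign. A hedged sketch, assuming the eager flow.Tensor API:

import oneflow as flow

x = flow.Tensor(4, 5)
x[1:3] = flow.Tensor(1, 5)  # leading 1-dim removed, then expanded to the (2, 5) target
x[0, 0] = 1.0               # scalars are wrapped via flow.F.constant in __setitem__ (see tensor.py below)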
} // namespace impl
ONEFLOW_FUNCTION_LIBRARY(m) {
......@@ -873,6 +943,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::DiagFunctor>("Diag");
m.add_functor<impl::DiagGradFunctor>("DiagGrad");
m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
};
} // namespace functional
......
......@@ -28,14 +28,15 @@ int64_t CountSpecifiedDims(const TensorIndex& index) {
return specified_ndims;
}
Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape) {
int64_t specified_ndims = CountSpecifiedDims(index);
Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
std::vector<detail::Slice>* slice_indices,
std::vector<std::shared_ptr<Tensor>>* tensor_indices,
std::vector<int64_t>* target_dims) {
int64_t ndims = shape.NumAxes();
int64_t specified_ndims = CountSpecifiedDims(index);
CHECK_LE_OR_RETURN(specified_ndims, ndims)
<< "Too many indices for tensor of dimension " << ndims;
auto regular_index = std::make_shared<TensorIndex>();
int64_t dim = 0;
int dim = 0;
for (int i = 0; i < index.size(); ++i) {
const auto& index_item = index.at(i);
if (index_item.IsSlice()) {
......@@ -50,37 +51,40 @@ Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& sha
if (start < 0) { start = 0; }
if (end < 0) { end += shape.At(dim); }
if (end < start) { end = start; }
regular_index->emplace_back(detail::IndexItem(start, end, step));
slice_indices->emplace_back(start, end, step);
int64_t length = (end - start + step - 1) / step;
target_dims->emplace_back(length);
dim++;
} else if (index_item.IsInteger()) {
CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
int64_t integer = index_item.integer();
if (integer < 0) { integer += shape.At(dim); }
CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim)) << Error::ValueError(
std::string("Index ") + std::to_string(index_item.integer())
+ std::string(" is out of bounds for dimension ") + std::to_string(dim)
+ std::string(" with size ") + std::to_string(shape.At(dim)));
regular_index->emplace_back(detail::IndexItem(integer));
if (integer < 0 || integer >= shape.At(dim)) {
return Error::IndexError()
<< "Index " << index_item.integer() << " is out of bounds for dimension " << dim
<< " with size " << shape.At(dim);
}
slice_indices->emplace_back(integer, integer + 1, 1);
dim++;
} else if (index_item.IsEllipsis()) {
int64_t unspecified_ndims = ndims - specified_ndims;
unspecified_ndims = std::min(ndims - dim, unspecified_ndims);
for (int j = 0; j < unspecified_ndims; ++j) {
regular_index->emplace_back(detail::IndexItem(0, shape.At(dim + j), 1));
slice_indices->emplace_back(0, shape.At(dim + j), 1);
target_dims->emplace_back(shape.At(dim + j));
}
dim += unspecified_ndims;
} else {
// None or Boolean.
if (index_item.IsBoolean()) {
CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
}
regular_index->emplace_back(index_item);
} else if (index_item.IsNone()) {
target_dims->emplace_back(1);
} else if (index_item.IsBoolean()) {
target_dims->emplace_back(index_item.boolean());
}
}
for (int i = dim; i < ndims; ++i) {
regular_index->emplace_back(detail::IndexItem(0, shape.At(i), 1));
slice_indices->emplace_back(0, shape.At(i), 1);
target_dims->emplace_back(shape.At(i));
}
return regular_index;
return Maybe<void>::Ok();
}
} // namespace functional
......
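A rough Python analogue of PrepareSliceIndices for the basic item kinds (hypothetical helper; assumes CountSpecifiedDims counts slice and integer items, and that Python's slice.indices matches the clamping above for the common positive-step cases):

def prepare_slice_indices(index, shape):
    # Produce per-dim (start, end, step) triples plus the target dims of the result.
    ndims, dim = len(shape), 0
    specified = sum(isinstance(it, slice) or
                    (isinstance(it, int) and not isinstance(it, bool)) for it in index)
    slices, target_dims = [], []
    for it in index:
        if isinstance(it, slice):
            start, end, step = it.indices(shape[dim])
            slices.append((start, end, step))
            target_dims.append((end - start + step - 1) // step)
            dim += 1
        elif isinstance(it, bool):        # must be tested before int
            target_dims.append(int(it))
        elif isinstance(it, int):
            i = it + shape[dim] if it < 0 else it
            if not 0 <= i < shape[dim]:
                raise IndexError("Index %d is out of bounds for dimension %d with size %d"
                                 % (it, dim, shape[dim]))
            slices.append((i, i + 1, 1))  # narrows the dim but adds no target dim
            dim += 1
        elif it is Ellipsis:
            n = min(ndims - dim, ndims - specified)
            for j in range(n):
                slices.append((0, shape[dim + j], 1))
                target_dims.append(shape[dim + j])
            dim += n
        elif it is None:
            target_dims.append(1)
    for i in range(dim, ndims):           # unindexed trailing dims pass through
        slices.append((0, shape[i], 1))
        target_dims.append(shape[i])
    return slices, target_dims

print(prepare_slice_indices((slice(1, 4, 2), 2, None), [4, 5]))
# ([(1, 4, 2), (2, 3, 1)], [2, 1])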
......@@ -25,6 +25,9 @@ limitations under the License.
namespace oneflow {
namespace one {
class Tensor;
namespace functional {
namespace detail {
......@@ -94,7 +97,11 @@ class TensorIndex : public std::vector<detail::IndexItem> {
};
int64_t CountSpecifiedDims(const TensorIndex& index);
Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape);
Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
std::vector<detail::Slice>* slice_indices,
std::vector<std::shared_ptr<Tensor>>* tensor_indices,
std::vector<int64_t>* target_dims);
} // namespace functional
} // namespace one
......
......@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util
from oneflow._oneflow_internal.exception import ValueException
from oneflow._oneflow_internal.exception import IndexException
from oneflow.python.oneflow_export import oneflow_export
import oneflow.python.framework.remote_blob as remote_blob_util
......@@ -454,44 +454,17 @@ class Tensor:
def __getitem__(self, key):
try:
return flow.F.tensor_getitem(self, key)
except ValueException as e:
except IndexException as e:
# The stop condition of for in python is IndexError,
# so we have to catch ValueException from C++ and throw IndexError
# so we have to catch IndexException from C++ and throw IndexError
raise IndexError(e)
@_auto_determine
@register_local_tensor_method()
def __setitem__(self, key, value):
if isinstance(key, tuple):
key = self._transform_ellipsis_type(key)
unsqueeze_dims = list(
filter(lambda idx: isinstance(key[idx], int), range(len(key)))
)
elif isinstance(key, int):
if key < 0:
key = self.shape[0] + key
unsqueeze_dims = [0]
else:
unsqueeze_dims = []
start, stop, step, shape = self._get_slice_obj(key)
if isinstance(value, (int, float)):
scalar = value
value = flow.Tensor(*shape)
value.fill_(scalar)
else:
prepended_broadcasting_dims = range(
len(self.shape) - len(unsqueeze_dims) - len(value.shape)
)
for dim in prepended_broadcasting_dims:
value = flow.experimental.unsqueeze(value, dim)
for dim in unsqueeze_dims:
value = flow.experimental.unsqueeze(value, dim)
value = flow.experimental.expand(value, *shape)
flow.experimental.tmp.logical_slice_assign(
self, value, list(zip(start, stop, step))
)
value = flow.F.constant([1], value, self.dtype)
flow.F.tensor_setitem(self, key, value)
return self
@register_local_tensor_method()
......
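The rewritten __setitem__ above is now a thin wrapper: int/float values are lifted to a shape-[1] constant tensor and everything else is delegated to flow.F.tensor_setitem. A hedged usage sketch (eager flow.Tensor API assumed):

import oneflow as flow

x = flow.Tensor(3, 4)
x[1] = 0.5                     # scalar: wrapped by flow.F.constant([1], 0.5, x.dtype)
x[:, 1:3] = flow.Tensor(3, 2)  # tensor: broadcast/reshape handled by TensorSetItemFunctor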
......@@ -28,39 +28,7 @@ class Expand(Module):
def forward(self, x):
if x.dtype == flow.int8:
x = flow.experimental.cast(x, flow.int32)
expand_size = self.expand_size
assert len(expand_size) >= len(
x.shape
), "The desired expanded dims should not be less than the input dims."
# calculate the original stride
original_stride = [1]
for i in range(len(x.shape) - 2, -1, -1):
original_stride.insert(0, original_stride[0] * x.shape[i + 1])
# calculate the output shape and stride
new_size = []
new_stride = []
diff = len(expand_size) - len(x.shape)
for i in range(len(expand_size) - 1, -1, -1):
if i >= diff:
if expand_size[i] == -1 or expand_size[i] == x.shape[i - diff]:
new_size.insert(0, x.shape[i - diff])
new_stride.insert(0, original_stride[i - diff])
else:
assert expand_size[i] >= 1 and x.shape[i - diff] == 1
new_size.insert(0, expand_size[i])
new_stride.insert(0, 0)
else:
assert expand_size[i] >= 1
new_size.insert(0, expand_size[i])
if expand_size[i] == 1:
new_stride.insert(0, new_stride[0])
else:
new_stride.insert(0, 0)
return flow.F.expand(
x, in_shape=list(x.shape), out_shape=new_size, stride=new_stride
)
return flow.F.expand(x, self.expand_size)
@oneflow_export("expand")
......
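One behavioral note on the simplified module: -1 in the target shape passes straight through to the functor, which keeps the input size for that dim. A hedged sketch:

import oneflow as flow

x = flow.Tensor(1, 3)
y = flow.F.expand(x, (4, -1))  # -1 keeps the input size, so the result shape is (4, 3)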