diff --git a/oneflow/api/python/framework/throw.h b/oneflow/api/python/framework/throw.h
index 7d8a82aa3397b8c363a891d33fa8418ad38dbbbc..a0def93dd3fe92357fea6d7a692712eb5d99df0e 100644
--- a/oneflow/api/python/framework/throw.h
+++ b/oneflow/api/python/framework/throw.h
@@ -36,7 +36,7 @@ class Throw final {
 #define CHECK_OR_THROW(expr)                                                                      \
   if (!(expr))                                                                                    \
   Throw(oneflow::Error::CheckFailedError().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)).error() \
-      << " Check failed: " << OF_PP_STRINGIZE(expr) << "\t"
+      << " Check failed: " << OF_PP_STRINGIZE(expr) << ": "
 
 #define CHECK_EQ_OR_THROW(lhs, rhs) \
   CHECK_OR_THROW((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
diff --git a/oneflow/core/common/error.cpp b/oneflow/core/common/error.cpp
index a15c225961709269d400198b1cd84a780ef36766..b589b33552dce4162f8ff31ae6de0e3b70e1c219 100644
--- a/oneflow/core/common/error.cpp
+++ b/oneflow/core/common/error.cpp
@@ -70,6 +70,12 @@ Error Error::ValueError(const std::string& error_summary) {
   return error;
 }
 
+Error Error::IndexError() {
+  auto error = std::make_shared<cfg::ErrorProto>();
+  error->mutable_index_error();
+  return error;
+}
+
 Error Error::JobNameExistError() {
   auto error = std::make_shared<cfg::ErrorProto>();
   error->mutable_job_name_exist_error();
diff --git a/oneflow/core/common/error.h b/oneflow/core/common/error.h
index 93e4470c4e8450a59442c637e8f9c38d2acb6f0e..5a878e3f9c634507437b90cfb09c780cede64d34 100644
--- a/oneflow/core/common/error.h
+++ b/oneflow/core/common/error.h
@@ -43,6 +43,7 @@ class Error final {
   static Error JobSetEmptyError();
   static Error DeviceTagNotFoundError();
   static Error ValueError(const std::string& error_summary);
+  static Error IndexError();
   static Error JobNameExistError();
   static Error JobNameEmptyError();
   static Error JobNameNotEqualError();
diff --git a/oneflow/core/common/error.proto b/oneflow/core/common/error.proto
index d090915c9daaa50f37747c17a5b6e9ffc815ebce..81dd7b622576a5d19e5e72f0b29fa0790d3b2574 100644
--- a/oneflow/core/common/error.proto
+++ b/oneflow/core/common/error.proto
@@ -125,6 +125,8 @@ message SymbolIdUninitializedError {}
 
 message ValueError {}
 
+message IndexError {}
+
 message ErrorProto {
   optional string error_summary = 1 [default = ""];
   optional string msg = 2 [default = ""];
@@ -145,6 +147,7 @@ message ErrorProto {
     JobSetEmptyError job_set_empty_error = 25;
     DeviceTagNotFoundError device_tag_not_found_error = 26;
     ValueError value_error = 27;
+    IndexError index_error = 28;
     JobNameExistError job_name_exist_error = 100;
     JobNameEmptyError job_name_empty_error = 101;
     JobNameNotEqualError job_name_not_equal_error = 102;
diff --git a/oneflow/core/common/exception.h b/oneflow/core/common/exception.h
index 084d02126c52351f64550970f1f8ef8a31bacab1..e1e1dee1fb7863e11bcaf5903c2e9d66bac1acaa 100644
--- a/oneflow/core/common/exception.h
+++ b/oneflow/core/common/exception.h
@@ -69,6 +69,7 @@ class Exception : public std::exception {
   OF_PP_MAKE_TUPLE_SEQ(Unknown)            \
   OF_PP_MAKE_TUPLE_SEQ(CompileOptionWrong) \
   OF_PP_MAKE_TUPLE_SEQ(Value)              \
+  OF_PP_MAKE_TUPLE_SEQ(Index)              \
   OF_PP_MAKE_TUPLE_SEQ(InputDeviceNotMatch)
 
 #define DEFINE_EXCEPTION_CLASS(cls) \
diff --git a/oneflow/core/common/shape.cpp b/oneflow/core/common/shape.cpp
index 949935d361acdab2656b00fd6cd8b8e36c05571d..f3f5cec1947f4d85548ee0ee75b03d956793dc1c 100644
--- a/oneflow/core/common/shape.cpp
+++ b/oneflow/core/common/shape.cpp
@@ -184,4 +184,14 @@ bool Shape::Containing(const Shape& small_shape) const {
   return true;
 }
 
+Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {
+  CHECK_OR_RETURN(start_dim >= 0 && end_dim >= start_dim);
+  int64_t ndims = this->NumAxes();
+  if (start_dim > ndims) { start_dim = ndims; }
+  if (end_dim > ndims) { end_dim = ndims; }
+  DimVector dim_vec;
+  for (int64_t i = start_dim; i < end_dim && i < ndims; ++i) { dim_vec.push_back(this->At(i)); }
+  return std::make_shared<Shape>(dim_vec);
+}
+
 }  // namespace oneflow
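Note on the new `Shape::Slice` above: out-of-range bounds are clamped to `[0, NumAxes()]` rather than rejected, so callers such as `TensorSetItemFunctor` below can pass any `end_dim` unconditionally. A minimal standalone sketch of the same semantics (plain Python with illustrative names, not the OneFlow API):

```python
def shape_slice(dims, start_dim, end_dim):
    # Mirrors Shape::Slice: negative or reversed ranges are rejected,
    # while bounds past the last axis are clamped instead of raising.
    assert start_dim >= 0 and end_dim >= start_dim
    ndims = len(dims)
    start_dim, end_dim = min(start_dim, ndims), min(end_dim, ndims)
    return dims[start_dim:end_dim]

assert shape_slice((4, 3, 2), 1, 3) == (3, 2)
assert shape_slice((4, 3, 2), 1, 100) == (3, 2)  # end_dim clamped to ndims
```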
diff --git a/oneflow/core/common/shape.h b/oneflow/core/common/shape.h
index 513ec410f281982a86bc05cbc25f9bdc47c54f3a..9b1a9f46c18d07bd196845d0ca5afb657bddddc8 100644
--- a/oneflow/core/common/shape.h
+++ b/oneflow/core/common/shape.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "oneflow/core/common/shape.pb.h"
 #include "oneflow/core/common/util.h"
+#include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/shape_vec.h"
 
 namespace oneflow {
@@ -68,6 +69,8 @@ class Shape final {
 
   bool Containing(const Shape& small_shape) const;
 
+  Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const;
+
  private:
   void UpdateElemCnt();
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 9eaa451c9deaac2822aa75aa1bd4749ebd8770a1..a936c88fd3e54870a6754d1cb2e516a01a0be796 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -313,7 +313,7 @@
   bind_python: False
 
- name: "expand"
-  signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)"
+  signature: "Tensor Expand(Tensor x, *, Shape shape)"
  bind_python: True
 
- name: "expand_dims"
@@ -694,3 +694,6 @@
   signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
   bind_python: True
 
+- name: "tensor_setitem"
+  signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)"
+  bind_python: True
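With the signature change above, Python callers pass a single target shape and all in/out/stride bookkeeping moves into `ExpandFunctor` (next hunk). A hypothetical before/after at a call site, assuming the generated `flow.F` binding accepts a Python tuple for the `Shape` parameter:

```python
# Before: the caller had to precompute everything.
# y = flow.F.expand(x, in_shape=[3, 1], out_shape=[2, 3, 4], stride=[0, 1, 0])

# After: only the target shape is supplied.
y = flow.F.expand(x, (2, 3, 4))
```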
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index 53ddfa662b9b5cf2e167a6a04fdf742266962841..e7d913b2c2cc9912f43b632749c854aeced047a7 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -167,10 +167,42 @@ class ConcatFunctor {
 class ExpandFunctor {
  public:
  ExpandFunctor() { op_ = CHECK_JUST(one::OpBuilder("expand").Input("in").Output("out").Build()); }
-  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
-                           const std::vector<int32_t>& in_shape,
-                           const std::vector<int32_t>& out_shape,
-                           const std::vector<int32_t>& stride) const {
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Shape& shape) const {
+    CHECK_GE_OR_RETURN(shape.NumAxes(), x->shape()->NumAxes())
+        << "The desired expanded dims should not be less than the input dims.";
+    std::vector<int32_t> in_shape(x->shape()->NumAxes());
+    for (int i = 0; i < in_shape.size(); ++i) { in_shape[i] = x->shape()->At(i); }
+
+    // calculate the original stride.
+    std::vector<int32_t> original_stride(in_shape.size(), 1);
+    for (int i = x->shape()->NumAxes() - 2; i >= 0; --i) {
+      original_stride[i] = in_shape.at(i + 1) * original_stride.at(i + 1);
+    }
+    std::vector<int32_t> out_shape(shape.NumAxes());
+    std::vector<int32_t> stride(shape.NumAxes());
+    int shift = out_shape.size() - in_shape.size();
+    for (int i = out_shape.size() - 1; i >= 0; --i) {
+      int index = i - shift;
+      if (index >= 0) {
+        if (shape.At(i) == -1 || shape.At(i) == in_shape.at(index)) {
+          out_shape[i] = in_shape.at(index);
+          stride[i] = original_stride.at(index);
+        } else {
+          CHECK_OR_RETURN(shape.At(i) > 0 && in_shape.at(index) == 1)
+              << "Invalid expand shape " << shape.ToString();
+          out_shape[i] = shape.At(i);
+          stride[i] = 0;
+        }
+      } else {
+        CHECK_GT_OR_RETURN(shape.At(i), 0) << "Invalid expand shape " << shape.ToString();
+        out_shape[i] = shape.At(i);
+        if (shape.At(i) == 1 && i < out_shape.size() - 1) {
+          stride[i] = stride.at(i + 1);
+        } else {
+          stride[i] = 0;
+        }
+      }
+    }
     MutableAttrMap attrs;
     JUST(attrs.SetAttr<std::vector<int32_t>>("in_shape", in_shape));
     JUST(attrs.SetAttr<std::vector<int32_t>>("out_shape", out_shape));
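To make the shape/stride loop above concrete, here is a hand-written Python mirror of `ExpandFunctor` (illustrative only; the helper name is mine). Expanding `(3, 1)` to `(2, 3, 4)` keeps the original stride on the preserved axis and assigns stride 0 to every broadcast axis, which is what lets the kernel repeat reads instead of copying:

```python
def expand_shape_and_stride(in_shape, target):
    # Contiguous strides of the input; the innermost axis has stride 1.
    original_stride = [1] * len(in_shape)
    for i in range(len(in_shape) - 2, -1, -1):
        original_stride[i] = in_shape[i + 1] * original_stride[i + 1]
    shift = len(target) - len(in_shape)
    out_shape, stride = [0] * len(target), [0] * len(target)
    for i in range(len(target) - 1, -1, -1):
        idx = i - shift
        if idx >= 0:  # axis exists in the input
            if target[i] in (-1, in_shape[idx]):
                out_shape[i], stride[i] = in_shape[idx], original_stride[idx]
            else:  # broadcast a size-1 axis: stride 0 makes reads repeat
                assert target[i] > 0 and in_shape[idx] == 1
                out_shape[i], stride[i] = target[i], 0
        else:  # newly prepended axis
            assert target[i] > 0
            out_shape[i] = target[i]
            stride[i] = stride[i + 1] if target[i] == 1 and i < len(target) - 1 else 0
    return out_shape, stride

assert expand_shape_and_stride([3, 1], [2, 3, 4]) == ([2, 3, 4], [0, 1, 0])
```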
@@ -773,41 +805,24 @@ class TensorGetItemFunctor {
  public:
   TensorGetItemFunctor() {}
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index) const {
-    const auto& regular_index = JUST(RegularTensorIndex(index, *(x->shape())));
     int64_t ndims = x->shape()->NumAxes();
-    CHECK_GE_OR_RETURN(regular_index->size(), ndims) << "Tensor index failed to be regularlized.";
+    std::vector<detail::Slice> slice_indices;
+    std::vector<std::shared_ptr<one::Tensor>> tensor_indices;
+    std::vector<int64_t> target_dims;
+
+    JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
+    CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
+    Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
+    CHECK_GT_OR_RETURN(target_shape.Count(0), 0)
+        << "Target shape is zero shape which was not supported yet.";
+
     std::vector<int64_t> start(ndims), end(ndims), step(ndims);
-    int dim = 0;
-    DimVector result_dims;
-    for (int i = 0; i < regular_index->size(); ++i) {
-      const auto& index_item = regular_index->at(i);
-      CHECK_OR_RETURN(!index_item.IsEllipsis())
-          << "Tensor index should not have ellipsis once regularlized.";
-      if (index_item.IsSlice()) {
-        CHECK_LT_OR_RETURN(dim, ndims);
-        start[dim] = index_item.slice().start();
-        end[dim] = index_item.slice().end();
-        step[dim] = index_item.slice().step();
-        int64_t length = (end[dim] - start[dim] + step[dim] - 1) / step[dim];
-        result_dims.emplace_back(length);
-        dim++;
-      } else if (index_item.IsInteger()) {
-        CHECK_LT_OR_RETURN(dim, ndims);
-        start[dim] = index_item.integer();
-        end[dim] = start[dim] + 1;
-        step[dim] = 1;
-        dim++;
-      } else if (index_item.IsNone()) {
-        result_dims.emplace_back(1);
-      } else if (index_item.IsBoolean()) {
-        CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
-        result_dims.emplace_back(1);
-      }
+    for (int i = 0; i < ndims; ++i) {
+      const auto& slice = slice_indices.at(i);
+      start[i] = slice.start();
+      end[i] = slice.end();
+      step[i] = slice.step();
     }
-    CHECK_EQ_OR_RETURN(dim, ndims)
-        << "Specified dims count for regularlized tensor index should equal to tensor dimension "
-        << ndims;
-
     bool is_identity = [&]() {
       for (int i = 0; i < ndims; ++i) {
         if (start[i] != 0 || end[i] != x->shape()->At(i) || step[i] != 1) { return false; }
@@ -821,14 +836,69 @@ class TensorGetItemFunctor {
       result = JUST(functional::Slice(x, start, end, step));
     }
 
-    Shape shape(result_dims);
+    Shape shape(DimVector(target_dims.begin(), target_dims.end()));
     if (shape.NumAxes() != 0 && shape != *(result->shape())) {
-      return functional::Reshape(result, shape);
+      result = JUST(functional::Reshape(result, shape));
     }
     return result;
   }
 };
 
+class TensorSetItemFunctor {
+ public:
+  TensorSetItemFunctor() {}
+  Maybe<void> operator()(const std::shared_ptr<one::Tensor>& x, const TensorIndex& index,
+                         const std::shared_ptr<one::Tensor>& value) const {
+    int64_t ndims = x->shape()->NumAxes();
+    std::vector<detail::Slice> slice_indices;
+    std::vector<std::shared_ptr<one::Tensor>> tensor_indices;
+    std::vector<int64_t> target_dims;
+
+    JUST(PrepareSliceIndices(index, *(x->shape()), &slice_indices, &tensor_indices, &target_dims));
+    CHECK_EQ_OR_RETURN(slice_indices.size(), ndims) << "Failed to prepare slice indices.";
+    Shape target_shape(DimVector(target_dims.begin(), target_dims.end()));
+    if (target_shape.Count(0) == 0) { return Maybe<void>::Ok(); }
+
+    const auto& value_shape = value->shape();
+    bool matched = [&]() {
+      for (int i = 0; i < value_shape->NumAxes() - target_shape.NumAxes(); ++i) {
+        if (value_shape->At(i) != 1) { return false; }
+      }
+      return true;
+    }();
+    CHECK_OR_RETURN(matched) << "The tensor size mismatch. Target sizes: "
+                             << target_shape.ToString()
+                             << ", value sizes: " << value_shape->ToString();
+    std::shared_ptr<one::Tensor> value_tensor(value);
+    if (target_shape.NumAxes() != 0 &&  // NOLINT
+        /*need_expand=*/value_shape->Count(0) != target_shape.Count(0)) {
+      // Remove the beginning redundant 1-dimensions.
+      if (value_shape->NumAxes() > target_shape.NumAxes()) {
+        int64_t start_axis = value_shape->NumAxes() - target_shape.NumAxes();
+        const auto& shape = JUST(value_shape->Slice(start_axis, value_shape->NumAxes()));
+        value_tensor = JUST(functional::Reshape(value, *shape));
+      }
+      value_tensor = JUST(functional::Expand(value_tensor, target_shape));
+    }
+
+    std::vector<int64_t> start(ndims), end(ndims), step(ndims);
+    DimVector slice_dims(ndims);
+    for (int i = 0; i < ndims; ++i) {
+      const auto& slice = slice_indices.at(i);
+      start[i] = slice.start();
+      end[i] = slice.end();
+      step[i] = slice.step();
+      slice_dims[i] = (end[i] - start[i] + step[i] - 1) / step[i];
+    }
+    Shape slice_shape(slice_dims);
+    if (slice_shape != *(value_tensor->shape())) {
+      value_tensor = JUST(functional::Reshape(value_tensor, slice_shape));
+    }
+    JUST(LogicalSliceAssign(x, value_tensor, start, end, step));
+    return Maybe<void>::Ok();
+  }
+};
+
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -873,6 +943,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::DiagFunctor>("Diag");
   m.add_functor<impl::DiagGradFunctor>("DiagGrad");
   m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
+  m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
 };
 
 }  // namespace functional
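The value-preparation step in `TensorSetItemFunctor` (drop redundant leading 1-dims via `Shape::Slice` + `Reshape`, then `Expand` to the target shape) matches NumPy's broadcasting of a right-aligned value, which makes it easy to sanity-check. A sketch with `np.broadcast_to` standing in for `functional::Expand` (helper name is mine):

```python
import numpy as np

def prepare_setitem_value(value, target_shape):
    # Extra leading axes must all be 1, otherwise the sizes mismatch.
    extra = max(value.ndim - len(target_shape), 0)
    assert all(d == 1 for d in value.shape[:extra]), "tensor size mismatch"
    if value.size != int(np.prod(target_shape)):
        if extra > 0:
            value = value.reshape(value.shape[extra:])  # Shape::Slice + Reshape
        value = np.broadcast_to(value, target_shape)    # functional::Expand
    return value

assert prepare_setitem_value(np.ones((1, 1, 3)), (2, 3)).shape == (2, 3)
```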
diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp
index 2a24a8319b4d083998cbc1d422cb681ea62a78ef..843991963c6445512833bf122c0e814e29663927 100644
--- a/oneflow/core/functional/tensor_index.cpp
+++ b/oneflow/core/functional/tensor_index.cpp
@@ -28,14 +28,15 @@ int64_t CountSpecifiedDims(const TensorIndex& index) {
   return specified_ndims;
 }
 
-Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape) {
-  int64_t specified_ndims = CountSpecifiedDims(index);
+Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
+                                std::vector<detail::Slice>* slice_indices,
+                                std::vector<std::shared_ptr<Tensor>>* tensor_indices,
+                                std::vector<int64_t>* target_dims) {
   int64_t ndims = shape.NumAxes();
+  int64_t specified_ndims = CountSpecifiedDims(index);
   CHECK_LE_OR_RETURN(specified_ndims, ndims)
       << "Too many indices for tensor of dimension " << ndims;
-
-  auto regular_index = std::make_shared<TensorIndex>();
-  int64_t dim = 0;
+  int dim = 0;
   for (int i = 0; i < index.size(); ++i) {
     const auto& index_item = index.at(i);
     if (index_item.IsSlice()) {
@@ -50,37 +51,40 @@ Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& sha
       if (start < 0) { start = 0; }
       if (end < 0) { end += shape.At(dim); }
       if (end < start) { end = start; }
-      regular_index->emplace_back(detail::IndexItem(start, end, step));
+      slice_indices->emplace_back(start, end, step);
+      int64_t length = (end - start + step - 1) / step;
+      target_dims->emplace_back(length);
       dim++;
     } else if (index_item.IsInteger()) {
       CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
       int64_t integer = index_item.integer();
       if (integer < 0) { integer += shape.At(dim); }
-      CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim)) << Error::ValueError(
-          std::string("Index ") + std::to_string(index_item.integer())
-          + std::string(" is out of bounds for dimension ") + std::to_string(dim)
-          + std::string(" with size ") + std::to_string(shape.At(dim)));
-      regular_index->emplace_back(detail::IndexItem(integer));
+      if (integer < 0 || integer >= shape.At(dim)) {
+        return Error::IndexError()
+               << "Index " << index_item.integer() << " is out of bounds for dimension " << dim
+               << " with size " << shape.At(dim);
+      }
+      slice_indices->emplace_back(integer, integer + 1, 1);
       dim++;
     } else if (index_item.IsEllipsis()) {
       int64_t unspecified_ndims = ndims - specified_ndims;
       unspecified_ndims = std::min(ndims - dim, unspecified_ndims);
       for (int j = 0; j < unspecified_ndims; ++j) {
-        regular_index->emplace_back(detail::IndexItem(0, shape.At(dim + j), 1));
+        slice_indices->emplace_back(0, shape.At(dim + j), 1);
+        target_dims->emplace_back(shape.At(dim + j));
       }
       dim += unspecified_ndims;
-    } else {
-      // None or Boolean.
-      if (index_item.IsBoolean()) {
-        CHECK_OR_RETURN(index_item.boolean()) << "Index false is not supported.";
-      }
-      regular_index->emplace_back(index_item);
+    } else if (index_item.IsNone()) {
+      target_dims->emplace_back(1);
+    } else if (index_item.IsBoolean()) {
+      target_dims->emplace_back(index_item.boolean());
     }
   }
   for (int i = dim; i < ndims; ++i) {
-    regular_index->emplace_back(detail::IndexItem(0, shape.At(i), 1));
+    slice_indices->emplace_back(0, shape.At(i), 1);
+    target_dims->emplace_back(shape.At(i));
   }
-  return regular_index;
+  return Maybe<void>::Ok();
 }
 
 }  // namespace functional
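A hand-traced example of what `PrepareSliceIndices` now produces (worked out by hand against the code above, not generated): for `x` of shape `(5, 4, 3)` indexed as `x[1, ::2, None]`, the integer index contributes a slice but no target dim, `None` contributes a target dim but no slice, and the unspecified last axis is filled in by the trailing loop:

```python
slice_indices = [(1, 2, 1), (0, 4, 2), (0, 3, 1)]  # (start, end, step) per input dim
target_dims = [2, 1, 3]  # int drops a dim, None inserts a size-1 dim

# Cross-check the resulting shape against NumPy's indexing semantics:
import numpy as np
assert np.zeros((5, 4, 3))[1, ::2, None].shape == tuple(target_dims)
```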
diff --git a/oneflow/core/functional/tensor_index.h b/oneflow/core/functional/tensor_index.h
index f2b7e78c8094a6fccfc75f49c27087b884f39568..eba019f8686fb8a110021a0c288c8370f5cabc57 100644
--- a/oneflow/core/functional/tensor_index.h
+++ b/oneflow/core/functional/tensor_index.h
@@ -25,6 +25,9 @@ limitations under the License.
 
 namespace oneflow {
 namespace one {
+
+class Tensor;
+
 namespace functional {
 namespace detail {
 
@@ -94,7 +97,11 @@ class TensorIndex : public std::vector<detail::IndexItem> {
 };
 
 int64_t CountSpecifiedDims(const TensorIndex& index);
-Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& shape);
+
+Maybe<void> PrepareSliceIndices(const TensorIndex& index, const Shape& shape,
+                                std::vector<detail::Slice>* slice_indices,
+                                std::vector<std::shared_ptr<Tensor>>* tensor_indices,
+                                std::vector<int64_t>* target_dims);
 
 }  // namespace functional
 }  // namespace one
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 5d2db0e7138b576a5ce9d359fe4594301e7bfc76..04dae678da5078f8a1beb26c9bb9ec18390676aa 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util
-from oneflow._oneflow_internal.exception import ValueException
+from oneflow._oneflow_internal.exception import IndexException
 from oneflow.python.oneflow_export import oneflow_export
 import oneflow.python.framework.remote_blob as remote_blob_util
@@ -454,44 +454,17 @@ class Tensor:
     def __getitem__(self, key):
         try:
             return flow.F.tensor_getitem(self, key)
-        except ValueException as e:
+        except IndexException as e:
             # The stop condition of for in python is IndexError,
-            # so we have to catch ValueException from C++ and throw IndexError
+            # so we have to catch IndexException from C++ and throw IndexError
             raise IndexError(e)
 
     @_auto_determine
     @register_local_tensor_method()
     def __setitem__(self, key, value):
-        if isinstance(key, tuple):
-            key = self._transform_ellipsis_type(key)
-            unsqueeze_dims = list(
-                filter(lambda idx: isinstance(key[idx], int), range(len(key)))
-            )
-        elif isinstance(key, int):
-            if key < 0:
-                key = self.shape[0] + key
-            unsqueeze_dims = [0]
-        else:
-            unsqueeze_dims = []
-
-        start, stop, step, shape = self._get_slice_obj(key)
         if isinstance(value, (int, float)):
-            scalar = value
-            value = flow.Tensor(*shape)
-            value.fill_(scalar)
-        else:
-            prepended_broadcasting_dims = range(
-                len(self.shape) - len(unsqueeze_dims) - len(value.shape)
-            )
-            for dim in prepended_broadcasting_dims:
-                value = flow.experimental.unsqueeze(value, dim)
-            for dim in unsqueeze_dims:
-                value = flow.experimental.unsqueeze(value, dim)
-            value = flow.experimental.expand(value, *shape)
-
-        flow.experimental.tmp.logical_slice_assign(
-            self, value, list(zip(start, stop, step))
-        )
+            value = flow.F.constant([1], value, self.dtype)
+        flow.F.tensor_setitem(self, key, value)
         return self
 
     @register_local_tensor_method()
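End to end, indexing on the Python side now reduces to the two functional calls. A usage sketch, assuming the 0.4-era eager API (`oneflow.experimental`, `enable_eager_execution`) that this file targets:

```python
import numpy as np
import oneflow.experimental as flow

flow.enable_eager_execution()

x = flow.Tensor(np.zeros((2, 3)))
x[0, 1:3] = 1.0                 # scalar -> flow.F.constant, then tensor_setitem
x[1] = flow.Tensor(np.ones(3))  # tensor value is broadcast inside TensorSetItem
y = x[0, ::2]                   # routed through flow.F.tensor_getitem
print(y.numpy())                # [0. 1.]
```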
diff --git a/oneflow/python/nn/modules/expand.py b/oneflow/python/nn/modules/expand.py
index acb03b33ea2f490c44f6b98dfdb7c0d1168e7ec8..11e9695392a2e897ca22e2577ba4a7477e3a45ff 100644
--- a/oneflow/python/nn/modules/expand.py
+++ b/oneflow/python/nn/modules/expand.py
@@ -28,39 +28,7 @@ class Expand(Module):
     def forward(self, x):
         if x.dtype == flow.int8:
             x = flow.experimental.cast(x, flow.int32)
-        expand_size = self.expand_size
-        assert len(expand_size) >= len(
-            x.shape
-        ), "The desired expanded dims should not be less than the input dims."
-        # calculate the original stride
-        original_stride = [1]
-        for i in range(len(x.shape) - 2, -1, -1):
-            original_stride.insert(0, original_stride[0] * x.shape[i + 1])
-
-        # calculate the output shape and stride
-        new_size = []
-        new_stride = []
-        diff = len(expand_size) - len(x.shape)
-        for i in range(len(expand_size) - 1, -1, -1):
-            if i >= diff:
-                if expand_size[i] == -1 or expand_size[i] == x.shape[i - diff]:
-                    new_size.insert(0, x.shape[i - diff])
-                    new_stride.insert(0, original_stride[i - diff])
-                else:
-                    assert expand_size[i] >= 1 and x.shape[i - diff] == 1
-                    new_size.insert(0, expand_size[i])
-                    new_stride.insert(0, 0)
-            else:
-                assert expand_size[i] >= 1
-                new_size.insert(0, expand_size[i])
-                if expand_size[i] == 1:
-                    new_stride.insert(0, new_stride[0])
-                else:
-                    new_stride.insert(0, 0)
-
-        return flow.F.expand(
-            x, in_shape=list(x.shape), out_shape=new_size, stride=new_stride
-        )
+        return flow.F.expand(x, self.expand_size)
 
 
 @oneflow_export("expand")
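With the Python-side computation deleted, the `expand` module becomes a thin wrapper around `flow.F.expand`. A usage sketch under the same assumptions as the earlier snippet (and assuming `expand` is still registered as a tensor method in this era):

```python
x = flow.Tensor(np.ones((3, 1)))
y = x.expand(2, 3, 4)  # module simply forwards: flow.F.expand(x, self.expand_size)
print(y.shape)         # flow.Size([2, 3, 4])
```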