From 0d010eadab8bb49ff867f62c5546f53a57abd186 Mon Sep 17 00:00:00 2001 From: Ziqi Zhou <834914152@qq.com> Date: Thu, 15 Jul 2021 09:57:10 +0800 Subject: [PATCH] Fix tensor getitem bug (#5474) * printf tensor index in cpp * revert printf debug information * add isinstance tuple and int condition * format code * implement and register __iter__ method * Change to catch exception from C++ and throw IndexError * add indexError error message Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/core/functional/tensor_index.cpp | 7 ++++--- oneflow/python/framework/tensor.py | 9 ++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp index 45d3e6294..2a24a8319 100644 --- a/oneflow/core/functional/tensor_index.cpp +++ b/oneflow/core/functional/tensor_index.cpp @@ -56,9 +56,10 @@ Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& sha CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims; int64_t integer = index_item.integer(); if (integer < 0) { integer += shape.At(dim); } - CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim)) - << "Index " << index_item.integer() << " is out of bounds for dimension " << dim - << " with size " << shape.At(dim); + CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim)) << Error::ValueError( + std::string("Index ") + std::to_string(index_item.integer()) + + std::string(" is out of bounds for dimension ") + std::to_string(dim) + + std::string(" with size ") + std::to_string(shape.At(dim))); regular_index->emplace_back(detail::IndexItem(integer)); dim++; } else if (index_item.IsEllipsis()) { diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index 0a9bb6bb1..5d2db0e71 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. """ import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util +from oneflow._oneflow_internal.exception import ValueException + from oneflow.python.oneflow_export import oneflow_export import oneflow.python.framework.remote_blob as remote_blob_util import oneflow._oneflow_internal @@ -450,7 +452,12 @@ class Tensor: @_auto_determine @register_local_tensor_method() def __getitem__(self, key): - return flow.F.tensor_getitem(self, key) + try: + return flow.F.tensor_getitem(self, key) + except ValueException as e: + # The stop condition of for in python is IndexError, + # so we have to catch ValueException from C++ and throw IndexError + raise IndexError(e) @_auto_determine @register_local_tensor_method() -- GitLab