From 0d010eadab8bb49ff867f62c5546f53a57abd186 Mon Sep 17 00:00:00 2001
From: Ziqi Zhou <834914152@qq.com>
Date: Thu, 15 Jul 2021 09:57:10 +0800
Subject: [PATCH] Fix tensor getitem bug (#5474)

* printf tensor index in cpp

* revert printf debug information

* add isinstance tuple and int condition

* format code

* implement and register __iter__ method

* Change to catch exception from C++ and throw IndexError

* add indexError error message

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/core/functional/tensor_index.cpp | 7 ++++---
 oneflow/python/framework/tensor.py       | 9 ++++++++-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp
index 45d3e6294..2a24a8319 100644
--- a/oneflow/core/functional/tensor_index.cpp
+++ b/oneflow/core/functional/tensor_index.cpp
@@ -56,9 +56,10 @@ Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& sha
       CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
       int64_t integer = index_item.integer();
       if (integer < 0) { integer += shape.At(dim); }
-      CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim))
-          << "Index " << index_item.integer() << " is out of bounds for dimension " << dim
-          << " with size " << shape.At(dim);
+      CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim)) << Error::ValueError(
+          std::string("Index ") + std::to_string(index_item.integer())
+          + std::string(" is out of bounds for dimension ") + std::to_string(dim)
+          + std::string(" with size ") + std::to_string(shape.At(dim)));
       regular_index->emplace_back(detail::IndexItem(integer));
       dim++;
     } else if (index_item.IsEllipsis()) {
diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py
index 0a9bb6bb1..5d2db0e71 100644
--- a/oneflow/python/framework/tensor.py
+++ b/oneflow/python/framework/tensor.py
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util
+from oneflow._oneflow_internal.exception import ValueException
+
 from oneflow.python.oneflow_export import oneflow_export
 import oneflow.python.framework.remote_blob as remote_blob_util
 import oneflow._oneflow_internal
@@ -450,7 +452,12 @@ class Tensor:
     @_auto_determine
     @register_local_tensor_method()
     def __getitem__(self, key):
-        return flow.F.tensor_getitem(self, key)
+        try:
+            return flow.F.tensor_getitem(self, key)
+        except ValueException as e:
+            # The stop condition of for in python is IndexError,
+            # so we have to catch ValueException from C++ and throw IndexError
+            raise IndexError(e)
 
     @_auto_determine
     @register_local_tensor_method()
-- 
GitLab