Skip to content
Snippets Groups Projects
Unverified Commit 0d010ead authored by Ziqi Zhou's avatar Ziqi Zhou Committed by GitHub
Browse files

Fix tensor getitem bug (#5474)


* printf tensor index in cpp

* revert printf debug information

* add isinstance tuple and int condition

* format code

* implement and register __iter__ method

* Change to catch exception from C++ and throw IndexError

* add indexError error message

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 11126849
No related branches found
No related tags found
No related merge requests found
......@@ -56,9 +56,10 @@ Maybe<TensorIndex> RegularTensorIndex(const TensorIndex& index, const Shape& sha
CHECK_LT_OR_RETURN(dim, ndims) << "Invalid index for tensor of dimension " << ndims;
int64_t integer = index_item.integer();
if (integer < 0) { integer += shape.At(dim); }
CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim))
<< "Index " << index_item.integer() << " is out of bounds for dimension " << dim
<< " with size " << shape.At(dim);
CHECK_OR_RETURN(integer >= 0 && integer < shape.At(dim)) << Error::ValueError(
std::string("Index ") + std::to_string(index_item.integer())
+ std::string(" is out of bounds for dimension ") + std::to_string(dim)
+ std::string(" with size ") + std::to_string(shape.At(dim)));
regular_index->emplace_back(detail::IndexItem(integer));
dim++;
} else if (index_item.IsEllipsis()) {
......
......@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util
from oneflow._oneflow_internal.exception import ValueException
from oneflow.python.oneflow_export import oneflow_export
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow._oneflow_internal
......@@ -450,7 +452,12 @@ class Tensor:
@_auto_determine
@register_local_tensor_method()
def __getitem__(self, key):
return flow.F.tensor_getitem(self, key)
try:
return flow.F.tensor_getitem(self, key)
except ValueException as e:
# The stop condition of for in python is IndexError,
# so we have to catch ValueException from C++ and throw IndexError
raise IndexError(e)
@_auto_determine
@register_local_tensor_method()
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment