Skip to content
Snippets Groups Projects
Unverified Commit 1299deae authored by Yinggang Wang's avatar Yinggang Wang Committed by GitHub
Browse files

Feat tensor stride property (#5543)


* feat(Stride): add Stride class

* feat(Tensor): support stride and storage_offset interface

* feat(Tensor): add is_contiguous interface

* remove test declaration

* feat(TensorMeta): add hash and compare for stride

* refine code

* refine IsContiguous function

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 2b208ec0
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,9 @@ limitations under the License.
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_method.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stride.h"
#include "oneflow/core/framework/py_distribute.h"
#include "oneflow/core/job/placement.cfg.h"
#include "oneflow/core/job/global_for.h"
......@@ -201,6 +203,10 @@ void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMe
return RegisterTensorHook(self, hook).GetOrThrow();
}
// Python-facing wrapper: unwraps the Maybe<bool> from IsContiguous,
// converting a failed Maybe into a thrown exception for pybind11.
bool ApiIsContiguous(const std::shared_ptr<Tensor>& tensor) {
  const auto& maybe_contiguous = IsContiguous(tensor);
  return maybe_contiguous.GetOrThrow();
}
} // namespace
ONEFLOW_API_PYBIND11_MODULE("", m) {
......@@ -219,6 +225,13 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
return std::shared_ptr<Tensor>();
}
})
.def("storage_offset", [](const Tensor& t) { return t.storage_offset().GetOrThrow(); })
.def("stride",
[](const Tensor& t) {
const auto& stride = t.stride().GetPtrOrThrow()->StrideVec();
return py::tuple(py::make_iterator(stride.begin(), stride.end()));
})
.def("is_contiguous", &ApiIsContiguous)
// setter of grad
.def("set_grad",
[](Tensor& t, const std::shared_ptr<Tensor>& grad) {
......
......@@ -27,11 +27,13 @@ namespace oneflow {
typedef std::vector<int64_t> DimVector;
typedef std::vector<int64_t> AxisVector;
typedef std::vector<int64_t> StrideVector;
#else
typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> DimVector;
typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> AxisVector;
typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> StrideVector;
#endif
} // namespace oneflow
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/stride.h"
namespace oneflow {
// Build compact row-major (C-contiguous) strides for `shape`: the last axis
// gets stride 1 and each preceding axis' stride is the product of all later
// dimension extents. A zero-rank shape leaves the stride vector empty.
Stride::Stride(const Shape& shape) {
  const int64_t ndim = shape.NumAxes();
  if (ndim > 0) {
    stride_vec_.assign(ndim, 1);
    for (int64_t i = ndim - 2; i >= 0; --i) {
      stride_vec_.at(i) = stride_vec_.at(i + 1) * shape.At(i + 1);
    }
  }
}
// Copy-assignment: adopt the other Stride's per-axis values.
Stride& Stride::operator=(const Stride& stride) {
  if (this != &stride) { stride_vec_ = stride.stride_vec_; }
  return *this;
}
// Two Strides are equal iff their per-axis stride values match exactly.
bool Stride::operator==(const Stride& rhs) const { return rhs.stride_vec_ == stride_vec_; }
// Render the stride as a Python-style tuple string, e.g. "(4,2,1)".
// A single-element stride keeps its trailing comma ("(4,)") to mirror
// Python's one-element tuple notation; an empty stride renders as "()".
std::string Stride::ToString() const {
  std::stringstream ss;
  ss << "(";
  const size_t num = stride_vec_.size();
  for (size_t i = 0; i < num; ++i) {
    ss << stride_vec_.at(i);
    if (i + 1 != num || num == 1) { ss << ","; }
  }
  ss << ")";
  return ss.str();
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_STRIDE_H_
#define ONEFLOW_CORE_FRAMEWORK_STRIDE_H_

#include <utility>

#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/util.h"

namespace oneflow {

// Stride describes the step (in elements, not bytes) between consecutive
// indices along each tensor axis. Stride(shape) produces the compact
// row-major strides for `shape`.
class Stride final {
 public:
  Stride() = default;
  explicit Stride(const Shape& shape);
  explicit Stride(const StrideVector& stride_vec) : stride_vec_(stride_vec) {}
  // BUGFIX: the parameter is an lvalue inside the constructor, so the
  // original init `stride_vec_(stride_vec)` silently copied; std::move is
  // required to actually move from the rvalue-reference parameter.
  explicit Stride(StrideVector&& stride_vec) : stride_vec_(std::move(stride_vec)) {}
  Stride(const std::initializer_list<int64_t>& stride_vec) : stride_vec_(stride_vec) {}
  Stride& operator=(const Stride& stride);
  ~Stride() = default;

  bool operator==(const Stride& rhs) const;
  bool operator!=(const Stride& rhs) const { return !(*this == rhs); }

  std::string ToString() const;

  // Getters and Setters
  const StrideVector& StrideVec() const { return stride_vec_; }
  int64_t NumAxes() const { return stride_vec_.size(); }
  int64_t At(int64_t index) const { return stride_vec_.at(index); }
  void Set(int64_t index, int64_t val) { stride_vec_.at(index) = val; }

 private:
  StrideVector stride_vec_;
};

}  // namespace oneflow

namespace std {
template<>
struct hash<oneflow::Stride> {
  size_t operator()(const oneflow::Stride& stride) const {
    // Combine per-axis hashes with a boost::hash_combine-style shift-xor mix;
    // the original plain XOR fold was order-insensitive, so any permutation
    // of the same stride values (e.g. (1,2) vs (2,1)) collided.
    size_t ret = 0;
    FOR_RANGE(int, i, 0, stride.NumAxes()) {
      ret ^= std::hash<int64_t>()(stride.At(i)) + 0x9e3779b9 + (ret << 6) + (ret >> 2);
    }
    return ret;
  }
};
}  // namespace std

#endif  // ONEFLOW_CORE_FRAMEWORK_STRIDE_H_
......@@ -64,6 +64,8 @@ class Tensor {
virtual Maybe<VmLocalDepObject> compute_local_dep_object() const = 0;
virtual Maybe<bool> has_eager_blob_object() const = 0;
virtual Maybe<TensorStorage> tensor_storage() const { OF_UNIMPLEMENTED(); }
virtual Maybe<const Stride> stride() const { OF_UNIMPLEMENTED(); }
virtual Maybe<int64_t> storage_offset() const { OF_UNIMPLEMENTED(); }
// Getters/Setters valid only for EagerConsistentTensor
virtual Maybe<Symbol<cfg::ParallelDistribution>> consumer_parallel_distribution_constraint()
......@@ -165,6 +167,8 @@ class MirroredTensor final : public TensorIf<MirroredTensor>,
}
Maybe<TensorStorage> tensor_storage() const override { return impl_->tensor_storage(); }
Maybe<bool> has_eager_blob_object() const override { return impl_->has_eager_blob_object(); }
Maybe<const Stride> stride() const override { return impl_->stride(); }
Maybe<int64_t> storage_offset() const override { return impl_->storage_offset(); }
// Getters for autograd
Maybe<Tensor> acc_grad() const override { return impl_->acc_grad(); }
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor_impl.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/stride.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/sbp_parallel.cfg.h"
#include "oneflow/core/framework/device.h"
......@@ -165,16 +166,23 @@ Maybe<MirroredTensorImpl> EagerMirroredTensorImpl::detach() const {
return std::shared_ptr<MirroredTensorImpl>(detached_impl);
}
// Metadata for a local (mirrored) tensor: shape, dtype and placement device.
// A freshly constructed meta describes a compact, contiguous tensor, so the
// stride defaults to the row-major stride of `shape` and the storage offset
// to 0.
MirroredTensorMeta::MirroredTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,
                                       Symbol<Device> device)
    : TensorMeta(shape, dtype),
      device_(device),
      stride_(std::make_shared<const Stride>(*shape)),
      storage_offset_(0) {}
// Equality over the fields that define a tensor's layout: shape, dtype,
// device and stride. (is_dynamic_ is intentionally ignored.)
// BUGFIX: the scraped diff kept both the pre- and post-change versions of the
// final conjunct, leaving two dangling `&& ...;` lines; only the new line
// (which includes the stride comparison) belongs in the body.
bool MirroredTensorMeta::operator==(const MirroredTensorMeta& other) const {
  // It's correct to ignore is_dynamic_ field.
  return *this->shape_ptr() == *other.shape_ptr() && this->dtype() == other.dtype()
         && *this->device() == *other.device() && this->stride() == other.stride();
}
// Hash consistent with operator== above: combines shape, dtype, device and
// stride. (is_dynamic_ is intentionally ignored.)
// BUGFIX: the scraped diff left both the old `^ std::hash<Device>()(...);`
// line and its replacement in the body; only the replacement line (which
// also hashes the stride) is kept.
size_t MirroredTensorMeta::CalcHashValue() const {
  // It's correct to ignore is_dynamic_ field.
  return std::hash<Shape>()(*shape_ptr()) ^ std::hash<DataType>()(dtype())
         ^ std::hash<Device>()(*device()) ^ std::hash<Stride>()(stride());
}
bool ConsistentTensorMeta::operator==(const ConsistentTensorMeta& other) const {
......
......@@ -63,6 +63,8 @@ class TensorImpl {
virtual Maybe<VmLocalDepObject> compute_local_dep_object() const = 0;
virtual Maybe<TensorStorage> tensor_storage() const { OF_UNIMPLEMENTED(); }
virtual Maybe<bool> has_eager_blob_object() const = 0;
virtual Maybe<const Stride> stride() const { OF_UNIMPLEMENTED(); }
virtual Maybe<int64_t> storage_offset() const { OF_UNIMPLEMENTED(); }
// Getters for autograd
Maybe<Tensor> acc_grad() const;
......@@ -210,6 +212,8 @@ class EagerMirroredTensorImpl final : public MirroredTensorImpl {
return tensor_storage_;
}
Maybe<bool> has_eager_blob_object() const override { return eager_blob_object_.get(); }
Maybe<const Stride> stride() const override { return tensor_meta_->stride_ptr(); }
Maybe<int64_t> storage_offset() const override { return tensor_meta_->storage_offset(); }
// Setters
TensorStorage* mut_tensor_storage() { return tensor_storage_.get(); }
......
......@@ -27,6 +27,7 @@ class ParallelDistribution;
class Shape;
class Device;
class Stride;
class ParallelDesc;
namespace one {
......@@ -63,19 +64,25 @@ class TensorMeta : public user_op::TensorDesc {
class MirroredTensorMeta : public TensorMeta {
public:
MirroredTensorMeta(const std::shared_ptr<const Shape>& shape, DataType dtype,
Symbol<Device> device)
: TensorMeta(shape, dtype), device_(device) {}
Symbol<Device> device);
virtual ~MirroredTensorMeta() = default;
const Symbol<Device>& device() const { return device_; }
const Stride& stride() const { return *stride_; }
const std::shared_ptr<const Stride>& stride_ptr() const { return stride_; }
int64_t storage_offset() const { return storage_offset_; }
Symbol<Device>* mut_device() { return &device_; }
void set_stride(const std::shared_ptr<const Stride>& stride) { stride_ = stride; }
void set_storage_offset(int64_t offset) { storage_offset_ = offset; }
bool operator==(const MirroredTensorMeta& other) const;
size_t CalcHashValue() const;
private:
Symbol<Device> device_;
std::shared_ptr<const Stride> stride_;
int64_t storage_offset_;
};
class ConsistentTensorMeta : public TensorMeta {
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/tensor_method.h"
#include "oneflow/core/framework/stride.h"
#include "oneflow/core/common/shape.h"
namespace oneflow {
namespace one {
// Returns whether `tensor` is laid out C-contiguously. Scanning axes from
// innermost to outermost, every axis with extent > 1 must carry the stride a
// compact row-major layout would give it; extent-1 axes are ignored because
// their stride is never used to address an element.
Maybe<bool> IsContiguous(const std::shared_ptr<Tensor>& tensor) {
  const Shape& shape = *tensor->shape();
  const Stride& stride = *JUST(tensor->stride());
  int64_t compact_stride = 1;
  bool contiguous = true;
  for (int64_t axis = shape.NumAxes() - 1; axis >= 0; --axis) {
    const int64_t extent = shape.At(axis);
    // Contiguous by default when any dim is equal to zero
    // https://stackoverflow.com/questions/31681324/identify-contiguous-segments-of-a-non-contiguous-numpy-array
    if (extent == 0) { return true; }
    // Once a mismatch is found we only keep scanning for zero-extent axes.
    if (extent == 1 || !contiguous) { continue; }
    if (stride.At(axis) != compact_stride) {
      contiguous = false;
    } else {
      compact_stride *= extent;
    }
  }
  return contiguous;
}
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_METHOD_H_
#define ONEFLOW_CORE_FRAMEWORK_TENSOR_METHOD_H_
#include "oneflow/core/framework/tensor.h"
namespace oneflow {
namespace one {
class Tensor;
// Returns whether `tensor` is stored C-contiguously (compact row-major
// strides); see the stride-by-stride check in tensor_method.cpp.
Maybe<bool> IsContiguous(const std::shared_ptr<Tensor>& tensor);
} // namespace one
} // namespace oneflow
#endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_METHOD_H_
......@@ -205,6 +205,18 @@ class Tensor:
else:
return self._undetermined_tensor.shape
def stride(self):
    """Return the per-axis element strides of this tensor as a tuple.

    Requires the tensor to be determined first.
    """
    assert self.is_determined
    determined = self._local_or_consistent_tensor
    return determined.stride()
def storage_offset(self):
    """Return the offset (in elements) of this tensor within its storage.

    Requires the tensor to be determined first.
    """
    assert self.is_determined
    determined = self._local_or_consistent_tensor
    return determined.storage_offset()
def is_contiguous(self):
    """Return True if this tensor is laid out contiguously in memory.

    Requires the tensor to be determined first.
    """
    assert self.is_determined
    determined = self._local_or_consistent_tensor
    return determined.is_contiguous()
@property
def device(self):
if self._local_or_consistent_tensor is not None:
......
......@@ -34,6 +34,15 @@ class TestTensor(flow.unittest.TestCase):
np.array_equal(tensor.numpy(), np.ones(shape, dtype=np.float32))
)
def test_tensor_property(test_case):
    """Check storage_offset/stride/is_cuda/is_contiguous on a fresh tensor."""
    shape = (2, 3, 4, 5)
    tensor = flow.Tensor(*shape)
    tensor.determine()
    # A newly created tensor is compact row-major, so the offset is 0 and
    # the strides for shape (2, 3, 4, 5) are (3*4*5, 4*5, 5, 1).
    test_case.assertEqual(tensor.storage_offset(), 0)
    test_case.assertEqual(tensor.stride(), (60, 20, 5, 1))
    test_case.assertEqual(tensor.is_cuda, False)
    test_case.assertTrue(tensor.is_contiguous())
def test_copy_to_and_from_numpy(test_case):
np_arr = np.array([4, 6], dtype=np.float32)
tensor = flow.Tensor(np_arr, dtype=flow.float32)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.