Skip to content
Snippets Groups Projects
Unverified Commit 594a64f9 authored by Shijie's avatar Shijie Committed by GitHub
Browse files

Dev tensor buffer eager (#5317)


* fix conflict

* add tensor_buffer_ops eager

* add testcase

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 75dc8371
No related branches found
No related tags found
No related merge requests found
......@@ -215,3 +215,5 @@ Experimental features
.. autofunction:: oneflow.experimental.Tensor.topk
.. autofunction:: oneflow.experimental.nn.GroupNorm
.. autofunction:: oneflow.experimental.nn.ZeroPad2d
.. autofunction:: oneflow.experimental.tensor_buffer_to_tensor
.. autofunction:: oneflow.experimental.tensor_to_tensor_buffer
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export, experimental_api
class TensorBufferToTensor(Module):
def __init__(self, dtype, instance_shape):
super().__init__()
self._op = (
flow.builtin_op("tensor_buffer_to_tensor")
.Input("in")
.Output("out")
.Attr("dtype", dtype)
.Attr("instance_shape", instance_shape)
.Build()
)
def forward(self, input):
return self._op(input)[0]
@oneflow_export("tensor_buffer_to_tensor")
@experimental_api
def tensor_buffer_to_tensor_op(x, dtype: flow.dtype, instance_shape: Sequence[int]):
"""This operator converts the Tensor's type from TensorBuffer to original type.
Some operator's output data type is `TensorBuffer`, you can use this operator to convert back
to `Tensor`.
Refer to `Concept Explanation <https://docs.oneflow.org/basics_topics/concept_explanation.html#3tensorbuffer-tensorlist>`_
for more about TensorBuffer.
Args:
x (oneflow.Tensor): The input Tensor.
dtype (flow.dtype): The data dtype.
instance_shape (Sequence[int]): The shape of each TensorBuffer instance.
Returns:
oneflow.Tensor: The result Tensor.
For example:
.. code-block:: python
>>> import numpy as np
>>> import oneflow.experimental as flow
>>> flow.enable_eager_execution()
>>> x = np.random.randn(4, 16, 64, 64).astype(np.float32)
>>> x = flow.Tensor(x)
>>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2)
>>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float)
>>> output.shape
flow.Size([4, 16, 64, 64])
"""
return TensorBufferToTensor(dtype=dtype, instance_shape=instance_shape)(x)
class TensorToTensorBuffer(Module):
def __init__(self, instance_dims):
super().__init__()
self._op = (
flow.builtin_op("tensor_to_tensor_buffer")
.Input("in")
.Output("out")
.Attr("instance_dims", instance_dims)
.Build()
)
def forward(self, input):
return self._op(input)[0]
@oneflow_export("tensor_to_tensor_buffer")
@experimental_api
def tensor_to_tensor_buffer(x, instance_dims: int):
"""This operator converts the Tensor's type to TensorBuffer.
Refer to `Concept Explanation <https://docs.oneflow.org/basics_topics/concept_explanation.html#3tensorbuffer-tensorlist>`_
for more about TensorBuffer.
Args:
x (oneflow.Tensor): The input Tensor.
instance_dims (int): The dimensions of dynamic tensor instance.
Returns:
oneflow.Tensor: The result Tensor.
For example:
.. code-block:: python
>>> import numpy as np
>>> import oneflow.experimental as flow
>>> flow.enable_eager_execution()
>>> x = np.random.randn(4, 16, 64, 64).astype(np.float32)
>>> x = flow.Tensor(x)
>>> x = flow.tensor_to_tensor_buffer(x, instance_dims=2)
>>> output = flow.tensor_buffer_to_tensor(x, instance_shape=(64, 64), dtype=flow.float)
>>> output.shape
flow.Size([4, 16, 64, 64])
"""
return TensorToTensorBuffer(instance_dims=instance_dims)(x)
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
......@@ -26,6 +26,7 @@ from typing import Optional, Sequence, List
@oneflow_export("tensor_buffer_to_tensor")
@stable_api
def tensor_buffer_to_tensor(
x: oneflow._oneflow_internal.BlobDesc,
dtype: flow.dtype,
......@@ -89,6 +90,7 @@ def tensor_buffer_to_tensor(
@oneflow_export("tensor_to_tensor_buffer")
@stable_api
def tensor_to_tensor_buffer(
x: oneflow._oneflow_internal.BlobDesc,
instance_dims: int,
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList, type_name_to_flow_type
def _test_tensor_buffer_convert(test_case, device):
input = flow.Tensor(
np.random.rand(16, 24, 32, 36), dtype=flow.float32, device=flow.device(device)
)
tensor_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=2)
orig_tensor = flow.tensor_buffer_to_tensor(
tensor_buffer, dtype=flow.float32, instance_shape=[32, 36]
)
test_case.assertTrue(np.array_equal(input.numpy(), orig_tensor.numpy()))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestTensorBufferOps(flow.unittest.TestCase):
def test_tensor_buffer_convert(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_tensor_buffer_convert]
arg_dict["device"] = ["cpu"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment