Commit 517e24d8 authored by YongtaoShi, committed by GitHub

add tile module (#5234)


* add tile module

* add backward testcase

* add docstring

* parameters consistent with pytorch

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 2a49e6a1
@@ -162,6 +162,8 @@ Experimental features
 .. autofunction:: oneflow.experimental.nn.MaxPool3d
 .. autofunction:: oneflow.experimental.repeat
 .. autofunction:: oneflow.experimental.Tensor.repeat
+.. autofunction:: oneflow.experimental.tile
+.. autofunction:: oneflow.experimental.Tensor.tile
 .. autofunction:: oneflow.experimental.reshape
 .. autofunction:: oneflow.experimental.Tensor.reshape
 .. autofunction:: oneflow.experimental.squeeze
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Union
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.framework.tensor import Tensor, register_tensor_op
class Tile(Module):
    def __init__(self, reps: tuple) -> None:
        super().__init__()
        self.reps = reps

    def forward(self, input: Tensor) -> Tensor:
        reps = self.reps
        for s in self.reps:
            assert s > 0, "reps elements must be positive"
        input_shape = input.shape
        diff = len(input_shape) - len(reps)
        if diff > 0:
            # reps covers fewer dimensions than the input has: prepend ones so
            # that every input dimension gets a repetition factor, matching
            # numpy.tile / torch.tile semantics.
            shape = [1 for _ in range(diff)]
            shape.extend(reps)
            reps = tuple(shape)
        return input.repeat(reps)

@oneflow_export("tile")
@register_tensor_op("tile")
@experimental_api
def tile_op(x, reps):
    r"""The interface is consistent with PyTorch.
    The documentation is referenced from:
    https://pytorch.org/docs/stable/generated/torch.tile.html

    Constructs a tensor by repeating the elements of ``input``. The ``reps`` argument specifies
    the number of repetitions in each dimension.

    If ``reps`` specifies fewer dimensions than ``input`` has, then ones are prepended to ``reps``
    until all dimensions are specified. For example, if ``input`` has shape (8, 6, 4, 2) and
    ``reps`` is (2, 2), then ``reps`` is treated as (1, 1, 2, 2).

    Analogously, if ``input`` has fewer dimensions than ``reps`` specifies, then ``input`` is
    treated as if it were unsqueezed at dimension zero until it has as many dimensions as ``reps``
    specifies. For example, if ``input`` has shape (4, 2) and ``reps`` is (3, 3, 2, 2), then
    ``input`` is treated as if it had the shape (1, 1, 4, 2).

    .. note::
        This function is similar to NumPy’s tile function.

    Args:
        input (oneflow.Tensor): the tensor whose elements to repeat.
        reps (tuple): the number of repetitions per dimension.

    For example:

    .. code-block:: python

        >>> import oneflow.experimental as flow
        >>> import numpy as np
        >>> flow.enable_eager_execution()

        >>> x = np.array([1, 2]).astype(np.int32)
        >>> input = flow.Tensor(x, dtype=flow.int32)
        >>> out = input.tile(reps=(2,))
        >>> out
        tensor([1, 2, 1, 2], dtype=oneflow.int32)

        >>> x = np.random.randn(5, 2, 1)
        >>> input = flow.Tensor(x)
        >>> out = input.tile(reps=(3, 4))
        >>> out.size()
        flow.Size([5, 6, 4])

    """
    return Tile(reps=reps)(x)
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
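
As a quick cross-check before the accompanying test file below: the reps-alignment rule that `Tile.forward` implements (prepending ones when `reps` is shorter than the input rank) mirrors what `numpy.tile` already does on its own. The following sketch is illustrative only and is not part of the commit:

import numpy as np

# Illustrative check of the prepend-ones rule described in the docstring above.
x = np.random.randn(5, 2, 1)
reps = (3, 4)  # fewer entries than x.ndim
aligned = (1,) * (x.ndim - len(reps)) + reps  # -> (1, 3, 4)

# numpy.tile applies the same promotion internally, so both calls agree.
assert np.array_equal(np.tile(x, reps), np.tile(x, aligned))
print(np.tile(x, reps).shape)  # (5, 6, 4), matching the docstring example
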
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
def np_tile(x, sizes):
    return np.tile(x, sizes)


def np_tile_grad(x, sizes):
    # The tests reduce the tiled output with sum(), so each input element is
    # copied prod(sizes) times and its gradient is exactly that count.
    times = np.array(sizes).prod()
    return np.ones(shape=x.shape) * times
def _test_tile_less_dim_a(test_case, device):
input = flow.Tensor(
np.random.randn(2, 4, 1, 3), dtype=flow.float32, device=flow.device(device)
)
sizes = (2,)
np_out = np_tile(input.numpy(), sizes)
of_out = input.tile(reps=sizes)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
def _test_tile_less_dim_b(test_case, device):
input = flow.Tensor(
np.random.randn(3, 2, 5), dtype=flow.float32, device=flow.device(device)
)
sizes = (3, 4)
np_out = np_tile(input.numpy(), sizes)
of_out = input.tile(reps=sizes)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
def _test_tile_less_dim_c(test_case, device):
input = flow.Tensor(
np.random.randn(4, 3, 2, 5, 3), dtype=flow.float32, device=flow.device(device)
)
sizes = (2, 3, 4, 4)
np_out = np_tile(input.numpy(), sizes)
of_out = input.tile(reps=sizes)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
def _test_tile_same_dim(test_case, device):
input = flow.Tensor(
np.random.randn(1, 2, 5, 3), dtype=flow.float32, device=flow.device(device)
)
sizes = (4, 2, 3, 19)
of_out = input.tile(reps=sizes)
np_out = np_tile(input.numpy(), sizes)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
def _test_tile_same_dim_int(test_case, device):
input = flow.Tensor(
np.random.randn(1, 2, 5, 3), dtype=flow.int32, device=flow.device(device)
)
size_tensor = flow.Tensor(np.random.randn(4, 2, 3, 19))
sizes = size_tensor.size()
of_out = input.tile(reps=sizes)
np_out = np_tile(input.numpy(), sizes)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out.astype(np.int32)))
def _test_tile_same_dim_int8(test_case, device):
input = flow.Tensor(
np.random.randn(1, 2, 5, 3), dtype=flow.int8, device=flow.device(device)
)
size_tensor = flow.Tensor(np.random.randn(4, 2, 3, 19))
sizes = size_tensor.size()
of_out = input.tile(reps=sizes)
np_out = np_tile(input.numpy(), sizes)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out.astype(np.int32)))
def _test_tile_less_dim_a_backward(test_case, device):
input = flow.Tensor(
np.random.randn(2, 4, 1, 3),
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
sizes = (2,)
of_out = input.tile(reps=sizes)
of_out = of_out.sum()
of_out.backward()
np_grad = np_tile_grad(input.numpy(), sizes)
test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
def _test_tile_less_dim_b_backward(test_case, device):
input = flow.Tensor(
np.random.randn(3, 2, 5),
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
sizes = (3, 4)
of_out = input.tile(reps=sizes)
of_out = of_out.sum()
of_out.backward()
np_grad = np_tile_grad(input.numpy(), sizes)
test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
def _test_tile_less_dim_c_backward(test_case, device):
input = flow.Tensor(
np.random.randn(4, 3, 2, 5, 3),
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
sizes = (2, 3, 4, 4)
of_out = input.tile(reps=sizes)
of_out = of_out.sum()
of_out.backward()
np_grad = np_tile_grad(input.numpy(), sizes)
test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
def _test_tile_same_dim_backward(test_case, device):
input = flow.Tensor(
np.random.randn(1, 2, 5, 3),
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
sizes = (1, 2, 3, 1)
of_out = input.tile(reps=sizes)
of_out = of_out.sum()
of_out.backward()
np_grad = np_tile_grad(input.numpy(), sizes)
test_case.assertTrue(np.array_equal(input.grad.numpy(), np_grad))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestTile(flow.unittest.TestCase):
def test_tile(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_tile_less_dim_a,
_test_tile_less_dim_b,
_test_tile_less_dim_c,
_test_tile_same_dim,
_test_tile_same_dim_int,
_test_tile_same_dim_int8,
_test_tile_less_dim_a_backward,
_test_tile_less_dim_b_backward,
_test_tile_less_dim_c_backward,
_test_tile_same_dim_backward,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
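
A note on `GenArgList`: it is imported from OneFlow's `test_util` helpers, and judging from its use above it yields the cartesian product of the OrderedDict's value lists so that every test function runs on both `cpu` and `cuda`. The sketch below is an assumed, simplified equivalent; the name `gen_arg_list` is hypothetical and not part of the commit:

import itertools
from collections import OrderedDict


def gen_arg_list(arg_dict: OrderedDict):
    # Assumed behavior of test_util.GenArgList: one argument list per
    # combination of values, preserving key insertion order.
    return [list(combo) for combo in itertools.product(*arg_dict.values())]


# {"test_fun": [f, g], "device": ["cpu", "cuda"]} ->
# [[f, "cpu"], [f, "cuda"], [g, "cpu"], [g, "cuda"]]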