Skip to content
Snippets Groups Projects
Unverified Commit 6fc696c6 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Add meshgrid module (#5205)


* add argmax test

* fix ci error

* fix docstring warning

* fix tensor greater and less bug

* add meshgrid module

* add meshgrid module

* fix meshgrid module bug

* add docstring for meshgrid module

* add doctest for meshgrid module

* auto format by CI

* fix commnet

Co-authored-by: default avataroneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 31c8b666
No related branches found
No related tags found
No related merge requests found
......@@ -197,5 +197,6 @@ Experimental features
.. autofunction:: oneflow.experimental.Tensor.ceil
.. autofunction:: oneflow.experimental.expm1
.. autofunction:: oneflow.experimental.Tensor.expm1
.. autofunction:: oneflow.experimental.meshgrid
.. autofunction:: oneflow.experimental.topk
.. autofunction:: oneflow.experimental.Tensor.topk
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export, experimental_api
class MeshGrid(Module):
def __init__(self) -> None:
super().__init__()
def forward(self, inputs):
size = len(inputs)
assert size > 0, f"meshgrid expects a non-empty TensorList"
shape = list()
for i in range(size):
assert inputs[i].dim() <= 1, f(
"Expected scalar or 1D tensor in the tensor list but got: ", inputs[i]
)
if inputs[i].dim() == 0:
shape.append(1)
else:
shape.append(inputs[i].shape[0])
for i in range(size - 1):
assert (
inputs[i].dtype == inputs[i + 1].dtype
and inputs[i].device == inputs[i + 1].device
), f"meshgrid expects all tensors to have the same dtype and device"
outputs = []
for i in range(size):
view_shape = [1] * size
view_shape[i] = -1
# TODO(BBuf) change reshape to view
outputs.append(inputs[i].reshape(view_shape).expand(*shape))
return outputs
@oneflow_export("meshgrid")
@experimental_api
def meshgrid_op(*inputs):
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/stable/_modules/torch/functional.html#meshgrid
Take :math:`N` tensors, each of which can be either scalar or 1-dimensional
vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by
expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.
Args:
tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
treated as tensors of size :math:`(1,)` automatically
Returns:
seq (sequence of Tensors): If the input has :math:`k` tensors of size
:math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also have :math:`k` tensors,
where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`.
For example:
.. code-block:: python
>>> import numpy as np
>>> import oneflow.experimental as flow
>>> flow.enable_eager_execution()
>>> input1 = flow.Tensor(np.array([1, 2, 3]), dtype=flow.float32)
>>> input2 = flow.Tensor(np.array([4, 5, 6]), dtype=flow.float32)
>>> of_x, of_y = flow.meshgrid(input1, input2)
>>> of_x
tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.]], dtype=oneflow.float32)
>>> of_y
tensor([[4., 5., 6.],
[4., 5., 6.],
[4., 5., 6.]], dtype=oneflow.float32)
"""
return MeshGrid()(inputs)
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
def _test_meshgrid_forawd(test_case, device):
input1 = flow.Tensor(
np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device),
)
input2 = flow.Tensor(
np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device),
)
np_x, np_y = np.meshgrid(input1.numpy(), input2.numpy(), indexing="ij")
of_x, of_y = flow.meshgrid(input1, input2)
test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 1e-4, 1e-4))
test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 1e-4, 1e-4))
def _test_meshgrid_forawd_scalr(test_case, device):
input1 = flow.Tensor(np.array(1.0), dtype=flow.float32, device=flow.device(device),)
input2 = flow.Tensor(np.array(2.0), dtype=flow.float32, device=flow.device(device),)
np_x, np_y = np.meshgrid(input1.numpy(), input2.numpy(), indexing="ij")
of_x, of_y = flow.meshgrid(input1, input2)
test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 1e-4, 1e-4))
test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 1e-4, 1e-4))
def _test_meshgrid_forawd_3tensor(test_case, device):
input1 = flow.Tensor(
np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device),
)
input2 = flow.Tensor(
np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device),
)
input3 = flow.Tensor(
np.array([7, 8, 9]), dtype=flow.float32, device=flow.device(device),
)
np_x, np_y, np_z = np.meshgrid(
input1.numpy(), input2.numpy(), input3.numpy(), indexing="ij"
)
of_x, of_y, of_z = flow.meshgrid(input1, input2, input3)
test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 1e-4, 1e-4))
test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 1e-4, 1e-4))
test_case.assertTrue(np.allclose(of_z.numpy(), np_z, 1e-4, 1e-4))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestGreater(flow.unittest.TestCase):
def test_greter(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_meshgrid_forawd,
_test_meshgrid_forawd_scalr,
_test_meshgrid_forawd_3tensor,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment