Skip to content
Snippets Groups Projects
Unverified Commit 4789392d authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

add greater_less_argmax module (#4756)


* add greater_less_argmax module

* fix comment

* add test_case

* fix conflict

* fix comment

* fix comment

* fix comment

* fix comment

* fix comment

* fix comment

* fix comment

* fix comment

* fix comment

* fix comment

* fix comment

* format file

Co-authored-by: default avatardaquexian <daquexian566@gmail.com>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent f7b5bb0f
No related branches found
No related tags found
No related merge requests found
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
from oneflow.python.ops.transpose_util import (
get_perm_when_transpose_axis_to_last_dim,
get_inversed_perm,
)
class Argmax(Module):
"""The op computes the index with the largest value of a Tensor at specified axis.
Args:
input (oneflow.Tensor): Input Tensor
dim (int, optional): dimension to be calculated. Defaults to the last dim (-1)
keepdim (bool optional): whether the output tensor has dim retained or not. Ignored if dim=None.
Returns:
oneflow.Tensor: A Tensor(dtype=int32) contains the index with the largest value of `input`
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
x = np.array([[1, 3, 8, 7, 2],
[1, 9, 4, 3, 2]], dtype=np.float32)
out = flow.argmax(flow.Tensor(x))
# out [2 1]
"""
def __init__(self, dim: int = None, keepdim: bool = False) -> None:
super().__init__()
self._op_softmax_last_dim = (
flow.builtin_op("argmax").Input("in").Output("out").Build()
)
self._expand_op = (
flow.builtin_op("expand_dims")
.Input("in")
.Output("out")
.Attr("axis", -1)
.Build()
)
self._flatten = (
flow.builtin_op("flatten")
.Input("in")
.Output("out")
.Attr("start_dim", 0)
.Attr("end_dim", -1)
.Build()
)
self.dim = dim
self.keepdim = keepdim
def forward(self, input):
if self.dim == None:
input = self._flatten(input)[0]
self.dim = 0
num_axes = len(input.shape)
axis = self.dim if self.dim >= 0 else self.dim + num_axes
assert 0 <= axis < num_axes, "axis out of range"
if axis == num_axes - 1:
x = self._op_softmax_last_dim(input)[0]
if self.keepdim == True:
x = self._expand_op(x)
return x
else:
perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis)
x = flow.tmp.transpose(input, perm=perm)
x = self._op_softmax_last_dim(x)[0]
x = self._expand_op(x)[0]
x = flow.tmp.transpose(x, perm=get_inversed_perm(perm))
if self.keepdim == False:
x = flow.tmp.squeeze(x, axis=[axis])
return x
@oneflow_export("argmax")
@register_tensor_op("argmax")
def argmax_op(tensor, dim: int = None, keepdim: bool = False):
return Argmax(dim=dim, keepdim=keepdim)(tensor)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
class Greater(Module):
r"""Returns the truth value of :math:`x > y` element-wise.
Args:
x (oneflow.Tensor): A Tensor
y (oneflow.Tensor): A Tensor
name (Optional[str], optional): The name for the operation. Defaults to None.
Returns:
oneflow.Tensor: A Tensor with int8 type.
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
input1 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
input2 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
out = flow.gt(input1, input2).numpy()
# out shape (2, 6, 5, 3)
"""
def __init__(self) -> None:
super().__init__()
self._op = (
flow.builtin_op("broadcast_greater")
.Input("x")
.Input("y")
.Output("z")
.Build()
)
def forward(self, x, y):
return self._op(x, y)[0]
@oneflow_export("gt")
@register_tensor_op("gt")
def greater_op(tensor1, tensor2):
return Greater()(tensor1, tensor2)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
class Less(Module):
r"""Returns the truth value of :math:`x < y` element-wise.
Args:
x (oneflow.Tensor): A Tensor
y (oneflow.Tensor): A Tensor
name (Optional[str], optional): The name for the operation. Defaults to None.
Returns:
oneflow.Tensor: A Tensor with int8 type.
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
input1 = flow.Tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)
input2 = flow.Tensor(np.array([1, 2, 4]).astype(np.float32), dtype=flow.float32)
out = flow.gt(input1, input2).numpy
# out [0 0 1]
"""
def __init__(self) -> None:
super().__init__()
self._op = (
flow.builtin_op("broadcast_less").Input("x").Input("y").Output("z").Build()
)
def forward(self, x, y):
return self._op(x, y)[0]
@oneflow_export("lt")
@register_tensor_op("lt")
def less_op(tensor1, tensor2):
return Less()(tensor1, tensor2)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
import oneflow.python.framework.id_util as id_util
from typing import Optional, Sequence
class Squeeze(Module):
"""This operator removes the specified dimention which size is 1 of the input Tensor.
If the `axis` is not specified, this operator will remove all the dimention which size is 1 of the input Tensor.
The amount of element in return value is the same as Tensor `input`.
Args:
input (oneflow.Tensor): The input Tensor.
axis (Optional[Sequence[int]], optional): The axis. Defaults to None.
Returns:
oneflow.Tensor: The result Tensor.
For example:
Example:
.. code-block:: python
import oneflow as flow
import numpy as np
input = flow.Tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))
out = flow.tmp.squeeze(input, axis=[1, 2]).numpy().shape
# out.shape (1, 3)
"""
def __init__(self, axis: Optional[Sequence[int]] = None) -> None:
super().__init__()
self._op = (
flow.builtin_op("squeeze")
.Input("in")
.Output("out")
.Attr("axes", axis)
.Build()
)
def forward(self, x):
return self._op(x)[0]
@oneflow_export("tmp.squeeze")
@register_tensor_op("squeeze")
def squeeze_op(tensor, axis: Optional[Sequence[int]] = None):
return Squeeze(axis=axis)(tensor)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
from typing import Optional, Sequence
class Transpose(Module):
r"""This operator transposes the specified axis of input Tensor.
Args:
a (oneflow.Tensor): The input tensor.
perm (Sequence[int], optional): The list of dimension permutation. Defaults to None.
conjugate (bool, optional): Still Unavailable. Defaults to False.
batch_axis_non_change (bool, optional): deprecated. Defaults to False.
Raises:
NotImplementedError: The attribute `conjugate` still unavailable.
Returns:
oneflow.Tensor: A transposed tensor.
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
out = flow.tmp.transpose(input, perm=(0, 2, 3, 1))
# out.shape (2, 5, 3, 6)
"""
def __init__(
self,
perm: Sequence[int] = None,
conjugate: bool = False,
batch_axis_non_change: bool = False,
) -> None:
super().__init__()
assert isinstance(perm, (tuple, list))
if conjugate:
raise NotImplementedError
if batch_axis_non_change:
raise NotImplementedError
self._op = (
flow.builtin_op("transpose")
.Input("input")
.Output("output")
.Attr("perm", perm)
.Build()
)
def forward(self, x):
return self._op(x)[0]
@oneflow_export("tmp.transpose")
@register_tensor_op("transpose")
def transpose_op(tensor, perm: Sequence[int] = None):
return Transpose(perm=perm)(tensor)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestArgmax(flow.unittest.TestCase):
def test_argmax_v1(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
axis = -1
of_out = flow.argmax(input, dim=axis)
np_out = np.argmax(input.numpy(), axis=axis)
test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))
def test_tensor_argmax(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
axis = 0
of_out = input.argmax(dim=axis)
np_out = np.argmax(input.numpy(), axis=axis)
test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))
test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))
def test_argmax_v3(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
axis = 1
of_out = flow.argmax(input, dim=axis)
np_out = np.argmax(input.numpy(), axis=axis)
test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))
def test_argmax_keepdims(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
axis = 0
of_out = input.argmax(axis, True)
np_out = np.argmax(input.numpy(), axis=axis)
np_out = np.expand_dims(np_out, axis=axis)
test_case.assertTrue(np.array_equal(of_out.numpy().shape, np_out.shape))
test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))
def test_argmax_dim_equal_none(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = input.argmax()
np_out = np.argmax(input.numpy().flatten(), axis=0)
test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestGreater(flow.unittest.TestCase):
def test_greater_v1(test_case):
input1 = flow.Tensor(np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32)
input2 = flow.Tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)
of_out = flow.gt(input1, input2)
np_out = np.greater(input1.numpy(), input2.numpy())
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
def test_tensor_greater(test_case):
input1 = flow.Tensor(np.array([1, 1, 4]).astype(np.float32), dtype=flow.float32)
input2 = flow.Tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32)
of_out = input1.gt(input2)
np_out = np.greater(input1.numpy(), input2.numpy())
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestLess(flow.unittest.TestCase):
def test_less_v1(test_case):
input1 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
input2 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = flow.lt(input1, input2)
np_out = np.less(input1.numpy(), input2.numpy())
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
def test_tensor_less(test_case):
input1 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
input2 = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = input1.lt(input2)
np_out = np.less(input1.numpy(), input2.numpy())
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestSqueeze(flow.unittest.TestCase):
def test_squeeze_v1(test_case):
input = flow.Tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))
of_shape = flow.tmp.squeeze(input, axis=[1, 2]).numpy().shape
np_shape = (1, 3)
test_case.assertTrue(np.array_equal(of_shape, np_shape))
def test_tensor_squeeze(test_case):
input = flow.Tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))
of_shape = input.squeeze(axis=[1, 2]).numpy().shape
np_shape = (1, 3)
test_case.assertTrue(np.array_equal(of_shape, np_shape))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestTranspose(flow.unittest.TestCase):
def test_transpose_v1(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = flow.tmp.transpose(input, perm=(0, 2, 3, 1))
np_out = input.numpy().transpose((0, 2, 3, 1))
test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))
def test_tensor_transpose(test_case):
input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
of_out = input.transpose(perm=(0, 2, 3, 1))
np_out = input.numpy().transpose((0, 2, 3, 1))
test_case.assertTrue(np.array_equal(of_out.numpy().flatten(), np_out.flatten()))
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment