Skip to content
Snippets Groups Projects
Unverified Commit a57b7649 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

align squeeze module with torch (#4855)


* align squeeze module with torch

* fix comment

* fix argmax bug

* fix bug

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 87c0c3d8
No related branches found
No related tags found
No related merge requests found
......@@ -68,7 +68,7 @@ class Argmax(Module):
x = self._expand_op(x)[0]
x = x.transpose(perm=get_inversed_perm(perm))
if self.keepdim == False:
x = x.squeeze(axis=[axis])
x = x.squeeze(dim=[axis])
return x
......
......@@ -225,7 +225,7 @@ class NLLLoss(Module):
def nllloss_1d(self, input, target):
target = flow.experimental.reshape(target, (target.shape[0], 1))
res = self._dim_gather_op(input, target)[0]
res = flow.experimental.squeeze(res, axis=[1])
res = flow.experimental.squeeze(res, dim=[1])
return res
def forward(self, input, target):
......
......@@ -22,14 +22,14 @@ from typing import Optional, Sequence
class Squeeze(Module):
def __init__(self, axis: Optional[Sequence[int]] = None) -> None:
def __init__(self, dim: Optional[Sequence[int]] = None) -> None:
super().__init__()
self._op = (
flow.builtin_op("squeeze")
.Input("in")
.Output("out")
.Attr("axes", axis)
.Attr("axes", dim)
.Build()
)
......@@ -40,15 +40,15 @@ class Squeeze(Module):
@oneflow_export("squeeze")
@register_tensor_op("squeeze")
@experimental_api
def squeeze_op(input, axis: Optional[Sequence[int]] = None):
def squeeze_op(input, dim: Optional[Sequence[int]] = None):
"""This operator removes the specified dimention which size is 1 of the input Tensor.
If the `axis` is not specified, this operator will remove all the dimention which size is 1 of the input Tensor.
If the `dim` is not specified, this operator will remove all the dimention which size is 1 of the input Tensor.
The amount of element in return value is the same as Tensor `input`.
Args:
input (oneflow.Tensor): The input Tensor.
axis (Optional[Sequence[int]], optional): The axis. Defaults to None.
dim (Optional[Sequence[int]]): The dim. Defaults to None.
Returns:
oneflow.Tensor: The result Tensor.
......@@ -63,9 +63,11 @@ def squeeze_op(input, axis: Optional[Sequence[int]] = None):
import numpy as np
input = flow.Tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))
out = flow.squeeze(input, axis=[1, 2]).numpy().shape
out = flow.squeeze(input, dim=[1, 2]).numpy().shape
# out.shape (1, 3)
"""
return Squeeze(axis=axis)(input)
if type(dim) == int:
dim = [dim]
return Squeeze(dim=dim)(input)
......@@ -24,18 +24,24 @@ import oneflow.experimental as flow
".numpy() doesn't work in lazy mode",
)
class TestSqueeze(flow.unittest.TestCase):
def test_squeeze_v1(test_case):
def test_squeeze(test_case):
input = flow.Tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))
of_shape = flow.squeeze(input, axis=[1, 2]).numpy().shape
of_shape = flow.squeeze(input, dim=[1, 2]).numpy().shape
np_shape = (1, 3)
test_case.assertTrue(np.array_equal(of_shape, np_shape))
def test_tensor_squeeze(test_case):
input = flow.Tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))
of_shape = input.squeeze(axis=[1, 2]).numpy().shape
of_shape = input.squeeze(dim=[1, 2]).numpy().shape
np_shape = (1, 3)
test_case.assertTrue(np.array_equal(of_shape, np_shape))
def test_squeeze_int(test_case):
input = flow.Tensor(np.array([[[[1, 1, 1]]]]).astype(np.int32))
of_shape = flow.squeeze(input, 1).numpy().shape
np_shape = (1, 1, 3)
test_case.assertTrue(np.array_equal(of_shape, np_shape))
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment