Skip to content
Snippets Groups Projects
Unverified Commit a53275f1 authored by Luyang's avatar Luyang Committed by GitHub
Browse files

Fix maxpool1d params (#5493)


* fix param

* fix maxpool1d params error

* refine

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 6bc0596b
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ from typing import Optional
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.nn.modules.utils import _single, _pair, _triple
from oneflow.python.nn.modules.utils import _getint, _single, _pair, _triple
from oneflow.python.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from oneflow.python.ops.nn_ops import calc_pool_padding, get_dhw_offset, _GetSequence
......@@ -309,22 +309,17 @@ class MaxPool1d(Module):
ceil_mode: bool = False,
):
super().__init__()
self.kernel_size = _pair(tuple(kernel_size)[0])
self.stride = (
_pair(tuple(stride)[0]) if (stride is not None) else _pair(kernel_size)
)
self.kernel_size = _getint(kernel_size)
self.stride = _getint(stride) if stride is not None else self.kernel_size
data_format = "NCL" # Only suport "NCL" for now!
self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last"
self.dilation = _GetSequence(dilation, 2, "dilation")
padding = _pair(tuple(padding)[0])
self.dilation = _getint(dilation)
self.padding = _getint(padding)
self.return_indices = return_indices
self.ceil_mode = ceil_mode
if len(padding) == 2:
if self.channel_pos == "channels_first":
padding = (0, 0, padding[0], padding[1])
else:
raise ValueError("error padding param!")
if self.channel_pos == "channels_first":
padding = (0, 0, self.padding, 0)
else:
raise ValueError("error padding param!")
......@@ -336,15 +331,16 @@ class MaxPool1d(Module):
def forward(self, x):
expand_x = x.unsqueeze(dim=-1)
expand_y, expand_indice = flow.F.maxpool_2d(
expand_x,
data_format=self.channel_pos,
padding=self.padding_type,
padding_before=self.padding_before,
padding_after=self.padding_after,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=self.dilation,
kernel_size=[self.kernel_size, 1],
stride=[self.stride, 1],
dilation=[self.dilation, 1],
return_indices=True,
ceil_mode=self.ceil_mode,
)
......
......@@ -28,6 +28,16 @@ def _ntuple(n):
return parse
def _getint():
def parse(x):
if isinstance(x, container_abcs.Iterable):
return int(x[0])
return int(x)
return parse
_getint = _getint()
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
......
......@@ -252,6 +252,18 @@ def _test_maxpool1d_impl(test_case, device):
test_case.assertTrue(np.allclose(of_output.numpy(), output, 1e-4, 1e-4))
def _test_maxpool1d_zero_padding(test_case, device):
arr = np.arange(1000).reshape(4, 5, 50).astype(np.float)
input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device))
m1 = flow.nn.MaxPool1d(kernel_size=3, stride=3, padding=0)
of_out = m1(input)
m2 = MaxPoolNumpy(2, kernel_size=(3, 1), stride=(3, 1), padding=(0, 0))
np_out = m2(arr.reshape(4, 5, 50, 1))
np_out = np.squeeze(np_out, axis=3)
test_case.assertTrue(np.allclose(np_out, of_out.numpy(), 1e-4, 1e-4))
def _test_maxpool2d(test_case, device):
dim = 2
......@@ -607,9 +619,7 @@ def _test_maxpool3d_negative_input_backward(test_case, device):
class TestPooling(flow.unittest.TestCase):
def test_maxpool1d(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_maxpool1d_impl,
]
arg_dict["test_fun"] = [_test_maxpool1d_impl, _test_maxpool1d_zero_padding]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment