From a53275f18c2b560552befed7f23e72b52413f08e Mon Sep 17 00:00:00 2001 From: Luyang <flowingsun007@163.com> Date: Sat, 17 Jul 2021 11:37:20 +0800 Subject: [PATCH] Fix maxpool1d params (#5493) * fix param * fix maxpool1d params error * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/python/nn/modules/pooling.py | 26 +++++++++------------ oneflow/python/nn/modules/utils.py | 10 ++++++++ oneflow/python/test/modules/test_pooling.py | 16 ++++++++++--- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/oneflow/python/nn/modules/pooling.py b/oneflow/python/nn/modules/pooling.py index 8a261e885..d34a71053 100644 --- a/oneflow/python/nn/modules/pooling.py +++ b/oneflow/python/nn/modules/pooling.py @@ -18,7 +18,7 @@ from typing import Optional import oneflow as flow from oneflow.python.oneflow_export import oneflow_export, experimental_api from oneflow.python.nn.module import Module -from oneflow.python.nn.modules.utils import _single, _pair, _triple +from oneflow.python.nn.modules.utils import _getint, _single, _pair, _triple from oneflow.python.nn.common_types import _size_1_t, _size_2_t, _size_3_t from oneflow.python.ops.nn_ops import calc_pool_padding, get_dhw_offset, _GetSequence @@ -309,22 +309,17 @@ class MaxPool1d(Module): ceil_mode: bool = False, ): super().__init__() - self.kernel_size = _pair(tuple(kernel_size)[0]) - self.stride = ( - _pair(tuple(stride)[0]) if (stride is not None) else _pair(kernel_size) - ) + self.kernel_size = _getint(kernel_size) + self.stride = _getint(stride) if stride is not None else self.kernel_size data_format = "NCL" # Only suport "NCL" for now! self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last" - self.dilation = _GetSequence(dilation, 2, "dilation") - padding = _pair(tuple(padding)[0]) + self.dilation = _getint(dilation) + self.padding = _getint(padding) self.return_indices = return_indices self.ceil_mode = ceil_mode - if len(padding) == 2: - if self.channel_pos == "channels_first": - padding = (0, 0, padding[0], padding[1]) - else: - raise ValueError("error padding param!") + if self.channel_pos == "channels_first": + padding = (0, 0, self.padding, 0) else: raise ValueError("error padding param!") @@ -336,15 +331,16 @@ class MaxPool1d(Module): def forward(self, x): expand_x = x.unsqueeze(dim=-1) + expand_y, expand_indice = flow.F.maxpool_2d( expand_x, data_format=self.channel_pos, padding=self.padding_type, padding_before=self.padding_before, padding_after=self.padding_after, - kernel_size=self.kernel_size, - stride=self.stride, - dilation=self.dilation, + kernel_size=[self.kernel_size, 1], + stride=[self.stride, 1], + dilation=[self.dilation, 1], return_indices=True, ceil_mode=self.ceil_mode, ) diff --git a/oneflow/python/nn/modules/utils.py b/oneflow/python/nn/modules/utils.py index e087c4c2b..2ec139bac 100644 --- a/oneflow/python/nn/modules/utils.py +++ b/oneflow/python/nn/modules/utils.py @@ -28,6 +28,16 @@ def _ntuple(n): return parse +def _getint(): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return int(x[0]) + return int(x) + + return parse + + +_getint = _getint() _single = _ntuple(1) _pair = _ntuple(2) _triple = _ntuple(3) diff --git a/oneflow/python/test/modules/test_pooling.py b/oneflow/python/test/modules/test_pooling.py index 5c5c9f29b..15d798fb7 100644 --- a/oneflow/python/test/modules/test_pooling.py +++ b/oneflow/python/test/modules/test_pooling.py @@ -252,6 +252,18 @@ def _test_maxpool1d_impl(test_case, device): test_case.assertTrue(np.allclose(of_output.numpy(), output, 1e-4, 1e-4)) +def _test_maxpool1d_zero_padding(test_case, device): + arr = np.arange(1000).reshape(4, 5, 50).astype(np.float) + input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device)) + m1 = flow.nn.MaxPool1d(kernel_size=3, stride=3, padding=0) + of_out = m1(input) + + m2 = MaxPoolNumpy(2, kernel_size=(3, 1), stride=(3, 1), padding=(0, 0)) + np_out = m2(arr.reshape(4, 5, 50, 1)) + np_out = np.squeeze(np_out, axis=3) + test_case.assertTrue(np.allclose(np_out, of_out.numpy(), 1e-4, 1e-4)) + + def _test_maxpool2d(test_case, device): dim = 2 @@ -607,9 +619,7 @@ def _test_maxpool3d_negative_input_backward(test_case, device): class TestPooling(flow.unittest.TestCase): def test_maxpool1d(test_case): arg_dict = OrderedDict() - arg_dict["test_fun"] = [ - _test_maxpool1d_impl, - ] + arg_dict["test_fun"] = [_test_maxpool1d_impl, _test_maxpool1d_zero_padding] arg_dict["device"] = ["cpu", "cuda"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) -- GitLab