From a53275f18c2b560552befed7f23e72b52413f08e Mon Sep 17 00:00:00 2001
From: Luyang <flowingsun007@163.com>
Date: Sat, 17 Jul 2021 11:37:20 +0800
Subject: [PATCH] Fix maxpool1d params (#5493)

* fix param

* fix maxpool1d params error

* refine

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/python/nn/modules/pooling.py        | 26 +++++++++------------
 oneflow/python/nn/modules/utils.py          | 10 ++++++++
 oneflow/python/test/modules/test_pooling.py | 16 ++++++++++---
 3 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/oneflow/python/nn/modules/pooling.py b/oneflow/python/nn/modules/pooling.py
index 8a261e885..d34a71053 100644
--- a/oneflow/python/nn/modules/pooling.py
+++ b/oneflow/python/nn/modules/pooling.py
@@ -18,7 +18,7 @@ from typing import Optional
 import oneflow as flow
 from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.nn.module import Module
-from oneflow.python.nn.modules.utils import _single, _pair, _triple
+from oneflow.python.nn.modules.utils import _getint, _single, _pair, _triple
 from oneflow.python.nn.common_types import _size_1_t, _size_2_t, _size_3_t
 from oneflow.python.ops.nn_ops import calc_pool_padding, get_dhw_offset, _GetSequence
 
@@ -309,22 +309,17 @@ class MaxPool1d(Module):
         ceil_mode: bool = False,
     ):
         super().__init__()
-        self.kernel_size = _pair(tuple(kernel_size)[0])
-        self.stride = (
-            _pair(tuple(stride)[0]) if (stride is not None) else _pair(kernel_size)
-        )
+        self.kernel_size = _getint(kernel_size)
+        self.stride = _getint(stride) if stride is not None else self.kernel_size
         data_format = "NCL"  # Only suport "NCL" for now!
         self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last"
-        self.dilation = _GetSequence(dilation, 2, "dilation")
-        padding = _pair(tuple(padding)[0])
+        self.dilation = _getint(dilation)
+        self.padding = _getint(padding)
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
 
-        if len(padding) == 2:
-            if self.channel_pos == "channels_first":
-                padding = (0, 0, padding[0], padding[1])
-            else:
-                raise ValueError("error padding param!")
+        if self.channel_pos == "channels_first":
+            padding = (0, 0, self.padding, 0)
         else:
             raise ValueError("error padding param!")
 
@@ -336,15 +331,16 @@ class MaxPool1d(Module):
 
     def forward(self, x):
         expand_x = x.unsqueeze(dim=-1)
+
         expand_y, expand_indice = flow.F.maxpool_2d(
             expand_x,
             data_format=self.channel_pos,
             padding=self.padding_type,
             padding_before=self.padding_before,
             padding_after=self.padding_after,
-            kernel_size=self.kernel_size,
-            stride=self.stride,
-            dilation=self.dilation,
+            kernel_size=[self.kernel_size, 1],
+            stride=[self.stride, 1],
+            dilation=[self.dilation, 1],
             return_indices=True,
             ceil_mode=self.ceil_mode,
         )
diff --git a/oneflow/python/nn/modules/utils.py b/oneflow/python/nn/modules/utils.py
index e087c4c2b..2ec139bac 100644
--- a/oneflow/python/nn/modules/utils.py
+++ b/oneflow/python/nn/modules/utils.py
@@ -28,6 +28,16 @@ def _ntuple(n):
     return parse
 
 
+def _getint():
+    def parse(x):
+        if isinstance(x, container_abcs.Iterable):
+            return int(x[0])
+        return int(x)
+
+    return parse
+
+
+_getint = _getint()
 _single = _ntuple(1)
 _pair = _ntuple(2)
 _triple = _ntuple(3)
diff --git a/oneflow/python/test/modules/test_pooling.py b/oneflow/python/test/modules/test_pooling.py
index 5c5c9f29b..15d798fb7 100644
--- a/oneflow/python/test/modules/test_pooling.py
+++ b/oneflow/python/test/modules/test_pooling.py
@@ -252,6 +252,18 @@ def _test_maxpool1d_impl(test_case, device):
     test_case.assertTrue(np.allclose(of_output.numpy(), output, 1e-4, 1e-4))
 
 
+def _test_maxpool1d_zero_padding(test_case, device):
+    arr = np.arange(1000).reshape(4, 5, 50).astype(np.float)
+    input = flow.tensor(arr, dtype=flow.float32, device=flow.device(device))
+    m1 = flow.nn.MaxPool1d(kernel_size=3, stride=3, padding=0)
+    of_out = m1(input)
+
+    m2 = MaxPoolNumpy(2, kernel_size=(3, 1), stride=(3, 1), padding=(0, 0))
+    np_out = m2(arr.reshape(4, 5, 50, 1))
+    np_out = np.squeeze(np_out, axis=3)
+    test_case.assertTrue(np.allclose(np_out, of_out.numpy(), 1e-4, 1e-4))
+
+
 def _test_maxpool2d(test_case, device):
     dim = 2
 
@@ -607,9 +619,7 @@ def _test_maxpool3d_negative_input_backward(test_case, device):
 class TestPooling(flow.unittest.TestCase):
     def test_maxpool1d(test_case):
         arg_dict = OrderedDict()
-        arg_dict["test_fun"] = [
-            _test_maxpool1d_impl,
-        ]
+        arg_dict["test_fun"] = [_test_maxpool1d_impl, _test_maxpool1d_zero_padding]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
-- 
GitLab