diff --git a/oneflow/core/autograd/gradient_funcs/pooling.cpp b/oneflow/core/autograd/gradient_funcs/pooling.cpp
index a5d6c8c3737e221741859de4e486b2463bb8a57a..488aaa9c77339bfe3cf079b6a836da6216bf8f5c 100644
--- a/oneflow/core/autograd/gradient_funcs/pooling.cpp
+++ b/oneflow/core/autograd/gradient_funcs/pooling.cpp
@@ -33,9 +33,7 @@ struct PoolingInterpState : public OpExprInterpState {
   size_t indice_index;
 
   std::string data_format;
-  std::string padding;
-  std::vector<int32_t> padding_before;
-  std::vector<int32_t> padding_after;
+  std::vector<int32_t> padding;
   std::vector<int32_t> kernel_size;
   std::vector<int32_t> stride;
   std::vector<int32_t> dilation;
@@ -76,9 +74,7 @@ Maybe<void> PoolingNdGrad::Capture(PoolingInterpState* ctx, const TensorTuple& i
 
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
-  ctx->padding = JUST(composed_attrs.GetAttr<std::string>("padding"));
-  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
-  ctx->padding_after = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_after"));
+  ctx->padding = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding"));
   ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
   ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
   ctx->dilation = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation"));
@@ -100,8 +96,7 @@ Maybe<void> PoolingNdGrad::Apply(const PoolingInterpState* ctx, const TensorTupl
   in_grads->resize(1);
   in_grads->at(0) = JUST(functional::PoolingNdGrad(
       input, output, indice, out_grads.at(0), mode_, ndims, ctx->data_format, ctx->padding,
-      ctx->padding_before, ctx->padding_after, ctx->kernel_size, ctx->stride, ctx->dilation,
-      ctx->return_indices, ctx->ceil_mode));
+      ctx->kernel_size, ctx->stride, ctx->dilation, ctx->return_indices, ctx->ceil_mode));
 
   return Maybe<void>::Ok();
 }
@@ -113,6 +108,7 @@ class MaxpoolNdGrad final : public PoolingNdGrad {
   Maybe<void> Init(const OpExpr& op) override { return PoolingNdGrad::Init(op, "max"); }
 };
 
+REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_1d", MaxpoolNdGrad);
 REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_2d", MaxpoolNdGrad);
 REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_3d", MaxpoolNdGrad);
 
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 5c1a6f5ab320936f275bb75426503ca7577e8521..f24b6a5dc654af47e9f6a1371318b63c17c75433 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -411,26 +411,31 @@
                        Int32List strides, Bool ceil_mode)"
   bind_python: False
 
+- name: "maxpool_1d"
+  signature:
+    "TensorTuple Maxpool1D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
+                      Int32List kernel_size, Int32List stride, Int32List dilation, 
+                      Bool return_indices=True, Bool ceil_mode=False)"
+  bind_python: True
+
 - name: "maxpool_2d"
   signature:
-    "TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", String padding,
-                      Int32List padding_before, Int32List padding_after,
+    "TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
                       Int32List kernel_size, Int32List stride, Int32List dilation, 
                       Bool return_indices=True, Bool ceil_mode=False)"
   bind_python: True
 
 - name: "maxpool_3d"
   signature:
-    "TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", String padding,
-                      Int32List padding_before, Int32List padding_after,
+    "TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", Int32List padding,
                       Int32List kernel_size, Int32List stride, Int32List dilation, 
                       Bool return_indices=True, Bool ceil_mode=False)"
   bind_python: True
 
 - name: "pooling_grad"
   signature:
-    "Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims, String data_format,
-                       String padding, Int32List padding_before, Int32List padding_after, Int32List kernel_size,
+    "Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims, 
+                       String data_format, Int32List padding, Int32List kernel_size,
                        Int32List stride, Int32List dilation, Bool return_indices, Bool ceil_mode)"
   bind_python: False
 
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 535b301820954039b0ec083863780bd904cb3910..54267dfdb72d24b7de484bf3e4bdec95b2edb94c 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -232,18 +232,14 @@ class PoolingNDFunctor {
   PoolingNDFunctor() = default;
   virtual ~PoolingNDFunctor() = default;
   Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,
-                                const std::string& data_format, const std::string& padding,
-                                const std::vector<int32_t>& padding_before,
-                                const std::vector<int32_t>& padding_after,
+                                const std::string& data_format, const std::vector<int32_t>& padding,
                                 const std::vector<int32_t>& kernel_size,
                                 const std::vector<int32_t>& stride,
                                 const std::vector<int32_t>& dilation, const bool& return_indices,
                                 const bool& ceil_mode) const {
     MutableAttrMap attrs;
-    JUST(attrs.SetAttr<std::string>("padding", padding));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
     JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
     JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
     JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
     JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
@@ -270,6 +266,13 @@ class MaxPool2DFunctor : public PoolNDFunctor {
   }
 };
 
+class Maxpool1DFunctor : public PoolingNDFunctor {  // added to the library as "Maxpool1D"
+ public:
+  Maxpool1DFunctor() {  // builds the "maxpool_1d" op expr: input x, outputs y and indice
+    op_ = CHECK_JUST(one::OpBuilder("maxpool_1d").Input("x").Output("y").Output("indice").Build());
+  }
+};
+
 class Maxpool2DFunctor : public PoolingNDFunctor {
  public:
   Maxpool2DFunctor() {
@@ -472,6 +475,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::LayerNormFunctor>("LayerNorm");
   m.add_functor<impl::LayerNormAffineFunctor>("LayerNormAffine");
   m.add_functor<impl::AvgPool2DFunctor>("AvgPool2D");
+  m.add_functor<impl::Maxpool1DFunctor>("Maxpool1D");
   m.add_functor<impl::Maxpool2DFunctor>("Maxpool2D");
   m.add_functor<impl::Maxpool3DFunctor>("Maxpool3D");
   m.add_functor<impl::MaxPool2DFunctor>("MaxPool2D");
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index fa6110e84bd73abbee04a3f32af5e460be7a1c36..5ab9a9716d9a10323adfba0d9ab4d4b949d09cb9 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -113,7 +113,7 @@ class PoolingNdGradFunctor {
  public:
   PoolingNdGradFunctor() {
     for (const auto& mode : {"max"}) {
-      for (int ndims = 2; ndims <= 3; ++ndims) {
+      for (int ndims = 1; ndims <= 3; ++ndims) {
         const auto& op_type_name = GetOpTypeName(mode, ndims);
         op_expr_map_[op_type_name] = CHECK_JUST(one::OpBuilder(op_type_name)
                                                     .Input("x")
@@ -133,16 +133,13 @@ class PoolingNdGradFunctor {
                            const std::shared_ptr<one::Tensor>& indice,
                            const std::shared_ptr<one::Tensor>& dy, const std::string& mode,
                            const int32_t& ndims, const std::string& data_format,
-                           const std::string& padding, const std::vector<int32_t>& padding_before,
-                           const std::vector<int32_t>& padding_after,
+                           const std::vector<int32_t>& padding,
                            const std::vector<int32_t>& kernel_size,
                            const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,
                            const bool& return_indices, const bool& ceil_mode) const {
     MutableAttrMap attrs;
-    JUST(attrs.SetAttr<std::string>("padding", padding));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
-    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
     JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding", padding));
     JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
     JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
     JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
diff --git a/oneflow/python/nn/modules/pooling.py b/oneflow/python/nn/modules/pooling.py
index 57de40a2cbb4fc10f614eb51e3b8e512a5801695..b183d196283efdc530cd926e3db53650ed231635 100644
--- a/oneflow/python/nn/modules/pooling.py
+++ b/oneflow/python/nn/modules/pooling.py
@@ -316,44 +316,26 @@ class MaxPool1d(Module):
         ceil_mode: bool = False,
     ):
         super().__init__()
-        self.kernel_size = _getint(kernel_size)
-        self.stride = _getint(stride) if stride is not None else self.kernel_size
-        data_format = "NCL"  # Only suport "NCL" for now!
+        self.kernel_size = _single(kernel_size)
+        self.stride = _single(stride) if stride is not None else self.kernel_size
+        data_format = "NCL"  # only support "NCL" for now !
         self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last"
-        self.dilation = _getint(dilation)
-        self.padding = _getint(padding)
+        self.dilation = _single(dilation)
+        self.padding = _single(padding)
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
 
-        if self.channel_pos == "channels_first":
-            padding = (0, 0, self.padding, 0)
-        else:
-            raise ValueError("error padding param!")
-
-        self.padding_type, pads_list = calc_pool_padding(
-            padding, get_dhw_offset(self.channel_pos), 2
-        )
-        self.padding_before = [pad[0] for pad in pads_list]
-        self.padding_after = [pad[1] for pad in pads_list]
-
     def forward(self, x):
-        expand_x = x.unsqueeze(dim=-1)
-
-        expand_y, expand_indice = flow.F.maxpool_2d(
-            expand_x,
+        y, indice = flow.F.maxpool_1d(
+            x,
             data_format=self.channel_pos,
-            padding=self.padding_type,
-            padding_before=self.padding_before,
-            padding_after=self.padding_after,
-            kernel_size=[self.kernel_size, 1],
-            stride=[self.stride, 1],
-            dilation=[self.dilation, 1],
+            padding=self.padding,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
             return_indices=True,
             ceil_mode=self.ceil_mode,
         )
-
-        y = expand_y.squeeze(dim=-1)
-        indice = expand_indice.squeeze(dim=-1)
         if self.return_indices:
             return y, indice
         else:
@@ -454,45 +436,27 @@ class MaxPool2d(Module):
     ):
         super().__init__()
         self.kernel_size = _pair(kernel_size)
-        self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)
-        data_format = "NCHW"  # Only suport "NCHW" for now!
+        data_format = "NCHW"  # only support "NCHW" for now !
         self.channel_pos = (
             "channels_first" if data_format == "NCHW" else "channels_last"
         )
+        self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size)
         self.dilation = _GetSequence(dilation, 2, "dilation")
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
-
-        padding = _pair(padding)
-        self.padding = padding
-        if len(padding) == 2:
-            if data_format == "NCHW":
-                padding = (0, 0, padding[0], padding[1])
-            else:
-                raise ValueError("error padding param!")
-        else:
-            raise ValueError("error padding param!")
-
-        self.padding_type, pads_list = calc_pool_padding(
-            padding, get_dhw_offset(self.channel_pos), 2
-        )
-        self.padding_before = [pad[0] for pad in pads_list]
-        self.padding_after = [pad[1] for pad in pads_list]
+        self.padding = _pair(padding)
 
     def forward(self, x):
         y, indice = flow.F.maxpool_2d(
             x,
             data_format=self.channel_pos,
-            padding=self.padding_type,
-            padding_before=self.padding_before,
-            padding_after=self.padding_after,
+            padding=self.padding,
             kernel_size=self.kernel_size,
             stride=self.stride,
             dilation=self.dilation,
             return_indices=True,
             ceil_mode=self.ceil_mode,
         )
-
         if self.return_indices:
             return y, indice
         else:
@@ -605,32 +569,15 @@ class MaxPool3d(Module):
             "channels_last" if data_format == "NDHWC" else "channels_first"
         )
         self.dilation = _GetSequence(dilation, 3, "dilation")
-        padding = _triple(padding)
-        self.padding = padding
+        self.padding = _triple(padding)
         self.return_indices = return_indices
         self.ceil_mode = ceil_mode
 
-        if len(padding) == 3:
-            if data_format == "NCDHW":
-                padding = (0, 0, padding[0], padding[1], padding[2])
-            else:
-                raise ValueError("error padding param!")
-        else:
-            raise ValueError("error padding param!")
-
-        self.padding_type, pads_list = calc_pool_padding(
-            padding, get_dhw_offset(self.channel_pos), 3
-        )
-        self.padding_before = [pad[0] for pad in pads_list]
-        self.padding_after = [pad[1] for pad in pads_list]
-
     def forward(self, x):
         y, indice = flow.F.maxpool_3d(
             x,
             data_format=self.channel_pos,
-            padding=self.padding_type,
-            padding_before=self.padding_before,
-            padding_after=self.padding_after,
+            padding=self.padding,
             kernel_size=self.kernel_size,
             stride=self.stride,
             dilation=self.dilation,
diff --git a/oneflow/user/kernels/pooling_kernel.cpp b/oneflow/user/kernels/pooling_kernel.cpp
index 39213ea3b42db8451331349f978f58933fff8eff..2f678241f327c4825861116834c2c17b88312fc4 100644
--- a/oneflow/user/kernels/pooling_kernel.cpp
+++ b/oneflow/user/kernels/pooling_kernel.cpp
@@ -26,31 +26,46 @@ struct PoolingOpKernelState final : public user_op::OpKernelState {
 std::shared_ptr<PoolingOpKernelState> DoCreateOpKernelState(user_op::KernelComputeContext* ctx,
                                                             const int32_t& dim) {
   const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape();
-  const std::string& padding = ctx->Attr<std::string>("padding");
   const std::string& data_format = ctx->Attr<std::string>("data_format");
-  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
   const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
   const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
   const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
   const bool return_indices = ctx->Attr<bool>("return_indices");
   const bool ceil_mode = ctx->Attr<bool>("ceil_mode");
 
-  PoolingParams3D params_3d =
-      PoolingParams3D(dim, x_shape, data_format, padding, padding_before, padding_after,
-                      kernel_size, stride, dilation, return_indices, ceil_mode);
+  PoolingParams3D params_3d = PoolingParams3D(dim, x_shape, data_format, padding, kernel_size,
+                                              stride, dilation, return_indices, ceil_mode);
   std::shared_ptr<PoolingOpKernelState> state(new PoolingOpKernelState(params_3d));
   return std::move(state);
 }
 
 template<typename T>
 struct PoolingKernelUtil<DeviceType::kCPU, T> {
+  static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {  // ctx is unused here
+    Maxpool1dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
+                               params_3d.padding()[2], params_3d.num_batch(),
+                               params_3d.num_channel(), params_3d.GetXShape5D().At(4),  // x length
+                               params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[2],
+                               params_3d.stride_3d()[2], params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    Maxpool1dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,  // ctx is unused
+                                params_3d.num_batch(), params_3d.num_channel(),
+                                params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));
+  }
+
   static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d) {
-    Maxpool2dFarwardCompute<T>(
-        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
-        params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
+    Maxpool2dForwardCompute<T>(
+        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
+        params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
         params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3),
         params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[1],
         params_3d.pooling_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
@@ -69,9 +84,9 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
   static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d) {
-    Maxpool3dFarwardCompute<T>(
-        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
-        params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2], params_3d.num_batch(),
+    Maxpool3dForwardCompute<T>(
+        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],
+        params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),
         params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),
         params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
         params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
@@ -91,6 +106,68 @@ struct PoolingKernelUtil<DeviceType::kCPU, T> {
   }
 };
 
+template<DeviceType device_type, typename T>
+class MaxPool1dKernel final : public user_op::OpKernel {
+ public:
+  MaxPool1dKernel() = default;
+  ~MaxPool1dKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  // Forward pass: hands x (source), y (destination) and indice to Maxpool1dForward.
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* in_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    user_op::Tensor* idx_tensor = ctx->Tensor4ArgNameAndIndex("indice", 0);
+
+    // Pooling parameters are rebuilt from ctx on every call, with dim = 1.
+    const auto& state = DoCreateOpKernelState(ctx, 1);
+    const PoolingParams3D& params = state->GetParams3D();
+
+    // The index helper is constructed from the dims of y.
+    DimVector out_dims;
+    out_tensor->shape().ToDimVector(&out_dims);
+    NdIndexOffsetHelper<int64_t, 3> offset_helper(out_dims.data());
+    const int64_t out_elem_cnt = out_tensor->shape().elem_cnt();
+
+    PoolingKernelUtil<device_type, T>::Maxpool1dForward(
+        ctx->device_ctx(), offset_helper, out_elem_cnt, in_tensor->dptr<T>(),
+        out_tensor->mut_dptr<T>(), idx_tensor->mut_dptr<int64_t>(), params);
+  };
+};
+
+template<DeviceType device_type, typename T>
+class MaxPool1dGradKernel final : public user_op::OpKernel {
+ public:
+  MaxPool1dGradKernel() = default;
+  ~MaxPool1dGradKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  // Backward pass: zero-fills dx, then hands dy, dx and indice to Maxpool1dBackward.
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const user_op::Tensor* idx_tensor = ctx->Tensor4ArgNameAndIndex("indice", 0);
+    user_op::Tensor* grad_in = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    // Pooling parameters are rebuilt from ctx on every call, with dim = 1.
+    const auto& state = DoCreateOpKernelState(ctx, 1);
+    const PoolingParams3D& params = state->GetParams3D();
+    // The index helper is constructed from the dims of dy.
+    DimVector grad_out_dims;
+    grad_out->shape().ToDimVector(&grad_out_dims);
+    NdIndexOffsetHelper<int64_t, 3> offset_helper(grad_out_dims.data());
+    // Clear every byte of dx before the backward helper runs.
+    T* grad_in_ptr = grad_in->mut_dptr<T>();
+    const size_t dx_bytes = grad_in->shape().elem_cnt() * GetSizeOfDataType(grad_in->data_type());
+    Memset<device_type>(ctx->device_ctx(), grad_in_ptr, 0, dx_bytes);
+
+    PoolingKernelUtil<device_type, T>::Maxpool1dBackward(
+        ctx->device_ctx(), offset_helper, grad_out->shape().elem_cnt(), grad_out->dptr<T>(),
+        grad_in_ptr, idx_tensor->dptr<int64_t>(), params);
+  };
+};
+
 template<DeviceType device_type, typename T>
 class MaxPool2dKernel final : public user_op::OpKernel {
  public:
@@ -217,6 +294,14 @@ class MaxPool3dGradKernel final : public user_op::OpKernel {
 };
 
 #define REGISTER_POOLING_KERNELS(device, dtype)                                        \
+  REGISTER_USER_KERNEL("maxpool_1d")                                                   \
+      .SetCreateFn<MaxPool1dKernel<device, dtype>>()                                   \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("maxpool_1d_grad")                                              \
+      .SetCreateFn<MaxPool1dGradKernel<device, dtype>>()                               \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
   REGISTER_USER_KERNEL("maxpool_2d")                                                   \
       .SetCreateFn<MaxPool2dKernel<device, dtype>>()                                   \
       .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
diff --git a/oneflow/user/kernels/pooling_kernel.cu b/oneflow/user/kernels/pooling_kernel.cu
index 7aef65319fe28441f044e5a1e819ca7779bcc295..5f6b4d42eea293bb54cccd5ee4272b306594a564 100644
--- a/oneflow/user/kernels/pooling_kernel.cu
+++ b/oneflow/user/kernels/pooling_kernel.cu
@@ -30,6 +30,17 @@ int GetNumBlocks(int64_t elem_cnt) {
   return num_blocks;
 }
 
+template<typename T>  // CUDA kernel: forwards every argument to Maxpool1dForwardCompute
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool1dForward(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                                int32_t padding_l, int64_t n_batch, int64_t n_channel,
+                                int64_t x_length, int64_t y_length, int32_t kernel_size_l,
+                                int32_t stride_l, int32_t dilation_l) {
+  Maxpool1dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_l, n_batch,
+                             n_channel, x_length, y_length, kernel_size_l, stride_l, dilation_l);
+};
+
 template<typename T>
 __launch_bounds__(kBlockSize) __global__
     void DoCUDAMaxPool2dForward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
@@ -39,7 +50,7 @@ __launch_bounds__(kBlockSize) __global__
                                 int64_t y_height, int64_t y_width, int32_t kernel_size_h,
                                 int32_t kernel_size_w, int32_t stride_h, int32_t stride_w,
                                 int32_t dilation_h, int32_t dilation_w) {
-  Maxpool2dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w,
+  Maxpool2dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w,
                              n_batch, n_channel, x_height, x_width, y_height, y_width,
                              kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h,
                              dilation_w);
@@ -56,12 +67,23 @@ __launch_bounds__(kBlockSize) __global__
                                 int32_t kernel_size_w, int32_t stride_t, int32_t stride_h,
                                 int32_t stride_w, int32_t dilation_t, int32_t dilation_h,
                                 int32_t dilation_w) {
-  Maxpool3dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h,
+  Maxpool3dForwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h,
                              padding_w, n_batch, n_channel, x_time, x_height, x_width, y_time,
                              y_height, y_width, kernel_size_t, kernel_size_h, kernel_size_w,
                              stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w);
 };
 
+template<typename T>  // CUDA kernel: forwards every argument to Maxpool1dBackwardCompute
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool1dBackward(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                 const int64_t elem_num, const T* src, T* dest,
+                                 const int64_t* indice_ptr, const int64_t n_batch,
+                                 const int64_t n_channel, const int64_t src_length,
+                                 const int64_t dst_length) {
+  Maxpool1dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel,
+                              src_length, dst_length);
+};
+
 template<typename T>
 __launch_bounds__(kBlockSize) __global__
     void DoCUDAMaxPool2dBackward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
@@ -89,13 +111,33 @@ __launch_bounds__(kBlockSize) __global__
 
 template<typename T>
 struct PoolingKernelUtil<DeviceType::kGPU, T> {
+  static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool1dForward<T>  // launched on ctx->cuda_stream() with 0 bytes of dynamic smem
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[2],
+            params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(4),
+            params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[2], params_3d.stride_3d()[2],
+            params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool1dBackward<T>  // launched on ctx->cuda_stream() with 0 bytes of dynamic smem
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),
+            params_3d.num_channel(), params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(4));
+  }
+
   static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d) {
     DoCUDAMaxPool2dForward<T>
         <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
-            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
-            params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[1],
+            params_3d.padding()[2], params_3d.num_batch(), params_3d.num_channel(),
             params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
             params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
             params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
@@ -118,15 +160,15 @@ struct PoolingKernelUtil<DeviceType::kGPU, T> {
                                const PoolingParams3D& params_3d) {
     DoCUDAMaxPool3dForward<T>
         <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
-            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
-            params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2],
-            params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2),
-            params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
-            params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
-            params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
-            params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
-            params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
-            params_3d.dilation_3d()[0], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding()[0],
+            params_3d.padding()[1], params_3d.padding()[2], params_3d.num_batch(),
+            params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),
+            params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2),
+            params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
+            params_3d.pooling_size_3d()[0], params_3d.pooling_size_3d()[1],
+            params_3d.pooling_size_3d()[2], params_3d.stride_3d()[0], params_3d.stride_3d()[1],
+            params_3d.stride_3d()[2], params_3d.dilation_3d()[0], params_3d.dilation_3d()[1],
+            params_3d.dilation_3d()[2]);
   }
 
   static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
diff --git a/oneflow/user/kernels/pooling_kernel_util.cpp b/oneflow/user/kernels/pooling_kernel_util.cpp
index c2f502ce2928c8fec6c60ccc5c288d1c074fb275..8da50e3b40abace6a31bc7350c59a919d34d0abb 100644
--- a/oneflow/user/kernels/pooling_kernel_util.cpp
+++ b/oneflow/user/kernels/pooling_kernel_util.cpp
@@ -43,19 +43,44 @@ std::vector<int32_t> Get3DPadVec(const std::vector<int32_t>& original_vec, int32
   return vec;
 }
 
+// Computes the pooled extent of one axis and stores it in *output_ptr.
+//
+// The value is (input_size + 2 * padding - dilation_rate * (filter_size - 1) - 1) / stride + 1,
+// where ceil_mode adds stride - 1 to the numerator so the division rounds up.
+void GetWindowedOutputShape(int64_t input_size, int32_t filter_size, int32_t stride,
+                            int32_t padding, bool ceil_mode, int32_t dilation_rate,
+                            int64_t* output_ptr) {
+  int64_t output_size = (input_size + 2 * padding - dilation_rate * (filter_size - 1) - 1 + stride
+                         + (ceil_mode ? stride - 1 : 0))
+                        / stride;
+  // Rounding up can add a final window that begins at or past input_size + padding, counted
+  // from the start of the left padding; such a window covers no input element, so drop it.
+  if (ceil_mode && (output_size - 1) * stride >= input_size + padding) { --output_size; }
+  *output_ptr = output_size;
+}
+
+// Fills *out with three entries, one GetWindowedOutputShape result per index 0..2 of the inputs.
+void Get3DOutputShape(const DimVector& in, const std::vector<int32_t>& pool_size,
+                      const std::vector<int32_t>& strides, const std::vector<int32_t>& padding,
+                      const bool ceil_mode, std::vector<int32_t> dilation_rate, DimVector* out) {
+  out->clear();
+  out->resize(3);
+  for (size_t axis = 0; axis < 3; ++axis) {
+    GetWindowedOutputShape(in.at(axis), pool_size.at(axis), strides.at(axis), padding.at(axis),
+                           ceil_mode, dilation_rate.at(axis), &out->at(axis));
+  }
+}
+
 PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
-                                 const std::string& data_format, const std::string& padding,
-                                 const std::vector<int32_t>& padding_before,
-                                 const std::vector<int32_t>& padding_after,
+                                 const std::string& data_format,
+                                 const std::vector<int32_t>& padding,
                                  const std::vector<int32_t>& kernel_size,
                                  const std::vector<int32_t>& stride,
                                  const std::vector<int32_t>& dilation, const bool return_indices,
                                  const bool ceil_mode)
     : dim_(dim),
       data_format_(data_format),
-      padding_(padding),
-      padding_before_3d_(Get3DPadVec(padding_before, dim)),
-      padding_after_3d_(Get3DPadVec(padding_after, dim)),
+      padding_(Get3DPadVec(padding, dim)),
       pooling_size_3d_(Get3DVec(kernel_size, dim)),
       stride_3d_(Get3DVec(stride, dim)),
       dilation_3d_(Get3DVec(dilation, dim)),
@@ -63,8 +80,7 @@ PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
       ceil_mode_(ceil_mode) {
   x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),
            GetInDim(x_shape, data_format, 2, dim)};
-  Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
-                  &padding_before_3d_, &padding_after_3d_);
+  Get3DOutputShape(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);
   if (data_format == "channels_first") {
     channel_num_ = x_shape.At(1);
   } else {
@@ -78,8 +94,7 @@ PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
 void PoolingParams3D::Reset(const ShapeView& x_shape) {
   x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_),
            GetInDim(x_shape, data_format_, 2, dim_)};
-  Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
-                  &padding_before_3d_, &padding_after_3d_);
+  Get3DOutputShape(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, dilation_3d_, &y_3d_);
 }
 
 Shape PoolingParams3D::GetYShape() const {
diff --git a/oneflow/user/kernels/pooling_kernel_util.h b/oneflow/user/kernels/pooling_kernel_util.h
index b25ba34e04a3ce8c5027948965c686c69e1134f9..0d126ea63fa066469151ed708604ecc29fc4abd1 100644
--- a/oneflow/user/kernels/pooling_kernel_util.h
+++ b/oneflow/user/kernels/pooling_kernel_util.h
@@ -53,16 +53,13 @@ struct DeviceAdd {
 class PoolingParams3D {
  public:
   PoolingParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,
-                  const std::string& padding, const std::vector<int32_t>& padding_before,
-                  const std::vector<int32_t>& padding_after,
-                  const std::vector<int32_t>& kernel_size, const std::vector<int32_t>& stride,
-                  const std::vector<int32_t>& dilation, const bool return_indices,
-                  const bool ceil_mode);
+                  const std::vector<int32_t>& padding, const std::vector<int32_t>& kernel_size,
+                  const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,
+                  const bool return_indices, const bool ceil_mode);
   ~PoolingParams3D() = default;
 
   const std::string& data_format() const { return data_format_; }
-  const std::vector<int32_t>& padding_before_3d() const { return padding_before_3d_; }
-  const std::vector<int32_t>& padding_after_3d() const { return padding_after_3d_; }
+  const std::vector<int32_t>& padding() const { return padding_; }
   const std::vector<int32_t>& pooling_size_3d() const { return pooling_size_3d_; }
   const std::vector<int32_t>& stride_3d() const { return stride_3d_; }
   const std::vector<int32_t>& dilation_3d() const { return dilation_3d_; }
@@ -81,9 +78,7 @@ class PoolingParams3D {
   FixedDimVector x_3d_;
   FixedDimVector y_3d_;
   std::string data_format_;
-  std::string padding_;
-  std::vector<int32_t> padding_before_3d_;
-  std::vector<int32_t> padding_after_3d_;
+  std::vector<int32_t> padding_;
   std::vector<int32_t> pooling_size_3d_;
   std::vector<int32_t> stride_3d_;
   std::vector<int32_t> dilation_3d_;
@@ -95,6 +90,14 @@ class PoolingParams3D {
 
 template<DeviceType device_type, typename T>
 struct PoolingKernelUtil {
+  static void Maxpool1dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d);
+
+  static void Maxpool1dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 3>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d);
+
   static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
                                const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
                                const PoolingParams3D& params_3d);
@@ -113,7 +116,70 @@ struct PoolingKernelUtil {
 };
 
 template<typename T>
-OF_DEVICE_FUNC void Maxpool2dFarwardCompute(
+OF_DEVICE_FUNC void Maxpool1dForwardCompute(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                            int64_t elem_num, const T* src, T* dest,
+                                            int64_t* indice_ptr, const int32_t padding_l,
+                                            const int64_t n_batch, const int64_t n_channel,
+                                            const int64_t x_length, const int64_t y_length,
+                                            const int32_t kernel_size_l, const int32_t stride_l,
+                                            const int32_t dilation_l) {
+  // One pass per output element: the selected input value is written to dest[num] and its
+  // position along the length axis to indice_ptr[num]. n_batch and y_length are not read here.
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, l;
+    index_helper.OffsetToNdIndex(num, n, c, l);
+
+    // Offset in src of the first element of row (n, c).
+    const int64_t row_base = (n * n_channel + c) * x_length;
+    // Window start; negative while it lies in the left padding.
+    int64_t first = l * stride_l - padding_l;
+    // One past the last tap, measured from the unclamped start and capped at x_length.
+    const int64_t span_end = first + (kernel_size_l - 1) * dilation_l + 1;
+    const int64_t last = span_end <= x_length ? span_end : x_length;
+    // Step over taps below index 0 while staying on the dilation grid.
+    for (; first < 0; first += dilation_l) {}
+    int64_t best_pos = first;
+    int64_t best_offset = 0;
+    // Running maximum, seeded with detail::numeric_limits<T>::lower_bound().
+    T best_val = detail::numeric_limits<T>::lower_bound();
+    for (int64_t pos = first; pos < last; pos += dilation_l) {
+      const int64_t offset = row_base + pos;
+      const T candidate = src[offset];
+      // Values flagged by numerics<T>::isnan are taken even if they do not compare greater.
+      if (candidate > best_val || detail::numerics<T>::isnan(candidate)) {
+        best_val = candidate;
+        best_pos = pos;
+        best_offset = offset;
+      }
+    }
+    dest[num] = src[best_offset];
+    indice_ptr[num] = best_pos;
+  }
+}
+
+// Routes each gradient in src to the input position recorded for it in indice_ptr.
+template<typename T>
+OF_DEVICE_FUNC void Maxpool1dBackwardCompute(const NdIndexOffsetHelper<int64_t, 3> index_helper,
+                                             const int64_t elem_num, const T* src, T* dest,
+                                             const int64_t* indice_ptr, const int64_t n_batch,
+                                             const int64_t n_channel, const int64_t src_length,
+                                             const int64_t dst_length) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, l;
+    index_helper.OffsetToNdIndex(num, n, c, l);
+    const int64_t row = n * n_channel + c;
+    const int64_t index = row * src_length + l;
+    // Compare the stored position, not the offset: adding the row base would hide a stored -1.
+    const int64_t recorded = indice_ptr[index];
+    if (recorded != -1) {
+      /* update gradient, equals to dest[row * dst_length + recorded] += src[index]; */
+      DeviceAdd<T>::Invoke(src + index, dest + row * dst_length + recorded);
+    }
+  }
+}
+
+template<typename T>
+OF_DEVICE_FUNC void Maxpool2dForwardCompute(
     const NdIndexOffsetHelper<int64_t, 4> index_helper, int64_t elem_num, const T* src, T* dest,
     int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int64_t n_batch,
     const int64_t n_channel, const int64_t x_height, const int64_t x_width, const int64_t y_height,
@@ -194,7 +260,7 @@ OF_DEVICE_FUNC void Maxpool2dBackwardCompute(const NdIndexOffsetHelper<int64_t,
 }
 
 template<typename T>
-OF_DEVICE_FUNC void Maxpool3dFarwardCompute(
+OF_DEVICE_FUNC void Maxpool3dForwardCompute(
     const NdIndexOffsetHelper<int64_t, 5> index_helper, int64_t elem_num, const T* src, T* dest,
     int64_t* indice_ptr, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,
     const int64_t n_batch, const int64_t n_channel, const int64_t x_time, const int64_t x_height,
diff --git a/oneflow/user/ops/pooling_op.cpp b/oneflow/user/ops/pooling_op.cpp
index 03896b002c1537ae929d693ffa24ccc4592dabc4..adfbf70be83d39ccd4c50c9d7b7f5556499734ef 100644
--- a/oneflow/user/ops/pooling_op.cpp
+++ b/oneflow/user/ops/pooling_op.cpp
@@ -28,9 +28,7 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
   return [dim](user_op::InferContext* ctx) -> Maybe<void> {
     const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0);
     const std::string& data_format = ctx->Attr<std::string>("data_format");
-    const std::string& padding = ctx->Attr<std::string>("padding");
-    const std::vector<int32_t>& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-    const std::vector<int32_t>& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+    const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
     const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
     const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
     const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
@@ -41,14 +39,13 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
     for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0); }
     CHECK_EQ_OR_RETURN(stride.size(), dim);
     for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); }
-    for (int32_t i = 0; i < padding_after.size(); i++) {
-      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding_after[i])
+    for (int32_t i = 0; i < padding.size(); i++) {
+      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding[i])
           << "pad should be smaller than half of kernel size";
     }
 
-    const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, padding_before,
-                                    padding_after, kernel_size, stride, dilation, return_indices,
-                                    ceil_mode);
+    const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, kernel_size, stride,
+                                    dilation, return_indices, ceil_mode);
     user_op::TensorDesc* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0);
     *y_desc = *ctx->TensorDesc4ArgNameAndIndex("x", 0);
     *y_desc->mut_shape() = params_3d.GetYShape();
@@ -64,10 +61,9 @@ TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
 
 Maybe<void> ForwardGetSbpFn(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
   FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
-    if (padding_before[i] == 0 && padding_after[i] == 0) {
+    if (padding[i] == 0) {
       ctx->NewBuilder()
           .Split(user_op::OpArg("x", 0), i)
           .Split(user_op::OpArg("y", 0), i)
@@ -85,10 +81,9 @@ Maybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) {
 
 Maybe<void> BackwardGetSbpFn(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
-  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  const std::vector<int32_t>& padding = ctx->Attr<std::vector<int32_t>>("padding");
   FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
-    if (padding_before[i] == 0 && padding_after[i] == 0) {
+    if (padding[i] == 0) {
       ctx->NewBuilder()
           .Split(user_op::OpArg("x", 0), i)
           .Split(user_op::OpArg("y", 0), i)
@@ -123,9 +118,7 @@ GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t
               .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
               .Output("dx")
               .Attr("data_format", op.attr<std::string>("data_format"))
-              .Attr("padding", op.attr<std::string>("padding"))
-              .Attr("padding_before", op.attr<std::vector<int32_t>>("padding_before"))
-              .Attr("padding_after", op.attr<std::vector<int32_t>>("padding_after"))
+              .Attr("padding", op.attr<std::vector<int32_t>>("padding"))
               .Attr("kernel_size", op.attr<std::vector<int32_t>>("kernel_size"))
               .Attr("stride", op.attr<std::vector<int32_t>>("stride"))
               .Attr("dilation", op.attr<std::vector<int32_t>>("dilation"))
@@ -141,13 +134,45 @@ GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t
 
 }  // namespace
 
+REGISTER_USER_OP("maxpool_1d")
+    .Input("x")
+    .Output("y")
+    .Output("indice")
+    .Attr<std::vector<int32_t>>("padding")  // must satisfy kernel_size[i] >= 2 * padding[i]
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")  // infer fn requires every entry > 0
+    .Attr<std::vector<int32_t>>("stride")  // infer fn requires size == dim and every entry > 0
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(MakeForwardTensorDescInferFn(1))  // dim = 1
+    .SetGetSbpFn(ForwardGetSbpFn)  // NOTE(review): reads padding[0] and padding[1]; confirm size
+    .SetDataTypeInferFn(FwInferDataType);
+
+REGISTER_USER_OP("maxpool_1d_grad")
+    .Input("x")
+    .Input("y")
+    .Input("indice")
+    .Input("dy")  // MakeBackwardOpConfFn feeds this from the gradient of the forward op's "y"
+    .Output("dx")
+    .Attr<std::vector<int32_t>>("padding")
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")
+    .Attr<std::vector<int32_t>>("stride")
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(BackwardTensorDescInferFn)
+    .SetGetSbpFn(BackwardGetSbpFn)  // NOTE(review): reads padding[0] and padding[1]; confirm size
+    .SetDataTypeInferFn(BwInferDataType);
+
+REGISTER_USER_OP_GRAD("maxpool_1d").SetGenBackwardOpConfFn(MakeBackwardOpConfFn("max", 1));
+
 REGISTER_USER_OP("maxpool_2d")
     .Input("x")
     .Output("y")
     .Output("indice")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
     .Attr<std::string>("data_format")
     .Attr<std::vector<int32_t>>("kernel_size")
     .Attr<std::vector<int32_t>>("stride")
@@ -164,9 +189,7 @@ REGISTER_USER_OP("maxpool_2d_grad")
     .Input("indice")
     .Input("dy")
     .Output("dx")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
     .Attr<std::string>("data_format")
     .Attr<std::vector<int32_t>>("kernel_size")
     .Attr<std::vector<int32_t>>("stride")
@@ -183,9 +206,7 @@ REGISTER_USER_OP("maxpool_3d")
     .Input("x")
     .Output("y")
     .Output("indice")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
     .Attr<std::string>("data_format")
     .Attr<std::vector<int32_t>>("kernel_size")
     .Attr<std::vector<int32_t>>("stride")
@@ -202,9 +223,7 @@ REGISTER_USER_OP("maxpool_3d_grad")
     .Input("indice")
     .Input("dy")
     .Output("dx")
-    .Attr<std::string>("padding")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::vector<int32_t>>("padding")
     .Attr<std::string>("data_format")
     .Attr<std::vector<int32_t>>("kernel_size")
     .Attr<std::vector<int32_t>>("stride")