diff --git a/oneflow/core/autograd/gradient_funcs/pooling.cpp b/oneflow/core/autograd/gradient_funcs/pooling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a5d6c8c3737e221741859de4e486b2463bb8a57a
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/pooling.cpp
@@ -0,0 +1,120 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+namespace {
+
+struct PoolingInterpState : public OpExprInterpState {
+  bool requires_grad;
+  size_t input_index;
+  size_t output_index;
+  size_t indice_index;
+
+  std::string data_format;
+  std::string padding;
+  std::vector<int32_t> padding_before;
+  std::vector<int32_t> padding_after;
+  std::vector<int32_t> kernel_size;
+  std::vector<int32_t> stride;
+  std::vector<int32_t> dilation;
+  bool return_indices;
+  bool ceil_mode;
+};
+
+class PoolingNdGrad : public OpExprGradFunction<PoolingInterpState> {
+ public:
+  virtual ~PoolingNdGrad() = default;
+  Maybe<void> Init(const OpExpr& op, const std::string& mode);
+  Maybe<void> Capture(PoolingInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const PoolingInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  std::string mode_;
+  AttrMap base_attrs_;
+};
+
+Maybe<void> PoolingNdGrad::Init(const OpExpr& op, const std::string& mode) {
+  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  mode_ = mode;
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> PoolingNdGrad::Capture(PoolingInterpState* ctx, const TensorTuple& inputs,
+                                   const TensorTuple& outputs, const AttrMap& attrs) const {
+  ctx->requires_grad = inputs.at(0)->requires_grad();
+  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+
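+  // Save the forward input, output and argmax indices; Apply() needs them to
+  // scatter the incoming gradient back to the max locations.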
+  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));
+  ctx->output_index = ctx->SaveTensorForBackward(outputs.at(0));
+  ctx->indice_index = ctx->SaveTensorForBackward(outputs.at(1));
+
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+  ctx->padding = JUST(composed_attrs.GetAttr<std::string>("padding"));
+  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
+  ctx->padding_after = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_after"));
+  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
+  ctx->stride = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("stride"));
+  ctx->dilation = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation"));
+  ctx->return_indices = JUST(composed_attrs.GetAttr<bool>("return_indices"));
+  ctx->ceil_mode = JUST(composed_attrs.GetAttr<bool>("ceil_mode"));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> PoolingNdGrad::Apply(const PoolingInterpState* ctx, const TensorTuple& out_grads,
+                                 TensorTuple* in_grads) const {
+  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+  CHECK_LE_OR_RETURN(out_grads.size(), 2);
+
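+  // kernel_size carries one entry per spatial dim, so its length tells us
+  // whether this is the 2d or the 3d max-pooling backward.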
+  int32_t ndims = ctx->kernel_size.size();
+  const auto& input = ctx->SavedTensors().at(ctx->input_index);
+  const auto& output = ctx->SavedTensors().at(ctx->output_index);
+  const auto& indice = ctx->SavedTensors().at(ctx->indice_index);
+
+  in_grads->resize(1);
+  in_grads->at(0) = JUST(functional::PoolingNdGrad(
+      input, output, indice, out_grads.at(0), mode_, ndims, ctx->data_format, ctx->padding,
+      ctx->padding_before, ctx->padding_after, ctx->kernel_size, ctx->stride, ctx->dilation,
+      ctx->return_indices, ctx->ceil_mode));
+
+  return Maybe<void>::Ok();
+}
+
+}  // namespace
+
+class MaxpoolNdGrad final : public PoolingNdGrad {
+ public:
+  Maybe<void> Init(const OpExpr& op) override { return PoolingNdGrad::Init(op, "max"); }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_2d", MaxpoolNdGrad);
+REGISTER_OP_EXPR_GRAD_FUNCTION("maxpool_3d", MaxpoolNdGrad);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 979f9556243cba135083f607987050ade3c0cc73..c422c2adfc90e3adf9f9744039c4e96a0170e0b8 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -403,6 +403,29 @@
                        Int32List strides, Bool ceil_mode)"
   bind_python: False
 
+- name: "maxpool_2d"
+  signature:
+    "TensorTuple Maxpool2D(Tensor x, *, String data_format=\"channels_first\", String padding,
+                      Int32List padding_before, Int32List padding_after,
+                      Int32List kernel_size, Int32List stride, Int32List dilation, 
+                      Bool return_indices=True, Bool ceil_mode=False)"
+  bind_python: True
+
+- name: "maxpool_3d"
+  signature:
+    "TensorTuple Maxpool3D(Tensor x, *, String data_format=\"channels_first\", String padding,
+                      Int32List padding_before, Int32List padding_after,
+                      Int32List kernel_size, Int32List stride, Int32List dilation, 
+                      Bool return_indices=True, Bool ceil_mode=False)"
+  bind_python: True
+
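+# Shared backward entry for the new maxpool_2d / maxpool_3d ops; "mode"
+# selects the pooling type and "ndims" the spatial rank.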
+- name: "pooling_grad"
+  signature:
+    "Tensor PoolingNdGrad(Tensor x, Tensor y, Tensor indice, Tensor dy, *, String mode, Int32 ndims, String data_format,
+                       String padding, Int32List padding_before, Int32List padding_after, Int32List kernel_size,
+                       Int32List stride, Int32List dilation, Bool return_indices, Bool ceil_mode)"
+  bind_python: False
+
 - name: "prelu"
   signature: "Tensor PRelu(Tensor x, Tensor alpha)"
   bind_python: True
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index da5619afdafd4ea6ae219d8f53c7453ceeb10747..d181e21af21e2002a87702c56c45db5978d2afb7 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -227,6 +227,35 @@ class PoolNDFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class PoolingNDFunctor {
+ public:
+  PoolingNDFunctor() = default;
+  virtual ~PoolingNDFunctor() = default;
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& x,
+                                const std::string& data_format, const std::string& padding,
+                                const std::vector<int32_t>& padding_before,
+                                const std::vector<int32_t>& padding_after,
+                                const std::vector<int32_t>& kernel_size,
+                                const std::vector<int32_t>& stride,
+                                const std::vector<int32_t>& dilation, const bool& return_indices,
+                                const bool& ceil_mode) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<std::string>("padding", padding));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
+    JUST(attrs.SetAttr<bool>("return_indices", return_indices));
+    JUST(attrs.SetAttr<bool>("ceil_mode", ceil_mode));
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
+  }
+
+ protected:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class AvgPool2DFunctor : public PoolNDFunctor {
  public:
   AvgPool2DFunctor() {
@@ -241,6 +270,20 @@ class MaxPool2DFunctor : public PoolNDFunctor {
   }
 };
 
+class Maxpool2DFunctor : public PoolingNDFunctor {
+ public:
+  Maxpool2DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("maxpool_2d").Input("x").Output("y").Output("indice").Build());
+  }
+};
+
+class Maxpool3DFunctor : public PoolingNDFunctor {
+ public:
+  Maxpool3DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("maxpool_3d").Input("x").Output("y").Output("indice").Build());
+  }
+};
+
 class SparseSoftmaxCrossEntropyFunctor {
  public:
   SparseSoftmaxCrossEntropyFunctor() {
@@ -420,6 +463,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::LayerNormFunctor>("LayerNorm");
   m.add_functor<impl::LayerNormAffineFunctor>("LayerNormAffine");
   m.add_functor<impl::AvgPool2DFunctor>("AvgPool2D");
+  m.add_functor<impl::Maxpool2DFunctor>("Maxpool2D");
+  m.add_functor<impl::Maxpool3DFunctor>("Maxpool3D");
   m.add_functor<impl::MaxPool2DFunctor>("MaxPool2D");
   m.add_functor<impl::SparseSoftmaxCrossEntropyFunctor>("SparseSoftmaxCrossEntropy");
   m.add_functor<impl::SmoothL1LossFunctor>("SmoothL1Loss");
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index 11c750d5e4e104bac054f3c3807d98b2c64cc280..b47f345d70bda1f4e46761f42409e9086a736143 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -109,6 +109,57 @@ class ConvDataGradFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class PoolingNdGradFunctor {
+ public:
+  PoolingNdGradFunctor() {
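+    // Build one backward op expression per supported (mode, ndims) pair,
+    // e.g. "maxpool_2d_grad" and "maxpool_3d_grad".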
+    for (const auto& mode : {"max"}) {
+      for (int ndims = 2; ndims <= 3; ++ndims) {
+        const auto& op_type_name = GetOpTypeName(mode, ndims);
+        op_expr_map_[op_type_name] = CHECK_JUST(one::OpBuilder(op_type_name)
+                                                    .Input("x")
+                                                    .Input("y")
+                                                    .Input("indice")
+                                                    .Input("dy")
+                                                    .Output("dx")
+                                                    .Build());
+      }
+    }
+  }
+  static std::string GetOpTypeName(const std::string& mode, const int32_t& ndims) {
+    return mode + "pool_" + std::to_string(ndims) + "d_grad";
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::shared_ptr<one::Tensor>& y,
+                           const std::shared_ptr<one::Tensor>& indice,
+                           const std::shared_ptr<one::Tensor>& dy, const std::string& mode,
+                           const int32_t& ndims, const std::string& data_format,
+                           const std::string& padding, const std::vector<int32_t>& padding_before,
+                           const std::vector<int32_t>& padding_after,
+                           const std::vector<int32_t>& kernel_size,
+                           const std::vector<int32_t>& stride, const std::vector<int32_t>& dilation,
+                           const bool& return_indices, const bool& ceil_mode) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<std::string>("padding", padding));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_after", padding_after));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("stride", stride));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("dilation", dilation));
+    JUST(attrs.SetAttr<bool>("return_indices", return_indices));
+    JUST(attrs.SetAttr<bool>("ceil_mode", ceil_mode));
+    const auto& op_type_name = GetOpTypeName(mode, ndims);
+    const auto& it = op_expr_map_.find(op_type_name);
+    CHECK_OR_RETURN(it != op_expr_map_.end())
+        << "Encounter unsupported op " << op_type_name << " in PoolingNdGradFunctor.";
+    CHECK_NOTNULL_OR_RETURN(it->second);
+    return OpInterpUtil::Dispatch<Tensor>(*it->second, {x, y, indice, dy}, attrs);
+  }
+
+ protected:
+  std::unordered_map<std::string, std::shared_ptr<OpExpr>> op_expr_map_;
+};
+
 class PoolNdGradFunctor {
  public:
   PoolNdGradFunctor() {
@@ -225,6 +276,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::ConvDataGradFunctor>("ConvDataGrad");
   m.add_functor<impl::PoolNdGradFunctor>("PoolNdGrad");
   m.add_functor<impl::SmoothL1LossGradFunctor>("SmoothL1LossGrad");
+  m.add_functor<impl::PoolingNdGradFunctor>("PoolingNdGrad");
   m.add_functor<impl::PadGradFunctor>("PadGrad");
 };
 
diff --git a/oneflow/core/kernel/util/numeric_limits.cuh b/oneflow/core/kernel/util/numeric_limits.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..7ef9e9e3ee0e1f5fbcd4cccaf59898a5ad3609ad
--- /dev/null
+++ b/oneflow/core/kernel/util/numeric_limits.cuh
@@ -0,0 +1,128 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+// reference: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/NumericLimits.cuh
+#pragma once
+#include <limits.h>
+#include <math.h>
+#include <float.h>
+
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/framework/framework.h"
+
+// numeric_limits.cuh is a holder for numeric limits definitions of commonly used
+// types. This header is very specific to ROCm HIP and may be removed in the future.
+
+// The lower_bound and upper_bound constants are the same as lowest and max for
+// integral types, but are -inf and +inf for floating point types. They are
+// useful in implementing min, max, etc.
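+//
+// A typical use inside a pooling kernel (a sketch) is to seed a running
+// maximum with the type's lower bound:
+//   T max_value = oneflow::detail::numeric_limits<T>::lower_bound();
+// and then update it while scanning the window.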
+
+namespace oneflow {
+namespace detail {
+
+#if defined(__CUDACC__)
+#define OF_NUMERICS_FUNC static inline __host__ __device__
+#else
+#define OF_NUMERICS_FUNC static inline
+#endif
+
+template<typename T>
+struct numeric_limits {};
+
+// WARNING: the following oneflow::numeric_limits definitions are there only to support
+//          HIP compilation for the moment. Use std::numeric_limits if you are not
+//          compiling for ROCm.
+//          from @colesbury: "The functions on numeric_limits aren't marked with
+//          __device__ which is why they don't work with ROCm. CUDA allows them
+//          because they're constexpr."
+
+namespace {
+// ROCm doesn't like INFINITY either.
+constexpr double inf = INFINITY;
+}  // namespace
+
+template<>
+struct numeric_limits<bool> {
+  OF_NUMERICS_FUNC bool lowest() { return false; }
+  OF_NUMERICS_FUNC bool max() { return true; }
+  OF_NUMERICS_FUNC bool lower_bound() { return false; }
+  OF_NUMERICS_FUNC bool upper_bound() { return true; }
+};
+
+template<>
+struct numeric_limits<uint8_t> {
+  OF_NUMERICS_FUNC uint8_t lowest() { return 0; }
+  OF_NUMERICS_FUNC uint8_t max() { return UINT8_MAX; }
+  OF_NUMERICS_FUNC uint8_t lower_bound() { return 0; }
+  OF_NUMERICS_FUNC uint8_t upper_bound() { return UINT8_MAX; }
+};
+
+template<>
+struct numeric_limits<int8_t> {
+  OF_NUMERICS_FUNC int8_t lowest() { return INT8_MIN; }
+  OF_NUMERICS_FUNC int8_t max() { return INT8_MAX; }
+  OF_NUMERICS_FUNC int8_t lower_bound() { return INT8_MIN; }
+  OF_NUMERICS_FUNC int8_t upper_bound() { return INT8_MAX; }
+};
+
+template<>
+struct numeric_limits<int16_t> {
+  OF_NUMERICS_FUNC int16_t lowest() { return INT16_MIN; }
+  OF_NUMERICS_FUNC int16_t max() { return INT16_MAX; }
+  OF_NUMERICS_FUNC int16_t lower_bound() { return INT16_MIN; }
+  OF_NUMERICS_FUNC int16_t upper_bound() { return INT16_MAX; }
+};
+
+template<>
+struct numeric_limits<int32_t> {
+  OF_NUMERICS_FUNC int32_t lowest() { return INT32_MIN; }
+  OF_NUMERICS_FUNC int32_t max() { return INT32_MAX; }
+  OF_NUMERICS_FUNC int32_t lower_bound() { return INT32_MIN; }
+  OF_NUMERICS_FUNC int32_t upper_bound() { return INT32_MAX; }
+};
+
+template<>
+struct numeric_limits<int64_t> {
+#ifdef _MSC_VER
+  OF_NUMERICS_FUNC int64_t lowest() { return _I64_MIN; }
+  OF_NUMERICS_FUNC int64_t max() { return _I64_MAX; }
+  OF_NUMERICS_FUNC int64_t lower_bound() { return _I64_MIN; }
+  OF_NUMERICS_FUNC int64_t upper_bound() { return _I64_MAX; }
+#else
+  OF_NUMERICS_FUNC int64_t lowest() { return INT64_MIN; }
+  OF_NUMERICS_FUNC int64_t max() { return INT64_MAX; }
+  OF_NUMERICS_FUNC int64_t lower_bound() { return INT64_MIN; }
+  OF_NUMERICS_FUNC int64_t upper_bound() { return INT64_MAX; }
+#endif
+};
+
+template<>
+struct numeric_limits<float> {
+  OF_NUMERICS_FUNC float lowest() { return -FLT_MAX; }
+  OF_NUMERICS_FUNC float max() { return FLT_MAX; }
+  OF_NUMERICS_FUNC float lower_bound() { return -static_cast<float>(inf); }
+  OF_NUMERICS_FUNC float upper_bound() { return static_cast<float>(inf); }
+};
+
+template<>
+struct numeric_limits<double> {
+  OF_NUMERICS_FUNC double lowest() { return -DBL_MAX; }
+  OF_NUMERICS_FUNC double max() { return DBL_MAX; }
+  OF_NUMERICS_FUNC double lower_bound() { return -inf; }
+  OF_NUMERICS_FUNC double upper_bound() { return inf; }
+};
+
+}  // namespace detail
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/util/numerics.cuh b/oneflow/core/kernel/util/numerics.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d0e46faa070feebded6603961dbe58bf87254199
--- /dev/null
+++ b/oneflow/core/kernel/util/numerics.cuh
@@ -0,0 +1,249 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+// reference: https://github.com/pytorch/pytorch/blob/master/aten/src/THC/THCNumerics.cuh
+#ifndef ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H
+#define ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H
+#pragma once
+
+#include <limits.h>
+#include <math.h>
+#include <float.h>
+#include <cstdlib>
+#include <assert.h>
+
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/util/numeric_limits.cuh"
+
+namespace oneflow {
+namespace detail {
+
+template<typename T>
+struct numerics {};
+
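+// Integer exponentiation by squaring; b must be non-negative (asserted below).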
+template<typename T>
+OF_NUMERICS_FUNC T powi(T a, T b) {
+  assert(numerics<T>::ge(b, 0));
+  T result = 1;
+  while (b) {
+    if (b & 1) { result *= a; }
+    b /= 2;
+    a *= a;
+  }
+  return result;
+}
+
+template<>
+struct numerics<uint8_t> {
+  OF_NUMERICS_FUNC uint8_t min() { return detail::numeric_limits<uint8_t>::lowest(); }
+  OF_NUMERICS_FUNC uint8_t max() { return detail::numeric_limits<uint8_t>::max(); }
+  OF_NUMERICS_FUNC uint8_t lower_bound() { return detail::numeric_limits<uint8_t>::lower_bound(); }
+  OF_NUMERICS_FUNC uint8_t upper_bound() { return detail::numeric_limits<uint8_t>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(uint8_t a, uint8_t b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(uint8_t a, uint8_t b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(uint8_t a, uint8_t b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(uint8_t a, uint8_t b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(uint8_t a, uint8_t b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(uint8_t a, uint8_t b) { return a != b; }
+
+  OF_NUMERICS_FUNC uint8_t add(uint8_t a, uint8_t b) { return a + b; }
+  OF_NUMERICS_FUNC uint8_t mul(uint8_t a, uint8_t b) { return a * b; }
+  OF_NUMERICS_FUNC uint8_t sub(uint8_t a, uint8_t b) { return a - b; }
+  OF_NUMERICS_FUNC uint8_t div(uint8_t a, uint8_t b) { return a / b; }
+  OF_NUMERICS_FUNC uint8_t pow(uint8_t a, uint8_t b) { return powi<uint8_t>(a, b); }
+  OF_NUMERICS_FUNC bool isnan(uint8_t a) { return false; }
+  OF_NUMERICS_FUNC bool isinf(uint8_t a) { return false; }
+};
+
+#ifdef _MSC_VER
+// Suppress warning C4804: '/': unsafe use of type 'bool' in operation
+#pragma warning(push)
+#pragma warning(disable : 4804)
+#endif
+
+template<>
+struct numerics<bool> {
+  OF_NUMERICS_FUNC bool min() { return detail::numeric_limits<bool>::lowest(); }
+  OF_NUMERICS_FUNC bool max() { return detail::numeric_limits<bool>::max(); }
+  OF_NUMERICS_FUNC bool lower_bound() { return detail::numeric_limits<bool>::lower_bound(); }
+  OF_NUMERICS_FUNC bool upper_bound() { return detail::numeric_limits<bool>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(bool a, bool b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(bool a, bool b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(bool a, bool b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(bool a, bool b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(bool a, bool b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(bool a, bool b) { return a != b; }
+  OF_NUMERICS_FUNC bool add(bool a, bool b) { return a + b; }
+  OF_NUMERICS_FUNC bool mul(bool a, bool b) { return a && b; }
+  OF_NUMERICS_FUNC bool sub(bool a, bool b) { return a - b; }
+  OF_NUMERICS_FUNC bool div(bool a, bool b) { return a / b; }
+  OF_NUMERICS_FUNC bool isnan(bool a) { return false; }
+  OF_NUMERICS_FUNC bool isinf(bool a) { return false; }
+};
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+template<>
+struct numerics<int8_t> {
+  OF_NUMERICS_FUNC int8_t min() { return detail::numeric_limits<int8_t>::lowest(); }
+  OF_NUMERICS_FUNC int8_t max() { return detail::numeric_limits<int8_t>::max(); }
+  OF_NUMERICS_FUNC int8_t lower_bound() { return detail::numeric_limits<int8_t>::lower_bound(); }
+  OF_NUMERICS_FUNC int8_t upper_bound() { return detail::numeric_limits<int8_t>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(int8_t a, int8_t b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(int8_t a, int8_t b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(int8_t a, int8_t b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(int8_t a, int8_t b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(int8_t a, int8_t b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(int8_t a, int8_t b) { return a != b; }
+
+  OF_NUMERICS_FUNC int8_t add(int8_t a, int8_t b) { return a + b; }
+  OF_NUMERICS_FUNC int8_t mul(int8_t a, int8_t b) { return a * b; }
+  OF_NUMERICS_FUNC int8_t sub(int8_t a, int8_t b) { return a - b; }
+  OF_NUMERICS_FUNC int8_t div(int8_t a, int8_t b) { return a / b; }
+  OF_NUMERICS_FUNC int8_t pow(int8_t a, int8_t b) { return powi<int8_t>(a, b); }
+  OF_NUMERICS_FUNC bool isnan(int8_t a) { return false; }
+  OF_NUMERICS_FUNC bool isinf(int8_t a) { return false; }
+};
+
+template<>
+struct numerics<int16_t> {
+  OF_NUMERICS_FUNC int16_t min() { return detail::numeric_limits<int16_t>::lowest(); }
+  OF_NUMERICS_FUNC int16_t max() { return detail::numeric_limits<int16_t>::max(); }
+  OF_NUMERICS_FUNC int16_t lower_bound() { return detail::numeric_limits<int16_t>::lower_bound(); }
+  OF_NUMERICS_FUNC int16_t upper_bound() { return detail::numeric_limits<int16_t>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(int16_t a, int16_t b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(int16_t a, int16_t b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(int16_t a, int16_t b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(int16_t a, int16_t b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(int16_t a, int16_t b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(int16_t a, int16_t b) { return a != b; }
+
+  OF_NUMERICS_FUNC int16_t add(int16_t a, int16_t b) { return a + b; }
+  OF_NUMERICS_FUNC int16_t mul(int16_t a, int16_t b) { return a * b; }
+  OF_NUMERICS_FUNC int16_t sub(int16_t a, int16_t b) { return a - b; }
+  OF_NUMERICS_FUNC int16_t div(int16_t a, int16_t b) { return a / b; }
+  OF_NUMERICS_FUNC int16_t pow(int16_t a, int16_t b) { return powi<int16_t>(a, b); }
+  OF_NUMERICS_FUNC bool isnan(int16_t a) { return false; }
+  OF_NUMERICS_FUNC bool isinf(int16_t a) { return false; }
+};
+
+template<>
+struct numerics<int32_t> {
+  OF_NUMERICS_FUNC int32_t min() { return detail::numeric_limits<int32_t>::lowest(); }
+  OF_NUMERICS_FUNC int32_t max() { return detail::numeric_limits<int32_t>::max(); }
+  OF_NUMERICS_FUNC int32_t lower_bound() { return detail::numeric_limits<int32_t>::lower_bound(); }
+  OF_NUMERICS_FUNC int32_t upper_bound() { return detail::numeric_limits<int32_t>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(int32_t a, int32_t b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(int32_t a, int32_t b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(int32_t a, int32_t b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(int32_t a, int32_t b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(int32_t a, int32_t b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(int32_t a, int32_t b) { return a != b; }
+
+  OF_NUMERICS_FUNC int32_t add(int32_t a, int32_t b) { return a + b; }
+  OF_NUMERICS_FUNC int32_t mul(int32_t a, int32_t b) { return a * b; }
+  OF_NUMERICS_FUNC int32_t sub(int32_t a, int32_t b) { return a - b; }
+  OF_NUMERICS_FUNC int32_t div(int32_t a, int32_t b) { return a / b; }
+  OF_NUMERICS_FUNC int32_t pow(int32_t a, int32_t b) { return powi<int32_t>(a, b); }
+  OF_NUMERICS_FUNC bool isnan(int32_t a) { return false; }
+  OF_NUMERICS_FUNC bool isinf(int32_t a) { return false; }
+};
+
+template<>
+struct numerics<int64_t> {
+  OF_NUMERICS_FUNC int64_t min() { return detail::numeric_limits<int64_t>::lowest(); }
+  OF_NUMERICS_FUNC int64_t max() { return detail::numeric_limits<int64_t>::max(); }
+  OF_NUMERICS_FUNC int64_t lower_bound() { return detail::numeric_limits<int64_t>::lower_bound(); }
+  OF_NUMERICS_FUNC int64_t upper_bound() { return detail::numeric_limits<int64_t>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(int64_t a, int64_t b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(int64_t a, int64_t b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(int64_t a, int64_t b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(int64_t a, int64_t b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(int64_t a, int64_t b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(int64_t a, int64_t b) { return a != b; }
+
+  OF_NUMERICS_FUNC int64_t add(int64_t a, int64_t b) { return a + b; }
+  OF_NUMERICS_FUNC int64_t mul(int64_t a, int64_t b) { return a * b; }
+  OF_NUMERICS_FUNC int64_t sub(int64_t a, int64_t b) { return a - b; }
+  OF_NUMERICS_FUNC int64_t div(int64_t a, int64_t b) { return a / b; }
+  OF_NUMERICS_FUNC int64_t pow(int64_t a, int64_t b) { return powi<int64_t>(a, b); }
+  OF_NUMERICS_FUNC bool isnan(int64_t a) { return false; }
+  OF_NUMERICS_FUNC bool isinf(int64_t a) { return false; }
+};
+
+// DEPRECATED: use math functions from std and cuda math API (if needed)
+template<>
+struct numerics<float> {
+  OF_NUMERICS_FUNC float min() { return detail::numeric_limits<float>::lowest(); }
+  OF_NUMERICS_FUNC float max() { return detail::numeric_limits<float>::max(); }
+  OF_NUMERICS_FUNC float lower_bound() { return detail::numeric_limits<float>::lower_bound(); }
+  OF_NUMERICS_FUNC float upper_bound() { return detail::numeric_limits<float>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(float a, float b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(float a, float b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(float a, float b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(float a, float b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(float a, float b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(float a, float b) { return a != b; }
+
+  OF_NUMERICS_FUNC float sqrt(float a) { return sqrtf(a); }
+  OF_NUMERICS_FUNC float atan(float a) { return atanf(a); }
+  OF_NUMERICS_FUNC float add(float a, float b) { return a + b; }
+  OF_NUMERICS_FUNC float div(float a, float b) { return a / b; }
+  OF_NUMERICS_FUNC float mul(float a, float b) { return a * b; }
+  OF_NUMERICS_FUNC float sub(float a, float b) { return a - b; }
+  OF_NUMERICS_FUNC float pow(float a, float b) { return powf(a, b); }
+  OF_NUMERICS_FUNC bool isnan(float a) { return ::isnan(a); }
+  OF_NUMERICS_FUNC bool isinf(float a) { return ::isinf(a); }
+};
+
+template<>
+struct numerics<double> {
+  OF_NUMERICS_FUNC double min() { return detail::numeric_limits<double>::lowest(); }
+  OF_NUMERICS_FUNC double max() { return detail::numeric_limits<double>::max(); }
+  OF_NUMERICS_FUNC double lower_bound() { return detail::numeric_limits<double>::lower_bound(); }
+  OF_NUMERICS_FUNC double upper_bound() { return detail::numeric_limits<double>::upper_bound(); }
+
+  OF_NUMERICS_FUNC bool lt(double a, double b) { return a < b; }
+  OF_NUMERICS_FUNC bool le(double a, double b) { return a <= b; }
+  OF_NUMERICS_FUNC bool gt(double a, double b) { return a > b; }
+  OF_NUMERICS_FUNC bool ge(double a, double b) { return a >= b; }
+  OF_NUMERICS_FUNC bool eq(double a, double b) { return a == b; }
+  OF_NUMERICS_FUNC bool ne(double a, double b) { return a != b; }
+
+  OF_NUMERICS_FUNC double sqrt(double a) { return ::sqrt(a); }
+  OF_NUMERICS_FUNC double atan(double a) { return ::atan(a); }
+  OF_NUMERICS_FUNC double add(double a, double b) { return a + b; }
+  OF_NUMERICS_FUNC double div(double a, double b) { return a / b; }
+  OF_NUMERICS_FUNC double mul(double a, double b) { return a * b; }
+  OF_NUMERICS_FUNC double sub(double a, double b) { return a - b; }
+  OF_NUMERICS_FUNC double pow(double a, double b) { return ::pow(a, b); }
+  OF_NUMERICS_FUNC bool isnan(double a) { return ::isnan(a); }
+  OF_NUMERICS_FUNC bool isinf(double a) { return ::isinf(a); }
+};
+
+}  // namespace detail
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_UTIL_NUMERICS_H
diff --git a/oneflow/python/nn/modules/pooling.py b/oneflow/python/nn/modules/pooling.py
index 8f2e1b899b5e7ad88cf8e3d5a091d3772e77ef3f..47cfca1d451e529c3a6f2dfc7bd75060da0c3d2d 100644
--- a/oneflow/python/nn/modules/pooling.py
+++ b/oneflow/python/nn/modules/pooling.py
@@ -20,7 +20,7 @@ from oneflow.python.oneflow_export import oneflow_export, experimental_api
 from oneflow.python.nn.module import Module
 from oneflow.python.nn.modules.utils import _single, _pair, _triple
 from oneflow.python.nn.common_types import _size_1_t, _size_2_t, _size_3_t
-from oneflow.python.ops.nn_ops import calc_pool_padding, get_dhw_offset
+from oneflow.python.ops.nn_ops import calc_pool_padding, get_dhw_offset, _GetSequence
 
 
 @oneflow_export("nn.AvgPool1d")
@@ -200,12 +200,13 @@ class AvgPool3d(Module):
 
         >>> import oneflow.experimental as flow
         >>> import numpy as np
-
         >>> flow.enable_eager_execution()
-        >>> inputarr = np.random.randn(9, 7, 11, 32, 20)
-        >>> of_avgpool3d = flow.nn.AvgPool3d(kernel_size=(2,2,2),padding=(0,0,0),stride=(1,1,1),)
-        >>> x = flow.Tensor(inputarr)
-        >>> y = of_avgpool3d(x)
+
+        >>> m = flow.nn.AvgPool3d(kernel_size=(2,2,2),padding=(0,0,0),stride=(1,1,1))
+        >>> x = flow.Tensor(np.random.randn(9, 7, 11, 32, 20))
+        >>> y = m(x)
+        >>> y.shape
+        flow.Size([9, 7, 10, 31, 19])
 
     """
 
@@ -311,8 +312,51 @@ class MaxPool1d(Module):
         return_indices: bool = False,
         ceil_mode: bool = False,
     ):
-        # TODO: fix cuDNN bugs in pooling_1d
-        raise NotImplementedError
+        super().__init__()
+        self.kernel_size = _pair(_single(kernel_size)[0])
+        self.stride = _pair(_single(stride)[0]) if (stride is not None) else self.kernel_size
+        data_format = "NCL"  # Only support "NCL" for now!
+        self.channel_pos = "channels_first" if data_format == "NCL" else "channels_last"
+        self.dilation = _GetSequence(dilation, 2, "dilation")
+        padding = _pair(_single(padding)[0])
+        self.return_indices = return_indices
+        self.ceil_mode = ceil_mode
+
+        if len(padding) == 2:
+            if self.channel_pos == "channels_first":
+                padding = (0, 0, padding[0], padding[1])
+            else:
+                raise ValueError("error padding param!")
+        else:
+            raise ValueError("error padding param!")
+
+        self.padding_type, pads_list = calc_pool_padding(
+            padding, get_dhw_offset(self.channel_pos), 2
+        )
+        self.padding_before = [pad[0] for pad in pads_list]
+        self.padding_after = [pad[1] for pad in pads_list]
+
+    def forward(self, x):
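+        # 1d pooling is routed through the 2d kernel: append a trailing unit
+        # dimension, run maxpool_2d, then squeeze that dimension off again.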
+        expand_x = x.unsqueeze(dim=-1)
+        expand_y, expand_indice = flow.F.maxpool_2d(
+            expand_x,
+            data_format=self.channel_pos,
+            padding=self.padding_type,
+            padding_before=self.padding_before,
+            padding_after=self.padding_after,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
+            return_indices=True,
+            ceil_mode=self.ceil_mode,
+        )
+
+        y = expand_y.squeeze(dim=-1)
+        indice = expand_indice.squeeze(dim=-1)
+        if self.return_indices:
+            return y, indice
+        else:
+            return y
 
 
 @oneflow_export("nn.MaxPool2d")
@@ -374,21 +418,21 @@ class MaxPool2d(Module):
         >>> import numpy as np
         >>> flow.enable_eager_execution()
 
-        >>> kernel_size, stride, padding = (3, 3), (1, 1), (1, 2)
+        >>> kernel_size, stride, padding = (3, 4), (1, 1), (1, 2)
         >>> m = flow.nn.MaxPool2d(kernel_size, stride, padding)
         >>> np.random.seed(0)
         >>> x = flow.Tensor(np.random.rand(1, 1, 5, 3))
         >>> y = m(x)
         >>> y #doctest: +ELLIPSIS
-        tensor([[[[0.5488, 0.7152, 0.7152, 0.7152, 0.6459],
+        tensor([[[[0.7152, 0.7152, 0.7152, 0.7152],
                   ...
-                  [0.568 , 0.9256, 0.9256, 0.9256, 0.5289]]]], dtype=oneflow.float32)
+                  [0.9256, 0.9256, 0.9256, 0.9256]]]], dtype=oneflow.float32)
 
-        >>> kernel_size, stride, padding = (2, 3), (4, 5), (1, 2)
+        >>> kernel_size, stride, padding = (2, 4), (4, 5), (1, 2)
         >>> m = flow.nn.MaxPool2d(kernel_size, stride, padding)
         >>> x = flow.Tensor(np.random.randn(9, 7, 32, 20))
         >>> y = m(x)
-        >>> y.size()
+        >>> y.shape
         flow.Size([9, 7, 9, 5])
 
     """
@@ -404,14 +448,14 @@ class MaxPool2d(Module):
     ):
         super().__init__()
         self.kernel_size = _pair(kernel_size)
-        self.strides = _pair(stride) if (stride is not None) else kernel_size
-        data_format = "NCHW"
+        self.stride = _pair(stride) if (stride is not None) else self.kernel_size
+        data_format = "NCHW"  # Only support "NCHW" for now!
         self.channel_pos = (
-            "channels_last" if data_format == "NHWC" else "channels_first"
+            "channels_first" if data_format == "NCHW" else "channels_last"
         )
-
-        assert return_indices is False, "Only support return_indices==False for now!"
-        assert dilation == 1 or dilation == (1, 1), "Only support dilation==1 for now!"
+        self.dilation = _GetSequence(dilation, 2, "dilation")
+        self.return_indices = return_indices
+        self.ceil_mode = ceil_mode
 
         padding = _pair(padding)
         if len(padding) == 2:
@@ -421,25 +465,32 @@ class MaxPool2d(Module):
                 raise ValueError("error padding param!")
         else:
             raise ValueError("error padding param!")
+
         self.padding_type, pads_list = calc_pool_padding(
             padding, get_dhw_offset(self.channel_pos), 2
         )
         self.padding_before = [pad[0] for pad in pads_list]
         self.padding_after = [pad[1] for pad in pads_list]
-        self.ceil_mode = ceil_mode
 
     def forward(self, x):
-        return flow.F.max_pool_2d(
+        y, indice = flow.F.maxpool_2d(
             x,
-            kernel_size=self.kernel_size,
-            stride=self.strides,
+            data_format=self.channel_pos,
             padding=self.padding_type,
             padding_before=self.padding_before,
             padding_after=self.padding_after,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
+            return_indices=True,
             ceil_mode=self.ceil_mode,
-            data_format=self.channel_pos,
         )
 
+        if self.return_indices:
+            return y, indice
+        else:
+            return y
+
 
 @oneflow_export("nn.MaxPool3d")
 @experimental_api
@@ -507,21 +558,21 @@ class MaxPool3d(Module):
         >>> import numpy as np
         >>> flow.enable_eager_execution()
 
-        >>> kernel_size, stride, padding = (3, 3, 3), (1, 1, 1), (1, 1, 2)
+        >>> kernel_size, stride, padding = (3, 3, 4), (1, 1, 1), (1, 1, 2)
         >>> m = flow.nn.MaxPool3d(kernel_size, stride, padding)
         >>> np.random.seed(0)
         >>> x = flow.Tensor(np.random.rand(1, 1, 3, 5, 3))
         >>> y = m(x)
         >>> y #doctest: +ELLIPSIS
-        tensor([[[[[0.7782, 0.87  , 0.9786, 0.9786, 0.9786],
+        tensor([[[[[0.87  , 0.9786, 0.9786, 0.9786],
                    ...
-                   [0.9447, 0.9447, 0.9447, 0.6668, 0.6668]]]]], dtype=oneflow.float32)
-        >>> kernel_size, stride, padding = (2, 2, 3), (3, 4, 5), (2, 1, 2)
+                   [0.9447, 0.9447, 0.9447, 0.6668]]]]], dtype=oneflow.float32)
+        >>> kernel_size, stride, padding = (4, 2, 4), (3, 4, 5), (2, 1, 2)
         >>> m = flow.nn.MaxPool3d(kernel_size, stride, padding)
         >>> x = flow.Tensor(np.random.randn(9, 7, 11, 32, 20))
         >>> y = m(x)
-        >>> y.size()
-        flow.Size([9, 7, 5, 9, 5])
+        >>> y.shape
+        flow.Size([9, 7, 4, 9, 5])
 
     """
 
@@ -535,19 +586,17 @@ class MaxPool3d(Module):
         ceil_mode: bool = False,
     ):
         super().__init__()
-        kernel_size = _triple(kernel_size)
-        strides = _triple(stride) if (stride is not None) else kernel_size
+        self.kernel_size = _triple(kernel_size)
+        self.stride = _triple(stride) if (stride is not None) else self.kernel_size
         data_format = "NCDHW"
-        channel_pos = "channels_last" if data_format == "NDHWC" else "channels_first"
-
-        assert return_indices is False, "Only support return_indices==False for now!"
-        assert dilation == 1 or dilation == (
-            1,
-            1,
-            1,
-        ), "Only support dilation==1 for now!"
-
+        self.channel_pos = (
+            "channels_last" if data_format == "NDHWC" else "channels_first"
+        )
+        self.dilation = _GetSequence(dilation, 3, "dilation")
         padding = _triple(padding)
+        self.return_indices = return_indices
+        self.ceil_mode = ceil_mode
+
         if len(padding) == 3:
             if data_format == "NCDHW":
                 padding = (0, 0, padding[0], padding[1], padding[2])
@@ -556,28 +605,30 @@ class MaxPool3d(Module):
         else:
             raise ValueError("error padding param!")
 
-        padding_type, pads_list = calc_pool_padding(
-            padding, get_dhw_offset(channel_pos), 3
+        self.padding_type, pads_list = calc_pool_padding(
+            padding, get_dhw_offset(self.channel_pos), 3
         )
-        padding_before = [pad[0] for pad in pads_list]
-        padding_after = [pad[1] for pad in pads_list]
+        self.padding_before = [pad[0] for pad in pads_list]
+        self.padding_after = [pad[1] for pad in pads_list]
 
-        self._op = (
-            flow.builtin_op("max_pool_3d")
-            .Attr("data_format", channel_pos)
-            .Attr("pool_size", kernel_size)
-            .Attr("strides", strides)
-            .Attr("ceil_mode", ceil_mode)
-            .Attr("padding", padding_type)
-            .Attr("padding_before", padding_before)
-            .Attr("padding_after", padding_after)
-            .Input("x")
-            .Output("y")
-            .Build()
+    def forward(self, x):
+        y, indice = flow.F.maxpool_3d(
+            x,
+            data_format=self.channel_pos,
+            padding=self.padding_type,
+            padding_before=self.padding_before,
+            padding_after=self.padding_after,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            dilation=self.dilation,
+            return_indices=True,
+            ceil_mode=self.ceil_mode,
         )
 
-    def forward(self, x):
-        return self._op(x)[0]
+        if self.return_indices:
+            return y, indice
+        else:
+            return y
 
 
 if __name__ == "__main__":
diff --git a/oneflow/python/ops/nn_ops.py b/oneflow/python/ops/nn_ops.py
index c76d3cfef784e2c35e2fcd37777ef7fa687c8277..bcc290bcba79d427149a9f5b058f20376d968119 100644
--- a/oneflow/python/ops/nn_ops.py
+++ b/oneflow/python/ops/nn_ops.py
@@ -1789,6 +1789,426 @@ def calc_pool_padding(padding, dhw_offset, ndims):
     return padding_type, ndim_pads_list
 
 
+@oneflow_export("nn.MaxPool1d")
+@stable_api
+def MaxPool1d(
+    input: oneflow._oneflow_internal.BlobDesc,
+    kernel_size: Union[int, IntPair],
+    stride: Union[int, IntPair],
+    padding: Union[str, IntPair],
+    dilation: Union[int, IntPair] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+    data_format: str = "NCHW",
+    name: Optional[str] = None,
+) -> oneflow._oneflow_internal.BlobDesc:
+    r""" Performs the 1d-max pooling on the input `Blob`.
+         Different from nn.max_pool1d, nn.MaxPool1d supports more params, e.g. dilation and return_indices.
+
+    Args:
+        input (remote_blob_util.BlobDesc): A 3-D `Blob` of the format specified by data_format.
+        kernel_size (Union[int, IntPair]): An int or list of ints that has length 1, 2. The size of the window for each dimension of the input `Blob`.
+        stride (Union[int, IntPair]): An int or list of ints that has length 1, 2. The stride of the sliding window for each dimension of the input `Blob`.
+        padding (Union[str, int]): '`VALID'` or '`SAME'` or an int value. The padding algorithm.
+        dilation (Union[int, IntPair]): a parameter that controls the stride of elements in the window.
+        return_indices (bool): if True, will return the max indices along with the outputs.
+        ceil_mode (bool): when True, will use ceil instead of floor to compute the output shape.
+        data_format (str, optional): '`NHWC'`, '`NCHW'` or '`NCHW_VECT_C'`. Defaults to "NCHW"; only 'NCHW' is supported for now.
+        name (Optional[str], optional): This operator's name(optional). Defaults to None.
+
+    Returns:
+        remote_blob_util.BlobDesc:  A `Blob` of format specified by data_format. The max pooled output `Blob`.
+
+    For example: 
+
+    .. code-block:: python 
+
+        import oneflow as flow
+        import oneflow.typing as tp
+        from typing import Tuple
+        import numpy as np
+
+        input_shape = (2, 2, 4)
+        @flow.global_function(type="train", function_config=func_config)
+        def maxpool1d_job_with_grad(
+            input_x: tp.Numpy.Placeholder(input_shape),
+        ) -> Tuple[tp.Numpy, tp.Numpy]:
+            x_var = flow.get_variable(
+                name="input_x",
+                shape=input_shape,
+                dtype=flow.float32,
+                initializer=flow.constant_initializer(0),
+                trainable=True,
+            )
+            x_var = flow.cast_to_current_logical_view(x_var)
+            flow.watch_diff(x_var, Setter("x_diff"))
+            x = x_var + input_x
+            # x = flow.cast(x, dtype=flow.int32)
+            with flow.scope.placement("cpu", "0:0"):
+                (y, indice) = flow.nn.MaxPool1d(
+                    x,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    dilation=1,
+                    return_indices=True,
+                    ceil_mode=False,
+                    data_format="NCHW",
+                )
+                flow.optimizer.SGD(
+                    flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0
+                ).minimize(y)
+            return (y, indice)
+
+        x = np.arange(16).reshape(input_shape).astype(np.float32)
+        y, indice = maxpool1d_job(x)
+        print("in:\n", x, "\ny:\n", y, "\nindice:\n", indice)
+
+        # x:
+        # [[[ 0.  1.  2.  3.]
+        # [ 4.  5.  6.  7.]]
+
+        # [[ 8.  9. 10. 11.]
+        # [12. 13. 14. 15.]]] 
+        # y:
+        # [[[ 1.  3.]
+        # [ 5.  7.]]
+
+        # [[ 9. 11.]
+        # [13. 15.]]] 
+        # indice:
+        # [[[1 3]
+        # [1 3]]
+
+        # [[1 3]
+        # [1 3]]]
+
+    """
+    assert data_format in ["NCHW"]
+    channel_pos = "channels_last" if data_format == "NHWC" else "channels_first"
+    kernel_size = _GetSequence(kernel_size, 2, "kernel_size")
+    dilation = _GetSequence(dilation, 2, "dilation")
+    stride = _GetSequence(stride, 2, "stride")
+    assert padding in ["SAME", "VALID"] or (isinstance(padding, int) and padding >= 0)
+    if isinstance(padding, int):
+        if data_format == "NCHW":
+            padding = (0, 0, padding, padding)
+        elif data_format == "NHWC":
+            padding = (0, padding, padding, 0)
+        else:
+            raise ValueError('data_format must be "NHWC" or "NCHW".')
+    padding_type, pads_list = calc_pool_padding(padding, get_dhw_offset(channel_pos), 2)
+    padding_before = [pad[0] for pad in pads_list]
+    padding_after = [pad[1] for pad in pads_list]
+
+    expand_input = flow.expand_dims(input=input, axis=2)
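+    # The 1d case reuses the maxpool_2d op: insert a unit axis so the
+    # (N, C, L) input becomes (N, C, 1, L), pool, then squeeze it back out of
+    # y and indice.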
+    assert len(pads_list) == len(expand_input.shape) - 2
+    y, indice = (
+        flow.user_op_builder(
+            name if name is not None else id_util.UniqueStr("MaxPool1d_")
+        )
+        .Op("maxpool_2d")
+        .Input("x", [expand_input])
+        .Output("y")
+        .Output("indice")
+        .Attr("data_format", channel_pos)
+        .Attr("stride", stride)
+        .Attr("kernel_size", kernel_size)
+        .Attr("padding", padding_type)
+        .Attr("padding_before", padding_before)
+        .Attr("padding_after", padding_after)
+        .Attr("dilation", dilation)
+        .Attr("return_indices", return_indices)
+        .Attr("ceil_mode", ceil_mode)
+        .Build()
+        .InferAndTryRun()
+        .RemoteBlobList()
+    )
+    y = flow.squeeze(y, axis=(2,))
+    indice = flow.squeeze(indice, axis=(2,))
+    if return_indices:
+        return y, indice
+    else:
+        return y
+
+
+@oneflow_export("nn.MaxPool2d")
+@stable_api
+def MaxPool2d(
+    input: oneflow._oneflow_internal.BlobDesc,
+    kernel_size: Union[int, IntPair],
+    stride: Union[int, IntPair],
+    padding: Union[str, int, Tuple[int, int]],
+    dilation: Union[int, IntPair] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+    data_format: str = "NCHW",
+    name: Optional[str] = None,
+) -> oneflow._oneflow_internal.BlobDesc:
+    r""" Performs the 2d-max pooling on the input `Blob`.
+         Different from nn.max_pool2d, nn.MaxPool2d supports more params, e.g. dilation and return_indices.
+
+    Args:
+        input (remote_blob_util.BlobDesc): A 4-D `Blob` of the format specified by data_format.
+        kernel_size (Union[int, IntPair]): An int or list of ints that has length 1, 2. The size of the window for each dimension of the input `Blob`.
+        stride (Union[int, IntPair]): An int or list of ints that has length 1, 2. The stride of the sliding window for each dimension of the input `Blob`.
+        padding (Union[str, int, Tuple[int, int]]): '`VALID'`, '`SAME'`, an int or a tuple of two ints. The padding algorithm.
+        dilation (Union[int, IntPair]): a parameter that controls the stride of elements in the window.
+        return_indices (bool): if True, will return the max indices along with the outputs.
+        ceil_mode (bool): when True, will use ceil instead of floor to compute the output shape.
+        data_format (str, optional): '`NHWC'`, '`NCHW'` or '`NCHW_VECT_C'`. Defaults to "NCHW"; only 'NCHW' is supported for now.
+        name (Optional[str], optional): This operator's name(optional). Defaults to None.
+
+    Returns:
+        remote_blob_util.BlobDesc:  A `Blob` of format specified by data_format. The max pooled output `Blob`.
+    
+    For example: 
+
+    .. code-block:: python 
+
+        import oneflow as flow
+        import oneflow.typing as tp
+        from typing import Tuple
+        import numpy as np
+
+        input_shape = (1, 2, 4, 4)
+        @flow.global_function(type="predict")
+        def maxpool_job(
+            x: tp.Numpy.Placeholder(input_shape),
+        ) -> Tuple[tp.Numpy, tp.Numpy]:
+            with flow.scope.placement("gpu", "0:0"):
+                (y, indice) = flow.nn.MaxPool2d(
+                    x,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    dilation=1,
+                    return_indices=True,
+                    ceil_mode=False,
+                    data_format="NCHW",
+                )
+            return (y, indice)
+
+        x = np.arange(32).reshape(input_shape).astype(np.float32)
+        y, indice = maxpool_job(x)
+        print("in:\n", x, "\ny:\n", y, "\nindice:\n", indice)
+
+        #in:
+        #[[[[ 0.  1.  2.  3.]
+        #[ 4.  5.  6.  7.]
+        #[ 8.  9. 10. 11.]
+        #[12. 13. 14. 15.]]
+
+        #[[16. 17. 18. 19.]
+        #[20. 21. 22. 23.]
+        #[24. 25. 26. 27.]
+        #[28. 29. 30. 31.]]]] 
+        #y:
+        #[[[[ 5.  7.]
+        #[13. 15.]]
+
+        #[[21. 23.]
+        #[29. 31.]]]] 
+
+        #indice:
+        #[[[[5  7]
+        #[13 15]]
+
+        #[[5  7]
+        #[13 15]]]]
+
+    """
+    assert data_format in ["NCHW"]
+    channel_pos = "channels_last" if data_format == "NHWC" else "channels_first"
+    kernel_size = _GetSequence(kernel_size, 2, "kernel_size")
+    dilation = _GetSequence(dilation, 2, "dilation")
+    stride = _GetSequence(stride, 2, "stride")
+    assert isinstance(padding, int) or len(padding) == 2 or padding in ["SAME", "VALID"]
+
+    if isinstance(padding, int):
+        padding = [padding, padding]
+    if len(padding) == 2:
+        if data_format == "NCHW":
+            padding = (0, 0, padding[0], padding[1])
+        elif data_format == "NHWC":
+            padding = (0, padding[0], padding[1], 0)
+        else:
+            raise ValueError('data_format must be "NHWC" or "NCHW".')
+
+    padding_type, pads_list = calc_pool_padding(padding, get_dhw_offset(channel_pos), 2)
+    padding_before = [pad[0] for pad in pads_list]
+    padding_after = [pad[1] for pad in pads_list]
+    assert len(pads_list) == len(input.shape) - 2
+    y, indice = (
+        flow.user_op_builder(
+            name if name is not None else id_util.UniqueStr("MaxPool2d_")
+        )
+        .Op("maxpool_2d")
+        .Input("x", [input])
+        .Output("y")
+        .Output("indice")
+        .Attr("data_format", channel_pos)
+        .Attr("stride", stride)
+        .Attr("kernel_size", kernel_size)
+        .Attr("padding", padding_type)
+        .Attr("padding_before", padding_before)
+        .Attr("padding_after", padding_after)
+        .Attr("dilation", dilation)
+        .Attr("return_indices", return_indices)
+        .Attr("ceil_mode", ceil_mode)
+        .Build()
+        .InferAndTryRun()
+        .RemoteBlobList()
+    )
+    if return_indices:
+        return y, indice
+    else:
+        return y
+
+
+@oneflow_export("nn.MaxPool3d")
+@stable_api
+def MaxPool3d(
+    input: oneflow._oneflow_internal.BlobDesc,
+    kernel_size: Union[int, IntPair],
+    stride: Union[int, IntPair],
+    padding: Union[str, int, Tuple[int, int, int]],
+    dilation: Union[int, IntPair] = 1,
+    return_indices: bool = False,
+    ceil_mode: bool = False,
+    data_format: str = "NCDHW",
+    name: Optional[str] = None,
+) -> oneflow._oneflow_internal.BlobDesc:
+    r""" Performs the 3d-max pooling on the input `Blob`.
+         Different from nn.max_pool3d, nn.MaxPool3d supports more params, e.g. dilation and return_indices.
+
+    Args:
+        input (remote_blob_util.BlobDesc): A 5-D `Blob` of the format specified by data_format.
+        kernel_size (Union[int, IntPair]): An int or list of ints that has length 1 or 3. The size of the window for each dimension of the input `Blob`.
+        stride (Union[int, IntPair]): An int or list of ints that has length 1 or 3. The stride of the sliding window for each dimension of the input `Blob`.
+        padding (str): '`VALID'` or '`SAME'` or '`SAME_LOWER'` or '`SAME_UPPER'` or int value or Tuple[int, int, int]`. The padding algorithm.
+        dilation (Union[int, IntPair]): a parameter that controls the stride of elements in the window.
+        return_indices (bool): if True, will return the max indices along with the outputs.
+        ceil_mode (bool): when True, will use ceil instead of floor to compute the output shape.
+        data_format (str, optional): '`NCDHW'` or '`NDHWC'`. Defaults to "NCDHW"; only 'NCDHW' is supported for now.
+        name (Optional[str], optional): This operator's name(optional). Defaults to None.
+
+    Returns:
+        remote_blob_util.BlobDesc:  A `Blob` of format specified by data_format. The max pooled output `Blob`.
+    
+    For example: 
+
+    .. code-block:: python 
+
+        import oneflow as flow
+        import oneflow.typing as tp
+        from typing import Tuple
+        import numpy as np
+
+        input_shape = (1, 1, 2, 4, 4)
+        @flow.global_function(type="predict")
+        def maxpool3d_job(
+            x: tp.Numpy.Placeholder(input_shape),
+        ) -> Tuple[tp.Numpy, tp.Numpy]:
+            with flow.scope.placement("gpu", "0:0"):
+                (y, indice) = flow.nn.MaxPool3d(
+                    input=x,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    dilation=1,
+                    return_indices=True,
+                    ceil_mode=False,
+                    data_format="NCDHW",
+                )
+            return (y, indice)
+
+        x = np.arange(32).reshape(input_shape).astype(np.float32)
+        y, indice = maxpool3d_job(x)
+        print("in:\n", x, "\ny:\n", y, "\nindice:\n", indice)
+
+        # in:
+        # [[[[[ 0.  1.  2.  3.]
+        #     [ 4.  5.  6.  7.]
+        #     [ 8.  9. 10. 11.]
+        #     [12. 13. 14. 15.]]
+
+        # [[16. 17. 18. 19.]
+        #     [20. 21. 22. 23.]
+        #     [24. 25. 26. 27.]
+        #     [28. 29. 30. 31.]]]]] 
+        # y:
+        # [[[[[21. 22. 23. 23.]
+        #     [25. 26. 27. 27.]
+        #     [29. 30. 31. 31.]
+        #     [29. 30. 31. 31.]]
+
+        # [[21. 22. 23. 23.]
+        #     [25. 26. 27. 27.]
+        #     [29. 30. 31. 31.]
+        #     [29. 30. 31. 31.]]]]] 
+        # indice:
+        # [[[[[21 22 23 23]
+        #     [25 26 27 27]
+        #     [29 30 31 31]
+        #     [29 30 31 31]]
+
+        # [[21 22 23 23]
+        #     [25 26 27 27]
+        #     [29 30 31 31]
+        #     [29 30 31 31]]]]] 
+
+    """
+    assert data_format in ["NCDHW"]
+    channel_pos = "channels_first" if data_format == "NCDHW" else "channels_last"
+    kernel_size = _GetSequence(kernel_size, 3, "kernel_size")
+    dilation = _GetSequence(dilation, 3, "dilation")
+    stride = _GetSequence(stride, 3, "stride")
+    assert (
+        isinstance(padding, int)
+        or isinstance(padding, tuple)
+        or padding in ["SAME", "VALID"]
+    )
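+    # Normalize int or 3-tuple padding to a 5-entry tuple matching the NCDHW/NDHWC layout.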
+    if isinstance(padding, int):
+        padding = (padding, padding, padding)
+    if len(padding) == 3:
+        if data_format == "NCDHW":
+            padding = (0, 0, padding[0], padding[1], padding[2])
+        elif data_format == "NDHWC":
+            padding = (0, padding[0], padding[1], padding[2], 0)
+        else:
+            raise ValueError('data_format must be "NDHWC" or "NCDHW".')
+    padding_type, pads_list = calc_pool_padding(padding, get_dhw_offset(channel_pos), 3)
+    padding_before = [pad[0] for pad in pads_list]
+    padding_after = [pad[1] for pad in pads_list]
+    assert len(pads_list) == len(input.shape) - 2
+    y, indice = (
+        flow.user_op_builder(
+            name if name is not None else id_util.UniqueStr("MaxPool3d_")
+        )
+        .Op("maxpool_3d")
+        .Input("x", [input])
+        .Output("y")
+        .Output("indice")
+        .Attr("data_format", channel_pos)
+        .Attr("stride", stride)
+        .Attr("kernel_size", kernel_size)
+        .Attr("padding", padding_type)
+        .Attr("padding_before", padding_before)
+        .Attr("padding_after", padding_after)
+        .Attr("dilation", dilation)
+        .Attr("return_indices", return_indices)
+        .Attr("ceil_mode", ceil_mode)
+        .Build()
+        .InferAndTryRun()
+        .RemoteBlobList()
+    )
+    if return_indices:
+        return y, indice
+    else:
+        return y
+
+
 @oneflow_export("nn.max_pool2d")
 def max_pool2d(
     input: oneflow._oneflow_internal.BlobDesc,
diff --git a/oneflow/python/test/modules/test_maxpool.py b/oneflow/python/test/modules/test_pooling.py
similarity index 58%
rename from oneflow/python/test/modules/test_maxpool.py
rename to oneflow/python/test/modules/test_pooling.py
index 08f23c6a7ae2c69adaf2c91098f60cf6f66f1228..ad1951c8a4f7e45040d711aaf9055389c4017f08 100644
--- a/oneflow/python/test/modules/test_maxpool.py
+++ b/oneflow/python/test/modules/test_pooling.py
@@ -158,21 +158,224 @@ class MaxPoolNumpy:
         return dx
 
 
+def _test_maxpool1d_impl(test_case, device):
+    input_arr = np.array(
+        [
+            [
+                [-0.89042996, 2.33971243, -0.86660827, 0.80398747],
+                [-1.46769364, -0.78125064, 1.50086563, -0.76278226],
+                [1.31984534, 0.20741192, -0.86507054, -0.40776015],
+                [-0.89910823, 0.44932938, 1.49148118, -0.22036761],
+            ],
+            [
+                [-0.5452334, -0.10255169, -1.42035108, 0.73922913],
+                [-0.03192764, 0.69341935, 0.96263152, -1.52070843],
+                [0.02058239, 1.504032, 1.84423001, -0.0130596],
+                [2.20517719, 0.38449598, 0.85677771, 0.60425179],
+            ],
+            [
+                [-1.64366213, 0.51370298, -0.21754866, -0.05085382],
+                [1.17065374, 1.13857674, -1.13070507, 0.44353707],
+                [-1.30783846, -0.48031445, 0.41807536, -2.13778887],
+                [0.08259005, 0.5798125, 0.03024696, 1.96100924],
+            ],
+        ]
+    )
+    kernel_size, stride, padding = (3,), (1,), (1,)
+
+    output = np.array(
+        [
+            [
+                [2.33971243, 2.33971243, 2.33971243, 0.80398747],
+                [-0.78125064, 1.50086563, 1.50086563, 1.50086563],
+                [1.31984534, 1.31984534, 0.20741192, -0.40776015],
+                [0.44932938, 1.49148118, 1.49148118, 1.49148118],
+            ],
+            [
+                [-0.10255169, -0.10255169, 0.73922913, 0.73922913],
+                [0.69341935, 0.96263152, 0.96263152, 0.96263152],
+                [1.504032, 1.84423001, 1.84423001, 1.84423001],
+                [2.20517719, 2.20517719, 0.85677771, 0.85677771],
+            ],
+            [
+                [0.51370298, 0.51370298, 0.51370298, -0.05085382],
+                [1.17065374, 1.17065374, 1.13857674, 0.44353707],
+                [-0.48031445, 0.41807536, 0.41807536, 0.41807536],
+                [0.5798125, 0.5798125, 1.96100924, 1.96100924],
+            ],
+        ]
+    )
+
+    output_indice = np.array(
+        [
+            [[1, 1, 1, 3], [1, 2, 2, 2], [0, 0, 1, 3], [1, 2, 2, 2]],
+            [[1, 1, 3, 3], [1, 2, 2, 2], [1, 2, 2, 2], [0, 0, 2, 2]],
+            [[1, 1, 1, 3], [0, 0, 1, 3], [1, 2, 2, 2], [1, 1, 3, 3]],
+        ]
+    )
+
+    grad = np.array(
+        [
+            [
+                [0.0, 3.0, 0.0, 1.0],
+                [0.0, 1.0, 3.0, 0.0],
+                [2.0, 1.0, 0.0, 1.0],
+                [0.0, 1.0, 3.0, 0.0],
+            ],
+            [
+                [0.0, 2.0, 0.0, 2.0],
+                [0.0, 1.0, 3.0, 0.0],
+                [0.0, 1.0, 3.0, 0.0],
+                [2.0, 0.0, 2.0, 0.0],
+            ],
+            [
+                [0.0, 3.0, 0.0, 1.0],
+                [2.0, 1.0, 0.0, 1.0],
+                [0.0, 1.0, 3.0, 0.0],
+                [0.0, 2.0, 0.0, 2.0],
+            ],
+        ]
+    )
+
+    m = flow.nn.MaxPool1d(
+        kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True
+    )
+    m.to(flow.device(device))
+    x = flow.Tensor(input_arr, device=flow.device(device), requires_grad=True)
+    of_output, of_indice = m(x)
+
+    y = of_output.sum()
+    y.backward()
+
+    test_case.assertTrue(np.allclose(x.grad.numpy(), grad, 1e-4, 1e-4))
+    test_case.assertTrue(np.allclose(of_indice.numpy(), output_indice, 1e-4, 1e-4))
+    test_case.assertTrue(np.allclose(of_output.numpy(), output, 1e-4, 1e-4))
+
+
 def _test_maxpool2d(test_case, device):
     dim = 2
-    input_arr = np.random.randn(6, 4, 7, 9)
-    kernel_size, stride, padding = (4, 4), (1, 1), (1, 2)
+
+    input_arr = np.random.randn(2, 3, 4, 5)
+    kernel_size, stride, padding = (3, 3), (1, 1), (1, 1)
 
     m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
     numpy_output = m_numpy(input_arr)
 
-    m = flow.nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding)
+    m = flow.nn.MaxPool2d(
+        kernel_size=kernel_size, stride=stride, padding=padding, return_indices=True
+    )
     m.to(flow.device(device))
     x = flow.Tensor(input_arr, device=flow.device(device))
-    output = m(x)
+    output, indice = m(x)
+    test_case.assertTrue(indice.shape == x.shape)
     test_case.assertTrue(np.allclose(numpy_output, output.numpy(), 1e-4, 1e-4))
 
 
+def _test_maxpool2d_ceil_mode(test_case, device):
+    dim = 2
+    input_arr = np.array(
+        [
+            [
+                [
+                    [-0.89042996, 2.33971243, -0.86660827, 0.80398747],
+                    [-1.46769364, -0.78125064, 1.50086563, -0.76278226],
+                    [1.31984534, 0.20741192, -0.86507054, -0.40776015],
+                    [-0.89910823, 0.44932938, 1.49148118, -0.22036761],
+                ],
+                [
+                    [-0.5452334, -0.10255169, -1.42035108, 0.73922913],
+                    [-0.03192764, 0.69341935, 0.96263152, -1.52070843],
+                    [0.02058239, 1.504032, 1.84423001, -0.0130596],
+                    [2.20517719, 0.38449598, 0.85677771, 0.60425179],
+                ],
+                [
+                    [-1.64366213, 0.51370298, -0.21754866, -0.05085382],
+                    [1.17065374, 1.13857674, -1.13070507, 0.44353707],
+                    [-1.30783846, -0.48031445, 0.41807536, -2.13778887],
+                    [0.08259005, 0.5798125, 0.03024696, 1.96100924],
+                ],
+            ],
+            [
+                [
+                    [0.45173843, -0.34680027, -0.99754943, 0.18539502],
+                    [-0.68451047, -0.03217399, 0.44705642, -0.39016231],
+                    [-0.18062337, 1.82099303, -0.19113869, 0.85298683],
+                    [0.14080452, 0.15306701, -1.02466827, -0.34480665],
+                ],
+                [
+                    [-0.21048489, 0.20933038, -0.09206508, -1.80402519],
+                    [-0.52028985, 0.01140166, -1.13452858, 0.96648332],
+                    [0.26454393, 0.48343972, -1.84055509, -0.01256443],
+                    [0.31024029, 0.11983007, 0.98806488, 0.93557438],
+                ],
+                [
+                    [0.39152445, 0.672159, 0.71289289, -0.68072016],
+                    [0.33711062, -1.78106242, 0.34545201, -1.62029359],
+                    [0.47343899, -2.3433269, -0.44517497, 0.09004267],
+                    [0.26310742, -1.53121271, 0.65028836, 1.3669488],
+                ],
+            ],
+        ]
+    )
+
+    ceil_mode_out = np.array(
+        [
+            [
+                [
+                    [2.33971243, 2.33971243, 0.80398747],
+                    [1.31984534, 1.50086563, -0.22036761],
+                    [0.44932938, 1.49148118, -0.22036761],
+                ],
+                [
+                    [0.69341935, 0.96263152, 0.73922913],
+                    [2.20517719, 1.84423001, 0.60425179],
+                    [2.20517719, 0.85677771, 0.60425179],
+                ],
+                [
+                    [1.17065374, 1.13857674, 0.44353707],
+                    [1.17065374, 1.96100924, 1.96100924],
+                    [0.5798125, 1.96100924, 1.96100924],
+                ],
+            ],
+            [
+                [
+                    [0.45173843, 0.44705642, 0.18539502],
+                    [1.82099303, 1.82099303, 0.85298683],
+                    [0.15306701, 0.15306701, -0.34480665],
+                ],
+                [
+                    [0.20933038, 0.96648332, 0.96648332],
+                    [0.48343972, 0.98806488, 0.96648332],
+                    [0.31024029, 0.98806488, 0.93557438],
+                ],
+                [
+                    [0.672159, 0.71289289, -0.68072016],
+                    [0.47343899, 1.3669488, 1.3669488],
+                    [0.26310742, 1.3669488, 1.3669488],
+                ],
+            ],
+        ]
+    )
+    kernel_size, stride, padding = (3, 3), (2, 2), (1, 1)
+
+    m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
+    numpy_output = m_numpy(input_arr)
+
+    m1 = flow.nn.MaxPool2d(
+        kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=False
+    )
+    m2 = flow.nn.MaxPool2d(
+        kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=True
+    )
+    m1.to(flow.device(device))
+    m2.to(flow.device(device))
+    x = flow.Tensor(input_arr, device=flow.device(device))
+    output1 = m1(x)
+    output2 = m2(x)
+    test_case.assertTrue(np.allclose(numpy_output, output1.numpy(), 1e-4, 1e-4))
+    test_case.assertTrue(np.allclose(ceil_mode_out, output2.numpy(), 1e-4, 1e-4))
+
+
 def _test_maxpool2d_special_kernel_size(test_case, device):
     dim = 2
     input_arr = np.random.randn(1, 1, 6, 6)
@@ -191,7 +394,7 @@ def _test_maxpool2d_special_kernel_size(test_case, device):
 def _test_maxpool2d_diff_kernel_stride(test_case, device):
     dim = 2
     input_arr = np.random.randn(9, 7, 32, 20)
-    kernel_size, stride, padding = (2, 3), (4, 5), (1, 2)
+    kernel_size, stride, padding = (2, 4), (4, 5), (1, 2)
 
     m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
     numpy_output = m_numpy(input_arr)
@@ -205,7 +408,7 @@ def _test_maxpool2d_diff_kernel_stride(test_case, device):
 
 def _test_maxpool2d_negative_input(test_case, device):
     dim = 2
-    input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float)
+    input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float32)
     kernel_size, stride, padding = (5, 5), (5, 5), (2, 2)
 
     m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
@@ -235,7 +438,7 @@ def _test_maxpool2d_backward(test_case, device):
     output.backward()
     doutput = np.ones_like(numpy_output, dtype=np.float64)
     numpy_grad = m_numpy.backward(doutput)
-    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
 
 
 def _test_maxpool2d_special_kernel_size_backward(test_case, device):
@@ -255,13 +458,13 @@ def _test_maxpool2d_special_kernel_size_backward(test_case, device):
     output.backward()
     doutput = np.ones_like(numpy_output, dtype=np.float64)
     numpy_grad = m_numpy.backward(doutput)
-    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
 
 
 def _test_maxpool2d_diff_kernel_stride_backward(test_case, device):
     dim = 2
     input_arr = np.random.randn(9, 7, 32, 20)
-    kernel_size, stride, padding = (2, 3), (4, 5), (1, 2)
+    kernel_size, stride, padding = (2, 4), (4, 5), (1, 2)
 
     m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
     numpy_output = m_numpy(input_arr)
@@ -275,12 +478,12 @@ def _test_maxpool2d_diff_kernel_stride_backward(test_case, device):
     output.backward()
     doutput = np.ones_like(numpy_output, dtype=np.float64)
     numpy_grad = m_numpy.backward(doutput)
-    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
 
 
 def _test_maxpool2d_negative_input_backward(test_case, device):
     dim = 2
-    input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float)
+    input_arr = -1.23456 * np.ones((1, 1, 1, 1), dtype=np.float32)
     kernel_size, stride, padding = (5, 5), (5, 5), (2, 2)
 
     m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
@@ -295,7 +498,22 @@ def _test_maxpool2d_negative_input_backward(test_case, device):
     output.backward()
     doutput = np.ones_like(numpy_output, dtype=np.float64)
     numpy_grad = m_numpy.backward(doutput)
-    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
+
+
+def _test_maxpool3d(test_case, device):
+    dim = 3
+    input_arr = np.random.randn(2, 3, 7, 9, 13)
+    kernel_size, stride, padding = (2, 3, 4), (2, 3, 4), (1, 1, 2)
+
+    m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
+    numpy_output = m_numpy(input_arr)
+
+    m = flow.nn.MaxPool3d(kernel_size=kernel_size, stride=stride, padding=padding)
+    m.to(flow.device(device))
+    x = flow.Tensor(input_arr, device=flow.device(device))
+    output = m(x)
+    test_case.assertTrue(np.allclose(numpy_output, output.numpy(), 1e-4, 1e-4))
 
 
 def _test_maxpool3d_backward(test_case, device):
@@ -316,7 +534,7 @@ def _test_maxpool3d_backward(test_case, device):
     output.backward()
     doutput = np.ones_like(numpy_output, dtype=np.float64)
     numpy_grad = m_numpy.backward(doutput)
-    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
 
 
 def _test_maxpool3d_special_kernel_size_backward(test_case, device):
@@ -337,12 +555,33 @@ def _test_maxpool3d_special_kernel_size_backward(test_case, device):
     output.backward()
     doutput = np.ones_like(numpy_output, dtype=np.float64)
     numpy_grad = m_numpy.backward(doutput)
-    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
+
+
+def _test_maxpool3d_diff_kernel_stride_backward(test_case, device):
+    dim = 3
+    input_arr = np.random.randn(9, 7, 48, 32, 20)
+    kernel_size, stride, padding = (6, 2, 4), (5, 4, 5), (3, 1, 2)
+
+    m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
+    numpy_output = m_numpy(input_arr)
+
+    m = flow.nn.MaxPool3d(kernel_size=kernel_size, stride=stride, padding=padding)
+    m.to(flow.device(device))
+    x = flow.Tensor(input_arr, requires_grad=True, device=flow.device(device))
+    output = m(x)
+    test_case.assertTrue(np.allclose(numpy_output, output.numpy(), 1e-4, 1e-4))
+
+    output = output.sum()
+    output.backward()
+    doutput = np.ones_like(numpy_output, dtype=np.float64)
+    numpy_grad = m_numpy.backward(doutput)
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
 
 
 def _test_maxpool3d_negative_input_backward(test_case, device):
     dim = 3
-    input_arr = -1.23456 * np.ones((1, 1, 1, 1, 1), dtype=np.float)
+    input_arr = -1.23456 * np.ones((1, 1, 1, 1, 1), dtype=np.float32)
     kernel_size, stride, padding = (5, 5, 5), (5, 5, 5), (2, 2, 2)
 
     m_numpy = MaxPoolNumpy(dim, kernel_size, stride, padding)
@@ -358,18 +597,28 @@ def _test_maxpool3d_negative_input_backward(test_case, device):
     output.backward()
     doutput = np.ones_like(numpy_output, dtype=np.float64)
     numpy_grad = m_numpy.backward(doutput)
-    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-5, 1e-5))
+    test_case.assertTrue(np.allclose(x.grad.numpy(), numpy_grad, 1e-4, 1e-4))
 
 
 @unittest.skipIf(
     not flow.unittest.env.eager_execution_enabled(),
     ".numpy() doesn't work in lazy mode",
 )
-class TestPoolingModule(flow.unittest.TestCase):
+class TestPooling(flow.unittest.TestCase):
+    def test_maxpool1d(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_maxpool1d_impl,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
     def test_maxpool2d(test_case):
         arg_dict = OrderedDict()
         arg_dict["test_fun"] = [
             _test_maxpool2d,
+            _test_maxpool2d_ceil_mode,
             _test_maxpool2d_special_kernel_size,
             _test_maxpool2d_diff_kernel_stride,
             _test_maxpool2d_negative_input,
@@ -378,6 +627,7 @@ class TestPoolingModule(flow.unittest.TestCase):
             _test_maxpool2d_diff_kernel_stride_backward,
             _test_maxpool2d_negative_input_backward,
         ]
+
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
@@ -385,9 +635,11 @@ class TestPoolingModule(flow.unittest.TestCase):
     def test_maxpool3d(test_case):
         arg_dict = OrderedDict()
         arg_dict["test_fun"] = [
+            _test_maxpool3d,
             _test_maxpool3d_backward,
             _test_maxpool3d_special_kernel_size_backward,
             _test_maxpool3d_negative_input_backward,
+            _test_maxpool3d_diff_kernel_stride_backward,
         ]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):
diff --git a/oneflow/user/kernels/pooling_kernel.cpp b/oneflow/user/kernels/pooling_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..39213ea3b42db8451331349f978f58933fff8eff
--- /dev/null
+++ b/oneflow/user/kernels/pooling_kernel.cpp
@@ -0,0 +1,252 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/user/kernels/pooling_kernel_util.h"
+
+namespace oneflow {
+
+struct PoolingOpKernelState final : public user_op::OpKernelState {
+  PoolingParams3D params_3d;
+  PoolingOpKernelState(PoolingParams3D params_3d) : params_3d(params_3d) {}
+  const PoolingParams3D& GetParams3D() { return params_3d; }
+};
+
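+// Packs the pooling attributes into PoolingParams3D; 1D/2D pooling is promoted to the 3D representation internally.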
+std::shared_ptr<PoolingOpKernelState> DoCreateOpKernelState(user_op::KernelComputeContext* ctx,
+                                                            const int32_t& dim) {
+  const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape();
+  const std::string& padding = ctx->Attr<std::string>("padding");
+  const std::string& data_format = ctx->Attr<std::string>("data_format");
+  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
+  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
+  const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
+  const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
+  const bool return_indices = ctx->Attr<bool>("return_indices");
+  const bool ceil_mode = ctx->Attr<bool>("ceil_mode");
+
+  PoolingParams3D params_3d =
+      PoolingParams3D(dim, x_shape, data_format, padding, padding_before, padding_after,
+                      kernel_size, stride, dilation, return_indices, ceil_mode);
+  std::shared_ptr<PoolingOpKernelState> state(new PoolingOpKernelState(params_3d));
+  return state;
+}
+
+template<typename T>
+struct PoolingKernelUtil<DeviceType::kCPU, T> {
+  static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {
+    Maxpool2dFarwardCompute<T>(
+        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
+        params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
+        params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(3),
+        params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[1],
+        params_3d.pooling_size_3d()[2], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
+        params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool2dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    Maxpool2dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
+                                params_3d.num_batch(), params_3d.num_channel(),
+                                params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
+                                params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));
+  }
+
+  static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {
+    Maxpool3dFarwardCompute<T>(
+        index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
+        params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2], params_3d.num_batch(),
+        params_3d.num_channel(), params_3d.GetXShape5D().At(2), params_3d.GetXShape5D().At(3),
+        params_3d.GetXShape5D().At(4), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
+        params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
+        params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2], params_3d.stride_3d()[0],
+        params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[0],
+        params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    Maxpool3dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr,
+                                params_3d.num_batch(), params_3d.num_channel(),
+                                params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
+                                params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2),
+                                params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));
+  }
+};
+
+template<DeviceType device_type, typename T>
+class MaxPool2dKernel final : public user_op::OpKernel {
+ public:
+  MaxPool2dKernel() = default;
+  ~MaxPool2dKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
+    user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
+
+    const auto& pooling_state = DoCreateOpKernelState(ctx, 2);
+    const PoolingParams3D& params_3d = pooling_state->GetParams3D();
+
+    const int64_t elem_num = y->shape().elem_cnt();
+    const T* src = x->dptr<T>();
+    T* dest = y->mut_dptr<T>();
+    int64_t* indice_ptr = indice->mut_dptr<int64_t>();
+
+    DimVector y_vector;
+    y->shape().ToDimVector(&y_vector);
+    NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());
+
+    PoolingKernelUtil<device_type, T>::Maxpool2dForward(ctx->device_ctx(), index_helper, elem_num,
+                                                        src, dest, indice_ptr, params_3d);
+  };
+};
+
+template<DeviceType device_type, typename T>
+class MaxPool2dGradKernel final : public user_op::OpKernel {
+ public:
+  MaxPool2dGradKernel() = default;
+  ~MaxPool2dGradKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
+    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    const auto& pooling_state = DoCreateOpKernelState(ctx, 2);
+    const PoolingParams3D& params_3d = pooling_state->GetParams3D();
+
+    const int64_t elem_num = dy->shape().elem_cnt();
+    const T* src = dy->dptr<T>();
+    const int64_t* indice_ptr = indice->dptr<int64_t>();
+    T* dest = dx->mut_dptr<T>();
+    DimVector dy_vector;
+    dy->shape().ToDimVector(&dy_vector);
+    NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());
+
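+    // Zero-fill dx first: the backward pass only scatters gradients to the recorded max locations.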
+    size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
+    Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
+
+    PoolingKernelUtil<device_type, T>::Maxpool2dBackward(ctx->device_ctx(), index_helper, elem_num,
+                                                         src, dest, indice_ptr, params_3d);
+  };
+};
+
+template<DeviceType device_type, typename T>
+class MaxPool3dKernel final : public user_op::OpKernel {
+ public:
+  MaxPool3dKernel() = default;
+  ~MaxPool3dKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
+    user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
+
+    const auto& pooling_state = DoCreateOpKernelState(ctx, 3);
+    const PoolingParams3D& params_3d = pooling_state->GetParams3D();
+
+    const int64_t elem_num = y->shape().elem_cnt();
+    const T* src = x->dptr<T>();
+    T* dest = y->mut_dptr<T>();
+    int64_t* indice_ptr = indice->mut_dptr<int64_t>();
+
+    DimVector y_vector;
+    y->shape().ToDimVector(&y_vector);
+    NdIndexOffsetHelper<int64_t, 5> index_helper(y_vector.data());
+
+    PoolingKernelUtil<device_type, T>::Maxpool3dForward(ctx->device_ctx(), index_helper, elem_num,
+                                                        src, dest, indice_ptr, params_3d);
+  };
+};
+
+template<DeviceType device_type, typename T>
+class MaxPool3dGradKernel final : public user_op::OpKernel {
+ public:
+  MaxPool3dGradKernel() = default;
+  ~MaxPool3dGradKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0);
+    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    const auto& pooling_state = DoCreateOpKernelState(ctx, 3);
+    const PoolingParams3D& params_3d = pooling_state->GetParams3D();
+
+    const int64_t elem_num = dy->shape().elem_cnt();
+    const T* src = dy->dptr<T>();
+    const int64_t* indice_ptr = indice->dptr<int64_t>();
+    T* dest = dx->mut_dptr<T>();
+
+    DimVector dy_vector;
+    dy->shape().ToDimVector(&dy_vector);
+    NdIndexOffsetHelper<int64_t, 5> index_helper(dy_vector.data());
+
+    size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
+    Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
+
+    PoolingKernelUtil<device_type, T>::Maxpool3dBackward(ctx->device_ctx(), index_helper, elem_num,
+                                                         src, dest, indice_ptr, params_3d);
+  };
+};
+
+#define REGISTER_POOLING_KERNELS(device, dtype)                                        \
+  REGISTER_USER_KERNEL("maxpool_2d")                                                   \
+      .SetCreateFn<MaxPool2dKernel<device, dtype>>()                                   \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("maxpool_2d_grad")                                              \
+      .SetCreateFn<MaxPool2dGradKernel<device, dtype>>()                               \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("maxpool_3d")                                                   \
+      .SetCreateFn<MaxPool3dKernel<device, dtype>>()                                   \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("maxpool_3d_grad")                                              \
+      .SetCreateFn<MaxPool3dGradKernel<device, dtype>>()                               \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("x", 0) == GetDataType<dtype>::value));
+
+#define REGISTER_POOLING_WITH_DEVICE(device) \
+  REGISTER_POOLING_KERNELS(device, int32_t)  \
+  REGISTER_POOLING_KERNELS(device, float)    \
+  REGISTER_POOLING_KERNELS(device, double)
+
+REGISTER_POOLING_WITH_DEVICE(DeviceType::kCPU)
+
+#ifdef WITH_CUDA
+REGISTER_POOLING_WITH_DEVICE(DeviceType::kGPU)
+// TODO: REGISTER_POOLING_KERNELS(DeviceType::kGPU, float16)
+#endif
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOLING_KERNEL_UTIL, (DeviceType::kCPU),
+                                 POOLING_DATA_TYPE_CPU_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/pooling_kernel.cu b/oneflow/user/kernels/pooling_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7aef65319fe28441f044e5a1e819ca7779bcc295
--- /dev/null
+++ b/oneflow/user/kernels/pooling_kernel.cu
@@ -0,0 +1,148 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <cstdint>
+#ifdef WITH_CUDA
+#include "oneflow/core/cuda/elementwise.cuh"
+#include "oneflow/user/kernels/pooling_kernel_util.h"
+
+namespace oneflow {
+
+constexpr int kBlockSize = cuda::elementwise::kBlockSize;
+
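+// Launch helpers: threads per block are capped at kBlockSize and the grid size is derived from the element count.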
+int GetMinThreadNum(int64_t elem_num) { return std::min<int64_t>(elem_num, kBlockSize); }
+
+int GetNumBlocks(int64_t elem_cnt) {
+  int num_blocks = 0;
+  OF_CUDA_CHECK(cuda::elementwise::GetNumBlocks(elem_cnt, &num_blocks));
+  return num_blocks;
+}
+
+template<typename T>
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool2dForward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
+                                int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                                int32_t padding_h, int32_t padding_w, int64_t n_batch,
+                                int64_t n_channel, int64_t x_height, int64_t x_width,
+                                int64_t y_height, int64_t y_width, int32_t kernel_size_h,
+                                int32_t kernel_size_w, int32_t stride_h, int32_t stride_w,
+                                int32_t dilation_h, int32_t dilation_w) {
+  Maxpool2dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_h, padding_w,
+                             n_batch, n_channel, x_height, x_width, y_height, y_width,
+                             kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h,
+                             dilation_w);
+};
+
+template<typename T>
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool3dForward(const NdIndexOffsetHelper<int64_t, 5> index_helper,
+                                int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                                int32_t padding_t, int32_t padding_h, int32_t padding_w,
+                                int64_t n_batch, int64_t n_channel, int64_t x_time,
+                                int64_t x_height, int64_t x_width, int64_t y_time, int64_t y_height,
+                                int64_t y_width, int32_t kernel_size_t, int32_t kernel_size_h,
+                                int32_t kernel_size_w, int32_t stride_t, int32_t stride_h,
+                                int32_t stride_w, int32_t dilation_t, int32_t dilation_h,
+                                int32_t dilation_w) {
+  Maxpool3dFarwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, padding_t, padding_h,
+                             padding_w, n_batch, n_channel, x_time, x_height, x_width, y_time,
+                             y_height, y_width, kernel_size_t, kernel_size_h, kernel_size_w,
+                             stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w);
+};
+
+template<typename T>
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool2dBackward(const NdIndexOffsetHelper<int64_t, 4> index_helper,
+                                 const int64_t elem_num, const T* src, T* dest,
+                                 const int64_t* indice_ptr, const int64_t n_batch,
+                                 const int64_t n_channel, const int64_t src_height,
+                                 const int64_t src_width, const int64_t dst_height,
+                                 const int64_t dst_width) {
+  Maxpool2dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel,
+                              src_height, src_width, dst_height, dst_width);
+};
+
+template<typename T>
+__launch_bounds__(kBlockSize) __global__
+    void DoCUDAMaxPool3dBackward(const NdIndexOffsetHelper<int64_t, 5> index_helper,
+                                 const int64_t elem_num, const T* src, T* dest,
+                                 const int64_t* indice_ptr, const int64_t n_batch,
+                                 const int64_t n_channel, const int64_t src_time,
+                                 const int64_t src_height, const int64_t src_width,
+                                 const int64_t dst_time, const int64_t dst_height,
+                                 const int64_t dst_width) {
+  Maxpool3dBackwardCompute<T>(index_helper, elem_num, src, dest, indice_ptr, n_batch, n_channel,
+                              src_time, src_height, src_width, dst_time, dst_height, dst_width);
+};
+
+template<typename T>
+struct PoolingKernelUtil<DeviceType::kGPU, T> {
+  static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool2dForward<T>
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[1],
+            params_3d.padding_before_3d()[2], params_3d.num_batch(), params_3d.num_channel(),
+            params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
+            params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
+            params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
+            params_3d.stride_3d()[1], params_3d.stride_3d()[2], params_3d.dilation_3d()[1],
+            params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool2dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool2dBackward<T>
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),
+            params_3d.num_channel(), params_3d.GetYShape5D().At(3), params_3d.GetYShape5D().At(4),
+            params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));
+  }
+
+  static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool3dForward<T>
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.padding_before_3d()[0],
+            params_3d.padding_before_3d()[1], params_3d.padding_before_3d()[2],
+            params_3d.num_batch(), params_3d.num_channel(), params_3d.GetXShape5D().At(2),
+            params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4),
+            params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
+            params_3d.GetYShape5D().At(4), params_3d.pooling_size_3d()[0],
+            params_3d.pooling_size_3d()[1], params_3d.pooling_size_3d()[2],
+            params_3d.stride_3d()[0], params_3d.stride_3d()[1], params_3d.stride_3d()[2],
+            params_3d.dilation_3d()[0], params_3d.dilation_3d()[1], params_3d.dilation_3d()[2]);
+  }
+
+  static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d) {
+    DoCUDAMaxPool3dBackward<T>
+        <<<GetNumBlocks(elem_num), GetMinThreadNum(elem_num), 0, ctx->cuda_stream()>>>(
+            index_helper, elem_num, src, dest, indice_ptr, params_3d.num_batch(),
+            params_3d.num_channel(), params_3d.GetYShape5D().At(2), params_3d.GetYShape5D().At(3),
+            params_3d.GetYShape5D().At(4), params_3d.GetXShape5D().At(2),
+            params_3d.GetXShape5D().At(3), params_3d.GetXShape5D().At(4));
+  }
+};
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_POOLING_KERNEL_UTIL, (DeviceType::kGPU),
+                                 POOLING_DATA_TYPE_GPU_SEQ);
+
+}  // namespace oneflow
+#endif  // WITH_CUDA
diff --git a/oneflow/user/kernels/pooling_kernel_util.cpp b/oneflow/user/kernels/pooling_kernel_util.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c2f502ce2928c8fec6c60ccc5c288d1c074fb275
--- /dev/null
+++ b/oneflow/user/kernels/pooling_kernel_util.cpp
@@ -0,0 +1,115 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/user/kernels/pooling_kernel_util.h"
+
+namespace oneflow {
+
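+// Promote a dim-entry attribute vector to 3 entries by left-padding with the neutral value (1 for kernel/stride/dilation, 0 for padding).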
+std::vector<int32_t> Get3DVec(const std::vector<int32_t>& original_vec, int32_t NDims) {
+  std::vector<int32_t> vec;
+  FOR_RANGE(uint8_t, dim, 0, 3) {
+    int64_t index = static_cast<int64_t>(dim) - (3 - NDims);
+    if (index < 0) {
+      vec.push_back(1);
+    } else {
+      vec.push_back(original_vec.at(index));
+    }
+  }
+  return vec;
+}
+
+std::vector<int32_t> Get3DPadVec(const std::vector<int32_t>& original_vec, int32_t NDims) {
+  std::vector<int32_t> vec;
+  FOR_RANGE(uint8_t, dim, 0, 3) {
+    int64_t index = static_cast<int64_t>(dim) - (3 - NDims);
+    if (index < 0) {
+      vec.push_back(0);
+    } else {
+      vec.push_back(original_vec.at(index));
+    }
+  }
+  return vec;
+}
+
+PoolingParams3D::PoolingParams3D(const int32_t dim, const ShapeView& x_shape,
+                                 const std::string& data_format, const std::string& padding,
+                                 const std::vector<int32_t>& padding_before,
+                                 const std::vector<int32_t>& padding_after,
+                                 const std::vector<int32_t>& kernel_size,
+                                 const std::vector<int32_t>& stride,
+                                 const std::vector<int32_t>& dilation, const bool return_indices,
+                                 const bool ceil_mode)
+    : dim_(dim),
+      data_format_(data_format),
+      padding_(padding),
+      padding_before_3d_(Get3DPadVec(padding_before, dim)),
+      padding_after_3d_(Get3DPadVec(padding_after, dim)),
+      pooling_size_3d_(Get3DVec(kernel_size, dim)),
+      stride_3d_(Get3DVec(stride, dim)),
+      dilation_3d_(Get3DVec(dilation, dim)),
+      return_indices_(return_indices),
+      ceil_mode_(ceil_mode) {
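+  /* compute the 3D input and output extents; Get3DOutputSize may also adjust padding_before/after (e.g. when ceil_mode is set) */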
+  x_3d_ = {GetInDim(x_shape, data_format, 0, dim), GetInDim(x_shape, data_format, 1, dim),
+           GetInDim(x_shape, data_format, 2, dim)};
+  Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
+                  &padding_before_3d_, &padding_after_3d_);
+  if (data_format == "channels_first") {
+    channel_num_ = x_shape.At(1);
+  } else {
+    CHECK_EQ(data_format_, "channels_last")
+        << "data_format must be 'channels_first' or 'channels_last'";
+    channel_num_ = x_shape.At(x_shape.NumAxes() - 1);
+  }
+  batch_num_ = x_shape.At(0);
+}
+
+void PoolingParams3D::Reset(const ShapeView& x_shape) {
+  x_3d_ = {GetInDim(x_shape, data_format_, 0, dim_), GetInDim(x_shape, data_format_, 1, dim_),
+           GetInDim(x_shape, data_format_, 2, dim_)};
+  Get3DOutputSize(x_3d_, pooling_size_3d_, stride_3d_, padding_, ceil_mode_, &dilation_3d_, &y_3d_,
+                  &padding_before_3d_, &padding_after_3d_);
+}
+
+Shape PoolingParams3D::GetYShape() const {
+  DimVector y_dim_vec;
+  if (dim_ == 1) {
+    y_dim_vec = {y_3d_.at(2)};
+  } else if (dim_ == 2) {
+    y_dim_vec = {y_3d_.at(1), y_3d_.at(2)};
+  } else if (dim_ == 3) {
+    y_dim_vec = {y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)};
+  } else {
+    UNIMPLEMENTED();
+  }
+  if (data_format_ == "channels_first") {
+    y_dim_vec.insert(y_dim_vec.begin(), channel_num_);
+  } else {
+    CHECK_EQ(data_format_, "channels_last")
+        << "data_format must be 'channels_first' or 'channels_last'";
+    y_dim_vec.insert(y_dim_vec.end(), channel_num_);
+  }
+  y_dim_vec.insert(y_dim_vec.begin(), batch_num_);
+  return Shape(y_dim_vec);
+}
+
+Shape PoolingParams3D::GetXShape5D() const {
+  return Shape({batch_num_, channel_num_, x_3d_.at(0), x_3d_.at(1), x_3d_.at(2)});
+}
+
+Shape PoolingParams3D::GetYShape5D() const {
+  return Shape({batch_num_, channel_num_, y_3d_.at(0), y_3d_.at(1), y_3d_.at(2)});
+}
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/pooling_kernel_util.h b/oneflow/user/kernels/pooling_kernel_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..b25ba34e04a3ce8c5027948965c686c69e1134f9
--- /dev/null
+++ b/oneflow/user/kernels/pooling_kernel_util.h
@@ -0,0 +1,277 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_KERNELS_POOLING_KERNEL_UTIL_H_
+#define ONEFLOW_USER_KERNELS_POOLING_KERNEL_UTIL_H_
+#include "oneflow/core/device/device_context.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/operator/operator_util.h"
+#include "oneflow/core/kernel/util/numerics.cuh"
+#include "oneflow/core/kernel/util/numeric_limits.cuh"
+#ifdef WITH_CUDA
+#include "oneflow/core/cuda/atomic.cuh"
+#endif  // WITH_CUDA
+
+namespace oneflow {
+
+#define POOLING_DATA_TYPE_SEQ                     \
+  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32) \
+  OF_PP_MAKE_TUPLE_SEQ(float, DataType::kFloat)   \
+  OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble)
+
+#define POOLING_DATA_TYPE_CPU_SEQ POOLING_DATA_TYPE_SEQ
+
+#define POOLING_DATA_TYPE_GPU_SEQ POOLING_DATA_TYPE_SEQ
+
+typedef fixed_vector<int64_t, SHAPE_MAX_AXIS_SIZE> FixedDimVector;
+
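+// Accumulates *x into *y; on GPU an atomic add is used because several output gradients may scatter to the same input element.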
+template<typename T>
+struct DeviceAdd {
+  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {
+#if defined(__CUDA_ARCH__)
+    cuda::atomic::Add(y, *x);
+#else
+    *y += *x;
+#endif
+  };
+};
+
+class PoolingParams3D {
+ public:
+  PoolingParams3D(const int32_t dim, const ShapeView& x_shape, const std::string& data_format,
+                  const std::string& padding, const std::vector<int32_t>& padding_before,
+                  const std::vector<int32_t>& padding_after,
+                  const std::vector<int32_t>& kernel_size, const std::vector<int32_t>& stride,
+                  const std::vector<int32_t>& dilation, const bool return_indices,
+                  const bool ceil_mode);
+  ~PoolingParams3D() = default;
+
+  const std::string& data_format() const { return data_format_; }
+  const std::vector<int32_t>& padding_before_3d() const { return padding_before_3d_; }
+  const std::vector<int32_t>& padding_after_3d() const { return padding_after_3d_; }
+  const std::vector<int32_t>& pooling_size_3d() const { return pooling_size_3d_; }
+  const std::vector<int32_t>& stride_3d() const { return stride_3d_; }
+  const std::vector<int32_t>& dilation_3d() const { return dilation_3d_; }
+  const bool& return_indices() const { return return_indices_; }
+  const bool& ceil_mode() const { return ceil_mode_; }
+  const int64_t& num_batch() const { return batch_num_; }
+  const int64_t& num_channel() const { return channel_num_; }
+
+  void Reset(const ShapeView& x_shape);
+  Shape GetYShape() const;
+  Shape GetXShape5D() const;
+  Shape GetYShape5D() const;
+
+ private:
+  int32_t dim_;
+  FixedDimVector x_3d_;
+  FixedDimVector y_3d_;
+  std::string data_format_;
+  std::string padding_;
+  std::vector<int32_t> padding_before_3d_;
+  std::vector<int32_t> padding_after_3d_;
+  std::vector<int32_t> pooling_size_3d_;
+  std::vector<int32_t> stride_3d_;
+  std::vector<int32_t> dilation_3d_;
+  bool return_indices_;
+  bool ceil_mode_;
+  int64_t batch_num_;
+  int64_t channel_num_;
+};
+
+template<DeviceType device_type, typename T>
+struct PoolingKernelUtil {
+  static void Maxpool2dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d);
+
+  static void Maxpool2dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 4>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d);
+
+  static void Maxpool3dForward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                               const int64_t elem_num, const T* src, T* dest, int64_t* indice_ptr,
+                               const PoolingParams3D& params_3d);
+
+  static void Maxpool3dBackward(DeviceCtx* ctx, const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                                const int64_t elem_num, const T* src, T* dest,
+                                const int64_t* indice_ptr, const PoolingParams3D& params_3d);
+};
+
+template<typename T>
+OF_DEVICE_FUNC void Maxpool2dFarwardCompute(
+    const NdIndexOffsetHelper<int64_t, 4> index_helper, int64_t elem_num, const T* src, T* dest,
+    int64_t* indice_ptr, const int32_t padding_h, const int32_t padding_w, const int64_t n_batch,
+    const int64_t n_channel, const int64_t x_height, const int64_t x_width, const int64_t y_height,
+    const int64_t y_width, const int32_t kernel_size_h, const int32_t kernel_size_w,
+    const int32_t stride_h, const int32_t stride_w, const int32_t dilation_h,
+    const int32_t dilation_w) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, h, w;
+    index_helper.OffsetToNdIndex(num, n, c, h, w);
+
+    const int64_t start_idx = (n * n_channel + c) * x_width * x_height;
+    int64_t hstart = h * stride_h - padding_h;
+    int64_t wstart = w * stride_w - padding_w;
+    const int64_t hend = (hstart + (kernel_size_h - 1) * dilation_h + 1) <= x_height
+                             ? (hstart + (kernel_size_h - 1) * dilation_h + 1)
+                             : x_height;
+    const int64_t wend = (wstart + (kernel_size_w - 1) * dilation_w + 1) <= x_width
+                             ? (wstart + (kernel_size_w - 1) * dilation_w + 1)
+                             : x_width;
+
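+    /* shift the window start into the valid range when padding makes it negative */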
+    while (hstart < 0) { hstart += dilation_h; }
+    while (wstart < 0) { wstart += dilation_w; }
+
+    /* compute the max value (src[src_idx]) inside the kernel window and store it in dest[num] */
+    int64_t maxindex = hstart * x_width + wstart;
+    int64_t src_idx = 0;
+
+    /* equal to -std::numeric_limits<T>::infinity(); */
+    T max_value = detail::numeric_limits<T>::lower_bound();
+
+    for (int64_t i = hstart; i < hend; i += dilation_h) {
+      for (int64_t j = wstart; j < wend; j += dilation_w) {
+        const int64_t tcntr = i * x_width + j;
+        const int64_t search_idx = start_idx + tcntr;
+        T val = src[search_idx];
+        /* NOTE:
+        std::isnan(val) only supports a few data types, see:
+        https://en.cppreference.com/w/cpp/numeric/math/isnan. When compiling with gcc/g++ 4.x,
+        the following error is thrown:
+
+        new_kernel_util.cu:24] Check failed: cudaMemcpyAsync(dst, src, sz, cudaMemcpyDefault,
+        ctx->cuda_stream() ) : unspecified launch failure (719)
+
+        but when compiling with gcc/g++ 7.x everything works; the exact reason is still unknown.
+        */
+        if (val > max_value || detail::numerics<T>::isnan(val)) {
+          max_value = val;
+          maxindex = tcntr;
+          src_idx = search_idx;
+        }
+      }
+    }
+    dest[num] = src[src_idx];
+    indice_ptr[num] = maxindex;
+  }
+}
+
+template<typename T>
+OF_DEVICE_FUNC void Maxpool2dBackwardCompute(const NdIndexOffsetHelper<int64_t, 4> index_helper,
+                                             const int64_t elem_num, const T* src, T* dest,
+                                             const int64_t* indice_ptr, const int64_t n_batch,
+                                             const int64_t n_channel, const int64_t src_height,
+                                             const int64_t src_width, const int64_t dst_height,
+                                             const int64_t dst_width) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, h, w;
+    index_helper.OffsetToNdIndex(num, n, c, h, w);
+
+    const int64_t src_start = (n * n_channel + c) * src_height * src_width;
+    const int64_t dst_start = (n * n_channel + c) * dst_height * dst_width;
+    const int64_t index = src_start + h * src_width + w;
+    const int64_t maxindex = dst_start + indice_ptr[index];
+    if (maxindex != -1) {
+      /* update gradient, equals to dest[maxindex] += src[index]; */
+      DeviceAdd<T>::Invoke(src + index, dest + maxindex);
+    }
+  }
+}
+
+template<typename T>
+OF_DEVICE_FUNC void Maxpool3dFarwardCompute(
+    const NdIndexOffsetHelper<int64_t, 5> index_helper, int64_t elem_num, const T* src, T* dest,
+    int64_t* indice_ptr, const int32_t padding_t, const int32_t padding_h, const int32_t padding_w,
+    const int64_t n_batch, const int64_t n_channel, const int64_t x_time, const int64_t x_height,
+    const int64_t x_width, const int64_t y_time, const int64_t y_height, const int64_t y_width,
+    const int32_t kernel_size_t, const int32_t kernel_size_h, const int32_t kernel_size_w,
+    const int32_t stride_t, const int32_t stride_h, const int32_t stride_w,
+    const int32_t dilation_t, const int32_t dilation_h, const int32_t dilation_w) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, t, h, w;
+    index_helper.OffsetToNdIndex(num, n, c, t, h, w);
+
+    int64_t xstart = n * n_channel * x_time * x_width * x_height;
+    int64_t start_idx = xstart + c * x_time * x_width * x_height;
+    int64_t tstart = t * stride_t - padding_t;
+    int64_t hstart = h * stride_h - padding_h;
+    int64_t wstart = w * stride_w - padding_w;
+
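+    /* exclusive window ends along t/h/w before clipping to the input extent */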
+    const int64_t t1 = tstart + (kernel_size_t - 1) * dilation_t + 1;
+    const int64_t t2 = hstart + (kernel_size_h - 1) * dilation_h + 1;
+    const int64_t t3 = wstart + (kernel_size_w - 1) * dilation_w + 1;
+    const int64_t tend = t1 <= x_time ? t1 : x_time;
+    const int64_t hend = t2 <= x_height ? t2 : x_height;
+    const int64_t wend = t3 <= x_width ? t3 : x_width;
+
+    while (tstart < 0) { tstart += dilation_t; }
+    while (hstart < 0) { hstart += dilation_h; }
+    while (wstart < 0) { wstart += dilation_w; }
+
+    int64_t maxindex = tstart * x_height * x_width + hstart * x_width + wstart;
+    int64_t src_idx = 0;
+
+    T max_value = detail::numeric_limits<T>::lower_bound();
+    for (int64_t zi = tstart; zi < tend; zi += dilation_t) {
+      for (int64_t i = hstart; i < hend; i += dilation_h) {
+        for (int64_t j = wstart; j < wend; j += dilation_w) {
+          const int64_t tcntr = zi * x_height * x_width + i * x_width + j;
+          const int64_t search_idx = start_idx + tcntr;
+          T val = src[search_idx];
+          if (val > max_value || detail::numerics<T>::isnan(val)) {
+            max_value = val;
+            maxindex = tcntr;
+            src_idx = search_idx;
+          }
+        }
+      }
+    }
+    /* set output to local max */
+    dest[num] = src[src_idx];
+    /* store location of max */
+    indice_ptr[num] = maxindex;
+  }
+}
+
+template<typename T>
+OF_DEVICE_FUNC void Maxpool3dBackwardCompute(const NdIndexOffsetHelper<int64_t, 5> index_helper,
+                                             const int64_t elem_num, const T* src, T* dest,
+                                             const int64_t* indice_ptr, const int64_t n_batch,
+                                             const int64_t n_channel, const int64_t src_time,
+                                             const int64_t src_height, const int64_t src_width,
+                                             const int64_t dst_time, const int64_t dst_height,
+                                             const int64_t dst_width) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, t, h, w;
+    index_helper.OffsetToNdIndex(num, n, c, t, h, w);
+
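+    /* route the gradient: src is dy, dest is dx, and indice_ptr gives the flattened spatial */
+    /* offset, inside the (n, c) plane of dx, of the max element that produced this dy value. */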
+    const int64_t src_start = (n * n_channel + c) * src_time * src_height * src_width;
+    const int64_t dst_start = (n * n_channel + c) * dst_time * dst_height * dst_width;
+    const int64_t index = src_start + t * src_height * src_width + h * src_width + w;
+    const int64_t stored_index = indice_ptr[index];
+    /* a stored index of -1 marks a position with no recorded max, so no gradient flows back */
+    if (stored_index != -1) {
+      const int64_t maxindex = dst_start + stored_index;
+      DeviceAdd<T>::Invoke(src + index, dest + maxindex);
+    }
+  }
+}
+
+#define INSTANTIATE_POOLING_KERNEL_UTIL(device_type_v, dtype_pair) \
+  template struct PoolingKernelUtil<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_POOLING_KERNEL_UTIL_H_
diff --git a/oneflow/user/ops/pooling_op.cpp b/oneflow/user/ops/pooling_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..09577f81ecd0f5a0f1159eb3dc020bab81c12dc6
--- /dev/null
+++ b/oneflow/user/ops/pooling_op.cpp
@@ -0,0 +1,219 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/user/kernels/pooling_kernel_util.h"
+
+namespace oneflow {
+
+namespace {
+
+typedef std::function<Maybe<void>(user_op::InferContext* ctx)> TensorDescInferFn;
+typedef std::function<void(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp)>
+    GenBackwardOpConfFn;
+
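+/* Returns the tensor-desc inference function for a `dim`-dimensional max pooling op: it
+   validates the pooling attributes, derives the output shape through PoolingParams3D, and
+   gives `indice` the same shape as `y` with an int64 element type. As an illustration
+   (assuming the usual floor-division output formula), an input plane of extent 4 with
+   kernel 2, stride 2, no padding, and dilation 1 yields an output extent of (4 - 2) / 2 + 1 = 2. */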
+TensorDescInferFn MakeForwardTensorDescInferFn(const int32_t dim) {
+  return [dim](user_op::InferContext* ctx) -> Maybe<void> {
+    const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0);
+    const std::string& data_format = ctx->Attr<std::string>("data_format");
+    const std::string& padding = ctx->Attr<std::string>("padding");
+    const std::vector<int32_t>& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
+    const std::vector<int32_t>& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+    const std::vector<int32_t>& kernel_size = ctx->Attr<std::vector<int32_t>>("kernel_size");
+    const std::vector<int32_t>& stride = ctx->Attr<std::vector<int32_t>>("stride");
+    const std::vector<int32_t>& dilation = ctx->Attr<std::vector<int32_t>>("dilation");
+    const bool return_indices = ctx->Attr<bool>("return_indices");
+    const bool ceil_mode = ctx->Attr<bool>("ceil_mode");
+
+    CHECK_EQ_OR_RETURN(kernel_size.size(), dim);
+    for (int32_t pool_dim : kernel_size) { CHECK_GT_OR_RETURN(pool_dim, 0); }
+    CHECK_EQ_OR_RETURN(stride.size(), dim);
+    for (int32_t stride_dim : stride) { CHECK_GT_OR_RETURN(stride_dim, 0); }
+    for (size_t i = 0; i < padding_after.size(); i++) {
+      CHECK_GE_OR_RETURN(kernel_size[i], 2 * padding_after[i])
+          << "pad should be no greater than half of the kernel size";
+    }
+
+    const PoolingParams3D params_3d(dim, *x_shape, data_format, padding, padding_before,
+                                    padding_after, kernel_size, stride, dilation, return_indices,
+                                    ceil_mode);
+    user_op::TensorDesc* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0);
+    *y_desc = *ctx->TensorDesc4ArgNameAndIndex("x", 0);
+    *y_desc->mut_shape() = params_3d.GetYShape();
+
+    user_op::TensorDesc* indice_desc = ctx->TensorDesc4ArgNameAndIndex("indice", 0);
+    *indice_desc = *ctx->TensorDesc4ArgNameAndIndex("y", 0);
+    *indice_desc->mut_shape() = *y_desc->mut_shape();
+    DataType* dtype = indice_desc->mut_data_type();
+    *dtype = kInt64;
+    return Maybe<void>::Ok();
+  };
+}
+
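+/* Pooling is computed independently per (n, c) plane, so x, y, and indice can be split along
+   the leading batch / channel axes. */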
+Maybe<void> ForwardGetSbpFn(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
+  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
+    if (padding_before[i] == 0 && padding_after[i] == 0) {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("x", 0), i)
+          .Split(user_op::OpArg("y", 0), i)
+          .Split(user_op::OpArg("indice", 0), i)
+          .Build();
+    }
+  }
+  return Maybe<void>::Ok();
+}
+
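+/* dx shares x's shape and data type. */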
+Maybe<void> BackwardTensorDescInferFn(user_op::InferContext* ctx) {
+  *ctx->TensorDesc4ArgNameAndIndex("dx", 0) = *ctx->TensorDesc4ArgNameAndIndex("x", 0);
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> BackwardGetSbpFn(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  const auto& padding_before = ctx->Attr<std::vector<int32_t>>("padding_before");
+  const auto& padding_after = ctx->Attr<std::vector<int32_t>>("padding_after");
+  FOR_RANGE(int64_t, i, 0, std::min(2, (int)tensor.shape().NumAxes())) {
+    if (padding_before[i] == 0 && padding_after[i] == 0) {
+      ctx->NewBuilder()
+          .Split(user_op::OpArg("x", 0), i)
+          .Split(user_op::OpArg("y", 0), i)
+          .Split(user_op::OpArg("indice", 0), i)
+          .Split(user_op::OpArg("dy", 0), i)
+          .Split(user_op::OpArg("dx", 0), i)
+          .Build();
+    }
+  }
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> FwInferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> BwInferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
+
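+/* Builds the backward-op generator: when x needs a gradient, it emits a
+   `<mode>pool_<dim>d_grad` op fed with x, y, indice, and dy, copies every pooling attribute
+   from the forward op, and binds the op's dx output to x's gradient. */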
+GenBackwardOpConfFn MakeBackwardOpConfFn(const std::string& mode, const int32_t dim) {
+  return [mode, dim](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+    if (op.NeedGenGradTensor4OpInput("x", 0)) {
+      user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+      user_op::UserOpConfWrapper grad_op =
+          builder.Op(mode + "pool_" + std::to_string(dim) + "d_grad")
+              .Input("x", op.input("x", 0))
+              .Input("y", op.output("y", 0))
+              .Input("indice", op.output("indice", 0))
+              .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+              .Output("dx")
+              .Attr("data_format", op.attr<std::string>("data_format"))
+              .Attr("padding", op.attr<std::string>("padding"))
+              .Attr("padding_before", op.attr<std::vector<int32_t>>("padding_before"))
+              .Attr("padding_after", op.attr<std::vector<int32_t>>("padding_after"))
+              .Attr("kernel_size", op.attr<std::vector<int32_t>>("kernel_size"))
+              .Attr("stride", op.attr<std::vector<int32_t>>("stride"))
+              .Attr("dilation", op.attr<std::vector<int32_t>>("dilation"))
+              .Attr("return_indices", op.attr<bool>("return_indices"))
+              .Attr("ceil_mode", op.attr<bool>("ceil_mode"))
+              .Build();
+      op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+      AddOp(grad_op);
+    }
+  };
+}
+
+}  // namespace
+
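+/* The forward max pooling ops output both the pooled tensor `y` and the int64 `indice` tensor
+   of flattened argmax positions, which the corresponding *_grad ops consume. */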
+REGISTER_USER_OP("maxpool_2d")
+    .Input("x")
+    .Output("y")
+    .Output("indice")
+    .Attr<std::string>("padding")
+    .Attr<std::vector<int32_t>>("padding_before")
+    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")
+    .Attr<std::vector<int32_t>>("stride")
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(MakeForwardTensorDescInferFn(2))
+    .SetGetSbpFn(ForwardGetSbpFn)
+    .SetDataTypeInferFn(FwInferDataType);
+
+REGISTER_USER_OP("maxpool_2d_grad")
+    .Input("x")
+    .Input("y")
+    .Input("indice")
+    .Input("dy")
+    .Output("dx")
+    .Attr<std::string>("padding")
+    .Attr<std::vector<int32_t>>("padding_before")
+    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")
+    .Attr<std::vector<int32_t>>("stride")
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(BackwardTensorDescInferFn)
+    .SetGetSbpFn(BackwardGetSbpFn)
+    .SetDataTypeInferFn(BwInferDataType);
+
+REGISTER_USER_OP_GRAD("maxpool_2d").SetGenBackwardOpConfFn(MakeBackwardOpConfFn("max", 2));
+
+REGISTER_USER_OP("maxpool_3d")
+    .Input("x")
+    .Output("y")
+    .Output("indice")
+    .Attr<std::string>("padding")
+    .Attr<std::vector<int32_t>>("padding_before")
+    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")
+    .Attr<std::vector<int32_t>>("stride")
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(MakeForwardTensorDescInferFn(3))
+    .SetGetSbpFn(ForwardGetSbpFn)
+    .SetDataTypeInferFn(FwInferDataType);
+
+REGISTER_USER_OP("maxpool_3d_grad")
+    .Input("x")
+    .Input("y")
+    .Input("indice")
+    .Input("dy")
+    .Output("dx")
+    .Attr<std::string>("padding")
+    .Attr<std::vector<int32_t>>("padding_before")
+    .Attr<std::vector<int32_t>>("padding_after")
+    .Attr<std::string>("data_format")
+    .Attr<std::vector<int32_t>>("kernel_size")
+    .Attr<std::vector<int32_t>>("stride")
+    .Attr<std::vector<int32_t>>("dilation")
+    .Attr<bool>("return_indices")
+    .Attr<bool>("ceil_mode")
+    .SetTensorDescInferFn(BackwardTensorDescInferFn)
+    .SetGetSbpFn(BackwardGetSbpFn)
+    .SetDataTypeInferFn(BwInferDataType);
+
+REGISTER_USER_OP_GRAD("maxpool_3d").SetGenBackwardOpConfFn(MakeBackwardOpConfFn("max", 3));
+
+}  // namespace oneflow