diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index 4242190fa01c9c224bdad9c7853546869d3c95d2..e58b0d9c1d66b184a40faad8635144d0d5de7c0e 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -49,7 +49,6 @@ Experimental features
 .. autofunction:: oneflow.experimental.Tensor.argmax
 .. autofunction:: oneflow.experimental.nn.BatchNorm1d
 .. autofunction:: oneflow.experimental.nn.BatchNorm2d
-.. autofunction:: oneflow.experimental.nn.ReplicationPad2d
 .. autofunction:: oneflow.experimental.nn.InstanceNorm1d
 .. autofunction:: oneflow.experimental.nn.InstanceNorm2d
 .. autofunction:: oneflow.experimental.nn.InstanceNorm3d
@@ -71,7 +70,11 @@ Experimental features
 .. autofunction:: oneflow.experimental.nn.ModuleDict
 .. autofunction:: oneflow.experimental.nn.Conv1d
 .. autofunction:: oneflow.experimental.nn.Conv2d
+.. autofunction:: oneflow.experimental.nn.ZeroPad2d
+.. autofunction:: oneflow.experimental.nn.ReflectionPad2d
+.. autofunction:: oneflow.experimental.nn.ReplicationPad2d
 .. autofunction:: oneflow.experimental.nn.ConstantPad2d
+.. autofunction:: oneflow.experimental.nn.ConstantPad3d
 .. autofunction:: oneflow.experimental.nn.ConvTranspose2d
 .. autofunction:: oneflow.experimental.nn.Dropout
 .. autofunction:: oneflow.experimental.slice
@@ -232,7 +235,6 @@ Experimental features
 .. autofunction:: oneflow.experimental.Tensor.ceil
 .. autofunction:: oneflow.experimental.expm1
 .. autofunction:: oneflow.experimental.Tensor.expm1
-.. autofunction:: oneflow.experimental.nn.ReflectionPad2d
 .. autofunction:: oneflow.experimental.meshgrid
 .. autofunction:: oneflow.experimental.topk
 .. autofunction:: oneflow.experimental.Tensor.topk
@@ -241,7 +243,6 @@ Experimental features
 .. autofunction:: oneflow.experimental.nn.GroupNorm
 .. autofunction:: oneflow.experimental.gather_nd
 .. autofunction:: oneflow.experimental.scatter_nd
-.. autofunction:: oneflow.experimental.nn.ZeroPad2d
 .. autofunction:: oneflow.experimental.nn.image.flip
 .. autofunction:: oneflow.experimental.tensor_buffer_to_tensor
 .. autofunction:: oneflow.experimental.tensor_to_tensor_buffer
diff --git a/oneflow/core/autograd/gradient_funcs/pad2d.cpp b/oneflow/core/autograd/gradient_funcs/padding.cpp
similarity index 92%
rename from oneflow/core/autograd/gradient_funcs/pad2d.cpp
rename to oneflow/core/autograd/gradient_funcs/padding.cpp
index 7b130315dc1f8c604014b71238d8afb01b97719b..a733cf953c322b89468cb9f0d8124bf516d26109 100644
--- a/oneflow/core/autograd/gradient_funcs/pad2d.cpp
+++ b/oneflow/core/autograd/gradient_funcs/padding.cpp
@@ -78,13 +78,13 @@ class ReplicationPad2d : public Pad2d {
 REGISTER_OP_EXPR_GRAD_FUNCTION("reflection_pad2d", ReflectionPad2d);
 REGISTER_OP_EXPR_GRAD_FUNCTION("replication_pad2d", ReplicationPad2d);
 
-struct ConstantPad2dInterpState : public OpExprInterpState {
+struct ConstantPadNdInterpState : public OpExprInterpState {
   bool requires_grad;
   std::vector<int64_t> paddings;
   functional::Scalar padding_value;
 };
 
-class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
+class ConstantPadNd : public OpExprGradFunction<ConstantPadNdInterpState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
@@ -93,7 +93,7 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Capture(ConstantPad2dInterpState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(ConstantPadNdInterpState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     CHECK_EQ_OR_RETURN(inputs.size(), 1);
     CHECK_EQ_OR_RETURN(outputs.size(), 1);
@@ -112,7 +112,7 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const ConstantPad2dInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ConstantPadNdInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     CHECK_EQ_OR_RETURN(out_grads.size(), 1);
     in_grads->resize(1);
@@ -127,7 +127,8 @@ class ConstantPad2d : public OpExprGradFunction<ConstantPad2dInterpState> {
   AttrMap base_attrs_;
 };
 
-REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad2d", ConstantPad2d);
+REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad2d", ConstantPadNd);
+REGISTER_OP_EXPR_GRAD_FUNCTION("constant_pad3d", ConstantPadNd);
 
 }  // namespace one
 }  // namespace oneflow
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index d181e21af21e2002a87702c56c45db5978d2afb7..535b301820954039b0ec083863780bd904cb3910 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -375,6 +375,7 @@ class NormalizationFunctor {
 class PadFunctor {
  public:
   PadFunctor() {
+    constant_pad_3d_ = CHECK_JUST(one::OpBuilder("constant_pad3d").Input("x").Output("y").Build());
     constant_pad_ = CHECK_JUST(one::OpBuilder("constant_pad2d").Input("x").Output("y").Build());
     reflect_pad_ = CHECK_JUST(one::OpBuilder("reflection_pad2d").Input("x").Output("y").Build());
     replicate_pad_ = CHECK_JUST(one::OpBuilder("replication_pad2d").Input("x").Output("y").Build());
@@ -396,7 +397,14 @@ class PadFunctor {
       } else {
         UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type.";
       }
-      return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {x}, attrs);
+      switch (x->shape()->NumAxes()) {
+        case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_, {x}, attrs);
+        case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_, {x}, attrs);
+        default:
+          UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but " << x->shape()->NumAxes()
+                                      << "d-tensor is not support yet! ";
+      }
+
     } else if (mode == "reflect") {
       return OpInterpUtil::Dispatch<Tensor>(*reflect_pad_, {x}, attrs);
     } else if (mode == "replicate") {
@@ -409,6 +417,7 @@ class PadFunctor {
 
  private:
   std::shared_ptr<OpExpr> constant_pad_;
+  std::shared_ptr<OpExpr> constant_pad_3d_;
   std::shared_ptr<OpExpr> reflect_pad_;
   std::shared_ptr<OpExpr> replicate_pad_;
 };
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index b47f345d70bda1f4e46761f42409e9086a736143..fa6110e84bd73abbee04a3f32af5e460be7a1c36 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -229,6 +229,8 @@ class PadGradFunctor {
   PadGradFunctor() {
     constant_pad_grad_ =
         CHECK_JUST(one::OpBuilder("constant_pad2d_grad").Input("dy").Output("dx").Build());
+    constant_pad_3d_grad_ =
+        CHECK_JUST(one::OpBuilder("constant_pad3d_grad").Input("dy").Output("dx").Build());
     reflect_pad_grad_ =
         CHECK_JUST(one::OpBuilder("reflection_pad2d_grad").Input("dy").Output("dx").Build());
     replicate_pad_grad_ =
@@ -251,7 +253,13 @@ class PadGradFunctor {
       } else {
         UNIMPLEMENTED_THEN_RETURN() << "Data type should be floating or integral type.";
       }
-      return OpInterpUtil::Dispatch<Tensor>(*constant_pad_grad_, {dy}, attrs);
+      switch (dy->shape()->NumAxes()) {
+        case 4: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_grad_, {dy}, attrs);
+        case 5: return OpInterpUtil::Dispatch<Tensor>(*constant_pad_3d_grad_, {dy}, attrs);
+        default:
+          UNIMPLEMENTED_THEN_RETURN() << "Pad mode is " << mode << ", but "
+                                      << dy->shape()->NumAxes() << "d-tensor is not support yet! ";
+      }
     } else if (mode == "reflect") {
       return OpInterpUtil::Dispatch<Tensor>(*reflect_pad_grad_, {dy}, attrs);
     } else if (mode == "replicate") {
@@ -266,6 +274,7 @@ class PadGradFunctor {
   std::shared_ptr<OpExpr> constant_pad_grad_;
   std::shared_ptr<OpExpr> reflect_pad_grad_;
   std::shared_ptr<OpExpr> replicate_pad_grad_;
+  std::shared_ptr<OpExpr> constant_pad_3d_grad_;
 };
 
 }  // namespace impl
diff --git a/oneflow/python/nn/modules/constantpad2d.py b/oneflow/python/nn/modules/constantpad2d.py
deleted file mode 100644
index 4a7b4eb08df9f9d6636550f8c44d676eea896b44..0000000000000000000000000000000000000000
--- a/oneflow/python/nn/modules/constantpad2d.py
+++ /dev/null
@@ -1,127 +0,0 @@
-"""
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-from __future__ import absolute_import
-
-from typing import Union
-
-import oneflow as flow
-from oneflow.python.oneflow_export import oneflow_export, experimental_api
-from oneflow.python.nn.module import Module
-
-
-@oneflow_export("nn.ConstantPad2d")
-@experimental_api
-class ConstantPad2d(Module):
-    r"""The interface is consistent with PyTorch.
-    The documentation is referenced from:
-    https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad2d.html?highlight=constantpad2d#torch.nn.ConstantPad2d
-
-    This operator pads the input with constant value that user specifies. User can set the amount of padding by setting the parameter `paddings`.
-
-    Args:
-        padding (Union[int, tuple, list]):  the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
-        
-        value (Union[int, float]): The constant value used for padding. Defaults to 0.
-
-    Shape:
-        - Input: :math:`(N, C, H_{in}, W_{in})`
-        - Output: :math:`(N, C, H_{out}, W_{out})` where
-
-            :math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
-
-            :math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
-
-    For example:
-
-    .. code-block:: python
-
-        >>> import oneflow.experimental as flow
-        >>> import numpy as np
-        >>> flow.enable_eager_execution()
-        >>> constantpad_layer_0 = flow.nn.ConstantPad2d((2, 2, 1, 1), 1)
-        >>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))
-        >>> input_int = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32))
-        >>> output = constantpad_layer_0(input)
-        >>> output.shape
-        flow.Size([1, 2, 5, 7])
-        >>> output
-        tensor([[[[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
-                  [ 1.,  1.,  0.,  1.,  2.,  1.,  1.],
-                  [ 1.,  1.,  3.,  4.,  5.,  1.,  1.],
-                  [ 1.,  1.,  6.,  7.,  8.,  1.,  1.],
-                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]],
-        <BLANKLINE>
-                 [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
-                  [ 1.,  1.,  9., 10., 11.,  1.,  1.],
-                  [ 1.,  1., 12., 13., 14.,  1.,  1.],
-                  [ 1.,  1., 15., 16., 17.,  1.,  1.],
-                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]]]], dtype=oneflow.float32)
-        >>> output_int = constantpad_layer_0(input_int)
-        >>> output_int
-        tensor([[[[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
-                  [ 1.,  1.,  0.,  1.,  2.,  1.,  1.],
-                  [ 1.,  1.,  3.,  4.,  5.,  1.,  1.],
-                  [ 1.,  1.,  6.,  7.,  8.,  1.,  1.],
-                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]],
-        <BLANKLINE>
-                 [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
-                  [ 1.,  1.,  9., 10., 11.,  1.,  1.],
-                  [ 1.,  1., 12., 13., 14.,  1.,  1.],
-                  [ 1.,  1., 15., 16., 17.,  1.,  1.],
-                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]]]], dtype=oneflow.float32)
-    """
-
-    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
-        super().__init__()
-        if isinstance(padding, (tuple, list)):
-            assert len(padding) == 4, ValueError("Length of padding must be 4")
-            boundary = [padding[0], padding[1], padding[2], padding[3]]
-        elif isinstance(padding, int):
-            boundary = [padding, padding, padding, padding]
-        else:
-            raise ValueError("padding must be int or list or tuple!")
-
-        self.padding = boundary
-        self.value = value
-
-    def forward(self, x):
-        _, _, h, w = x.shape
-
-        if x.dtype in [flow.float32, flow.float16, flow.float64]:
-            floating_value = float(self.value)
-            integral_value = int(0)
-        else:
-            floating_value = float(0)
-            integral_value = int(self.value)
-
-        self._op = (
-            flow.builtin_op("constant_pad2d")
-            .Input("x")
-            .Output("y")
-            .Attr("padding", self.padding)
-            .Attr("floating_value", floating_value)
-            .Attr("integral_value", integral_value)
-            .Build()
-        )
-
-        res = self._op(x)[0]
-        return res
-
-
-if __name__ == "__main__":
-    import doctest
-
-    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/nn/modules/padding.py b/oneflow/python/nn/modules/padding.py
index 1d24752e86098db6f278dafeff51addbfbdf8f8c..edbea801ea479f3d9727d5ab180323dfa164d4fd 100644
--- a/oneflow/python/nn/modules/padding.py
+++ b/oneflow/python/nn/modules/padding.py
@@ -190,6 +190,176 @@ class ReflectionPad2d(Module):
         return "{}".format(self.padding)
 
 
+@oneflow_export("nn.ConstantPad2d")
+@experimental_api
+class ConstantPad2d(Module):
+    r"""The interface is consistent with PyTorch.
+    The documentation is referenced from:
+    https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad2d.html?highlight=constantpad2d#torch.nn.ConstantPad2d
+
+    This operator pads the input with constant value that user specifies. User can set the amount of padding by setting the parameter `paddings`.
+
+    Args:
+        padding (Union[int, tuple, list]):  the size of the padding. If is `int`, uses the same padding in all boundaries. If a 4-`tuple`, uses (:math:`\mathrm{padding_{left}}`, :math:`\mathrm{padding_{right}}`, :math:`\mathrm{padding_{top}}`, :math:`\mathrm{padding_{bottom}}`)
+        
+        value (Union[int, float]): The constant value used for padding. Defaults to 0.
+
+    Shape:
+        - Input: :math:`(N, C, H_{in}, W_{in})`
+        - Output: :math:`(N, C, H_{out}, W_{out})` where
+
+            :math:`H_{out} = H_{in} + \mathrm{padding_{top}} + \mathrm{padding_{bottom}}`
+
+            :math:`W_{out} = W_{in} + \mathrm{padding_{left}} + \mathrm{padding_{right}}`
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+        >>> constantpad_layer_0 = flow.nn.ConstantPad2d((2, 2, 1, 1), 1)
+        >>> input = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32))
+        >>> input_int = flow.Tensor(np.arange(18).reshape((1, 2, 3, 3)).astype(np.int32))
+        >>> output = constantpad_layer_0(input)
+        >>> output.shape
+        flow.Size([1, 2, 5, 7])
+        >>> output
+        tensor([[[[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  0.,  1.,  2.,  1.,  1.],
+                  [ 1.,  1.,  3.,  4.,  5.,  1.,  1.],
+                  [ 1.,  1.,  6.,  7.,  8.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]],
+        <BLANKLINE>
+                 [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  9., 10., 11.,  1.,  1.],
+                  [ 1.,  1., 12., 13., 14.,  1.,  1.],
+                  [ 1.,  1., 15., 16., 17.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]]]], dtype=oneflow.float32)
+        >>> output_int = constantpad_layer_0(input_int)
+        >>> output_int
+        tensor([[[[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  0.,  1.,  2.,  1.,  1.],
+                  [ 1.,  1.,  3.,  4.,  5.,  1.,  1.],
+                  [ 1.,  1.,  6.,  7.,  8.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]],
+        <BLANKLINE>
+                 [[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
+                  [ 1.,  1.,  9., 10., 11.,  1.,  1.],
+                  [ 1.,  1., 12., 13., 14.,  1.,  1.],
+                  [ 1.,  1., 15., 16., 17.,  1.,  1.],
+                  [ 1.,  1.,  1.,  1.,  1.,  1.,  1.]]]], dtype=oneflow.float32)
+    """
+
+    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
+        super().__init__()
+        if isinstance(padding, (tuple, list)):
+            assert len(padding) == 4, ValueError("Length of padding must be 4")
+            boundary = [padding[0], padding[1], padding[2], padding[3]]
+        elif isinstance(padding, int):
+            boundary = [padding, padding, padding, padding]
+        else:
+            raise ValueError("padding must be int or list or tuple!")
+
+        self.padding = boundary
+        self.value = value
+
+    def forward(self, x):
+        if x.dtype in (flow.float32, flow.float16, flow.float64):
+            self.value = float(self.value)
+        else:
+            self.value = int(self.value)
+
+        return flow.F.pad(x, pad=self.padding, mode="constant", value=self.value)
+
+
+@oneflow_export("nn.ConstantPad3d")
+@experimental_api
+class ConstantPad3d(Module):
+    r"""Pads the input tensor boundaries with a constant value.
+    The interface is consistent with PyTorch, and referenced from:
+    https://pytorch.org/docs/stable/generated/torch.nn.ConstantPad3d.html?highlight=constantpad3d#torch.nn.ConstantPad3d
+
+    For `N`-dimensional padding, use :func:`flow.nn.functional.pad()`.
+
+    Args:
+        padding (int, list, tuple): the size of the padding. If is `int`, uses the same
+            padding in all boundaries. If a 6-`tuple`, uses
+            (:math:`\text{padding_left}`, :math:`\text{padding_right}`,
+            :math:`\text{padding_top}`, :math:`\text{padding_bottom}`,
+            :math:`\text{padding_front}`, :math:`\text{padding_back}`)
+        
+        value (Union[int, float]): The constant value used for padding. Defaults to 0.
+
+    Shape:
+        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
+
+          :math:`D_{out} = D_{in} + \text{padding_front} + \text{padding_back}`
+
+          :math:`H_{out} = H_{in} + \text{padding_top} + \text{padding_bottom}`
+
+          :math:`W_{out} = W_{in} + \text{padding_left} + \text{padding_right}`
+
+    Examples::
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+
+        >>> input = flow.tensor(np.arange(8).reshape(1,1,2,2,2).astype(np.int32))
+        >>> m = flow.nn.ConstantPad3d(padding=1, value=9)
+        >>> output = m(input)
+        >>> output
+        tensor([[[[[9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9]],
+        <BLANKLINE>
+                  [[9, 9, 9, 9],
+                   [9, 0, 1, 9],
+                   [9, 2, 3, 9],
+                   [9, 9, 9, 9]],
+        <BLANKLINE>
+                  [[9, 9, 9, 9],
+                   [9, 4, 5, 9],
+                   [9, 6, 7, 9],
+                   [9, 9, 9, 9]],
+        <BLANKLINE>
+                  [[9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9],
+                   [9, 9, 9, 9]]]]], dtype=oneflow.int32)
+    """
+
+    def __init__(self, padding: Union[int, tuple, list], value: Union[int, float] = 0):
+        super().__init__()
+        if isinstance(padding, (tuple, list)):
+            assert len(padding) == 6, ValueError("Length of padding must be 6")
+            boundary = [
+                padding[0],
+                padding[1],
+                padding[2],
+                padding[3],
+                padding[4],
+                padding[5],
+            ]
+        elif isinstance(padding, int):
+            boundary = [padding, padding, padding, padding, padding, padding]
+        else:
+            raise ValueError("padding must be int or list or tuple!")
+
+        self.padding = boundary
+        self.value = value
+
+    def forward(self, x):
+        if x.dtype in (flow.float32, flow.float16, flow.float64):
+            self.value = float(self.value)
+        else:
+            self.value = int(self.value)
+        return flow.F.pad(x, pad=self.padding, mode="constant", value=self.value)
+
+
 if __name__ == "__main__":
     import doctest
 
diff --git a/oneflow/python/test/modules/test_constantpad2d.py b/oneflow/python/test/modules/test_constantpad.py
similarity index 81%
rename from oneflow/python/test/modules/test_constantpad2d.py
rename to oneflow/python/test/modules/test_constantpad.py
index d3a42eb45e956e1a47169cd7425efc6787742de1..22391dbb728f52c38a5ffb481473a0b8858adc70 100644
--- a/oneflow/python/test/modules/test_constantpad2d.py
+++ b/oneflow/python/test/modules/test_constantpad.py
@@ -111,7 +111,7 @@ class TestConstantPad2dModule(flow.unittest.TestCase):
 
     def test_with_random_data(test_case):
         for device in ["cpu", "cuda"]:
-            spatial_size = np.random.randint(10, 20)
+            spatial_size = np.random.randint(1, 6)
             test_module_against_pytorch(
                 test_case,
                 "nn.ConstantPad2d",
@@ -120,8 +120,31 @@ class TestConstantPad2dModule(flow.unittest.TestCase):
                     "input": random_tensor(
                         ndim=4, dim2=spatial_size, dim3=spatial_size
                     ),
-                    "padding": random(0, 6),
-                    "value": random(0, 6),
+                    "padding": random(0, 3),
+                    "value": random(0, 10),
+                },
+                device=device,
+            )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestConstantPad3dModule(flow.unittest.TestCase):
+    def test_with_random_data(test_case):
+        for device in ["cpu", "cuda"]:
+            spatial_size = np.random.randint(1, 6)
+            test_module_against_pytorch(
+                test_case,
+                "nn.ConstantPad3d",
+                extra_annotations={"padding": int, "value": float},
+                extra_generators={
+                    "input": random_tensor(
+                        ndim=5, dim2=spatial_size, dim3=spatial_size, dim4=spatial_size
+                    ),
+                    "padding": random(0, 3),
+                    "value": random(0, 10),
                 },
                 device=device,
             )
diff --git a/oneflow/user/kernels/constantpad3d_kernel.cpp b/oneflow/user/kernels/constantpad3d_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d146fb8d9a3b8ee2f87e338762734c4c5b51101d
--- /dev/null
+++ b/oneflow/user/kernels/constantpad3d_kernel.cpp
@@ -0,0 +1,147 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/device/memory_copier.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+T GetDtypeMatchedValue(double floating, int64_t integral);
+
+template<>
+float16 GetDtypeMatchedValue(double floating, int64_t integral) {
+  return static_cast<float16>(floating);
+}
+
+template<>
+float GetDtypeMatchedValue(double floating, int64_t integral) {
+  return static_cast<float>(floating);
+}
+
+template<>
+double GetDtypeMatchedValue(double floating, int64_t integral) {
+  return floating;
+}
+
+template<>
+int8_t GetDtypeMatchedValue(double floating, int64_t integral) {
+  return static_cast<int8_t>(integral);
+}
+
+template<>
+int32_t GetDtypeMatchedValue(double floating, int64_t integral) {
+  return static_cast<int32_t>(integral);
+}
+
+template<>
+int64_t GetDtypeMatchedValue(double floating, int64_t integral) {
+  return integral;
+}
+
+}  // namespace
+
+namespace user_op {
+
+template<DeviceType device_type, typename IN_T>
+class ConstantPad3dKernel final : public OpKernel {
+ public:
+  ConstantPad3dKernel() = default;
+  ~ConstantPad3dKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
+    Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const ShapeView& x_shape = x->shape();
+    const ShapeView& y_shape = y->shape();
+    CHECK_EQ(x->shape().NumAxes(), 5);
+    const std::vector<int64_t>& padding = ctx->Attr<std::vector<int64_t>>("padding");
+    CHECK_EQ(padding.size(), 6);
+    const IN_T constant_value = GetDtypeMatchedValue<IN_T>(ctx->Attr<double>("floating_value"),
+                                                           ctx->Attr<int64_t>("integral_value"));
+
+    IN_T* dest = y->mut_dptr<IN_T>();
+    const IN_T* src = x->dptr<IN_T>();
+    DimVector y_vector;
+    y->shape().ToDimVector(&y_vector);
+    NdIndexOffsetHelper<int64_t, 5> index_helper(y_vector.data());
+
+    ConstantPad3dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, x_shape,
+                                              y_shape, padding, constant_value);
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<DeviceType device_type, typename IN_T>
+class ConstantPad3dGradKernel final : public OpKernel {
+ public:
+  ConstantPad3dGradKernel() = default;
+  ~ConstantPad3dGradKernel() = default;
+
+ private:
+  void Compute(KernelComputeContext* ctx) const override {
+    const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    CHECK_EQ(dy->shape().NumAxes(), 5);
+    Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    const ShapeView& dx_shape = dx->shape();
+    const ShapeView& dy_shape = dy->shape();
+
+    const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+    CHECK_EQ(padding.size(), 6);
+
+    const IN_T* src = dy->dptr<IN_T>();
+    IN_T* dest = dx->mut_dptr<IN_T>();
+    DimVector dy_vector;
+    dy->shape().ToDimVector(&dy_vector);
+    NdIndexOffsetHelper<int64_t, 5> index_helper(dy_vector.data());
+
+    size_t out_bytes_size = dx->shape().Count(0) * GetSizeOfDataType(dx->data_type());
+    Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
+
+    ConstantPad3dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper,
+                                                  dy_shape, dx_shape, padding);
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_CONSTANT_PAD3D_KERNELS(device, dtype)                                 \
+  REGISTER_USER_KERNEL("constant_pad3d")                                               \
+      .SetCreateFn<ConstantPad3dKernel<device, dtype>>()                               \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("constant_pad3d_grad")                                          \
+      .SetCreateFn<ConstantPad3dGradKernel<device, dtype>>()                           \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+#define REGISTER_CONSTANT_PAD3D_WITH_DEVICE(device) \
+  REGISTER_CONSTANT_PAD3D_KERNELS(device, float)    \
+  REGISTER_CONSTANT_PAD3D_KERNELS(device, double)   \
+  REGISTER_CONSTANT_PAD3D_KERNELS(device, int32_t)
+
+REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kCPU)
+#ifdef WITH_CUDA
+REGISTER_CONSTANT_PAD3D_WITH_DEVICE(DeviceType::kGPU)
+REGISTER_CONSTANT_PAD3D_KERNELS(DeviceType::kGPU, float16)
+#endif
+
+}  // namespace user_op
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/constantpad3d_kernel_util.cpp b/oneflow/user/kernels/constantpad3d_kernel_util.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0282cd2d3d0e266fade148844e1de03d323808ad
--- /dev/null
+++ b/oneflow/user/kernels/constantpad3d_kernel_util.cpp
@@ -0,0 +1,64 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+namespace user_op {
+
+template<typename IN_T>
+struct ConstantPad3dFunctor<DeviceType::kCPU, IN_T> final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
+                  const ShapeView& y_shape, const std::vector<int64_t>& padding,
+                  IN_T constant_value) {
+    // for NCDHW format input tensor, index of n, c, d, h, w is 0, 1, 2, 3, 4
+    const int64_t c_idx = 1;
+    const int64_t d_idx = 2;
+    const int64_t h_idx = 3;
+    const int64_t w_idx = 4;
+    // padding vector: [left, right, top, bottom, front, back]
+    DoConstantPad3d<IN_T>(src, dest, index_helper, y_shape.Count(0), y_shape.At(c_idx),
+                          y_shape.At(d_idx), y_shape.At(h_idx), y_shape.At(w_idx),
+                          x_shape.At(d_idx), x_shape.At(h_idx), x_shape.At(w_idx), padding[4],
+                          padding[0], padding[2], constant_value);
+  }
+};
+
+template<typename IN_T>
+struct ConstantPad3dGradFunctor<DeviceType::kCPU, IN_T> final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
+                  const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
+    const int64_t c_idx = 1;
+    const int64_t d_idx = 2;
+    const int64_t h_idx = 3;
+    const int64_t w_idx = 4;
+    DoConstantPad3dGrad<IN_T>(src, dest, index_helper, dy_shape.Count(0), dy_shape.At(c_idx),
+                              dy_shape.At(d_idx), dy_shape.At(h_idx), dy_shape.At(w_idx),
+                              dx_shape.At(d_idx), dx_shape.At(h_idx), dx_shape.At(w_idx),
+                              padding[4], padding[0], padding[2]);
+  }
+};
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR, (DeviceType::kCPU),
+                                 PADDING_DATA_TYPE_CPU_SEQ);
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR, (DeviceType::kCPU),
+                                 PADDING_DATA_TYPE_CPU_SEQ);
+
+}  // namespace user_op
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/constantpad3d_kernel_util.cu b/oneflow/user/kernels/constantpad3d_kernel_util.cu
new file mode 100644
index 0000000000000000000000000000000000000000..459177be8259dff5947d3a78d7f3a72bf1bb5f87
--- /dev/null
+++ b/oneflow/user/kernels/constantpad3d_kernel_util.cu
@@ -0,0 +1,125 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <cstdint>
+#include "oneflow/core/common/data_type.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/user/kernels/constantpad3d_kernel_util.h"
+
+namespace oneflow {
+namespace user_op {
+
+template<typename IN_T>
+__global__ void DoCUDAConstantPad3d(const IN_T* src, IN_T* dest,
+                                    const NdIndexOffsetHelper<int64_t, 5> index_helper,
+                                    int64_t elem_num, int64_t n_channel, int64_t y_depth,
+                                    int64_t y_height, int64_t y_width, int64_t x_depth,
+                                    int64_t x_height, int64_t x_width, int64_t pad_front,
+                                    int64_t pad_left, int64_t pad_top, const IN_T const_value) {
+  DoConstantPad3d<IN_T>(src, dest, index_helper, elem_num, n_channel, y_depth, y_height, y_width,
+                        x_depth, x_height, x_width, pad_front, pad_left, pad_top, const_value);
+};
+
+template<typename IN_T>
+__global__ void DoCUDAConstantPad3dGrad(const IN_T* src, IN_T* dest,
+                                        const NdIndexOffsetHelper<int64_t, 5> index_helper,
+                                        int64_t elem_num, int64_t n_channel, int64_t dy_depth,
+                                        int64_t dy_height, int64_t dy_width, int64_t dx_depth,
+                                        int64_t dx_height, int64_t dx_width, int64_t pad_front,
+                                        int64_t pad_left, int64_t pad_top) {
+  DoConstantPad3dGrad<IN_T>(src, dest, index_helper, elem_num, n_channel, dy_depth, dy_height,
+                            dy_width, dx_height, dx_depth, dx_width, pad_front, pad_left, pad_top);
+};
+
+template<typename IN_T>
+struct ConstantPad3dFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
+                  const ShapeView& y_shape, const std::vector<int64_t>& padding,
+                  IN_T constant_value) {
+    const int64_t c_idx = 1;
+    const int64_t d_idx = 2;
+    const int64_t h_idx = 3;
+    const int64_t w_idx = 4;
+
+    DoCUDAConstantPad3d<IN_T><<<BlocksNum4ThreadsNum(y_shape.Count(0)), kCudaThreadsNumPerBlock, 0,
+                                ctx->cuda_stream()>>>(
+        src, dest, index_helper, y_shape.Count(0), y_shape.At(c_idx), y_shape.At(d_idx),
+        y_shape.At(h_idx), y_shape.At(w_idx), x_shape.At(d_idx), x_shape.At(h_idx),
+        x_shape.At(w_idx), padding[4], padding[0], padding[2], constant_value);
+  }
+};
+
+// float16 implementation
+template<>
+void ConstantPad3dFunctor<DeviceType::kGPU, float16>::operator()(
+    DeviceCtx* ctx, const float16* src, float16* dest,
+    const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
+    const ShapeView& y_shape, const std::vector<int64_t>& padding, float16 constant_value) {
+  const int64_t c_idx = 1;
+  const int64_t d_idx = 2;
+  const int64_t h_idx = 3;
+  const int64_t w_idx = 4;
+  DoCUDAConstantPad3d<half>
+      <<<BlocksNum4ThreadsNum(y_shape.Count(0)), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper,
+          y_shape.Count(0), y_shape.At(c_idx), y_shape.At(d_idx), y_shape.At(h_idx),
+          y_shape.At(w_idx), x_shape.At(d_idx), x_shape.At(h_idx), x_shape.At(w_idx), padding[4],
+          padding[0], padding[2], static_cast<const half>(constant_value));
+}
+
+template<typename IN_T>
+struct ConstantPad3dGradFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
+                  const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
+    const int64_t c_idx = 1;
+    const int64_t d_idx = 2;
+    const int64_t h_idx = 3;
+    const int64_t w_idx = 4;
+    DoCUDAConstantPad3dGrad<IN_T><<<BlocksNum4ThreadsNum(dy_shape.Count(0)),
+                                    kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+        src, dest, index_helper, dy_shape.Count(0), dy_shape.At(c_idx), dy_shape.At(d_idx),
+        dy_shape.At(h_idx), dy_shape.At(w_idx), dx_shape.At(d_idx), dx_shape.At(h_idx),
+        dx_shape.At(w_idx), padding[4], padding[0], padding[2]);
+  }
+};
+
+// float16 implementation
+template<>
+void ConstantPad3dGradFunctor<DeviceType::kGPU, float16>::operator()(
+    DeviceCtx* ctx, const float16* src, float16* dest,
+    const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
+    const ShapeView& dx_shape, const std::vector<int64_t>& padding) {
+  const int64_t c_idx = 1;
+  const int64_t d_idx = 2;
+  const int64_t h_idx = 3;
+  const int64_t w_idx = 4;
+  DoCUDAConstantPad3dGrad<half>
+      <<<BlocksNum4ThreadsNum(dy_shape.Count(0)), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper,
+          dy_shape.Count(0), dy_shape.At(c_idx), dy_shape.At(d_idx), dy_shape.At(h_idx),
+          dy_shape.At(w_idx), dx_shape.At(d_idx), dx_shape.At(h_idx), dx_shape.At(w_idx),
+          padding[4], padding[0], padding[2]);
+}
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_FUNCTOR,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
+
+}  // namespace user_op
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/constantpad3d_kernel_util.h b/oneflow/user/kernels/constantpad3d_kernel_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..0839dd5b8b00c4a9d72ef61b1eb595f26405daa7
--- /dev/null
+++ b/oneflow/user/kernels/constantpad3d_kernel_util.h
@@ -0,0 +1,122 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
+#define ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
+#ifdef WITH_CUDA
+#include "oneflow/core/cuda/atomic.cuh"
+#endif  // WITH_CUDA
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+
+namespace oneflow {
+
+#define PADDING_DATA_TYPE_CPU_SEQ \
+  FLOATING_DATA_TYPE_SEQ          \
+  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
+
+#define PADDING_DATA_TYPE_GPU_SEQ \
+  FLOAT16_DATA_TYPE_SEQ           \
+  PADDING_DATA_TYPE_CPU_SEQ
+
+namespace user_op {
+
+template<typename T>
+struct DeviceAdd {
+  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {
+#if defined(__CUDA_ARCH__)
+    cuda::atomic::Add(y, *x);
+#else
+    *y += *x;
+#endif
+  };
+};
+
+template<DeviceType device_type, typename IN_T>
+struct ConstantPad3dFunctor final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& x_shape,
+                  const ShapeView& y_shape, const std::vector<int64_t>& padding,
+                  IN_T constant_value);
+};
+
+template<DeviceType device_type, typename IN_T>
+struct ConstantPad3dGradFunctor final {
+  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
+                  const NdIndexOffsetHelper<int64_t, 5>& index_helper, const ShapeView& dy_shape,
+                  const ShapeView& dx_shape, const std::vector<int64_t>& padding);
+};
+
+template<typename IN_T>
+OF_DEVICE_FUNC void DoConstantPad3d(const IN_T* src, IN_T* dest,
+                                    const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                                    int64_t elem_num, int64_t n_channel, int64_t y_depth,
+                                    int64_t y_height, int64_t y_width, int64_t x_depth,
+                                    int64_t x_height, int64_t x_width, int64_t pad_front,
+                                    int64_t pad_left, int64_t pad_top, IN_T constant_value) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, d, h, w;
+    index_helper.OffsetToNdIndex(num, n, c, d, h, w);
+
+    const int64_t src_num = n_channel * x_depth * x_height * x_width;
+    if (pad_front <= d && d < pad_front + x_depth && w >= pad_left && w < x_width + pad_left
+        && h >= pad_top && h < x_height + pad_top) {
+      const int64_t len_w = w - pad_left;
+      const int64_t len_h = h - pad_top;
+      const int64_t len_d = d - pad_front;
+      const int64_t src_index = n * src_num + c * x_depth * x_width * x_height
+                                + len_d * x_height * x_width + len_h * x_width + len_w;
+      dest[num] = src[src_index];
+    } else {
+      dest[num] = constant_value;
+    }
+  }
+}
+
+template<typename IN_T>
+OF_DEVICE_FUNC void DoConstantPad3dGrad(const IN_T* src, IN_T* dest,
+                                        const NdIndexOffsetHelper<int64_t, 5>& index_helper,
+                                        int64_t elem_num, int64_t n_channel, int64_t dy_depth,
+                                        int64_t dy_height, int64_t dy_width, int64_t dx_depth,
+                                        int64_t dx_height, int64_t dx_width, int64_t pad_front,
+                                        int64_t pad_left, int64_t pad_top) {
+  XPU_1D_KERNEL_LOOP(num, elem_num) {
+    int64_t n, c, d, h, w;
+    index_helper.OffsetToNdIndex(num, n, c, d, h, w);
+
+    const int64_t dest_num = n_channel * dx_depth * dx_height * dx_width;
+    if (pad_front <= d && d < pad_front + dx_depth && w >= pad_left && w < dx_width + pad_left
+        && h >= pad_top && h < dx_height + pad_top) {
+      const int64_t len_d = d - pad_front;
+      const int64_t len_w = w - pad_left;
+      const int64_t len_h = h - pad_top;
+      const int64_t dest_index = n * dest_num + c * dx_depth * dx_width * dx_height
+                                 + len_d * dx_width * dx_height + len_h * dx_width + len_w;
+
+      DeviceAdd<IN_T>::Invoke(src + num, dest + dest_index);
+    }
+  }
+}
+
+#define INSTANTIATE_CONSTANT_PAD3D_FUNCTOR(device_type_v, dtype_pair) \
+  template struct ConstantPad3dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+#define INSTANTIATE_CONSTANT_PAD3D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
+  template struct ConstantPad3dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+}  // namespace user_op
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_PAD3D_KERNELS_UTIL_H_
diff --git a/oneflow/user/ops/pad2d_ops.cpp b/oneflow/user/ops/padding_ops.cpp
similarity index 75%
rename from oneflow/user/ops/pad2d_ops.cpp
rename to oneflow/user/ops/padding_ops.cpp
index a64e0092d81a8027d1d57d5b3187ded5bfb1f734..4e84c86a514ffe816eff1a349934eae40e760e1e 100644
--- a/oneflow/user/ops/pad2d_ops.cpp
+++ b/oneflow/user/ops/padding_ops.cpp
@@ -343,4 +343,113 @@ REGISTER_USER_OP_GRAD("constant_pad2d")
       return Maybe<void>::Ok();
     });
 
+REGISTER_USER_OP("constant_pad3d")
+    .Input("x")
+    .Output("y")
+    .Attr<std::vector<int64_t>>("padding")
+    .Attr<double>("floating_value")
+    .Attr<int64_t>("integral_value")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& x_shape = ctx->InputShape("x", 0);
+      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+      CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 5);
+      // only support NCDHW format input tensor for now !
+      // for NCDHW format, index of num,channel,depth,height,width is 0,1,2,3,4
+      const int64_t n_idx = 0;
+      const int64_t c_idx = 1;
+      const int64_t d_idx = 2;
+      const int64_t h_idx = 3;
+      const int64_t w_idx = 4;
+
+      DimVector y_dim_vec(x_shape.NumAxes());
+      const int64_t d_x = x_shape.At(d_idx);
+      const int64_t h_x = x_shape.At(h_idx);
+      const int64_t w_x = x_shape.At(w_idx);
+
+      y_dim_vec[n_idx] = x_shape.At(n_idx);
+      y_dim_vec[c_idx] = x_shape.At(c_idx);
+      y_dim_vec[d_idx] = d_x + padding[4] + padding[5];
+      y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
+      y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
+
+      *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn(GetOpSbpSignature)
+    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
+                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
+      user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
+      CHECK_NOTNULL_OR_RETURN(x_modifier);
+      x_modifier->set_requires_grad(true);
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("constant_pad3d_grad")
+    .Input("dy")
+    .Output("dx")
+    .Attr<std::vector<int64_t>>("padding")
+    .Attr<double>("floating_value")
+    .Attr<int64_t>("integral_value")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+      CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 5);
+      const int64_t n_idx = 0;
+      const int64_t c_idx = 1;
+      const int64_t d_idx = 2;
+      const int64_t h_idx = 3;
+      const int64_t w_idx = 4;
+
+      DimVector dx_dim_vec(dy_shape.NumAxes());
+      int64_t d_dy, h_dy, w_dy;
+      d_dy = dy_shape.At(d_idx);
+      h_dy = dy_shape.At(h_idx);
+      w_dy = dy_shape.At(w_idx);
+
+      dx_dim_vec[n_idx] = dy_shape.At(0);
+      dx_dim_vec[c_idx] = dy_shape.At(1);
+      dx_dim_vec[d_idx] = d_dy - padding[4] - padding[5];
+      dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
+      dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
+
+      *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn(GetOpGradSbpSignature)
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP_GRAD("constant_pad3d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
+                               user_op::AddOpFn AddOp) -> Maybe<void> {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("constant_pad3d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Output("dx")
+                .Attr("padding", op.attr<std::vector<int64_t>>("padding"))
+                .Attr("floating_value", op.attr<double>("floating_value"))
+                .Attr("integral_value", op.attr<int64_t>("integral_value"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+      return Maybe<void>::Ok();
+    });
+
 }  // namespace oneflow