diff --git a/oneflow/python/ops/pad.py b/oneflow/python/ops/pad.py
index b6d5a5d718dc05e7f49b5bad843b8e1a453750f1..9602ecbc180e1aacac637abc828aaacba2f6cec3 100644
--- a/oneflow/python/ops/pad.py
+++ b/oneflow/python/ops/pad.py
@@ -317,3 +317,84 @@ def reflection_pad2d(
         .InferAndTryRun()
         .RemoteBlobList()[0]
     )
+
+
+@oneflow_export("replication_pad2d")
+def replication_pad2d(
+    x: oneflow_api.BlobDesc,
+    padding: Union[int, tuple, list],
+    name: Optional[str] = None,
+) -> oneflow_api.BlobDesc:
+    """Pads the input tensor using the replication of the input boundary. 
+
+    Args:
+        x (oneflow_api.BlobDesc): input blob, only support "NCHW" format.
+        padding (Union[int, oneflow_api.BlobDesc]): The size or bundary of padding, if is int uses the same padding in all dimension;
+        if 4-dims tuple, uses (\text{padding\_left}padding_left , \text{padding\_right}padding_right , \text{padding\_top}padding_top , \text{padding\_bottom}padding_bottom )
+        name (Optional[str], optional): The name for the operation. Defaults to None.
+
+    Returns:
+        oneflow_api.BlobDesc: [description]
+
+    For example: 
+
+    .. code-block:: python 
+
+        import oneflow as flow
+        import oneflow.typing as tp
+        import numpy as np 
+
+
+        @flow.global_function()
+        def pad_Job(
+            x: tp.Numpy.Placeholder((1, 2, 3, 3))
+        ) -> tp.Numpy:
+            return flow.replication_pad2d(x, padding=[2, 2, 1, 1])
+
+
+        x = np.arange(18).reshape((1, 2, 3, 3)).astype(np.float32)
+        out = pad_Job(x)
+
+        # out [[[[ 0.  0.  0.  1.  2.  2.  2.]
+        #    [ 0.  0.  0.  1.  2.  2.  2.]
+        #    [ 3.  3.  3.  4.  5.  5.  5.]
+        #    [ 6.  6.  6.  7.  8.  8.  8.]
+        #    [ 6.  6.  6.  7.  8.  8.  8.]]
+
+        #   [[9.  9.  9.  10.  11.  11.  11.]
+        #    [9.  9.  9.  10.  11.  11.  11.]
+        #    [12.  12.  12.  13.  14.  14.  14.]
+        #    [15.  15.  15.  16.  17.  17.  17.]
+        #    [15.  15.  15.  16.  17.  17.  17.]]]]
+
+    """
+    H, W = x.shape[2], x.shape[3]
+    if isinstance(padding, (tuple, list)):
+        assert len(padding) == len(x.shape), ValueError(
+            "padding boundary must have the same length as the input dims"
+        )
+        assert (
+            padding[2] < H and padding[3] < H and padding[0] < W and padding[1] < W
+        ), ValueError(
+            "Padding size should be less than the corresponding input dimension!"
+        )
+        boundary = [padding[0], padding[1], padding[2], padding[3]]
+    elif isinstance(padding, int):
+        assert padding < H and padding < W, ValueError(
+            "Padding size should be less than the corresponding input dimension!"
+        )
+        boundary = [padding, padding, padding, padding]
+    else:
+        raise ValueError("padding must be an int, list, or tuple!")
+
+    return (
+        oneflow.user_op_builder(
+            name if name is not None else id_util.UniqueStr("Replication_Pad2d")
+        )
+        .Op("replication_pad2d")
+        .Input("x", [x])
+        .Output("y")
+        .Attr("padding", list(boundry))
+        .Build()
+        .InferAndTryRun()
+        .RemoteBlobList()[0]
+    )
diff --git a/oneflow/python/test/ops/test_replication_pad2d.py b/oneflow/python/test/ops/test_replication_pad2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e23faec580d0ee0f2a495c5949c3d8948bf0e0e3
--- /dev/null
+++ b/oneflow/python/test/ops/test_replication_pad2d.py
@@ -0,0 +1,329 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+import os
+from collections import OrderedDict
+
+import numpy as np
+import oneflow as flow
+import oneflow.typing as tp
+from test_util import Args, GenArgDict, GenArgList
+
+
+def flatten_array(input_array):
+    output_array = list()
+    for x in np.nditer(input_array):
+        output_array.append(x.tolist())
+    return output_array
+
+
+def array_to_numpy(input_array, target_shape):
+    return np.array(input_array).reshape(target_shape, order="C")
+
+
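+# Decode a flat row-major offset into per-axis coordinates, e.g.
+# index2coordinate(5, [2, 3]) == [1, 2] since 5 == 1 * 3 + 2.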
+def index2coordinate(idx, tensor_shape):
+    coordinate = []
+    tmp = idx
+    for i in range(len(tensor_shape) - 1, -1, -1):
+        axis_size = tensor_shape[i]
+        coor = tmp % axis_size
+        coordinate.insert(0, int(coor))
+        tmp = (tmp - coor) / axis_size
+    return coordinate
+
+
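+# Inverse of index2coordinate: fold per-axis coordinates back into a flat
+# row-major offset.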
+def coordinate2index(coordinate, tensor_shape):
+    if len(coordinate) != len(tensor_shape):
+        raise "wrong coordinate or shape"
+    idx = 0
+    for i, coor in enumerate(coordinate):
+        size_at_axis = coor
+        for j in range(i + 1, len(tensor_shape)):
+            size_at_axis *= tensor_shape[j]
+
+        idx += size_at_axis
+    return idx
+
+
+def _make_op_function(
+    test_case,
+    input,
+    padding,
+    grad,
+    device_type,
+    value_type,
+    machine_ids,
+    device_counts,
+):
+    flow.clear_default_session()
+    if device_type == "cpu":
+        flow.config.cpu_device_num(device_counts)
+    else:
+        flow.config.gpu_device_num(device_counts)
+
+    func_config = flow.FunctionConfig()
+
+    # global functions need float32 arguments and return values, so float16 jobs fall back to float32
+    if value_type == flow.float16:
+        func_config.default_data_type(flow.float32)
+    else:
+        func_config.default_data_type(value_type)
+
+    func_config.default_placement_scope(flow.scope.placement(device_type, machine_ids))
+    func_config.default_logical_view(flow.scope.consistent_view())
+
+    def _compare_diff(blob: tp.Numpy):
+        test_case.assertTrue(np.allclose(grad, blob, 1e-3, 1e-3))
+
+    if value_type == flow.float32 or value_type == flow.float64:
+
+        @flow.global_function(type="train", function_config=func_config)
+        def op_function(x: tp.Numpy.Placeholder(input.shape, dtype=value_type)):
+            with flow.scope.placement(device_type, "0:0"):
+                x += flow.get_variable(
+                    name="input",
+                    shape=input.shape,
+                    dtype=value_type,
+                    initializer=flow.zeros_initializer(),
+                )
+                out = flow.replication_pad2d(x, padding)
+                flow.optimizer.SGD(
+                    flow.optimizer.PiecewiseConstantScheduler([], [0]), momentum=0
+                ).minimize(out)
+
+            flow.watch_diff(x, _compare_diff)
+            return out
+
+        return op_function
+
+    elif value_type == flow.int32:
+
+        @flow.global_function(type="train", function_config=func_config)
+        def op_function(x: tp.Numpy.Placeholder(input.shape, dtype=flow.float32)):
+            with flow.scope.placement(device_type, "0:0"):
+                x += flow.get_variable(
+                    name="input",
+                    shape=input.shape,
+                    dtype=flow.float32,
+                    initializer=flow.zeros_initializer(),
+                )
+                y = flow.replication_pad2d(x, padding)
+                y_fp32 = flow.cast(y, dtype=flow.float32)
+                flow.optimizer.SGD(
+                    flow.optimizer.PiecewiseConstantScheduler([], [0]), momentum=0
+                ).minimize(y_fp32)
+
+            flow.watch_diff(x, _compare_diff)
+            return y_fp32
+
+        return op_function
+
+    elif value_type == flow.float16:
+
+        @flow.global_function(type="train", function_config=func_config)
+        def op_function(x: tp.Numpy.Placeholder(input.shape, dtype=flow.float32)):
+            with flow.scope.placement(device_type, "0:0"):
+                x_var = flow.get_variable(
+                    name="input",
+                    shape=input.shape,
+                    dtype=flow.float32,
+                    initializer=flow.constant_initializer(0),
+                )
+                x_var = flow.cast_to_current_logical_view(x_var)
+                input_x = x_var + x
+                x_fp32 = flow.cast(input_x, flow.float32)
+                x_fp16 = flow.cast(input_x, dtype=flow.float16)
+                y_fp16 = flow.replication_pad2d(x_fp16, padding)
+                y_fp32 = flow.cast(y_fp16, dtype=flow.float32)
+                flow.optimizer.SGD(
+                    flow.optimizer.PiecewiseConstantScheduler([], [0]), momentum=0
+                ).minimize(y_fp32)
+
+            flow.watch_diff(x_fp32, _compare_diff)
+            return y_fp32
+
+        return op_function
+
+
+def gen_numpy_test_sample(input_shape, padding, is_float=True):
+    c_idx, h_idx, w_idx = 1, 2, 3
+    pad_left = padding[0]
+    pad_right = padding[1]
+    pad_top = padding[2]
+    pad_bottom = padding[3]
+    pad_shape = ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right))
+
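+    # np.pad with mode="edge" replicates border values, matching replication_pad2d.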
+    def _np_replication_pad2d(input, pad_shape):
+        numpy_replicate = np.pad(input, pad_shape, "edge")
+        return numpy_replicate
+
+    def _np_replication_pad2d_grad(src, dest):
+        dx_height, dx_width = input.shape[h_idx], input.shape[w_idx]
+        dy_height, dy_width = output.shape[h_idx], output.shape[w_idx]
+
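+        # The gradient of replication padding routes each dy element back to the
+        # (clamped) input pixel it was copied from, so border pixels accumulate
+        # contributions from every padded position that replicated them.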
+        numpy_src = np.ones(src.shape, np.int32)
+        numpy_dest = np.zeros(dest.shape, np.int32)
+        array_src = flatten_array(numpy_src)
+        array_dest = flatten_array(numpy_dest)
+
+        src_num = src.shape[c_idx] * src.shape[h_idx] * src.shape[w_idx]
+        dest_num = dest.shape[c_idx] * dest.shape[h_idx] * dest.shape[w_idx]
+        elements_num = src.shape[0] * src_num
+        for iter_n in range(elements_num):
+            coords = index2coordinate(iter_n, src.shape)
+            n, c, i, j = coords[0], coords[c_idx], coords[h_idx], coords[w_idx]
+            ip_x = ip_y = 0
+            if j < pad_left:
+                ip_x = pad_left
+            elif j >= pad_left and j < (dx_width + pad_left):
+                ip_x = j
+            else:
+                ip_x = dx_width + pad_left - 1
+
+            if i < pad_top:
+                ip_y = pad_top
+            elif i >= pad_top and i < (dx_height + pad_top):
+                ip_y = i
+            else:
+                ip_y = dx_height + pad_top - 1
+
+            ip_x = ip_x - pad_left
+            ip_y = ip_y - pad_top
+            src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j
+            dest_index = (
+                n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x
+            )
+            array_dest[dest_index] += array_src[src_index]
+
+        numpy_dest = array_to_numpy(array_dest, dest.shape)
+        return numpy_dest
+
+    if is_float:
+        input = np.random.random(input_shape).astype(np.float32)
+    else:
+        input = np.random.randint(0, 100, input_shape)
+
+    output = _np_replication_pad2d(input, pad_shape)
+    grad = _np_replication_pad2d_grad(output, input)
+
+    numpy_results = {
+        "input": input,
+        "padding": padding,
+        "output": output,
+        "grad": grad,
+    }
+
+    return numpy_results
+
+
+def _compare_op_function_with_samples(
+    test_case, device_type, sample, value_type, machine_ids, device_count
+):
+    op_function = _make_op_function(
+        test_case,
+        sample["input"].astype(value_type[0]),
+        sample["padding"],
+        sample["grad"].astype(value_type[0]),
+        device_type,
+        value_type[1],
+        machine_ids,
+        device_count,
+    )
+    y = (
+        op_function(sample["input"].astype(value_type[0]))
+        .get()
+        .numpy()
+        .astype(value_type[0])
+    )
+
+    if value_type[1] == flow.float16:
+        test_case.assertTrue(
+            np.allclose(y, sample["output"].astype(np.float32), 1e-3, 1e-3)
+        )
+    else:
+        test_case.assertTrue(np.allclose(y, sample["output"].astype(value_type[0])))
+
+
+def _gen_arg_dict(
+    device_type="gpu", value_type="float", machine_ids="0:0", device_count=1
+):
+    arg_dict = OrderedDict()
+    arg_dict["device_type"] = [device_type]
+    arg_dict["samples"] = []
+    arg_dict["samples"].append(gen_numpy_test_sample((2, 1, 2, 2), [1, 1, 1, 1]))
+    arg_dict["samples"].append(gen_numpy_test_sample((4, 2, 3, 3), [2, 2, 2, 2]))
+    arg_dict["samples"].append(gen_numpy_test_sample((2, 3, 4, 5), [3, 2, 1, 2]))
+    if value_type == "float":
+        if device_type == "gpu":
+            arg_dict["value_type"] = [
+                (np.float32, flow.float32),
+                # (np.float16, flow.float16),
+            ]
+        else:
+            arg_dict["value_type"] = [(np.float32, flow.float32)]
+
+    elif value_type == "int":
+        arg_dict["value_type"] = [(np.float32, flow.int32)]
+    else:
+        raise "float or int for value type only"
+
+    arg_dict["machine_ids"] = [machine_ids]
+    arg_dict["device_count"] = [device_count]
+    return arg_dict
+
+
+@flow.unittest.skip_unless_1n1d()
+class TestReplicationPad2d1n1d(flow.unittest.TestCase):
+    def test_op_function_int_cpu(test_case):
+        arg_dict = _gen_arg_dict("cpu", "int", "0:0", 1)
+        for arg in GenArgList(arg_dict):
+            _compare_op_function_with_samples(test_case, *arg)
+
+    def test_op_function_float_cpu(test_case):
+        arg_dict = _gen_arg_dict("cpu", "float", "0:0", 1)
+        for arg in GenArgList(arg_dict):
+            _compare_op_function_with_samples(test_case, *arg)
+
+    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+    def test_op_function_int_gpu(test_case):
+        arg_dict = _gen_arg_dict("gpu", "int", "0:0", 1)
+        for arg in GenArgList(arg_dict):
+            _compare_op_function_with_samples(test_case, *arg)
+
+    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+    def test_op_function_float_gpu(test_case):
+        arg_dict = _gen_arg_dict("gpu", "float", "0:0", 1)
+        for arg in GenArgList(arg_dict):
+            _compare_op_function_with_samples(test_case, *arg)
+
+
+@flow.unittest.skip_unless_1n2d()
+class TestReplicationPad2d1n2d(flow.unittest.TestCase):
+    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+    def test_op_function_float(test_case):
+        arg_dict = _gen_arg_dict("gpu", "float", "0:0-1", 2)
+        for arg in GenArgList(arg_dict):
+            _compare_op_function_with_samples(test_case, *arg)
+
+    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+    def test_op_function_int(test_case):
+        arg_dict = _gen_arg_dict("gpu", "int", "0:0-1", 2)
+        for arg in GenArgList(arg_dict):
+            _compare_op_function_with_samples(test_case, *arg)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/user/kernels/pad2d_kernels.cpp b/oneflow/user/kernels/pad2d_kernels.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..87949e7c5246fe613a2cf45d9126e3b2a713e820
--- /dev/null
+++ b/oneflow/user/kernels/pad2d_kernels.cpp
@@ -0,0 +1,250 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/device/memory_copier.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/user/kernels/pad2d_kernels_util.h"
+
+namespace oneflow {
+namespace user_op {
+
+template<DeviceType device_type, typename IN_T>
+class ReflectionPad2dKernel final : public OpKernel {
+ public:
+  ReflectionPad2dKernel() = default;
+  ~ReflectionPad2dKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext *ctx) const override {
+    const Tensor *x = ctx->Tensor4ArgNameAndIndex("x", 0);
+    Tensor *y = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");
+    const int64_t ndims = x->shape().NumAxes();
+    CHECK_EQ(padding.size(), ndims);
+    const int64_t n_idx = 0;
+    const int64_t c_idx = 1;
+    const int64_t h_idx = 2;
+    const int64_t w_idx = 3;
+
+    const int64_t pad_left = padding[0];
+    const int64_t pad_top = padding[2];
+
+    const int64_t n_batch = y->shape().At(n_idx);
+    const int64_t n_channel = y->shape().At(c_idx);
+    const int64_t y_height = y->shape().At(h_idx);
+    const int64_t y_width = y->shape().At(w_idx);
+    const int64_t x_height = x->shape().At(h_idx);
+    const int64_t x_width = x->shape().At(w_idx);
+
+    IN_T *dest = y->mut_dptr<IN_T>();
+    const IN_T *src = x->dptr<IN_T>();
+    DimVector y_vector;
+    y->shape().ToDimVector(&y_vector);
+    NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());
+
+    ReflectionPad2dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, n_batch,
+                                                n_channel, y_height, y_width, x_height, x_width,
+                                                pad_left, pad_top);
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<DeviceType device_type, typename IN_T>
+class ReflectionPad2dGradKernel final : public OpKernel {
+ public:
+  ReflectionPad2dGradKernel() = default;
+  ~ReflectionPad2dGradKernel() = default;
+
+ private:
+  void Compute(KernelComputeContext *ctx) const override {
+    const Tensor *dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    Tensor *dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");
+    const int64_t ndims = dy->shape().NumAxes();
+    CHECK_EQ(padding.size(), ndims);
+
+    const int64_t n_idx = 0;
+    const int64_t c_idx = 1;
+    const int64_t h_idx = 2;
+    const int64_t w_idx = 3;
+
+    int64_t pad_left = padding[0];
+    int64_t pad_top = padding[2];
+    int64_t n_batch = dy->shape().At(n_idx);
+    int64_t n_channel = dy->shape().At(c_idx);
+    int64_t dy_height = dy->shape().At(h_idx);
+    int64_t dy_width = dy->shape().At(w_idx);
+    int64_t dx_height = dx->shape().At(h_idx);
+    int64_t dx_width = dx->shape().At(w_idx);
+
+    const IN_T *src = dy->dptr<IN_T>();
+    IN_T *dest = dx->mut_dptr<IN_T>();
+    DimVector dy_vector;
+    dy->shape().ToDimVector(&dy_vector);
+    NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());
+
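+    // dx is filled by scatter-add below, so it must be zero-initialized first.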
+    size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
+    Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
+
+    ReflectionPad2dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper,
+                                                    n_batch, n_channel, dy_height, dy_width,
+                                                    dx_height, dx_width, pad_left, pad_top);
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_REFLECTION_PAD2D_KERNELS(device, dtype)                               \
+  REGISTER_USER_KERNEL("reflection_pad2d")                                             \
+      .SetCreateFn<ReflectionPad2dKernel<device, dtype>>()                             \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("reflection_pad2d_grad")                                        \
+      .SetCreateFn<ReflectionPad2dGradKernel<device, dtype>>()                         \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+#define REGISTER_REFLECTION_PAD2D_WITH_DEVICE(device) \
+  REGISTER_REFLECTION_PAD2D_KERNELS(device, float)    \
+  REGISTER_REFLECTION_PAD2D_KERNELS(device, double)   \
+  REGISTER_REFLECTION_PAD2D_KERNELS(device, int32_t)
+
+REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kCPU)
+#ifdef WITH_CUDA
+REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kGPU)
+REGISTER_REFLECTION_PAD2D_KERNELS(DeviceType::kGPU, float16)
+#endif
+
+template<DeviceType device_type, typename IN_T>
+class ReplicationPad2dKernel final : public OpKernel {
+ public:
+  ReplicationPad2dKernel() = default;
+  ~ReplicationPad2dKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext *ctx) const override {
+    const Tensor *x = ctx->Tensor4ArgNameAndIndex("x", 0);
+    Tensor *y = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");
+    const int64_t ndims = x->shape().NumAxes();
+    CHECK_EQ(padding.size(), ndims);
+    const int64_t n_idx = 0;
+    const int64_t c_idx = 1;
+    const int64_t h_idx = 2;
+    const int64_t w_idx = 3;
+
+    const int64_t pad_left = padding[0];
+    const int64_t pad_top = padding[2];
+
+    const int64_t n_batch = y->shape().At(n_idx);
+    const int64_t n_channel = y->shape().At(c_idx);
+    const int64_t y_height = y->shape().At(h_idx);
+    const int64_t y_width = y->shape().At(w_idx);
+    const int64_t x_height = x->shape().At(h_idx);
+    const int64_t x_width = x->shape().At(w_idx);
+
+    IN_T *dest = y->mut_dptr<IN_T>();
+    const IN_T *src = x->dptr<IN_T>();
+    DimVector y_vector;
+    y->shape().ToDimVector(&y_vector);
+    NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());
+
+    ReplicationPad2dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper,
+                                                 n_batch, n_channel, y_height, y_width, x_height,
+                                                 x_width, pad_left, pad_top);
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<DeviceType device_type, typename IN_T>
+class ReplicationPad2dGradKernel final : public OpKernel {
+ public:
+  ReplicationPad2dGradKernel() = default;
+  ~ReplicationPad2dGradKernel() = default;
+
+ private:
+  void Compute(KernelComputeContext *ctx) const override {
+    const Tensor *dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    Tensor *dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");
+    const int64_t ndims = dy->shape().NumAxes();
+    CHECK_EQ(padding.size(), ndims);
+
+    const int64_t n_idx = 0;
+    const int64_t c_idx = 1;
+    const int64_t h_idx = 2;
+    const int64_t w_idx = 3;
+
+    int64_t pad_left = padding[0];
+    int64_t pad_top = padding[2];
+    int64_t n_batch = dy->shape().At(n_idx);
+    int64_t n_channel = dy->shape().At(c_idx);
+    int64_t dy_height = dy->shape().At(h_idx);
+    int64_t dy_width = dy->shape().At(w_idx);
+    int64_t dx_height = dx->shape().At(h_idx);
+    int64_t dx_width = dx->shape().At(w_idx);
+
+    const IN_T *src = dy->dptr<IN_T>();
+    IN_T *dest = dx->mut_dptr<IN_T>();
+    DimVector dy_vector;
+    dy->shape().ToDimVector(&dy_vector);
+    NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());
+
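+    // dx is filled by scatter-add below, so it must be zero-initialized first.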
+    size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
+    Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
+
+    ReplicationPad2dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper,
+                                                     n_batch, n_channel, dy_height, dy_width,
+                                                     dx_height, dx_width, pad_left, pad_top);
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_REPLICATION_PAD2D_KERNELS(device, dtype)                              \
+  REGISTER_USER_KERNEL("replication_pad2d")                                            \
+      .SetCreateFn<ReplicationPad2dKernel<device, dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("replication_pad2d_grad")                                       \
+      .SetCreateFn<ReplicationPad2dGradKernel<device, dtype>>()                        \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+#define REGISTER_REPLICATION_PAD2D_WITH_DEVICE(device) \
+  REGISTER_REPLICATION_PAD2D_KERNELS(device, float)    \
+  REGISTER_REPLICATION_PAD2D_KERNELS(device, double)   \
+  REGISTER_REPLICATION_PAD2D_KERNELS(device, int32_t)
+
+REGISTER_REPLICATION_PAD2D_WITH_DEVICE(DeviceType::kCPU)
+#ifdef WITH_CUDA
+REGISTER_REPLICATION_PAD2D_WITH_DEVICE(DeviceType::kGPU)
+REGISTER_REPLICATION_PAD2D_KERNELS(DeviceType::kGPU, float16)
+#endif
+
+}  // namespace user_op
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/pad2d_kernels_util.cpp b/oneflow/user/kernels/pad2d_kernels_util.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..58ab1cd3ac6277ed8c3a2287438165c28e9f2205
--- /dev/null
+++ b/oneflow/user/kernels/pad2d_kernels_util.cpp
@@ -0,0 +1,92 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/user/kernels/pad2d_kernels_util.h"
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace user_op {
+
+template<typename IN_T>
+struct ReflectionPad2dFunctor<DeviceType::kCPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
+                  int64_t x_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * y_height * y_width;
+    int64_t src_num = n_channel * x_height * x_width;
+    int64_t elem_num = n_batch * dest_num;
+    DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,
+                            x_height, x_width, pad_left, pad_top);
+  }
+};
+
+template<typename IN_T>
+struct ReflectionPad2dGradFunctor<DeviceType::kCPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                  int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * dx_height * dx_width;
+    int64_t src_num = n_channel * dy_height * dy_width;
+    int64_t elem_num = n_batch * src_num;
+    DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,
+                                dy_width, dx_height, dx_width, pad_left, pad_top);
+  }
+};
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR, (DeviceType::kCPU),
+                                 PADDING_DATA_TYPE_CPU_SEQ);
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR, (DeviceType::kCPU),
+                                 PADDING_DATA_TYPE_CPU_SEQ);
+
+template<typename IN_T>
+struct ReplicationPad2dFunctor<DeviceType::kCPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
+                  int64_t x_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * y_height * y_width;
+    int64_t src_num = n_channel * x_height * x_width;
+    int64_t elem_num = n_batch * dest_num;
+    DoReplicationPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height,
+                             y_width, x_height, x_width, pad_left, pad_top);
+  }
+};
+
+template<typename IN_T>
+struct ReplicationPad2dGradFunctor<DeviceType::kCPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                  int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * dx_height * dx_width;
+    int64_t src_num = n_channel * dy_height * dy_width;
+    int64_t elem_num = n_batch * src_num;
+    DoReplicationPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,
+                                 dy_width, dx_height, dx_width, pad_left, pad_top);
+  }
+};
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_FUNCTOR, (DeviceType::kCPU),
+                                 PADDING_DATA_TYPE_CPU_SEQ);
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_GRAD_FUNCTOR, (DeviceType::kCPU),
+                                 PADDING_DATA_TYPE_CPU_SEQ);
+
+}  // namespace user_op
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/pad2d_kernels_util.cu b/oneflow/user/kernels/pad2d_kernels_util.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0dbbfb7ff6448e918724db09479a08863665590c
--- /dev/null
+++ b/oneflow/user/kernels/pad2d_kernels_util.cu
@@ -0,0 +1,208 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <cstdint>
+#ifdef WITH_CUDA
+#include "oneflow/core/common/data_type.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/user/kernels/pad2d_kernels_util.h"
+
+namespace oneflow {
+namespace user_op {
+
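+// Thin __global__ wrappers around the shared Do*Pad2d helpers; the helpers use
+// XPU_1D_KERNEL_LOOP, which compiles to a grid-stride loop under CUDA.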
+template<typename IN_T>
+__global__ void DoCUDAReflectionPad2d(const IN_T *src, IN_T *dest,
+                                      const NdIndexOffsetHelper<int64_t, 4> index_helper,
+                                      int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                      int64_t y_height, int64_t y_width, int64_t x_height,
+                                      int64_t x_width, int64_t pad_left, int64_t pad_top) {
+  DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,
+                          x_height, x_width, pad_left, pad_top);
+}
+
+template<typename IN_T>
+__global__ void DoCUDAReflectionPad2dGrad(const IN_T *src, IN_T *dest,
+                                          const NdIndexOffsetHelper<int64_t, 4> index_helper,
+                                          int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                          int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                                          int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+  DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,
+                              dy_width, dx_height, dx_width, pad_left, pad_top);
+}
+
+template<typename IN_T>
+struct ReflectionPad2dFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
+                  int64_t x_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * y_height * y_width;
+    int64_t src_num = n_channel * x_height * x_width;
+    int64_t elem_num = n_batch * dest_num;
+    DoCUDAReflectionPad2d<IN_T>
+        <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+            src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height,
+            x_width, pad_left, pad_top);
+  }
+};
+
+// float16 implementation
+template<>
+void ReflectionPad2dFunctor<DeviceType::kGPU, float16>::operator()(
+    DeviceCtx *ctx, const float16 *src, float16 *dest,
+    const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel,
+    int64_t y_height, int64_t y_width, int64_t x_height, int64_t x_width, int64_t pad_left,
+    int64_t pad_top) {
+  int64_t dest_num = n_channel * y_height * y_width;
+  int64_t src_num = n_channel * x_height * x_width;
+  int64_t elem_num = n_batch * dest_num;
+  DoCUDAReflectionPad2d<half>
+      <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper,
+          elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top);
+}
+
+template<typename IN_T>
+struct ReflectionPad2dGradFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                  int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * dx_height * dx_width;
+    int64_t src_num = n_channel * dy_height * dy_width;
+    int64_t elem_num = n_batch * src_num;
+    DoCUDAReflectionPad2dGrad<IN_T>
+        <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+            src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height,
+            dx_width, pad_left, pad_top);
+  }
+};
+
+// float16 implementation
+template<>
+void ReflectionPad2dGradFunctor<DeviceType::kGPU, float16>::operator()(
+    DeviceCtx *ctx, const float16 *src, float16 *dest,
+    const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel,
+    int64_t dy_height, int64_t dy_width, int64_t dx_height, int64_t dx_width, int64_t pad_left,
+    int64_t pad_top) {
+  int64_t dest_num = n_channel * dx_height * dx_width;
+  int64_t src_num = n_channel * dy_height * dy_width;
+  int64_t elem_num = n_batch * src_num;
+  DoCUDAReflectionPad2dGrad<half>
+      <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper,
+          elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top);
+}
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
+
+template<typename IN_T>
+__global__ void DoCUDAReplicationPad2d(const IN_T *src, IN_T *dest,
+                                       const NdIndexOffsetHelper<int64_t, 4> index_helper,
+                                       int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                       int64_t y_height, int64_t y_width, int64_t x_height,
+                                       int64_t x_width, int64_t pad_left, int64_t pad_top) {
+  DoReplicationPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,
+                           x_height, x_width, pad_left, pad_top);
+}
+
+template<typename IN_T>
+__global__ void DoCUDAReplicationPad2dGrad(const IN_T *src, IN_T *dest,
+                                           const NdIndexOffsetHelper<int64_t, 4> index_helper,
+                                           int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                           int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                                           int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+  DoReplicationPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,
+                               dy_width, dx_height, dx_width, pad_left, pad_top);
+}
+
+template<typename IN_T>
+struct ReplicationPad2dFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
+                  int64_t x_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * y_height * y_width;
+    int64_t src_num = n_channel * x_height * x_width;
+    int64_t elem_num = n_batch * dest_num;
+    DoCUDAReplicationPad2d<IN_T>
+        <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+            src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height,
+            x_width, pad_left, pad_top);
+  }
+};
+
+// float16 implementation
+template<>
+void ReplicationPad2dFunctor<DeviceType::kGPU, float16>::operator()(
+    DeviceCtx *ctx, const float16 *src, float16 *dest,
+    const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel,
+    int64_t y_height, int64_t y_width, int64_t x_height, int64_t x_width, int64_t pad_left,
+    int64_t pad_top) {
+  int64_t dest_num = n_channel * y_height * y_width;
+  int64_t src_num = n_channel * x_height * x_width;
+  int64_t elem_num = n_batch * dest_num;
+  DoCUDAReplicationPad2d<half>
+      <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper,
+          elem_num, src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top);
+}
+
+template<typename IN_T>
+struct ReplicationPad2dGradFunctor<DeviceType::kGPU, IN_T> final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                  int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+    int64_t dest_num = n_channel * dx_height * dx_width;
+    int64_t src_num = n_channel * dy_height * dy_width;
+    int64_t elem_num = n_batch * src_num;
+    DoCUDAReplicationPad2dGrad<IN_T>
+        <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+            src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height,
+            dx_width, pad_left, pad_top);
+  }
+};
+
+// float16 implementation
+template<>
+void ReplicationPad2dGradFunctor<DeviceType::kGPU, float16>::operator()(
+    DeviceCtx *ctx, const float16 *src, float16 *dest,
+    const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch, int64_t n_channel,
+    int64_t dy_height, int64_t dy_width, int64_t dx_height, int64_t dx_width, int64_t pad_left,
+    int64_t pad_top) {
+  int64_t dest_num = n_channel * dx_height * dx_width;
+  int64_t src_num = n_channel * dy_height * dy_width;
+  int64_t elem_num = n_batch * src_num;
+  DoCUDAReplicationPad2dGrad<half>
+      <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
+          reinterpret_cast<const half *>(src), reinterpret_cast<half *>(dest), index_helper,
+          elem_num, src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top);
+}
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_FUNCTOR,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
+
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REPLICATION_PAD2D_GRAD_FUNCTOR,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU), PADDING_DATA_TYPE_GPU_SEQ);
+
+}  // namespace user_op
+}  // namespace oneflow
+
+#endif  // WITH_CUDA
diff --git a/oneflow/user/kernels/pad2d_kernels_util.h b/oneflow/user/kernels/pad2d_kernels_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..d527cd0626e32f2278dde9ac9e1c33b3a651f7f3
--- /dev/null
+++ b/oneflow/user/kernels/pad2d_kernels_util.h
@@ -0,0 +1,247 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_KERNELS_PAD2D_KERNELS_UTIL_H_
+#define ONEFLOW_USER_KERNELS_PAD2D_KERNELS_UTIL_H_
+#ifdef WITH_CUDA
+#include "oneflow/core/cuda/atomic.cuh"
+#endif  // WITH_CUDA
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/ndarray/xpu_util.h"
+
+namespace oneflow {
+
+#define PADDING_DATA_TYPE_CPU_SEQ \
+  FLOATING_DATA_TYPE_SEQ          \
+  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
+
+#define PADDING_DATA_TYPE_GPU_SEQ \
+  FLOAT16_DATA_TYPE_SEQ           \
+  PADDING_DATA_TYPE_CPU_SEQ
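+// float16 appears only in the GPU sequence; the CPU kernels are not
+// instantiated for half precision.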
+
+namespace user_op {
+
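+// Scatter-add primitive for the grad paths: a plain += on CPU, an atomic add on
+// GPU because several source elements may map to the same destination pixel.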
+template<typename T>
+struct DeviceAdd {
+  OF_DEVICE_FUNC static void Invoke(const T *x, T *y) {
+#if defined(__CUDA_ARCH__)
+    cuda::atomic::Add(y, *x);
+#else
+    *y += *x;
+#endif
+  }
+};
+
+template<DeviceType device_type, typename IN_T>
+struct ReflectionPad2dFunctor final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
+                  int64_t x_width, int64_t pad_left, int64_t pad_top);
+};
+
+template<DeviceType device_type, typename IN_T>
+struct ReflectionPad2dGradFunctor final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                  int64_t dx_width, int64_t pad_left, int64_t pad_top);
+};
+
+template<typename IN_T>
+OF_DEVICE_FUNC void DoReflectionPad2d(const IN_T *src, IN_T *dest,
+                                      const NdIndexOffsetHelper<int64_t, 4> &index_helper,
+                                      int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                      int64_t y_height, int64_t y_width, int64_t x_height,
+                                      int64_t x_width, int64_t pad_left, int64_t pad_top) {
+  XPU_1D_KERNEL_LOOP(k, elem_num) {
+    int64_t n, c, i, j, ip_x, ip_y;
+    int64_t coord_y[4];
+    index_helper.OffsetToNdIndex(k, coord_y);
+    n = coord_y[0];
+    c = coord_y[1];
+    i = coord_y[2];
+    j = coord_y[3];
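+    // Reflect out-of-range indices across the border, excluding the border
+    // pixel itself: an index d pixels outside maps to d pixels inside.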
+    if (j < pad_left) {
+      ip_x = pad_left * 2 - j;
+    } else if (j >= pad_left && j < x_width + pad_left) {
+      ip_x = j;
+    } else {
+      ip_x = (x_width + pad_left - 1) * 2 - j;
+    }
+
+    if (i < pad_top) {
+      ip_y = pad_top * 2 - i;
+    } else if (i >= pad_top && i < x_height + pad_top) {
+      ip_y = i;
+    } else {
+      ip_y = (x_height + pad_top - 1) * 2 - i;
+    }
+    ip_x = ip_x - pad_left;
+    ip_y = ip_y - pad_top;
+    int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j;
+    int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x;
+    dest[dest_index] = src[src_index];
+  }
+}
+
+template<typename IN_T>
+OF_DEVICE_FUNC void DoReflectionPad2dGrad(const IN_T *src, IN_T *dest,
+                                          const NdIndexOffsetHelper<int64_t, 4> &index_helper,
+                                          int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                          int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                                          int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+  XPU_1D_KERNEL_LOOP(k, elem_num) {
+    int64_t n, c, i, j, ip_x, ip_y;
+    int64_t coord[4];
+    index_helper.OffsetToNdIndex(k, coord);
+    n = coord[0];
+    c = coord[1];
+    i = coord[2];
+    j = coord[3];
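+    // Recompute the forward mirror mapping, then scatter-add dy into dx.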
+    if (j < pad_left) {
+      ip_x = pad_left * 2 - j;
+    } else if (j >= pad_left && j < dx_width + pad_left) {
+      ip_x = j;
+    } else {
+      ip_x = (dx_width + pad_left - 1) * 2 - j;
+    }
+
+    if (i < pad_top) {
+      ip_y = pad_top * 2 - i;
+    } else if (i >= pad_top && i < dx_height + pad_top) {
+      ip_y = i;
+    } else {
+      ip_y = (dx_height + pad_top - 1) * 2 - i;
+    }
+    ip_x = ip_x - pad_left;
+    ip_y = ip_y - pad_top;
+
+    int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j;
+    int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x;
+    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);
+  }
+}
+
+// Macros to instantiate the reflection-pad functors (used by pad2d_kernels_util.cu)
+#define INSTANTIATE_REFLECTION_PAD2D_FUNCTOR(device_type_v, dtype_pair) \
+  template struct ReflectionPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+#define INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
+  template struct ReflectionPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+template<DeviceType device_type, typename IN_T>
+struct ReplicationPad2dFunctor final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
+                  int64_t x_width, int64_t pad_left, int64_t pad_top);
+};
+
+template<DeviceType device_type, typename IN_T>
+struct ReplicationPad2dGradFunctor final {
+  void operator()(DeviceCtx *ctx, const IN_T *src, IN_T *dest,
+                  const NdIndexOffsetHelper<int64_t, 4> &index_helper, int64_t n_batch,
+                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                  int64_t dx_width, int64_t pad_left, int64_t pad_top);
+};
+
+template<typename IN_T>
+OF_DEVICE_FUNC void DoReplicationPad2d(const IN_T *src, IN_T *dest,
+                                       const NdIndexOffsetHelper<int64_t, 4> &index_helper,
+                                       int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                       int64_t y_height, int64_t y_width, int64_t x_height,
+                                       int64_t x_width, int64_t pad_left, int64_t pad_top) {
+  XPU_1D_KERNEL_LOOP(k, elem_num) {
+    int64_t n, c, i, j, ip_x, ip_y;
+    int64_t coord_y[4];
+    index_helper.OffsetToNdIndex(k, coord_y);
+    n = coord_y[0];
+    c = coord_y[1];
+    i = coord_y[2];
+    j = coord_y[3];
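+    // Clamp the output coordinate into the source region: padded positions read
+    // from the nearest edge pixel of x.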
+    if (j < pad_left) {
+      ip_x = pad_left;
+    } else if (j >= pad_left && j < x_width + pad_left) {
+      ip_x = j;
+    } else {
+      ip_x = x_width + pad_left - 1;
+    }
+
+    if (i < pad_top) {
+      ip_y = pad_top;
+    } else if (i >= pad_top && i < x_height + pad_top) {
+      ip_y = i;
+    } else {
+      ip_y = x_height + pad_top - 1;
+    }
+    ip_x = ip_x - pad_left;
+    ip_y = ip_y - pad_top;
+
+    int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j;
+    int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x;
+    dest[dest_index] = src[src_index];
+  }
+}
+
+template<typename IN_T>
+OF_DEVICE_FUNC void DoReplicationPad2dGrad(const IN_T *src, IN_T *dest,
+                                           const NdIndexOffsetHelper<int64_t, 4> &index_helper,
+                                           int64_t elem_num, int64_t src_num, int64_t dest_num,
+                                           int64_t dy_height, int64_t dy_width, int64_t dx_height,
+                                           int64_t dx_width, int64_t pad_left, int64_t pad_top) {
+  XPU_1D_KERNEL_LOOP(k, elem_num) {
+    int64_t n, c, i, j, ip_x, ip_y;
+    int64_t coord[4];
+    index_helper.OffsetToNdIndex(k, coord);
+    n = coord[0];
+    c = coord[1];
+    i = coord[2];
+    j = coord[3];
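+    // Same clamping as the forward pass; dy elements that clamp to the same dx
+    // pixel are summed via DeviceAdd.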
+    if (j < pad_left) {
+      ip_x = pad_left;
+    } else if (j >= pad_left && j < dx_width + pad_left) {
+      ip_x = j;
+    } else {
+      ip_x = dx_width + pad_left - 1;
+    }
+
+    if (i < pad_top) {
+      ip_y = pad_top;
+    } else if (i >= pad_top && i < dx_height + pad_top) {
+      ip_y = i;
+    } else {
+      ip_y = dx_height + pad_top - 1;
+    }
+    ip_x = ip_x - pad_left;
+    ip_y = ip_y - pad_top;
+
+    int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j;
+    int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x;
+    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);
+  }
+}
+
+// Macros to instantiate the replication-pad functors (used by pad2d_kernels_util.cu)
+#define INSTANTIATE_REPLICATION_PAD2D_FUNCTOR(device_type_v, dtype_pair) \
+  template struct ReplicationPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+#define INSTANTIATE_REPLICATION_PAD2D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
+  template struct ReplicationPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
+
+}  // namespace user_op
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_KERNELS_PAD2D_KERNELS_UTIL_H_
diff --git a/oneflow/user/kernels/reflection_pad2d_kernel.cpp b/oneflow/user/kernels/reflection_pad2d_kernel.cpp
deleted file mode 100644
index e9a876781d1fc36ed39f9f2ad8910a56cee8a090..0000000000000000000000000000000000000000
--- a/oneflow/user/kernels/reflection_pad2d_kernel.cpp
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#include "oneflow/core/framework/framework.h"
-#include "oneflow/core/device/memory_copier.h"
-#include "oneflow/core/kernel/new_kernel_util.h"
-#include "oneflow/core/common/nd_index_offset_helper.h"
-#include "oneflow/user/kernels/reflection_pad2d_kernel_util.h"
-
-namespace oneflow {
-namespace user_op {
-
-// Fill ShapeView into dim vector
-DimVector ShapeViewToDimVector(const ShapeView& tensor_shape) {
-  int64_t ndims = tensor_shape.NumAxes();
-  DimVector shape_vec(ndims);
-  for (int64_t i = 0; i < ndims; ++i) { shape_vec[i] = tensor_shape.At(i); }
-  shape_vec[ndims - 1] = shape_vec[ndims - 1];
-  return shape_vec;
-}
-
-template<DeviceType device_type, typename IN_T>
-class ReflectionPad2dKernel final : public OpKernel {
- public:
-  ReflectionPad2dKernel() = default;
-  ~ReflectionPad2dKernel() = default;
-
- private:
-  void Compute(user_op::KernelComputeContext* ctx) const override {
-    const Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
-    Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
-    const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-    const int64_t ndims = x->shape().NumAxes();
-    CHECK_EQ(padding.size(), ndims);
-    const int64_t n_idx = 0;
-    const int64_t c_idx = 1;
-    const int64_t h_idx = 2;
-    const int64_t w_idx = 3;
-
-    const int64_t pad_left = padding[0];
-    const int64_t pad_top = padding[2];
-
-    const int64_t n_batch = y->shape().At(n_idx);
-    const int64_t n_channel = y->shape().At(c_idx);
-    const int64_t y_height = y->shape().At(h_idx);
-    const int64_t y_width = y->shape().At(w_idx);
-    const int64_t x_height = x->shape().At(h_idx);
-    const int64_t x_width = x->shape().At(w_idx);
-
-    IN_T* dest = y->mut_dptr<IN_T>();
-    const IN_T* src = x->dptr<IN_T>();
-    DimVector y_vector = ShapeViewToDimVector(y->shape());
-    NdIndexOffsetHelper<int64_t, 4> index_helper(y_vector.data());
-
-    ReflectionPad2dFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper, n_batch,
-                                                n_channel, y_height, y_width, x_height, x_width,
-                                                pad_left, pad_top);
-  }
-  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
-};
-
-template<DeviceType device_type, typename IN_T>
-class ReflectionPad2dGradKernel final : public OpKernel {
- public:
-  ReflectionPad2dGradKernel() = default;
-  ~ReflectionPad2dGradKernel() = default;
-
- private:
-  void Compute(KernelComputeContext* ctx) const override {
-    const Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
-    Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
-    const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-    const int64_t ndims = dy->shape().NumAxes();
-    CHECK_EQ(padding.size(), ndims);
-
-    const int64_t n_idx = 0;
-    const int64_t c_idx = 1;
-    const int64_t h_idx = 2;
-    const int64_t w_idx = 3;
-
-    int64_t pad_left = padding[0];
-    int64_t pad_top = padding[2];
-    int64_t n_batch = dy->shape().At(n_idx);
-    int64_t n_channel = dy->shape().At(c_idx);
-    int64_t dy_height = dy->shape().At(h_idx);
-    int64_t dy_width = dy->shape().At(w_idx);
-    int64_t dx_height = dx->shape().At(h_idx);
-    int64_t dx_width = dx->shape().At(w_idx);
-
-    const IN_T* src = dy->dptr<IN_T>();
-    IN_T* dest = dx->mut_dptr<IN_T>();
-    DimVector dy_vector = ShapeViewToDimVector(dy->shape());
-    NdIndexOffsetHelper<int64_t, 4> index_helper(dy_vector.data());
-
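-    // Zero dx before accumulating: the backward pass scatter-adds into it.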
-    size_t out_bytes_size = dx->shape().elem_cnt() * GetSizeOfDataType(dx->data_type());
-    Memset<device_type>(ctx->device_ctx(), dest, 0, out_bytes_size);
-
-    ReflectionPad2dGradFunctor<device_type, IN_T>()(ctx->device_ctx(), src, dest, index_helper,
-                                                    n_batch, n_channel, dy_height, dy_width,
-                                                    dx_height, dx_width, pad_left, pad_top);
-  }
-  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
-};
-
-#define REGISTER_REFLECTION_PAD2D_KERNELS(device, dtype)                               \
-  REGISTER_USER_KERNEL("reflection_pad2d")                                             \
-      .SetCreateFn<ReflectionPad2dKernel<device, dtype>>()                             \
-      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
-                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
-  REGISTER_USER_KERNEL("reflection_pad2d_grad")                                        \
-      .SetCreateFn<ReflectionPad2dGradKernel<device, dtype>>()                         \
-      .SetIsMatchedHob((user_op::HobDeviceTag() == device)                             \
-                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
-
-#define REGISTER_REFLECTION_PAD2D_WITH_DEVICE(device) \
-  REGISTER_REFLECTION_PAD2D_KERNELS(device, float)    \
-  REGISTER_REFLECTION_PAD2D_KERNELS(device, double)   \
-  REGISTER_REFLECTION_PAD2D_KERNELS(device, int32_t)
-
-REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kCPU)
-#ifdef WITH_CUDA
-REGISTER_REFLECTION_PAD2D_WITH_DEVICE(DeviceType::kGPU)
-REGISTER_REFLECTION_PAD2D_KERNELS(DeviceType::kGPU, float16)
-#endif
-
-}  // namespace user_op
-}  // namespace oneflow
\ No newline at end of file
diff --git a/oneflow/user/kernels/reflection_pad2d_kernel_util.cpp b/oneflow/user/kernels/reflection_pad2d_kernel_util.cpp
deleted file mode 100644
index 73ec583430c10eb9ebc1ea2aca66c3393714385b..0000000000000000000000000000000000000000
--- a/oneflow/user/kernels/reflection_pad2d_kernel_util.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#include "oneflow/core/framework/framework.h"
-#include "oneflow/user/kernels/reflection_pad2d_kernel_util.h"
-
-namespace oneflow {
-
-namespace user_op {
-
-template<typename IN_T>
-struct ReflectionPad2dFunctor<DeviceType::kCPU, IN_T> final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
-                  int64_t x_width, int64_t pad_left, int64_t pad_top) {
-    int64_t dest_num = n_channel * y_height * y_width;
-    int64_t src_num = n_channel * x_height * x_width;
-    int64_t elem_num = n_batch * dest_num;
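-    // The forward pass is output-driven: one loop iteration per element of y.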
-    DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,
-                            x_height, x_width, pad_left, pad_top);
-  }
-};
-
-template<typename IN_T>
-struct ReflectionPad2dGradFunctor<DeviceType::kCPU, IN_T> final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
-                  int64_t dx_width, int64_t pad_left, int64_t pad_top) {
-    int64_t dest_num = n_channel * dx_height * dx_width;
-    int64_t src_num = n_channel * dy_height * dy_width;
-    int64_t elem_num = n_batch * src_num;
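-    // The backward pass is dy-driven; several dy elements may fold into one dx element.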
-    DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,
-                                dy_width, dx_height, dx_width, pad_left, pad_top);
-  }
-};
-
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR, (DeviceType::kCPU),
-                                 REFLECTION_PAD2D_DATA_TYPE_CPU_SEQ);
-
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR, (DeviceType::kCPU),
-                                 REFLECTION_PAD2D_GRAD_DATA_TYPE_CPU_SEQ);
-
-}  // namespace user_op
-}  // namespace oneflow
diff --git a/oneflow/user/kernels/reflection_pad2d_kernel_util.cu b/oneflow/user/kernels/reflection_pad2d_kernel_util.cu
deleted file mode 100644
index 9c36b88061712b134f77b1e551aa5c3b80d7a8eb..0000000000000000000000000000000000000000
--- a/oneflow/user/kernels/reflection_pad2d_kernel_util.cu
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#include <cstdint>
-#ifdef WITH_CUDA
-#include "oneflow/core/framework/framework.h"
-#include "oneflow/core/common/data_type.h"
-#include "oneflow/user/kernels/reflection_pad2d_kernel_util.h"
-
-namespace oneflow {
-namespace user_op {
-
-template<typename IN_T>
-__global__ void DoCUDAReflectionPad2d(const IN_T* src, IN_T* dest,
-                                      const NdIndexOffsetHelper<int64_t, 4> index_helper,
-                                      int64_t elem_num, int64_t src_num, int64_t dest_num,
-                                      int64_t y_height, int64_t y_width, int64_t x_height,
-                                      int64_t x_width, int64_t pad_left, int64_t pad_top) {
-  DoReflectionPad2d<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width,
-                          x_height, x_width, pad_left, pad_top);
-}
-
-template<typename IN_T>
-__global__ void DoCUDAReflectionPad2dGrad(const IN_T* src, IN_T* dest,
-                                          const NdIndexOffsetHelper<int64_t, 4> index_helper,
-                                          int64_t elem_num, int64_t src_num, int64_t dest_num,
-                                          int64_t dy_height, int64_t dy_width, int64_t dx_height,
-                                          int64_t dx_width, int64_t pad_left, int64_t pad_top) {
-  DoReflectionPad2dGrad<IN_T>(src, dest, index_helper, elem_num, src_num, dest_num, dy_height,
-                              dy_width, dx_height, dx_width, pad_left, pad_top);
-}
-
-template<typename IN_T>
-struct ReflectionPad2dFunctor<DeviceType::kGPU, IN_T> final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
-                  int64_t x_width, int64_t pad_left, int64_t pad_top) {
-    int64_t dest_num = n_channel * y_height * y_width;
-    int64_t src_num = n_channel * x_height * x_width;
-    int64_t elem_num = n_batch * dest_num;
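-    // One CUDA thread per output element; BlocksNum4ThreadsNum sizes the grid for
-    // the standard kCudaThreadsNumPerBlock block width.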
-    DoCUDAReflectionPad2d<IN_T>
-        <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-            src, dest, index_helper, elem_num, src_num, dest_num, y_height, y_width, x_height,
-            x_width, pad_left, pad_top);
-  }
-};
-
-// float16 specialization: the device kernel computes on CUDA's native half type,
-// so the float16 pointers are reinterpreted before launch.
-template<>
-void ReflectionPad2dFunctor<DeviceType::kGPU, float16>::operator()(
-    DeviceCtx* ctx, const float16* src, float16* dest,
-    const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, int64_t n_channel,
-    int64_t y_height, int64_t y_width, int64_t x_height, int64_t x_width, int64_t pad_left,
-    int64_t pad_top) {
-  int64_t dest_num = n_channel * y_height * y_width;
-  int64_t src_num = n_channel * x_height * x_width;
-  int64_t elem_num = n_batch * dest_num;
-  DoCUDAReflectionPad2d<half>
-      <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-          reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,
-          src_num, dest_num, y_height, y_width, x_height, x_width, pad_left, pad_top);
-}
-
-template<typename IN_T>
-struct ReflectionPad2dGradFunctor<DeviceType::kGPU, IN_T> final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
-                  int64_t dx_width, int64_t pad_left, int64_t pad_top) {
-    int64_t dest_num = n_channel * dx_height * dx_width;
-    int64_t src_num = n_channel * dy_height * dy_width;
-    int64_t elem_num = n_batch * src_num;
-    DoCUDAReflectionPad2dGrad<IN_T>
-        <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-            src, dest, index_helper, elem_num, src_num, dest_num, dy_height, dy_width, dx_height,
-            dx_width, pad_left, pad_top);
-  }
-};
-
-// float16 specialization of the backward functor; see the forward float16 note above.
-template<>
-void ReflectionPad2dGradFunctor<DeviceType::kGPU, float16>::operator()(
-    DeviceCtx* ctx, const float16* src, float16* dest,
-    const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch, int64_t n_channel,
-    int64_t dy_height, int64_t dy_width, int64_t dx_height, int64_t dx_width, int64_t pad_left,
-    int64_t pad_top) {
-  int64_t dest_num = n_channel * dx_height * dx_width;
-  int64_t src_num = n_channel * dy_height * dy_width;
-  int64_t elem_num = n_batch * src_num;
-  DoCUDAReflectionPad2dGrad<half>
-      <<<BlocksNum4ThreadsNum(elem_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-          reinterpret_cast<const half*>(src), reinterpret_cast<half*>(dest), index_helper, elem_num,
-          src_num, dest_num, dy_height, dy_width, dx_height, dx_width, pad_left, pad_top);
-}
-
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_FUNCTOR,
-                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU),
-                                 REFLECTION_PAD2D_DATA_TYPE_GPU_SEQ);
-
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR,
-                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kGPU),
-                                 REFLECTION_PAD2D_GRAD_DATA_TYPE_GPU_SEQ);
-
-}  // namespace user_op
-}  // namespace oneflow
-
-#endif  // WITH_CUDA
\ No newline at end of file
diff --git a/oneflow/user/kernels/reflection_pad2d_kernel_util.h b/oneflow/user/kernels/reflection_pad2d_kernel_util.h
deleted file mode 100644
index 3fe1ccf52cbe8e3be676720ef9ce7492e1b8c2e0..0000000000000000000000000000000000000000
--- a/oneflow/user/kernels/reflection_pad2d_kernel_util.h
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#ifndef ONEFLOW_USER_KERNELS_REFLECTION_PAD2D_KERNEL_UTIL_H_
-#define ONEFLOW_USER_KERNELS_REFLECTION_PAD2D_KERNEL_UTIL_H_
-#ifdef WITH_CUDA
-#include "oneflow/core/cuda/atomic.cuh"
-#endif  // WITH_CUDA
-#include "oneflow/core/ndarray/xpu_util.h"
-#include "oneflow/core/common/nd_index_offset_helper.h"
-
-namespace oneflow {
-
-#define REFLECTION_PAD2D_DATA_TYPE_CPU_SEQ \
-  FLOATING_DATA_TYPE_SEQ                   \
-  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
-
-#define REFLECTION_PAD2D_DATA_TYPE_GPU_SEQ \
-  FLOAT16_DATA_TYPE_SEQ                    \
-  REFLECTION_PAD2D_DATA_TYPE_CPU_SEQ
-
-#define REFLECTION_PAD2D_GRAD_DATA_TYPE_CPU_SEQ \
-  FLOATING_DATA_TYPE_SEQ                        \
-  OF_PP_MAKE_TUPLE_SEQ(int32_t, DataType::kInt32)
-
-#define REFLECTION_PAD2D_GRAD_DATA_TYPE_GPU_SEQ \
-  FLOAT16_DATA_TYPE_SEQ                         \
-  REFLECTION_PAD2D_GRAD_DATA_TYPE_CPU_SEQ
-
-namespace user_op {
-
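-// Add *x into *y. On CUDA several threads can target the same dx element (the
-// reflection maps many dy positions onto one input position), so the add is atomic.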
-template<typename T>
-struct DeviceAdd {
-  OF_DEVICE_FUNC static void Invoke(const T* x, T* y) {
-#if defined(__CUDA_ARCH__)
-    cuda::atomic::Add(y, *x);
-#else
-    *y += *x;
-#endif
-  }
-};
-
-template<DeviceType device_type, typename IN_T>
-struct ReflectionPad2dFunctor final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t y_height, int64_t y_width, int64_t x_height,
-                  int64_t x_width, int64_t pad_left, int64_t pad_top);
-};
-
-template<DeviceType device_type, typename IN_T>
-struct ReflectionPad2dGradFunctor final {
-  void operator()(DeviceCtx* ctx, const IN_T* src, IN_T* dest,
-                  const NdIndexOffsetHelper<int64_t, 4>& index_helper, int64_t n_batch,
-                  int64_t n_channel, int64_t dy_height, int64_t dy_width, int64_t dx_height,
-                  int64_t dx_width, int64_t pad_left, int64_t pad_top);
-};
-
-template<typename IN_T>
-OF_DEVICE_FUNC void DoReflectionPad2d(const IN_T* src, IN_T* dest,
-                                      const NdIndexOffsetHelper<int64_t, 4>& index_helper,
-                                      int64_t elem_num, int64_t src_num, int64_t dest_num,
-                                      int64_t y_height, int64_t y_width, int64_t x_height,
-                                      int64_t x_width, int64_t pad_left, int64_t pad_top) {
-  XPU_1D_KERNEL_LOOP(k, elem_num) {
-    int64_t n, c, i, j, ip_x, ip_y;
-    int64_t coord_y[4];
-    index_helper.OffsetToNdIndex(k, coord_y);
-    n = coord_y[0];
-    c = coord_y[1];
-    i = coord_y[2];
-    j = coord_y[3];
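-    // Reflect out-of-range output coordinates back into the input: positions left of
-    // the pad mirror about the first input column, positions past the input mirror
-    // about the last one (the border element itself is not repeated); rows likewise.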
-    if (j < pad_left) {
-      ip_x = pad_left * 2 - j;
-    } else if (j >= pad_left && j < x_width + pad_left) {
-      ip_x = j;
-    } else {
-      ip_x = (x_width + pad_left - 1) * 2 - j;
-    }
-
-    if (i < pad_top) {
-      ip_y = pad_top * 2 - i;
-    } else if (i >= pad_top && i < x_height + pad_top) {
-      ip_y = i;
-    } else {
-      ip_y = (x_height + pad_top - 1) * 2 - i;
-    }
-    ip_x = ip_x - pad_left;
-    ip_y = ip_y - pad_top;
-    int64_t dest_index = n * dest_num + c * y_width * y_height + i * y_width + j;
-    int64_t src_index = n * src_num + c * x_width * x_height + ip_y * x_width + ip_x;
-    dest[dest_index] = src[src_index];
-  }
-}
-
-template<typename IN_T>
-OF_DEVICE_FUNC void DoReflectionPad2dGrad(const IN_T* src, IN_T* dest,
-                                          const NdIndexOffsetHelper<int64_t, 4>& index_helper,
-                                          int64_t elem_num, int64_t src_num, int64_t dest_num,
-                                          int64_t dy_height, int64_t dy_width, int64_t dx_height,
-                                          int64_t dx_width, int64_t pad_left, int64_t pad_top) {
-  XPU_1D_KERNEL_LOOP(k, elem_num) {
-    int64_t n, c, i, j, ip_x, ip_y;
-    int64_t coord[4];
-    index_helper.OffsetToNdIndex(k, coord);
-    n = coord[0];
-    c = coord[1];
-    i = coord[2];
-    j = coord[3];
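-    // Same reflection mapping as the forward pass: locate the input (dx) element
-    // that produced this dy element, then accumulate the gradient into it.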
-    if (j < pad_left) {
-      ip_x = pad_left * 2 - j;
-    } else if (j >= pad_left && j < dx_width + pad_left) {
-      ip_x = j;
-    } else {
-      ip_x = (dx_width + pad_left - 1) * 2 - j;
-    }
-
-    if (i < pad_top) {
-      ip_y = pad_top * 2 - i;
-    } else if (i >= pad_top && i < dx_height + pad_top) {
-      ip_y = i;
-    } else {
-      ip_y = (dx_height + pad_top - 1) * 2 - i;
-    }
-    ip_x = ip_x - pad_left;
-    ip_y = ip_y - pad_top;
-
-    int64_t src_index = n * src_num + c * dy_width * dy_height + i * dy_width + j;
-    int64_t dest_index = n * dest_num + c * dx_width * dx_height + ip_y * dx_width + ip_x;
-    DeviceAdd<IN_T>::Invoke(src + src_index, dest + dest_index);
-  }
-}
-
-// Macros for explicit functor instantiation (used by reflection_pad2d_kernel_util.cpp and .cu).
-#define INSTANTIATE_REFLECTION_PAD2D_FUNCTOR(device_type_v, dtype_pair) \
-  template struct ReflectionPad2dFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
-
-#define INSTANTIATE_REFLECTION_PAD2D_GRAD_FUNCTOR(device_type_v, dtype_pair) \
-  template struct ReflectionPad2dGradFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>;
-
-}  // namespace user_op
-}  // namespace oneflow
-
-#endif  // ONEFLOW_USER_KERNELS_REFLECTION_PAD2D_KERNEL_UTIL_H_
\ No newline at end of file
diff --git a/oneflow/user/ops/pad2d_ops.cpp b/oneflow/user/ops/pad2d_ops.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..05d5e12a4bd8de53fbfa8f946d605d0c7e0b88c7
--- /dev/null
+++ b/oneflow/user/ops/pad2d_ops.cpp
@@ -0,0 +1,134 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/common/balanced_splitter.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/user/ops/nn_util.h"
+#include "pad_2d_seq.h"
+
+namespace oneflow {
+
+namespace {
+
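+// Propose splitting axis i only when padding[i] == 0. Note that padding follows the
+// (left, right, top, bottom) attr layout rather than the NCHW axis order, so the
+// entry consulted for an axis is not that axis's own padding amount.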
+Maybe<void> GetOpSbpSignature(user_op::SbpContext *ctx) {
+  const user_op::TensorDesc &x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");
+  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+    if (padding[i] == 0) {
+      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+    }
+  }
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> GetOpGradSbpSignature(user_op::SbpContext *ctx) {
+  const user_op::TensorDesc &dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
+  const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");
+  FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) {
+    if (padding[i] == 0) {
+      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+    }
+  }
+  return Maybe<void>::Ok();
+}
+
+}  // namespace
+
+#define REGISTER_PAD_2D_OP_AND_GRAD(pad_2d_type)                                             \
+  REGISTER_USER_OP(pad_2d_type)                                                              \
+      .Input("x")                                                                            \
+      .Output("y")                                                                           \
+      .Attr<std::vector<int64_t>>("padding")                                                 \
+      .SetTensorDescInferFn([](user_op::InferContext *ctx) -> Maybe<void> {                  \
+        Shape *x_shape = ctx->Shape4ArgNameAndIndex("x", 0);                                 \
+        const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");                    \
+        CHECK_EQ_OR_RETURN(padding.size(), x_shape->NumAxes());                              \
+        const int64_t n_idx = 0;                                                             \
+        const int64_t c_idx = 1;                                                             \
+        const int64_t h_idx = 2;                                                             \
+        const int64_t w_idx = 3;                                                             \
+        CHECK_LT_OR_RETURN(padding[0], x_shape->At(w_idx));                                  \
+        CHECK_LT_OR_RETURN(padding[1], x_shape->At(w_idx));                                  \
+        CHECK_LT_OR_RETURN(padding[2], x_shape->At(h_idx));                                  \
+        CHECK_LT_OR_RETURN(padding[3], x_shape->At(h_idx));                                  \
+                                                                                             \
+        DimVector y_dim_vec(x_shape->NumAxes());                                             \
+        const int64_t h_x = x_shape->At(h_idx);                                              \
+        const int64_t w_x = x_shape->At(w_idx);                                              \
+                                                                                             \
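+        /* N and C pass through; H grows by top+bottom, W grows by left+right. */            \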
+        y_dim_vec[n_idx] = x_shape->At(n_idx);                                               \
+        y_dim_vec[c_idx] = x_shape->At(c_idx);                                               \
+        y_dim_vec[h_idx] = h_x + padding[2] + padding[3];                                    \
+        y_dim_vec[w_idx] = w_x + padding[0] + padding[1];                                    \
+                                                                                             \
+        *ctx->Shape4ArgNameAndIndex("y", 0) = Shape(y_dim_vec);                              \
+        *ctx->Dtype4ArgNameAndIndex("y", 0) = *ctx->Dtype4ArgNameAndIndex("x", 0);           \
+        return Maybe<void>::Ok();                                                            \
+      })                                                                                     \
+      .SetGetSbpFn(GetOpSbpSignature)                                                        \
+      .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,            \
+                              const user_op::UserOpConfWrapper &) {                          \
+        user_op::InputArgModifier *x_modifier = GetInputArgModifierFn("x", 0);               \
+        CHECK_NOTNULL(x_modifier);                                                           \
+        x_modifier->set_requires_grad(true);                                                 \
+      });                                                                                    \
+                                                                                             \
+  REGISTER_USER_OP((std::string("") + pad_2d_type + "_grad"))                                \
+      .Input("dy")                                                                           \
+      .Output("dx")                                                                          \
+      .Attr<std::vector<int64_t>>("padding")                                                 \
+      .SetTensorDescInferFn([](user_op::InferContext *ctx) -> Maybe<void> {                  \
+        Shape *dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0);                               \
+        const auto &padding = ctx->Attr<std::vector<int64_t>>("padding");                    \
+        CHECK_EQ_OR_RETURN(padding.size(), dy_shape->NumAxes());                             \
+        const int64_t n_idx = 0;                                                             \
+        const int64_t c_idx = 1;                                                             \
+        const int64_t h_idx = 2;                                                             \
+        const int64_t w_idx = 3;                                                             \
+                                                                                             \
+        DimVector dx_dim_vec(dy_shape->NumAxes());                                           \
+        int64_t h_dy, w_dy;                                                                  \
+        h_dy = dy_shape->At(h_idx);                                                          \
+        w_dy = dy_shape->At(w_idx);                                                          \
+                                                                                             \
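+        /* dx restores the original spatial dims by subtracting the padding back out. */     \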
+        dx_dim_vec[n_idx] = dy_shape->At(n_idx);                                             \
+        dx_dim_vec[c_idx] = dy_shape->At(c_idx);                                             \
+        dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];                                  \
+        dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];                                  \
+                                                                                             \
+        *ctx->Shape4ArgNameAndIndex("dx", 0) = Shape(dx_dim_vec);                            \
+        *ctx->Dtype4ArgNameAndIndex("dx", 0) = *ctx->Dtype4ArgNameAndIndex("dy", 0);         \
+        return Maybe<void>::Ok();                                                            \
+      })                                                                                     \
+      .SetGetSbpFn(GetOpGradSbpSignature);                                                   \
+                                                                                             \
+  REGISTER_USER_OP_GRAD(pad_2d_type)                                                         \
+      .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper &op, user_op::AddOpFn AddOp) { \
+        if (op.NeedGenGradTensor4OpInput("x", 0)) {                                          \
+          user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");                 \
+          user_op::UserOpConfWrapper grad_op =                                               \
+              builder.Op((std::string("") + pad_2d_type + "_grad"))                          \
+                  .Input("dy", op.GetGradTensorWithOpOutput("y", 0))                         \
+                  .Output("dx")                                                              \
+                  .Attr("padding", op.attr<std::vector<int64_t>>("padding"))                 \
+                  .Build();                                                                  \
+          op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);                     \
+          AddOp(grad_op);                                                                    \
+        }                                                                                    \
+      });
+
+OF_PP_FOR_EACH_TUPLE(REGISTER_PAD_2D_OP_AND_GRAD, PAD_2D_TYPE_SEQ)
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/pad_2d_seq.h b/oneflow/user/ops/pad_2d_seq.h
new file mode 100644
index 0000000000000000000000000000000000000000..fb750a1e4b9eaac5c80b8697ecc8784eba0fbee0
--- /dev/null
+++ b/oneflow/user/ops/pad_2d_seq.h
@@ -0,0 +1,28 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_USER_OPS_PAD_2D_SEQ_H_
+#define ONEFLOW_USER_OPS_PAD_2D_SEQ_H_
+
+#include "oneflow/core/common/util.h"
+
+namespace oneflow {
+
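+// Op type names shared by the 2-D padding ops; pad2d_ops.cpp expands this seq with
+// OF_PP_FOR_EACH_TUPLE to register each op, its grad op, and its grad-generation rule.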
+#define PAD_2D_TYPE_SEQ                    \
+  OF_PP_MAKE_TUPLE_SEQ("reflection_pad2d") \
+  OF_PP_MAKE_TUPLE_SEQ("replication_pad2d")
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_USER_OPS_PAD_2D_SEQ_H_
diff --git a/oneflow/user/ops/reflection_pad2d_op.cpp b/oneflow/user/ops/reflection_pad2d_op.cpp
deleted file mode 100644
index 8380e4830957377cb171f1da8d2a0d1edc30c8ee..0000000000000000000000000000000000000000
--- a/oneflow/user/ops/reflection_pad2d_op.cpp
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#include "oneflow/core/framework/framework.h"
-#include "oneflow/core/common/balanced_splitter.h"
-#include "oneflow/user/ops/nn_util.h"
-
-namespace oneflow {
-
-namespace {
-
-Maybe<void> GetOpSbpSignature(user_op::SbpContext* ctx) {
-  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
-    if (padding[i] == 0) {
-      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-    }
-  }
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> GetOpGradSbpSignature(user_op::SbpContext* ctx) {
-  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
-  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-  FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) {
-    if (padding[i] == 0) {
-      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-    }
-  }
-  return Maybe<void>::Ok();
-}
-
-}  // namespace
-
-REGISTER_USER_OP("reflection_pad2d")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("padding")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(padding.size(), x_shape->NumAxes());
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t h_idx = 2;
-      const int64_t w_idx = 3;
-      // (padding_left, padding_right, padding_top, padding_bottom)
-      CHECK_LT_OR_RETURN(padding[0], x_shape->At(w_idx));
-      CHECK_LT_OR_RETURN(padding[1], x_shape->At(w_idx));
-      CHECK_LT_OR_RETURN(padding[2], x_shape->At(h_idx));
-      CHECK_LT_OR_RETURN(padding[3], x_shape->At(h_idx));
-
-      DimVector y_dim_vec(x_shape->NumAxes());
-      const int64_t h_x = x_shape->At(h_idx);
-      const int64_t w_x = x_shape->At(w_idx);
-
-      y_dim_vec[n_idx] = x_shape->At(n_idx);
-      y_dim_vec[c_idx] = x_shape->At(c_idx);
-      y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
-      y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
-
-      *ctx->Shape4ArgNameAndIndex("y", 0) = Shape(y_dim_vec);
-      *ctx->Dtype4ArgNameAndIndex("y", 0) = *ctx->Dtype4ArgNameAndIndex("x", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpSbpSignature)
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) {
-      user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
-      CHECK_NOTNULL(x_modifier);
-      x_modifier->set_requires_grad(true);
-    });
-
-REGISTER_USER_OP("reflection_pad2d_grad")
-    .Input("dy")
-    .Output("dx")
-    .Attr<std::vector<int64_t>>("padding")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      Shape* dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(padding.size(), dy_shape->NumAxes());
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t h_idx = 2;
-      const int64_t w_idx = 3;
-
-      DimVector dx_dim_vec(dy_shape->NumAxes());
-      int64_t h_dy, w_dy;
-      h_dy = dy_shape->At(h_idx);
-      w_dy = dy_shape->At(w_idx);
-
-      dx_dim_vec[n_idx] = dy_shape->At(0);
-      dx_dim_vec[c_idx] = dy_shape->At(1);
-      dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
-      dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
-
-      *ctx->Shape4ArgNameAndIndex("dx", 0) = Shape(dx_dim_vec);
-      *ctx->Dtype4ArgNameAndIndex("dx", 0) = *ctx->Dtype4ArgNameAndIndex("dy", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpGradSbpSignature);
-
-REGISTER_USER_OP_GRAD("reflection_pad2d")
-    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
-      if (op.NeedGenGradTensor4OpInput("x", 0)) {
-        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
-        user_op::UserOpConfWrapper grad_op =
-            builder.Op("reflection_pad2d_grad")
-                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
-                .Output("dx")
-                .Attr("padding", op.attr<std::vector<int64_t>>("padding"))
-                .Build();
-        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
-        AddOp(grad_op);
-      }
-    });
-
-}  // namespace oneflow
\ No newline at end of file