From 2b208ec02d5ee77cdf7e960b7287a24386da41a2 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Tue, 20 Jul 2021 15:29:00 +0800
Subject: [PATCH] Add flip module (#5541)

* fix upsample nearest bug

* fix upsample nearest bug (#5347)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* fix upsample bilinear bug

* init flip op

* add flip op register

* add flip cpu kernel forward

* add flip kernel impl

* add flip op functor and gradient_funcs

* add test, still need fix bug

* fix segmentation fault bug

* add docs

* fix comments

* fix comments

* fix comments

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 docs/source/experimental.rst                  |   2 +
 oneflow/core/autograd/gradient_funcs/flip.cpp |  69 +++++++++
 oneflow/core/functional/functional_api.yaml   |   8 ++
 .../core/functional/impl/array_functor.cpp    |  32 +++++
 oneflow/python/nn/modules/flip.py             | 100 +++++++++++++
 oneflow/python/test/modules/test_flip.py      |  53 +++++++
 oneflow/user/kernels/flip_kernel.cpp          | 132 ++++++++++++++++++
 oneflow/user/kernels/flip_kernel.cu           | 132 ++++++++++++++++++
 oneflow/user/ops/flip_op.cpp                  |  63 +++++++++
 9 files changed, 591 insertions(+)
 create mode 100644 oneflow/core/autograd/gradient_funcs/flip.cpp
 create mode 100644 oneflow/python/nn/modules/flip.py
 create mode 100644 oneflow/python/test/modules/test_flip.py
 create mode 100644 oneflow/user/kernels/flip_kernel.cpp
 create mode 100644 oneflow/user/kernels/flip_kernel.cu
 create mode 100644 oneflow/user/ops/flip_op.cpp

diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index e58b0d9c1..6aacefedd 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -194,6 +194,8 @@ Experimental features
 .. autofunction:: oneflow.experimental.Tensor.reshape
 .. autofunction:: oneflow.experimental.squeeze
 .. autofunction:: oneflow.experimental.Tensor.squeeze
+.. autofunction:: oneflow.experimental.flip
+.. autofunction:: oneflow.experimental.Tensor.flip
 .. autofunction:: oneflow.experimental.transpose
 .. autofunction:: oneflow.experimental.Tensor.transpose
 .. autofunction:: oneflow.experimental.unsqueeze
diff --git a/oneflow/core/autograd/gradient_funcs/flip.cpp b/oneflow/core/autograd/gradient_funcs/flip.cpp
new file mode 100644
index 000000000..1022e4900
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/flip.cpp
@@ -0,0 +1,69 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct FlipInterpState : public OpExprInterpState {
+  bool requires_grad;
+  std::vector<int32_t> dims;
+};
+
+class Flip : public OpExprGradFunction<FlipInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(FlipInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override;
+  Maybe<void> Apply(const FlipInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+};
+
+Maybe<void> Flip::Init(const OpExpr& op) {
+  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> Flip::Capture(FlipInterpState* ctx, const TensorTuple& inputs,
+                          const TensorTuple& outputs, const AttrMap& attrs) const {
+  ctx->requires_grad = inputs.at(0)->requires_grad();
+  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->dims = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dims"));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> Flip::Apply(const FlipInterpState* ctx, const TensorTuple& out_grads,
+                        TensorTuple* in_grads) const {
+  CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+  in_grads->resize(1);
+  if (ctx->requires_grad) {
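+    // Forward the dims captured in Capture() to the flip_grad op.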
+    in_grads->at(0) = JUST(functional::FlipGrad(out_grads.at(0), ctx->dims));
+  }
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("flip", Flip);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index a936c88fd..5c1a6f5ab 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -470,6 +470,14 @@
   signature: "Tensor Copy(Tensor x, *, String device_type, Int64 device_id)"
   bind_python: True
 
+- name: "flip"
+  signature: "Tensor Flip(Tensor x, *, Int32List dims)"
+  bind_python: True
+
+- name: "flip_grad"
+  signature: "Tensor FlipGrad(Tensor dy, *, Int32List dims)"
+  bind_python: False
+
 - name: "upsample"
   signature:
     "Tensor Upsample(Tensor x, *, Float height_scale, Float width_scale, Bool align_corners,
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index e7d913b2c..b9cd3f395 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -436,6 +436,36 @@ class CopyFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class FlipFunctor {
+ public:
+  FlipFunctor() { op_ = CHECK_JUST(one::OpBuilder("flip").Input("x").Output("y").Build()); }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::vector<int32_t>& dims) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<std::vector<int32_t>>("dims", dims));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class FlipGradFunctor {
+ public:
+  FlipGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("flip_grad").Input("dy").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::vector<int32_t>& dims) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<std::vector<int32_t>>("dims", dims));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class UpsampleFunctor {
  public:
   UpsampleFunctor() { op_ = CHECK_JUST(one::OpBuilder("upsample").Input("x").Output("y").Build()); }
@@ -923,6 +953,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::SliceUpdateFunctor>("SliceUpdate");
   m.add_functor<impl::SqueezeFunctor>("Squeeze");
   m.add_functor<impl::CopyFunctor>("Copy");
+  m.add_functor<impl::FlipFunctor>("Flip");
+  m.add_functor<impl::FlipGradFunctor>("FlipGrad");
   m.add_functor<impl::UpsampleFunctor>("Upsample");
   m.add_functor<impl::UpsampleNearest2DFunctor>("UpsampleNearest2D");
   m.add_functor<impl::UpsampleNearest2DGradFunctor>("UpsampleNearest2DGrad");
diff --git a/oneflow/python/nn/modules/flip.py b/oneflow/python/nn/modules/flip.py
new file mode 100644
index 000000000..6c61a6b83
--- /dev/null
+++ b/oneflow/python/nn/modules/flip.py
@@ -0,0 +1,100 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import collections
+from typing import Optional, Sequence, Union
+
+import oneflow as flow
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.nn.module import Module
+from oneflow.python.framework.tensor import register_tensor_op
+from oneflow.python.nn.modules.utils import _check_axis
+
+
+class Flip(Module):
+    def __init__(self, dims) -> None:
+        super().__init__()
+        assert isinstance(dims, (list, tuple)), "dims must be a list or tuple"
+        self.dims = dims
+
+    def forward(self, x):
+        input_len = len(x.shape)
+        assert (
+            len(self.dims) <= input_len
+        ), "len of dims must not exceed the number of dimensions of the input tensor"
+        new_dims = []
+        for i in self.dims:
+            if i < 0:
+                i += input_len
+            assert (
+                i < input_len
+            ), f"IndexError: Dimension out of range (expected to be in range of {input_len}, but got {i})"
+            new_dims.append(i)
+        return flow.F.flip(x, new_dims)
+
+
+@oneflow_export("flip")
+@experimental_api
+def flip_op(input, dims):
+
+    r"""
+    
+    Reverse the order of a n-D tensor along given axis in dims.
+
+    .. note::
+        `flow.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`,
+        which returns a view in constant time. Since copying a tensor's data is more work than viewing that data,
+        `flow.flip` is expected to be slower than `np.flip`.
+
+    Args:
+        input (Tensor): the input tensor
+        dims (list or tuple): the axes to flip on
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+
+        >>> np_arr = np.arange(0, 8).reshape((2, 2, 2)).astype(np.float32)
+        >>> input = flow.Tensor(np_arr)
+        >>> out = flow.flip(input, [0, 1])
+        >>> out
+        tensor([[[6., 7.],
+                 [4., 5.]],
+        <BLANKLINE>
+                [[2., 3.],
+                 [0., 1.]]], dtype=oneflow.float32)
+
+    """
+
+    return Flip(dims)(input)
+
+
+@register_tensor_op("flip")
+@experimental_api
+def flip_op_tensor(input, dims):
+    r"""
+    See :func:`oneflow.experimental.flip`
+    """
+    return Flip(dims)(input)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/test/modules/test_flip.py b/oneflow/python/test/modules/test_flip.py
new file mode 100644
index 000000000..a5ad28f21
--- /dev/null
+++ b/oneflow/python/test/modules/test_flip.py
@@ -0,0 +1,53 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList
+from automated_test_util import *
+
+
+def _test_flip(test_case, device):
+    np_arr = np.arange(0, 16).reshape((2, 2, 2, 2)).astype(np.float32)
+    input = flow.Tensor(np_arr, device=flow.device(device), requires_grad=True)
+    out = flow.flip(input, [0, 1, 2])
+    np_out = [
+        [[[14.0, 15.0], [12.0, 13.0]], [[10.0, 11.0], [8.0, 9.0]]],
+        [[[6.0, 7.0], [4.0, 5.0]], [[2.0, 3.0], [0.0, 1.0]]],
+    ]
+    test_case.assertTrue(np.allclose(out.numpy(), np_out, 1e-5, 1e-5))
+    out = out.sum()
+    out.backward()
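+    # flip only permutes elements, so the gradient of sum() w.r.t. the input is all ones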
+    np_grad = np.ones_like(np_arr)
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+class TestFlip(flow.unittest.TestCase):
+    def test_flip(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_flip,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/user/kernels/flip_kernel.cpp b/oneflow/user/kernels/flip_kernel.cpp
new file mode 100644
index 000000000..89ad832cb
--- /dev/null
+++ b/oneflow/user/kernels/flip_kernel.cpp
@@ -0,0 +1,132 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+
+namespace oneflow {
+
+namespace {
+
+const int32_t NDIMS = 16;
+
+struct SIZE_V {
+  int32_t val[NDIMS];
+};
+
+struct VIS {
+  bool val[NDIMS] = {false};
+};
+
+template<typename T>
+void FlipCpuForward(const int32_t element, const int64_t total_dims,
+                    const SIZE_V stride_contiguous_v, const SIZE_V sizes_v, const VIS vis,
+                    SIZE_V strides_v, const T* in_dptr, T* out_dptr) {
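+  // For each output element i: decode the flat index into per-dimension indices
+  // using the contiguous strides, mirror the index on every flipped dimension,
+  // and gather the corresponding input element.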
+  for (int i = 0; i < element; i++) {
+    int32_t cur_indices = i;
+    int32_t rem = 0;
+    int32_t dst_offset = 0;
+    for (int32_t d = 0; d < total_dims; d++) {
+      int32_t temp = cur_indices;
+      cur_indices = cur_indices / stride_contiguous_v.val[d];
+      rem = temp - cur_indices * stride_contiguous_v.val[d];
+      dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
+                               : cur_indices * strides_v.val[d];
+      cur_indices = rem;
+    }
+    out_dptr[i] = in_dptr[dst_offset];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class FlipCpuKernel final : public user_op::OpKernel {
+ public:
+  FlipCpuKernel() = default;
+  ~FlipCpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const int32_t elem_cnt = y_tensor->shape().elem_cnt();
+
+    const int32_t total_dims = y_tensor->shape().NumAxes();
+
+    std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
+    VIS vis;
+    for (auto x : dims) { vis.val[x] = true; }
+
+    SIZE_V sizes_v;
+    for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }
+
+    // TODO(bbuf) remove the strides calculation once tensor strides are supported
+    SIZE_V strides_v;
+    strides_v.val[total_dims - 1] = 1;
+    for (int32_t i = total_dims - 2; i >= 0; i--) {
+      strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape().At(i);
+    }
+
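+    // Contiguous (row-major) strides of the input, used to decode a flat
+    // element index back into multi-dimensional coordinates.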
+    SIZE_V stride_contiguous_v;
+
+    for (int32_t i = total_dims - 1; i >= 0; i--) {
+      if (i == total_dims - 1) {
+        stride_contiguous_v.val[i] = 1;
+      } else {
+        stride_contiguous_v.val[i] =
+            std::max<int32_t>(x_tensor->shape().At(i + 1), 1) * stride_contiguous_v.val[i + 1];
+      }
+    }
+
+    FlipCpuForward(elem_cnt, total_dims, stride_contiguous_v, sizes_v, vis, strides_v,
+                   x_tensor->dptr<T>(), y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class FlipGrad1DCpuKernel final : public user_op::OpKernel {
+ public:
+  FlipGrad1DCpuKernel() = default;
+  ~FlipGrad1DCpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    memcpy((void*)dx_tensor->mut_dptr<T>(), (void*)dy_tensor->dptr<T>(),
+           dy_tensor->shape().elem_cnt() * sizeof(T));
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_FLIP_CPU_KERNEL(dtype)                                             \
+  REGISTER_USER_KERNEL("flip").SetCreateFn<FlipCpuKernel<dtype>>().SetIsMatchedHob( \
+      (user_op::HobDeviceTag() == "cpu")                                            \
+      & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value));               \
+  REGISTER_USER_KERNEL("flip_grad")                                                 \
+      .SetCreateFn<FlipGrad1DCpuKernel<dtype>>()                                    \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                           \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_FLIP_CPU_KERNEL(float)
+REGISTER_FLIP_CPU_KERNEL(double)
+REGISTER_FLIP_CPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu
new file mode 100644
index 000000000..00e5ed24d
--- /dev/null
+++ b/oneflow/user/kernels/flip_kernel.cu
@@ -0,0 +1,132 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+
+namespace oneflow {
+
+namespace {
+
+const int32_t NDIMS = 16;
+struct SIZE_V {
+  int32_t val[NDIMS];
+};
+
+struct VIS {
+  bool val[NDIMS] = {false};
+};
+
+template<typename T>
+__global__ void FlipGpuForward(const int32_t element, const int64_t total_dims,
+                               const SIZE_V stride_contiguous_v, const SIZE_V sizes_v,
+                               const VIS vis, SIZE_V strides_v, const T* in_dptr, T* out_dptr) {
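+  // Same indexing scheme as the CPU kernel: each thread of the grid-stride loop
+  // decodes its flat index with the contiguous strides, mirrors it on every
+  // flipped dimension, and gathers the corresponding input element.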
+  CUDA_1D_KERNEL_LOOP(i, element) {
+    int32_t cur_indices = i;
+    int32_t rem = 0;
+    int32_t dst_offset = 0;
+    for (int32_t d = 0; d < total_dims; d++) {
+      int32_t temp = cur_indices;
+      cur_indices = cur_indices / stride_contiguous_v.val[d];
+      rem = temp - cur_indices * stride_contiguous_v.val[d];
+      dst_offset += vis.val[d] ? (sizes_v.val[d] - 1 - cur_indices) * strides_v.val[d]
+                               : cur_indices * strides_v.val[d];
+      cur_indices = rem;
+    }
+    out_dptr[i] = in_dptr[dst_offset];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class FlipGpuKernel final : public user_op::OpKernel {
+ public:
+  FlipGpuKernel() = default;
+  ~FlipGpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const int32_t elem_cnt = y_tensor->shape().elem_cnt();
+
+    const int32_t total_dims = y_tensor->shape().NumAxes();
+
+    std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
+    VIS vis;
+    for (auto x : dims) { vis.val[x] = true; }
+
+    SIZE_V sizes_v;
+    for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape().At(i); }
+
+    // TODO(bbuf) remove the strides calculation once tensor strides are supported
+    SIZE_V strides_v;
+    strides_v.val[total_dims - 1] = 1;
+    for (int32_t i = total_dims - 2; i >= 0; i--) {
+      strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape().At(i);
+    }
+
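+    // Contiguous (row-major) strides of the input, used by the kernel to decode a
+    // flat element index back into multi-dimensional coordinates.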
+    SIZE_V stride_contiguous_v;
+
+    for (int32_t i = total_dims - 1; i >= 0; i--) {
+      if (i == total_dims - 1) {
+        stride_contiguous_v.val[i] = 1;
+      } else {
+        stride_contiguous_v.val[i] =
+            std::max<int32_t>(x_tensor->shape().At(i + 1), 1) * stride_contiguous_v.val[i + 1];
+      }
+    }
+    RUN_CUDA_KERNEL((FlipGpuForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt, total_dims,
+                    stride_contiguous_v, sizes_v, vis, strides_v, x_tensor->dptr<T>(),
+                    y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class FlipGrad1DGpuKernel final : public user_op::OpKernel {
+ public:
+  FlipGrad1DGpuKernel() = default;
+  ~FlipGrad1DGpuKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    Memcpy<DeviceType::kGPU>(
+        ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+        dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_FLIP_GPU_KERNEL(dtype)                                             \
+  REGISTER_USER_KERNEL("flip").SetCreateFn<FlipGpuKernel<dtype>>().SetIsMatchedHob( \
+      (user_op::HobDeviceTag() == "gpu")                                            \
+      & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value));               \
+  REGISTER_USER_KERNEL("flip_grad")                                                 \
+      .SetCreateFn<FlipGrad1DGpuKernel<dtype>>()                                    \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                           \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_FLIP_GPU_KERNEL(float)
+REGISTER_FLIP_GPU_KERNEL(double)
+REGISTER_FLIP_GPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/flip_op.cpp b/oneflow/user/ops/flip_op.cpp
new file mode 100644
index 000000000..417608205
--- /dev/null
+++ b/oneflow/user/ops/flip_op.cpp
@@ -0,0 +1,63 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+REGISTER_USER_OP("flip")
+    .Input("x")
+    .Output("y")
+    .Attr<std::vector<int32_t>>("dims")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      const int input_dims = x_desc->shape().NumAxes();
+      const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
+      CHECK_OR_RETURN(dims.size() <= input_dims)
+          << "len of dims must not exceed the number of dimensions of the input tensor";
+      for (auto x : dims) {
+        CHECK_OR_RETURN(x < input_dims) << "each entry of dims must be less than the number of input dimensions";
+      }
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      *y_desc->mut_shape() = x_desc->shape();
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("flip_grad")
+    .Input("dy")
+    .Output("dx")
+    .Attr<std::vector<int32_t>>("dims")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      *dx_shape = dy_shape;
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+}  // namespace oneflow
-- 
GitLab