From b7b58491754ab8c82dcc0b72102b949e12b51e6d Mon Sep 17 00:00:00 2001
From: leaves-zwx <kunta0932@gmail.com>
Date: Thu, 22 Apr 2021 13:55:49 +0800
Subject: [PATCH] Fix transpose op (#4695)

* implement the pack (vectorized int64 load/store) fast path inside the transpose util

* fix the transpose op's SBP signature: split the input on `axis` and the output on `i` (the axes were swapped)

* remove the now-duplicated pack logic from the transpose kernel (handled by the util instead)

* fix bug

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 .../kernel/util/cuda_arithemetic_interface.cu | 38 +++++++++++++++++--
 oneflow/user/kernels/transpose_kernel.cpp     | 25 ++----------
 oneflow/user/ops/transpose_ops.cpp            |  2 +-
 3 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu
index 67c519fd9..8148294c0 100644
--- a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu
+++ b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu
@@ -55,9 +55,9 @@ __global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<N
 }
 
 template<int32_t NDIMS, typename T>
-void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
-                   const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x,
-                   T* y) {
+void LaunchTransposeGpu(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
+                        const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x,
+                        T* y) {
   CHECK_LE(y_shape.elem_cnt(), GetMaxVal<int32_t>());
   Int32Array<NDIMS> y_shape_struct;
   FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); }
@@ -74,6 +74,38 @@ void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_
           y_shape_struct, x_strides, elem_cnt, x, y);
 }
 
+template<int32_t NDIMS, typename T>
+void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
+                   const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x,
+                   T* y) {
+  CHECK_EQ(x_shape.NumAxes(), NDIMS);
+  CHECK_EQ(y_shape.NumAxes(), NDIMS);
+
+  using PackType = int64_t;
+  const size_t pack_size = sizeof(PackType) / sizeof(T);
+  int64_t in_last_dim = x_shape.At(x_shape.NumAxes() - 1);
+  int64_t out_last_dim = y_shape.At(y_shape.NumAxes() - 1);
+  if (pack_size != 1 && permutation.back() == permutation.size() - 1
+      && in_last_dim % pack_size == 0) {
+    CHECK_EQ(in_last_dim, out_last_dim);
+    DimVector packed_in_dim_vec;
+    x_shape.ToDimVector(&packed_in_dim_vec);
+    packed_in_dim_vec.back() /= pack_size;
+    Shape packed_in_shape(packed_in_dim_vec);
+    DimVector packed_out_dim_vec;
+    y_shape.ToDimVector(&packed_out_dim_vec);
+    packed_out_dim_vec.back() /= pack_size;
+    Shape packed_out_shape(packed_out_dim_vec);
+
+    LaunchTransposeGpu<NDIMS, PackType>(
+        ctx, ShapeView(packed_in_shape), ShapeView(packed_out_shape), permutation,
+        packed_in_shape.elem_cnt(), reinterpret_cast<const PackType*>(x),
+        reinterpret_cast<PackType*>(y));
+  } else {
+    LaunchTransposeGpu<NDIMS, T>(ctx, x_shape, y_shape, permutation, elem_cnt, x, y);
+  }
+}
+
 template<typename T>
 struct TransposeUtil final {
 #define MAKE_TRANSPOSE_SWITCH_ENTRY(func_name, NDIMS) func_name<NDIMS, T>
diff --git a/oneflow/user/kernels/transpose_kernel.cpp b/oneflow/user/kernels/transpose_kernel.cpp
index ee1991ff2..9fd137185 100644
--- a/oneflow/user/kernels/transpose_kernel.cpp
+++ b/oneflow/user/kernels/transpose_kernel.cpp
@@ -32,30 +32,11 @@ class TransposeKernel final : public OpKernel {
     const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex("input", 0);
     Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex("output", 0);
     const auto& perm = ctx->Attr<std::vector<int32_t>>("perm");
-    using PackType = int64_t;
-    const size_t num_elem_per_pack = sizeof(PackType) / sizeof(T);
     const ShapeView& in_shape = tensor_in->shape();
     const ShapeView& out_shape = tensor_out->shape();
-    if (num_elem_per_pack != 1 && perm.back() == perm.size() - 1
-        && in_shape.At(in_shape.NumAxes() - 1) % num_elem_per_pack == 0) {
-      CHECK_EQ(in_shape.At(in_shape.NumAxes() - 1), out_shape.At(out_shape.NumAxes() - 1));
-      DimVector packed_in_dim_vec;
-      in_shape.ToDimVector(&packed_in_dim_vec);
-      packed_in_dim_vec.back() /= num_elem_per_pack;
-      const Shape packed_in_shape(packed_in_dim_vec);
-      DimVector packed_out_dim_vec;
-      out_shape.ToDimVector(&packed_out_dim_vec);
-      packed_out_dim_vec.back() /= num_elem_per_pack;
-      const Shape packed_out_shape(packed_out_dim_vec);
-      NewKernelUtil<device_type>::Transpose(
-          ctx->device_ctx(), packed_in_shape.NumAxes(), packed_in_shape, packed_out_shape, perm,
-          packed_in_shape.elem_cnt(), reinterpret_cast<const PackType*>(tensor_in->dptr<T>()),
-          reinterpret_cast<PackType*>(tensor_out->mut_dptr<T>()));
-    } else {
-      NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape,
-                                            tensor_out->shape(), perm, in_shape.elem_cnt(),
-                                            tensor_in->dptr<T>(), tensor_out->mut_dptr<T>());
-    }
+    NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape,
+                                          out_shape, perm, in_shape.elem_cnt(),
+                                          tensor_in->dptr<T>(), tensor_out->mut_dptr<T>());
   }
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
diff --git a/oneflow/user/ops/transpose_ops.cpp b/oneflow/user/ops/transpose_ops.cpp
index 4b4d6d0fd..c417403e3 100644
--- a/oneflow/user/ops/transpose_ops.cpp
+++ b/oneflow/user/ops/transpose_ops.cpp
@@ -61,7 +61,7 @@ REGISTER_USER_OP("transpose")
         if (axis < 0) { axis += perm.size(); }
         CHECK_GE(axis, 0);
         CHECK_LT(axis, perm.size());
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), axis).Build();
+        ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build();
       }
       ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
       return Maybe<void>::Ok();
-- 
GitLab