From b7b58491754ab8c82dcc0b72102b949e12b51e6d Mon Sep 17 00:00:00 2001 From: leaves-zwx <kunta0932@gmail.com> Date: Thu, 22 Apr 2021 13:55:49 +0800 Subject: [PATCH] Fix transpose op (#4695) * impl pack in transpose util * fix transpose sbp * rm pack in transpose kernel * fix bug Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../kernel/util/cuda_arithemetic_interface.cu | 38 +++++++++++++++++-- oneflow/user/kernels/transpose_kernel.cpp | 25 ++---------- oneflow/user/ops/transpose_ops.cpp | 2 +- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu index 67c519fd9..8148294c0 100644 --- a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu +++ b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu @@ -55,9 +55,9 @@ __global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<N } template<int32_t NDIMS, typename T> -void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape, - const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x, - T* y) { +void LaunchTransposeGpu(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape, + const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x, + T* y) { CHECK_LE(y_shape.elem_cnt(), GetMaxVal<int32_t>()); Int32Array<NDIMS> y_shape_struct; FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); } @@ -74,6 +74,38 @@ void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_ y_shape_struct, x_strides, elem_cnt, x, y); } +template<int32_t NDIMS, typename T> +void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape, + const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x, + T* y) { + CHECK_EQ(x_shape.NumAxes(), NDIMS); + CHECK_EQ(y_shape.NumAxes(), NDIMS); + + using PackType = int64_t; + const size_t pack_size = sizeof(PackType) / sizeof(T); + int64_t in_last_dim = x_shape.At(x_shape.NumAxes() - 1); + int64_t out_last_dim = y_shape.At(y_shape.NumAxes() - 1); + if (pack_size != 1 && permutation.back() == permutation.size() - 1 + && in_last_dim % pack_size == 0) { + CHECK_EQ(in_last_dim, out_last_dim); + DimVector packed_in_dim_vec; + x_shape.ToDimVector(&packed_in_dim_vec); + packed_in_dim_vec.back() /= pack_size; + Shape packed_in_shape(packed_in_dim_vec); + DimVector packed_out_dim_vec; + y_shape.ToDimVector(&packed_out_dim_vec); + packed_out_dim_vec.back() /= pack_size; + Shape packed_out_shape(packed_out_dim_vec); + + LaunchTransposeGpu<NDIMS, PackType>( + ctx, ShapeView(packed_in_shape), ShapeView(packed_out_shape), permutation, + packed_in_shape.elem_cnt(), reinterpret_cast<const PackType*>(x), + reinterpret_cast<PackType*>(y)); + } else { + LaunchTransposeGpu<NDIMS, T>(ctx, x_shape, y_shape, permutation, elem_cnt, x, y); + } +} + template<typename T> struct TransposeUtil final { #define MAKE_TRANSPOSE_SWITCH_ENTRY(func_name, NDIMS) func_name<NDIMS, T> diff --git a/oneflow/user/kernels/transpose_kernel.cpp b/oneflow/user/kernels/transpose_kernel.cpp index ee1991ff2..9fd137185 100644 --- a/oneflow/user/kernels/transpose_kernel.cpp +++ b/oneflow/user/kernels/transpose_kernel.cpp @@ -32,30 +32,11 @@ class TransposeKernel final : public OpKernel { const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex("input", 0); Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex("output", 0); const auto& perm = ctx->Attr<std::vector<int32_t>>("perm"); - using PackType = int64_t; - const size_t num_elem_per_pack = sizeof(PackType) / sizeof(T); const ShapeView& in_shape = tensor_in->shape(); const ShapeView& out_shape = tensor_out->shape(); - if (num_elem_per_pack != 1 && perm.back() == perm.size() - 1 - && in_shape.At(in_shape.NumAxes() - 1) % num_elem_per_pack == 0) { - CHECK_EQ(in_shape.At(in_shape.NumAxes() - 1), out_shape.At(out_shape.NumAxes() - 1)); - DimVector packed_in_dim_vec; - in_shape.ToDimVector(&packed_in_dim_vec); - packed_in_dim_vec.back() /= num_elem_per_pack; - const Shape packed_in_shape(packed_in_dim_vec); - DimVector packed_out_dim_vec; - out_shape.ToDimVector(&packed_out_dim_vec); - packed_out_dim_vec.back() /= num_elem_per_pack; - const Shape packed_out_shape(packed_out_dim_vec); - NewKernelUtil<device_type>::Transpose( - ctx->device_ctx(), packed_in_shape.NumAxes(), packed_in_shape, packed_out_shape, perm, - packed_in_shape.elem_cnt(), reinterpret_cast<const PackType*>(tensor_in->dptr<T>()), - reinterpret_cast<PackType*>(tensor_out->mut_dptr<T>())); - } else { - NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape, - tensor_out->shape(), perm, in_shape.elem_cnt(), - tensor_in->dptr<T>(), tensor_out->mut_dptr<T>()); - } + NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape, + out_shape, perm, in_shape.elem_cnt(), + tensor_in->dptr<T>(), tensor_out->mut_dptr<T>()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/oneflow/user/ops/transpose_ops.cpp b/oneflow/user/ops/transpose_ops.cpp index 4b4d6d0fd..c417403e3 100644 --- a/oneflow/user/ops/transpose_ops.cpp +++ b/oneflow/user/ops/transpose_ops.cpp @@ -61,7 +61,7 @@ REGISTER_USER_OP("transpose") if (axis < 0) { axis += perm.size(); } CHECK_GE(axis, 0); CHECK_LT(axis, perm.size()); - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), axis).Build(); + ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build(); } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe<void>::Ok(); -- GitLab