Commit b7b58491 authored by leaves-zwx, committed by GitHub

Fix transpose op (#4695)


* impl pack in transpose util

* fix transpose sbp

* rm pack in transpose kernel

* fix bug

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 709fa873
@@ -55,7 +55,7 @@ __global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<N
 }
 
 template<int32_t NDIMS, typename T>
-void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
+void LaunchTransposeGpu(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
                     const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x,
                     T* y) {
   CHECK_LE(y_shape.elem_cnt(), GetMaxVal<int32_t>());
@@ -74,6 +74,38 @@ void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_
                                          y_shape_struct, x_strides, elem_cnt, x, y);
 }
 
+template<int32_t NDIMS, typename T>
+void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
+                   const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x,
+                   T* y) {
+  CHECK_EQ(x_shape.NumAxes(), NDIMS);
+  CHECK_EQ(y_shape.NumAxes(), NDIMS);
+  using PackType = int64_t;
+  const size_t pack_size = sizeof(PackType) / sizeof(T);
+  int64_t in_last_dim = x_shape.At(x_shape.NumAxes() - 1);
+  int64_t out_last_dim = y_shape.At(y_shape.NumAxes() - 1);
+  if (pack_size != 1 && permutation.back() == permutation.size() - 1
+      && in_last_dim % pack_size == 0) {
+    CHECK_EQ(in_last_dim, out_last_dim);
+    DimVector packed_in_dim_vec;
+    x_shape.ToDimVector(&packed_in_dim_vec);
+    packed_in_dim_vec.back() /= pack_size;
+    Shape packed_in_shape(packed_in_dim_vec);
+    DimVector packed_out_dim_vec;
+    y_shape.ToDimVector(&packed_out_dim_vec);
+    packed_out_dim_vec.back() /= pack_size;
+    Shape packed_out_shape(packed_out_dim_vec);
+    LaunchTransposeGpu<NDIMS, PackType>(
+        ctx, ShapeView(packed_in_shape), ShapeView(packed_out_shape), permutation,
+        packed_in_shape.elem_cnt(), reinterpret_cast<const PackType*>(x),
+        reinterpret_cast<PackType*>(y));
+  } else {
+    LaunchTransposeGpu<NDIMS, T>(ctx, x_shape, y_shape, permutation, elem_cnt, x, y);
+  }
+}
+
 template<typename T>
 struct TransposeUtil final {
 #define MAKE_TRANSPOSE_SWITCH_ENTRY(func_name, NDIMS) func_name<NDIMS, T>
......
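
The new TransposeImpl wraps the raw launcher with a vectorized fast path: when T is narrower than int64_t, the permutation leaves the innermost axis in place, and the last dimension is divisible by the pack width, every run of pack_size consecutive T values moves as one unit, so the transpose can run on a shape whose last dimension is divided by pack_size with int64_t as the element type, cutting the element count and index arithmetic. A minimal host-side C++ sketch of why this is value-preserving (a naive reference transpose, not OneFlow's GPU kernel; NaiveTranspose is an illustrative name):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Reference N-d transpose of contiguous row-major data:
// y[out coords] = x[in coords], where in_coord[perm[i]] == out_coord[i].
template<typename T>
void NaiveTranspose(const std::vector<int64_t>& in_dims,
                    const std::vector<int32_t>& perm, const T* x, T* y) {
  const int n = static_cast<int>(in_dims.size());
  std::vector<int64_t> out_dims(n), in_strides(n), out_strides(n);
  for (int i = 0; i < n; ++i) { out_dims[i] = in_dims[perm[i]]; }
  in_strides[n - 1] = 1;
  out_strides[n - 1] = 1;
  for (int i = n - 2; i >= 0; --i) {
    in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
    out_strides[i] = out_strides[i + 1] * out_dims[i + 1];
  }
  int64_t elem_cnt = 1;
  for (int64_t d : in_dims) { elem_cnt *= d; }
  for (int64_t out_off = 0; out_off < elem_cnt; ++out_off) {
    int64_t rem = out_off, in_off = 0;
    for (int i = 0; i < n; ++i) {
      in_off += (rem / out_strides[i]) * in_strides[perm[i]];
      rem %= out_strides[i];
    }
    y[out_off] = x[in_off];
  }
}

int main() {
  // perm keeps the last axis in place and the last dim (4) is divisible by
  // pack_size = sizeof(int64_t) / sizeof(float) = 2, so the fast path applies.
  const std::vector<int64_t> dims = {2, 3, 4};
  const std::vector<int32_t> perm = {1, 0, 2};
  std::vector<float> x(24), y_ref(24), y_packed(24);
  for (int i = 0; i < 24; ++i) { x[i] = static_cast<float>(i); }

  NaiveTranspose(dims, perm, x.data(), y_ref.data());

  // Packed path: halve the last dim and move int64_t words instead of floats.
  const std::vector<int64_t> packed_dims = {2, 3, 2};
  NaiveTranspose(packed_dims, perm,
                 reinterpret_cast<const int64_t*>(x.data()),
                 reinterpret_cast<int64_t*>(y_packed.data()));

  // Both paths must produce byte-identical output.
  assert(std::memcmp(y_ref.data(), y_packed.data(), 24 * sizeof(float)) == 0);
  return 0;
}
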
@@ -32,31 +32,12 @@ class TransposeKernel final : public OpKernel {
     const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex("input", 0);
     Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex("output", 0);
     const auto& perm = ctx->Attr<std::vector<int32_t>>("perm");
-    using PackType = int64_t;
-    const size_t num_elem_per_pack = sizeof(PackType) / sizeof(T);
     const ShapeView& in_shape = tensor_in->shape();
     const ShapeView& out_shape = tensor_out->shape();
-    if (num_elem_per_pack != 1 && perm.back() == perm.size() - 1
-        && in_shape.At(in_shape.NumAxes() - 1) % num_elem_per_pack == 0) {
-      CHECK_EQ(in_shape.At(in_shape.NumAxes() - 1), out_shape.At(out_shape.NumAxes() - 1));
-      DimVector packed_in_dim_vec;
-      in_shape.ToDimVector(&packed_in_dim_vec);
-      packed_in_dim_vec.back() /= num_elem_per_pack;
-      const Shape packed_in_shape(packed_in_dim_vec);
-      DimVector packed_out_dim_vec;
-      out_shape.ToDimVector(&packed_out_dim_vec);
-      packed_out_dim_vec.back() /= num_elem_per_pack;
-      const Shape packed_out_shape(packed_out_dim_vec);
-      NewKernelUtil<device_type>::Transpose(
-          ctx->device_ctx(), packed_in_shape.NumAxes(), packed_in_shape, packed_out_shape, perm,
-          packed_in_shape.elem_cnt(), reinterpret_cast<const PackType*>(tensor_in->dptr<T>()),
-          reinterpret_cast<PackType*>(tensor_out->mut_dptr<T>()));
-    } else {
-      NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape,
-                                            tensor_out->shape(), perm, in_shape.elem_cnt(),
-                                            tensor_in->dptr<T>(), tensor_out->mut_dptr<T>());
-    }
+    NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape,
+                                          out_shape, perm, in_shape.elem_cnt(),
+                                          tensor_in->dptr<T>(), tensor_out->mut_dptr<T>());
   }
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
......
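
With the fast path now implemented once inside the transpose util (first hunk above), the kernel's own PackType bookkeeping is dead weight and is deleted; every caller of NewKernelUtil<device_type>::Transpose picks up the packing automatically. The guard pack_size != 1 is what keeps 8-byte element types on the plain path, since they cannot be widened further by packing into int64_t. A tiny standalone check of the pack widths this implies (plain C++, illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
  // pack_size = sizeof(PackType) / sizeof(T) with PackType = int64_t:
  std::printf("int8_t : %zu elems per pack\n", sizeof(int64_t) / sizeof(int8_t));   // 8
  std::printf("float  : %zu elems per pack\n", sizeof(int64_t) / sizeof(float));    // 2
  std::printf("double : %zu elems per pack\n", sizeof(int64_t) / sizeof(double));   // 1 -> fast path skipped
  return 0;
}
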
@@ -61,7 +61,7 @@ REGISTER_USER_OP("transpose")
         if (axis < 0) { axis += perm.size(); }
         CHECK_GE(axis, 0);
         CHECK_LT(axis, perm.size());
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), axis).Build();
+        ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build();
       }
       ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
       return Maybe<void>::Ok();
......
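
The SBP fix swaps the two split axes. Since y = transpose(x, perm) carries input axis perm[i] to output axis i (axis is the normalized perm[i] here), a consistent signature must pair Split(input, axis) with Split(output, i); the old code paired Split(input, i) with Split(output, axis), which only agrees when the permutation is its own inverse. A minimal 2-D sketch (plain C++, not the OneFlow SBP machinery) showing that transposing the column shards of x yields exactly the row shards of y:

#include <algorithm>
#include <cassert>
#include <vector>

// Transpose of a rows x cols row-major matrix (perm = {1, 0}).
std::vector<int> Transpose2D(const std::vector<int>& x, int rows, int cols) {
  std::vector<int> y(x.size());
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) { y[c * rows + r] = x[r * cols + c]; }
  }
  return y;
}

int main() {
  // Full tensors: x is 2 x 4, y = transpose(x) is 4 x 2.
  const std::vector<int> x = {0, 1, 2, 3,
                              4, 5, 6, 7};
  const std::vector<int> y = Transpose2D(x, 2, 4);

  // Shard x along axis 1 (columns): perm[0] == 1, so this should match
  // sharding y along axis 0 (rows), i.e. Split(in, perm[i]) <-> Split(out, i).
  const std::vector<int> x0 = {0, 1, 4, 5};  // columns 0-1
  const std::vector<int> x1 = {2, 3, 6, 7};  // columns 2-3
  const std::vector<int> y0 = Transpose2D(x0, 2, 2);
  const std::vector<int> y1 = Transpose2D(x1, 2, 2);

  // y0 and y1 are exactly rows 0-1 and rows 2-3 of y: the per-shard
  // transposes compose into the full result with no data exchange.
  assert(std::equal(y0.begin(), y0.end(), y.begin()));
  assert(std::equal(y1.begin(), y1.end(), y.begin() + 4));
  return 0;
}
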