diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
index 5d4d54cb63a982fe4e656ae19c3f03c502286a08..536d6f107625e2bd6a94e828e610a189404576e8 100644
--- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
+++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
@@ -189,6 +189,20 @@ bool TryBuildNcclBy1DHierarchy(OperatorConf* ret, const SbpParallel& src_sbp,
                .Build()
                .op_conf();
     return true;
+  } else if ((src_sbp.has_split_parallel() && dst_sbp.has_broadcast_parallel())
+             && (src_sbp.split_parallel().axis() > 0)
+             && (logical_blob_desc.shape().At(src_sbp.split_parallel().axis()) % parallel_num
+                 == 0)) {
+    // S(1)->B : AllGather Noncontinuous
+    *ret = user_op::UserOpConfWrapperBuilder(kNcclLogicalOpNamePrefix + "-S2B-" + NewUniqueId())
+               .Op("_nccl_logical_all_gather_noncontinuous")
+               .Input("in", lbn)
+               .Output("out")
+               .Attr<int64_t>("in_split_axis", src_sbp.split_parallel().axis())
+               .ScopeSymbolId(scope_symbol_id)
+               .Build()
+               .op_conf();
+    return true;
   } else if ((src_sbp.has_split_parallel() && dst_sbp.has_split_parallel())
              && (src_sbp.split_parallel().axis() != dst_sbp.split_parallel().axis())
              && (logical_blob_desc.shape().At(src_sbp.split_parallel().axis()) % parallel_num == 0)
diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
index 14effa226d19440d36b8246d01e863210023c1ee..4cfd4aeb4dd1abd01c2d1ae149ee7705e2e8bf45 100644
--- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
+++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
@@ -144,7 +144,7 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern
     user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
     user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
     const int64_t dtype_size = GetSizeOfDataType(in->data_type());
-    int64_t data_size = GetCudaAlignedSize(in->shape().elem_cnt() * dtype_size);
+    int64_t data_size = GetCudaAlignedSize(out->shape().elem_cnt() * dtype_size);
     void* unpack_from_ptr = tmp_buffer->mut_dptr();
     CHECK_EQ(tmp_buffer->shape().elem_cnt(), data_size);
 
@@ -188,9 +188,9 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern
 };
 
 size_t Infer2DSameDim0AllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {
-  const user_op::TensorDesc* in_tensor = ctx->TensorDesc4ArgNameAndIndex("in", 0);
-  return GetCudaAlignedSize(in_tensor->shape().elem_cnt()
-                            * GetSizeOfDataType(in_tensor->data_type()));
+  const user_op::TensorDesc* out_tensor = ctx->TensorDesc4ArgNameAndIndex("out", 0);
+  return GetCudaAlignedSize(out_tensor->shape().elem_cnt()
+                            * GetSizeOfDataType(out_tensor->data_type()));
 }
 
 template<typename T>
diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp
index e95a1ad21cd5868786b6e9f839416a87f828acaf..a524564ec2b165e616ba13370ce5152a8a659c78 100644
--- a/oneflow/user/kernels/nccl_logical_kernels.cpp
+++ b/oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -129,6 +129,74 @@ class NcclLogicalAllGatherKernel final : public user_op::OpKernel {
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 };
 
+template<typename T>
+class NcclLogicalAllGatherNoncontinuous final : public user_op::OpKernel {
+ public:
+  NcclLogicalAllGatherNoncontinuous() = default;
+  ~NcclLogicalAllGatherNoncontinuous() override = default;
+
+  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
+      user_op::KernelInitContext* ctx) const override {
+    return std::make_shared<NcclLogicalKernelCommState>(ctx);
+  }
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
+    auto* nccl_comm = dynamic_cast<NcclLogicalKernelCommState*>(state);
+    CHECK(nccl_comm != nullptr);
+    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
+    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
+    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+    const int64_t dtype_size = GetSizeOfDataType(in->data_type());
+    int64_t data_size = GetCudaAlignedSize(out->shape().elem_cnt() * dtype_size);
+    void* unpack_from_ptr = tmp_buffer->mut_dptr();
+    CHECK_EQ(tmp_buffer->shape().elem_cnt(), data_size);
+
+    CHECK_EQ(in->data_type(), out->data_type());
+    const int64_t num_ranks = ctx->parallel_ctx().parallel_num();
+    const int64_t in_split_axis = ctx->Attr<int64_t>("in_split_axis");
+
+    DimVector logical_shape_dim_vec;
+    in->shape().ToDimVector(&logical_shape_dim_vec);
+    logical_shape_dim_vec[in_split_axis] = logical_shape_dim_vec.at(in_split_axis) * num_ranks;
+
+    // NOTE(chengcheng): Do AllGather
+    CHECK_EQ(in->shape().elem_cnt() * num_ranks, out->shape().elem_cnt());
+    OF_NCCL_CHECK(ncclAllGather(in->dptr(), unpack_from_ptr, in->shape().elem_cnt(),
+                                GetNcclDataType(in->data_type()), nccl_comm->comm(),
+                                ctx->device_ctx()->cuda_stream()));
+
+    CHECK_GT(in_split_axis, 0);
+    // NOTE(chengcheng): Do unpack.
+    DimVector unpack_from_dim_vec = logical_shape_dim_vec;
+    CHECK_EQ(unpack_from_dim_vec.at(in_split_axis) % num_ranks, 0);
+    unpack_from_dim_vec[in_split_axis] = unpack_from_dim_vec.at(in_split_axis) / num_ranks;
+    unpack_from_dim_vec.insert(unpack_from_dim_vec.begin(), num_ranks);
+    const Shape unpack_from_shape(unpack_from_dim_vec);
+    DimVector transpose_out_dim_vec;
+    std::vector<int32_t> perm;
+    FOR_RANGE(int64_t, i, 1, unpack_from_shape.NumAxes()) {
+      perm.push_back(i);
+      transpose_out_dim_vec.push_back(unpack_from_shape.At(i));
+    }
+    perm.insert(perm.begin() + in_split_axis, 0);
+    transpose_out_dim_vec.insert(transpose_out_dim_vec.begin() + in_split_axis,
+                                 unpack_from_shape.At(0));
+    const Shape transpose_out_shape(transpose_out_dim_vec);
+    NewKernelUtil<DeviceType::kGPU>::Transpose(
+        ctx->device_ctx(), unpack_from_shape.NumAxes(), unpack_from_shape, transpose_out_shape,
+        perm, unpack_from_shape.elem_cnt(), reinterpret_cast<const T*>(unpack_from_ptr),
+        out->mut_dptr<T>());
+  };
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+size_t InferAllGatherNoncontinuousKernelTmpBufferSize(user_op::InferContext* ctx) {
+  const user_op::TensorDesc* out_tensor = ctx->TensorDesc4ArgNameAndIndex("out", 0);
+  return GetCudaAlignedSize(out_tensor->shape().elem_cnt()
+                            * GetSizeOfDataType(out_tensor->data_type()));
+}
+
 template<typename T>
 class NcclLogicalS2SKernel final : public user_op::OpKernel {
  public:
@@ -278,6 +346,21 @@ REGISTER_USER_KERNEL("_nccl_logical_all_gather")
     .SetCreateFn<NcclLogicalAllGatherKernel>()
     .SetIsMatchedHob(user_op::HobDeviceTag() == "gpu");
 
+#define REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(dtype)                                  \
+  REGISTER_USER_KERNEL("_nccl_logical_all_gather_noncontinuous")                        \
+      .SetCreateFn<NcclLogicalAllGatherNoncontinuous<dtype>>()                          \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                               \
+                       & (user_op::HobDataType("in", 0) == GetDataType<dtype>::value)   \
+                       & (user_op::HobDataType("out", 0) == GetDataType<dtype>::value)) \
+      .SetInferTmpSizeFn(InferAllGatherNoncontinuousKernelTmpBufferSize);
+
+REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int8_t)
+REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int32_t)
+REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(int64_t)
+REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float)
+REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(double)
+REGISTER_ALLGATHER_NONCONTINUOUS_KERNEL(float16)
+
 #define REGISTER_S2S_KERNEL(dtype)                                                       \
   REGISTER_USER_KERNEL("_nccl_logical_s2s")                                              \
       .SetCreateFn<NcclLogicalS2SKernel<dtype>>()                                        \
diff --git a/oneflow/user/ops/nccl_logical_ops.cpp b/oneflow/user/ops/nccl_logical_ops.cpp
index 6fd884f45fb23f0dc75ffca7464bb1984fdbc1c7..63fdd1cc8292529fe7db70181bf7d3f5f579b27b 100644
--- a/oneflow/user/ops/nccl_logical_ops.cpp
+++ b/oneflow/user/ops/nccl_logical_ops.cpp
@@ -127,6 +127,46 @@ REGISTER_USER_OP("_nccl_logical_all_gather")
       return Maybe<void>::Ok();
     });
 
+REGISTER_USER_OP("_nccl_logical_all_gather_noncontinuous")
+    .Input("in")
+    .Output("out")
+    .Attr<int64_t>("in_split_axis", -1)
+    .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->Shape4ArgNameAndIndex("out", 0) = *ctx->Shape4ArgNameAndIndex("in", 0);
+      *ctx->IsDynamic4ArgNameAndIndex("out", 0) = *ctx->IsDynamic4ArgNameAndIndex("in", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->Dtype4ArgNameAndIndex("out", 0) = *ctx->Dtype4ArgNameAndIndex("in", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetParallelDistributionInferFn([](user_op::InferParallelDistributionFnContext* ctx)
+                                        -> Maybe<void> {
+      const ParallelDistribution& in_dis_hint =
+          ctx->ParallelDistributionHint4InputArgNameAndIndex("in", 0);
+      CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1);
+      const int64_t in_split_axis = ctx->user_op_conf().attr<int64_t>("in_split_axis");
+      CHECK_GE_OR_RETURN(in_split_axis, 1);
+      for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) {
+        CHECK_OR_RETURN(sbp_hint.has_split_parallel());
+        CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis);
+      }
+
+      ParallelDistribution* in_distribution = ctx->ParallelDistribution4ArgNameAndIndex("in", 0);
+      ParallelDistribution* out_distribution = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
+      in_distribution->clear_sbp_parallel();
+      out_distribution->clear_sbp_parallel();
+
+      // S(1)->(B)
+      const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
+      CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1);
+      for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
+        in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis);
+        out_distribution->add_sbp_parallel()->mutable_broadcast_parallel();
+      }
+      return Maybe<void>::Ok();
+    });
+
 REGISTER_USER_OP("_nccl_logical_s2s")
     .Input("in")
     .Output("out")
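
Note (illustration, not part of the patch): the subtle step in the new kernel is the unpack after ncclAllGather. For S(axis>0)->B, the gathered buffer is rank-major, shaped [num_ranks, ..., dim/num_ranks, ...], and the kernel restores the logical layout by transposing with a perm that moves the rank axis to in_split_axis. The following standalone host-side C++ sketch reproduces that layout transform for a toy 2-D tensor split on axis 1 across 2 ranks; all sizes and names here are made up for the example and do not come from the OneFlow API.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t num_ranks = 2, rows = 2, cols = 4;  // logical shape [2, 4], S(1) across 2 ranks
  const int64_t cols_per_rank = cols / num_ranks;   // each rank holds a [2, 2] slice

  // Logical tensor values 0..7 in row-major order.
  std::vector<int64_t> logical(rows * cols);
  for (int64_t i = 0; i < rows * cols; ++i) { logical[i] = i; }

  // What ncclAllGather would produce: rank 0's slice followed by rank 1's slice,
  // i.e. a buffer laid out as [num_ranks, rows, cols_per_rank].
  std::vector<int64_t> gathered;
  for (int64_t r = 0; r < num_ranks; ++r) {
    for (int64_t i = 0; i < rows; ++i) {
      for (int64_t j = 0; j < cols_per_rank; ++j) {
        gathered.push_back(logical[i * cols + r * cols_per_rank + j]);
      }
    }
  }

  // Unpack: transpose [num_ranks, rows, cols_per_rank] with perm = {1, 0, 2}
  // (the kernel builds perm = {1, 2, ...} and inserts 0 at in_split_axis),
  // giving [rows, num_ranks, cols_per_rank], which is exactly the contiguous
  // row-major layout of the logical [rows, cols] tensor.
  std::vector<int64_t> out(rows * cols);
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t r = 0; r < num_ranks; ++r) {
      for (int64_t j = 0; j < cols_per_rank; ++j) {
        out[(i * num_ranks + r) * cols_per_rank + j] =
            gathered[(r * rows + i) * cols_per_rank + j];
      }
    }
  }

  for (int64_t i = 0; i < rows * cols; ++i) {
    std::cout << out[i] << (i + 1 == rows * cols ? '\n' : ' ');
  }
  // Prints: 0 1 2 3 4 5 6 7  (matches the logical tensor on every rank)
  return 0;
}

This also shows why the tmp buffer and its size inference were switched from the "in" to the "out" tensor: the gathered (pre-transpose) buffer already holds one full copy of the output, so it must be sized by out, not by the per-rank input slice.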