From 32a240cb5e37d2d61ac5704734fdc40ca4f38005 Mon Sep 17 00:00:00 2001
From: cheng cheng <472491134@qq.com>
Date: Thu, 29 Apr 2021 01:08:23 +0800
Subject: [PATCH] Fix NcclLogicalS2S kernel comm BUG. (#4774)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/user/kernels/nccl_logical_kernels.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp
index b539ffd6a..e95a1ad21 100644
--- a/oneflow/user/kernels/nccl_logical_kernels.cpp
+++ b/oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -203,6 +203,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
     }
 
     {
+      // NOTE(chengcheng): init nccl comm need before ncclGroupStart.
+      ncclComm_t comm = nccl_comm->comm();
       // NOTE(chengcheng): Do S2S
       OF_NCCL_CHECK(ncclGroupStart());
       const int64_t elem_per_chunk = elem_cnt / num_ranks;
@@ -210,8 +212,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
       for (int64_t j = 0; j < num_ranks; ++j) {
         OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
                                    reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                               elem_per_chunk, GetNcclDataType(in->data_type()), j,
-                               nccl_comm->comm(), ctx->device_ctx()->cuda_stream()));
+                               elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
+                               ctx->device_ctx()->cuda_stream()));
         OF_NCCL_CHECK(ncclRecv(
             reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
             elem_per_chunk, GetNcclDataType(in->data_type()), j, nccl_comm->comm(),
-- 
GitLab