diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp index b539ffd6a057202ae0932732b7cc210feb15a897..e95a1ad21cd5868786b6e9f839416a87f828acaf 100644 --- a/oneflow/user/kernels/nccl_logical_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_kernels.cpp @@ -203,6 +203,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel { } { + // NOTE(chengcheng): init nccl comm need before ncclGroupStart. + ncclComm_t comm = nccl_comm->comm(); // NOTE(chengcheng): Do S2S OF_NCCL_CHECK(ncclGroupStart()); const int64_t elem_per_chunk = elem_cnt / num_ranks; @@ -210,8 +212,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel { for (int64_t j = 0; j < num_ranks; ++j) { OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>( reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, - nccl_comm->comm(), ctx->device_ctx()->cuda_stream())); + elem_per_chunk, GetNcclDataType(in->data_type()), j, comm, + ctx->device_ctx()->cuda_stream())); OF_NCCL_CHECK(ncclRecv( reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size), elem_per_chunk, GetNcclDataType(in->data_type()), j, nccl_comm->comm(),