From 32a240cb5e37d2d61ac5704734fdc40ca4f38005 Mon Sep 17 00:00:00 2001
From: cheng cheng <472491134@qq.com>
Date: Thu, 29 Apr 2021 01:08:23 +0800
Subject: [PATCH] Fix NcclLogicalS2S kernel comm BUG. (#4774)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/user/kernels/nccl_logical_kernels.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp
index b539ffd6a..e95a1ad21 100644
--- a/oneflow/user/kernels/nccl_logical_kernels.cpp
+++ b/oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -203,6 +203,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
     }
 
     {
+      // NOTE(chengcheng): init nccl comm need before ncclGroupStart.
+      ncclComm_t comm = nccl_comm->comm();
       // NOTE(chengcheng): Do S2S
       OF_NCCL_CHECK(ncclGroupStart());
       const int64_t elem_per_chunk = elem_cnt / num_ranks;
@@ -210,8 +212,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
       for (int64_t j = 0; j < num_ranks; ++j) {
         OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
                                    reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                               elem_per_chunk, GetNcclDataType(in->data_type()), j,
-                               nccl_comm->comm(), ctx->device_ctx()->cuda_stream()));
+                               elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
+                               ctx->device_ctx()->cuda_stream()));
         OF_NCCL_CHECK(ncclRecv(
             reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
             elem_per_chunk, GetNcclDataType(in->data_type()), j, nccl_comm->comm(),
-- 
GitLab