diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp
index b539ffd6a057202ae0932732b7cc210feb15a897..e95a1ad21cd5868786b6e9f839416a87f828acaf 100644
--- a/oneflow/user/kernels/nccl_logical_kernels.cpp
+++ b/oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -203,6 +203,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
     }
 
     {
+      // NOTE(chengcheng): the nccl comm must be initialized before calling ncclGroupStart.
+      ncclComm_t comm = nccl_comm->comm();
       // NOTE(chengcheng): Do S2S
       OF_NCCL_CHECK(ncclGroupStart());
       const int64_t elem_per_chunk = elem_cnt / num_ranks;
@@ -210,8 +212,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
       for (int64_t j = 0; j < num_ranks; ++j) {
         OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
                                    reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
-                               elem_per_chunk, GetNcclDataType(in->data_type()), j,
-                               nccl_comm->comm(), ctx->device_ctx()->cuda_stream()));
+                               elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
+                               ctx->device_ctx()->cuda_stream()));
         OF_NCCL_CHECK(ncclRecv(
             reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
             elem_per_chunk, GetNcclDataType(in->data_type()), j, nccl_comm->comm(),