Skip to content
Snippets Groups Projects
Unverified Commit 32a240cb authored by cheng cheng's avatar cheng cheng Committed by GitHub
Browse files

Fix NcclLogicalS2S kernel comm BUG. (#4774)


Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 4789392d
No related branches found
No related tags found
No related merge requests found
......@@ -203,6 +203,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
}
{
// NOTE(chengcheng): init nccl comm need before ncclGroupStart.
ncclComm_t comm = nccl_comm->comm();
// NOTE(chengcheng): Do S2S
OF_NCCL_CHECK(ncclGroupStart());
const int64_t elem_per_chunk = elem_cnt / num_ranks;
......@@ -210,8 +212,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel {
for (int64_t j = 0; j < num_ranks; ++j) {
OF_NCCL_CHECK(ncclSend(reinterpret_cast<const void*>(
reinterpret_cast<const char*>(pack_to_ptr) + j * chunk_size),
elem_per_chunk, GetNcclDataType(in->data_type()), j,
nccl_comm->comm(), ctx->device_ctx()->cuda_stream()));
elem_per_chunk, GetNcclDataType(in->data_type()), j, comm,
ctx->device_ctx()->cuda_stream()));
OF_NCCL_CHECK(ncclRecv(
reinterpret_cast<void*>(reinterpret_cast<char*>(unpack_from_ptr) + j * chunk_size),
elem_per_chunk, GetNcclDataType(in->data_type()), j, nccl_comm->comm(),
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment