From e7222a8490e9fd6a5156446865b1cf2cc1f14660 Mon Sep 17 00:00:00 2001
From: lambda7xx <lambda7xx@gmail.com>
Date: Wed, 21 Jul 2021 14:53:48 +0800
Subject: [PATCH] Optimize IB sending when the number of messages is bigger
 than qp_init_attr.cap.max_send_wr: use msg_pendding_list to store the extra
 messages and send the messages in msg_pendding_list after waiting for some
 time

---
 oneflow/core/actor/actor_message.h            | 23 +++++++------
 .../ibverbs/ibverbs_comm_network.cpp          | 23 +++++++------
 .../core/comm_network/ibverbs/ibverbs_qp.cpp  | 33 ++++++++++---------
 .../core/comm_network/ibverbs/ibverbs_qp.h    |  2 +-
 4 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/oneflow/core/actor/actor_message.h b/oneflow/core/actor/actor_message.h
index 733de0c10..fc35651dc 100644
--- a/oneflow/core/actor/actor_message.h
+++ b/oneflow/core/actor/actor_message.h
@@ -13,6 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
 #ifndef ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
 #define ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
 
@@ -72,15 +84,7 @@ class ActorMsg final {
   void Deserialize(StreamT& in_stream) {
     in_stream.Read(this, sizeof(ActorMsg));
   }
-  
-  //operate flag
-  void setFlag(bool flag) {
-    flag_ = flag;
-  }
-  bool getFlag() const {
-    return flag_;
-  }
-  
+
  private:
   struct RegstWrapper {
     Regst* regst;
@@ -100,7 +104,6 @@ class ActorMsg final {
   };
   uint8_t user_data_size_;
   unsigned char user_data_[kActorMsgUserDataMaxSize];
-  bool flag_;
 };
 
 template<typename StreamT>
diff --git a/oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp b/oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
index 9acbe3c00..86a260a2e 100644
--- a/oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
+++ b/oneflow/core/comm_network/ibverbs/ibverbs_comm_network.cpp
@@ -13,10 +13,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
 #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
 #include "oneflow/core/control/ctrl_client.h"
 #include "oneflow/core/control/global_process_ctx.h"
-#include "oneflow/core/graph/node.h"
 #include "oneflow/core/job/resource_desc.h"
 #include "oneflow/core/job/global_for.h"
 #include "oneflow/core/platform/include/ibv.h"
@@ -70,18 +81,11 @@ void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
     static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
     new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
   }
-  new_msg.setFlag(true);
   qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
-  ActorMsg unuseful_msg = msg;
-  unuseful_msg.setFlag(false);
-  for(int i =0 ; i < 64; i++) {
-    qp_vec_.at(dst_machine_id)->PostSendRequest(unuseful_msg);
-  }
 }
 
 void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
   ActorMsg new_msg = msg;
-  if(new_msg.getFlag() == true) {
   if (msg.IsDataRegstMsgToConsumer()) {
     std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
     auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
@@ -92,7 +96,6 @@ void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
     new_msg.set_comm_net_token(desc.get());
   }
   Global<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
-  }
 }
 
 IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) {
@@ -185,4 +188,4 @@ COMMAND(IBVForkInit());
 
 }  // namespace oneflow
 
-#endif  // WITH_RDMA && OF_PLATFORM_POSIX
+#endif  // WITH_RDMA && OF_PLATFORM_POSIX
\ No newline at end of file
diff --git a/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp b/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
index 12a4f453f..07905a763 100644
--- a/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
+++ b/oneflow/core/comm_network/ibverbs/ibverbs_qp.cpp
@@ -15,8 +15,6 @@ limitations under the License.
 */
 #include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h"
 #include <infiniband/verbs.h>
-#include <memory>
-#include <mutex>
 #include "oneflow/core/comm_network/comm_network.h"
 #include "oneflow/core/actor/actor_message_bus.h"
 #include "oneflow/core/job/resource_desc.h"
@@ -24,13 +22,16 @@ limitations under the License.
 #include "oneflow/core/platform/include/ibv.h"
 #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
 
+#include <memory>
+#include <mutex>
+
 #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
 
 namespace oneflow {
 
 namespace {
 
-constexpr int kMaxSendWr = 32;
+constexpr int kMaxSendWr = 4096;
 
 }
 
@@ -156,7 +157,7 @@ void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,
     wr.imm_data = 0;
     wr.wr.rdma.remote_addr = remote_mem.mem_ptr + i * block_size;
     wr.wr.rdma.rkey = remote_mem.mr_rkey;
-    PostSendReadInQueue(wr,  sge);
+    PostSendReadInQueue(wr, sge);
   }
 }
 
@@ -178,7 +179,7 @@ void IBVerbsQP::PostSendRequest(const ActorMsg& msg) {
   wr.send_flags = 0;
   wr.imm_data = 0;
   memset(&(wr.wr), 0, sizeof(wr.wr));
-  PostSendReadInQueue(wr,  sge);
+  PostSendReadInQueue(wr, sge);
 }
 
 void IBVerbsQP::PostSendReadInQueue(ibv_send_wr wr, ibv_sge sge) {
@@ -201,7 +202,7 @@ void IBVerbsQP::ReadDone(WorkRequestId* wr_id) {
     Global<CommNet>::Get()->ReadDone(wr_id->read_id);
     DeleteWorkRequestId(wr_id);
   }
-  ReadSendDoneSendQueueMessage();
+  EnqueuePostSend();
 }
 
 void IBVerbsQP::SendDone(WorkRequestId* wr_id) {
@@ -210,7 +211,7 @@ void IBVerbsQP::SendDone(WorkRequestId* wr_id) {
     send_msg_buf_.push(wr_id->msg_mr);
   }
   DeleteWorkRequestId(wr_id);
-  ReadSendDoneSendQueueMessage();
+  EnqueuePostSend();
 }
 
 void IBVerbsQP::RecvDone(WorkRequestId* wr_id) {
@@ -221,19 +222,19 @@ void IBVerbsQP::RecvDone(WorkRequestId* wr_id) {
   DeleteWorkRequestId(wr_id);
 }
 
-void IBVerbsQP::ReadSendDoneSendQueueMessage() {
+void IBVerbsQP::EnqueuePostSend() {
   std::unique_lock<std::mutex> num_outstanding_send_wr_lck(num_outstanding_send_wr_mutex_);
   if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; }
   std::unique_lock<std::mutex> msg_pendding_list_lck(msg_pendding_list_mutex_);
   if (msg_pendding_list_.empty() == false) {
-      std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(msg_pendding_list_.front());
-      ibv_send_wr wr = ibv_send_wr_sge.first;
-      wr.sg_list = &ibv_send_wr_sge.second;
-      msg_pendding_list_.pop();
-      ibv_send_wr* bad_wr = nullptr;
-      num_outstanding_send_wr_++;
-      CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0);
-    }
+    std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(msg_pendding_list_.front());
+    ibv_send_wr wr = ibv_send_wr_sge.first;
+    wr.sg_list = &ibv_send_wr_sge.second;
+    msg_pendding_list_.pop();
+    ibv_send_wr* bad_wr = nullptr;
+    num_outstanding_send_wr_++;
+    CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0);
+  }
 }
 
 void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) {
diff --git a/oneflow/core/comm_network/ibverbs/ibverbs_qp.h b/oneflow/core/comm_network/ibverbs/ibverbs_qp.h
index 8adb23d2a..fd62cd5a0 100644
--- a/oneflow/core/comm_network/ibverbs/ibverbs_qp.h
+++ b/oneflow/core/comm_network/ibverbs/ibverbs_qp.h
@@ -71,7 +71,7 @@ class IBVerbsQP final {
   void ReadDone(WorkRequestId*);
   void SendDone(WorkRequestId*);
   void RecvDone(WorkRequestId*);
-  void ReadSendDoneSendQueueMessage();
+  void EnqueuePostSend();
 
  private:
   WorkRequestId* NewWorkRequestId();
-- 
GitLab