Skip to content
Snippets Groups Projects
Commit e7222a84 authored by lambda7xx's avatar lambda7xx
Browse files

optimize the IB

when send too many message whose number is bigger then qp_init_attr.cap.max_send_wr,
I use the msg_pendding_list to store the extra message and wait for some time to send these message of the msg_pendding_list
parent 93c4c439
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,6 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_ #ifndef ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
#define ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_ #define ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
...@@ -72,15 +84,7 @@ class ActorMsg final { ...@@ -72,15 +84,7 @@ class ActorMsg final {
void Deserialize(StreamT& in_stream) { void Deserialize(StreamT& in_stream) {
in_stream.Read(this, sizeof(ActorMsg)); in_stream.Read(this, sizeof(ActorMsg));
} }
//operate flag
void setFlag(bool flag) {
flag_ = flag;
}
bool getFlag() const {
return flag_;
}
private: private:
struct RegstWrapper { struct RegstWrapper {
Regst* regst; Regst* regst;
...@@ -100,7 +104,6 @@ class ActorMsg final { ...@@ -100,7 +104,6 @@ class ActorMsg final {
}; };
uint8_t user_data_size_; uint8_t user_data_size_;
unsigned char user_data_[kActorMsgUserDataMaxSize]; unsigned char user_data_[kActorMsgUserDataMaxSize];
bool flag_;
}; };
template<typename StreamT> template<typename StreamT>
......
...@@ -13,10 +13,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,10 +13,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h" #include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/graph/node.h"
#include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h" #include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h" #include "oneflow/core/platform/include/ibv.h"
...@@ -70,18 +81,11 @@ void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) { ...@@ -70,18 +81,11 @@ void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, ""); static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc); new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
} }
new_msg.setFlag(true);
qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg); qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
ActorMsg unuseful_msg = msg;
unuseful_msg.setFlag(false);
for(int i =0 ; i < 64; i++) {
qp_vec_.at(dst_machine_id)->PostSendRequest(unuseful_msg);
}
} }
void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) { void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
ActorMsg new_msg = msg; ActorMsg new_msg = msg;
if(new_msg.getFlag() == true) {
if (msg.IsDataRegstMsgToConsumer()) { if (msg.IsDataRegstMsgToConsumer()) {
std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_); std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(), auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
...@@ -92,7 +96,6 @@ void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) { ...@@ -92,7 +96,6 @@ void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
new_msg.set_comm_net_token(desc.get()); new_msg.set_comm_net_token(desc.get());
} }
Global<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg); Global<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
}
} }
IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) { IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) {
...@@ -185,4 +188,4 @@ COMMAND(IBVForkInit()); ...@@ -185,4 +188,4 @@ COMMAND(IBVForkInit());
} // namespace oneflow } // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX #endif // WITH_RDMA && OF_PLATFORM_POSIX
\ No newline at end of file
...@@ -15,8 +15,6 @@ limitations under the License. ...@@ -15,8 +15,6 @@ limitations under the License.
*/ */
#include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h"
#include <infiniband/verbs.h> #include <infiniband/verbs.h>
#include <memory>
#include <mutex>
#include "oneflow/core/comm_network/comm_network.h" #include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/actor/actor_message_bus.h" #include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/job/resource_desc.h" #include "oneflow/core/job/resource_desc.h"
...@@ -24,13 +22,16 @@ limitations under the License. ...@@ -24,13 +22,16 @@ limitations under the License.
#include "oneflow/core/platform/include/ibv.h" #include "oneflow/core/platform/include/ibv.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h" #include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include <memory>
#include <mutex>
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX) #if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
namespace oneflow { namespace oneflow {
namespace { namespace {
constexpr int kMaxSendWr = 32; constexpr int kMaxSendWr = 4096;
} }
...@@ -156,7 +157,7 @@ void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem, ...@@ -156,7 +157,7 @@ void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,
wr.imm_data = 0; wr.imm_data = 0;
wr.wr.rdma.remote_addr = remote_mem.mem_ptr + i * block_size; wr.wr.rdma.remote_addr = remote_mem.mem_ptr + i * block_size;
wr.wr.rdma.rkey = remote_mem.mr_rkey; wr.wr.rdma.rkey = remote_mem.mr_rkey;
PostSendReadInQueue(wr, sge); PostSendReadInQueue(wr, sge);
} }
} }
...@@ -178,7 +179,7 @@ void IBVerbsQP::PostSendRequest(const ActorMsg& msg) { ...@@ -178,7 +179,7 @@ void IBVerbsQP::PostSendRequest(const ActorMsg& msg) {
wr.send_flags = 0; wr.send_flags = 0;
wr.imm_data = 0; wr.imm_data = 0;
memset(&(wr.wr), 0, sizeof(wr.wr)); memset(&(wr.wr), 0, sizeof(wr.wr));
PostSendReadInQueue(wr, sge); PostSendReadInQueue(wr, sge);
} }
void IBVerbsQP::PostSendReadInQueue(ibv_send_wr wr, ibv_sge sge) { void IBVerbsQP::PostSendReadInQueue(ibv_send_wr wr, ibv_sge sge) {
...@@ -201,7 +202,7 @@ void IBVerbsQP::ReadDone(WorkRequestId* wr_id) { ...@@ -201,7 +202,7 @@ void IBVerbsQP::ReadDone(WorkRequestId* wr_id) {
Global<CommNet>::Get()->ReadDone(wr_id->read_id); Global<CommNet>::Get()->ReadDone(wr_id->read_id);
DeleteWorkRequestId(wr_id); DeleteWorkRequestId(wr_id);
} }
ReadSendDoneSendQueueMessage(); EnqueuePostSend();
} }
void IBVerbsQP::SendDone(WorkRequestId* wr_id) { void IBVerbsQP::SendDone(WorkRequestId* wr_id) {
...@@ -210,7 +211,7 @@ void IBVerbsQP::SendDone(WorkRequestId* wr_id) { ...@@ -210,7 +211,7 @@ void IBVerbsQP::SendDone(WorkRequestId* wr_id) {
send_msg_buf_.push(wr_id->msg_mr); send_msg_buf_.push(wr_id->msg_mr);
} }
DeleteWorkRequestId(wr_id); DeleteWorkRequestId(wr_id);
ReadSendDoneSendQueueMessage(); EnqueuePostSend();
} }
void IBVerbsQP::RecvDone(WorkRequestId* wr_id) { void IBVerbsQP::RecvDone(WorkRequestId* wr_id) {
...@@ -221,19 +222,19 @@ void IBVerbsQP::RecvDone(WorkRequestId* wr_id) { ...@@ -221,19 +222,19 @@ void IBVerbsQP::RecvDone(WorkRequestId* wr_id) {
DeleteWorkRequestId(wr_id); DeleteWorkRequestId(wr_id);
} }
void IBVerbsQP::ReadSendDoneSendQueueMessage() { void IBVerbsQP::EnqueuePostSend() {
std::unique_lock<std::mutex> num_outstanding_send_wr_lck(num_outstanding_send_wr_mutex_); std::unique_lock<std::mutex> num_outstanding_send_wr_lck(num_outstanding_send_wr_mutex_);
if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; } if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; }
std::unique_lock<std::mutex> msg_pendding_list_lck(msg_pendding_list_mutex_); std::unique_lock<std::mutex> msg_pendding_list_lck(msg_pendding_list_mutex_);
if (msg_pendding_list_.empty() == false) { if (msg_pendding_list_.empty() == false) {
std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(msg_pendding_list_.front()); std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(msg_pendding_list_.front());
ibv_send_wr wr = ibv_send_wr_sge.first; ibv_send_wr wr = ibv_send_wr_sge.first;
wr.sg_list = &ibv_send_wr_sge.second; wr.sg_list = &ibv_send_wr_sge.second;
msg_pendding_list_.pop(); msg_pendding_list_.pop();
ibv_send_wr* bad_wr = nullptr; ibv_send_wr* bad_wr = nullptr;
num_outstanding_send_wr_++; num_outstanding_send_wr_++;
CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0); CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0);
} }
} }
void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) { void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) {
......
...@@ -71,7 +71,7 @@ class IBVerbsQP final { ...@@ -71,7 +71,7 @@ class IBVerbsQP final {
void ReadDone(WorkRequestId*); void ReadDone(WorkRequestId*);
void SendDone(WorkRequestId*); void SendDone(WorkRequestId*);
void RecvDone(WorkRequestId*); void RecvDone(WorkRequestId*);
void ReadSendDoneSendQueueMessage(); void EnqueuePostSend();
private: private:
WorkRequestId* NewWorkRequestId(); WorkRequestId* NewWorkRequestId();
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment