Skip to content
Snippets Groups Projects
Commit e7222a84 authored by lambda7xx's avatar lambda7xx
Browse files

optimize the IB

When the number of messages to send exceeds qp_init_attr.cap.max_send_wr,
the extra messages are stored in msg_pendding_list and are sent later, once earlier send requests have completed.
parent 93c4c439
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
#define ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
......@@ -72,15 +84,7 @@ class ActorMsg final {
void Deserialize(StreamT& in_stream) {
in_stream.Read(this, sizeof(ActorMsg));
}
//operate flag
void setFlag(bool flag) {
flag_ = flag;
}
bool getFlag() const {
return flag_;
}
private:
struct RegstWrapper {
Regst* regst;
......@@ -100,7 +104,6 @@ class ActorMsg final {
};
uint8_t user_data_size_;
unsigned char user_data_[kActorMsgUserDataMaxSize];
bool flag_;
};
template<typename StreamT>
......
......@@ -13,10 +13,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/graph/node.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"
......@@ -70,18 +81,11 @@ void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
}
new_msg.setFlag(true);
qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
ActorMsg unuseful_msg = msg;
unuseful_msg.setFlag(false);
for(int i =0 ; i < 64; i++) {
qp_vec_.at(dst_machine_id)->PostSendRequest(unuseful_msg);
}
}
void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
ActorMsg new_msg = msg;
if(new_msg.getFlag() == true) {
if (msg.IsDataRegstMsgToConsumer()) {
std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
......@@ -92,7 +96,6 @@ void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
new_msg.set_comm_net_token(desc.get());
}
Global<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
}
}
IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) {
......@@ -185,4 +188,4 @@ COMMAND(IBVForkInit());
} // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX
#endif // WITH_RDMA && OF_PLATFORM_POSIX
\ No newline at end of file
......@@ -15,8 +15,6 @@ limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h"
#include <infiniband/verbs.h>
#include <memory>
#include <mutex>
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/job/resource_desc.h"
......@@ -24,13 +22,16 @@ limitations under the License.
#include "oneflow/core/platform/include/ibv.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include <memory>
#include <mutex>
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
namespace oneflow {
namespace {
constexpr int kMaxSendWr = 32;
constexpr int kMaxSendWr = 4096;
}
......@@ -156,7 +157,7 @@ void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,
wr.imm_data = 0;
wr.wr.rdma.remote_addr = remote_mem.mem_ptr + i * block_size;
wr.wr.rdma.rkey = remote_mem.mr_rkey;
PostSendReadInQueue(wr, sge);
PostSendReadInQueue(wr, sge);
}
}
......@@ -178,7 +179,7 @@ void IBVerbsQP::PostSendRequest(const ActorMsg& msg) {
wr.send_flags = 0;
wr.imm_data = 0;
memset(&(wr.wr), 0, sizeof(wr.wr));
PostSendReadInQueue(wr, sge);
PostSendReadInQueue(wr, sge);
}
void IBVerbsQP::PostSendReadInQueue(ibv_send_wr wr, ibv_sge sge) {
......@@ -201,7 +202,7 @@ void IBVerbsQP::ReadDone(WorkRequestId* wr_id) {
Global<CommNet>::Get()->ReadDone(wr_id->read_id);
DeleteWorkRequestId(wr_id);
}
ReadSendDoneSendQueueMessage();
EnqueuePostSend();
}
void IBVerbsQP::SendDone(WorkRequestId* wr_id) {
......@@ -210,7 +211,7 @@ void IBVerbsQP::SendDone(WorkRequestId* wr_id) {
send_msg_buf_.push(wr_id->msg_mr);
}
DeleteWorkRequestId(wr_id);
ReadSendDoneSendQueueMessage();
EnqueuePostSend();
}
void IBVerbsQP::RecvDone(WorkRequestId* wr_id) {
......@@ -221,19 +222,19 @@ void IBVerbsQP::RecvDone(WorkRequestId* wr_id) {
DeleteWorkRequestId(wr_id);
}
void IBVerbsQP::ReadSendDoneSendQueueMessage() {
void IBVerbsQP::EnqueuePostSend() {
std::unique_lock<std::mutex> num_outstanding_send_wr_lck(num_outstanding_send_wr_mutex_);
if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; }
std::unique_lock<std::mutex> msg_pendding_list_lck(msg_pendding_list_mutex_);
if (msg_pendding_list_.empty() == false) {
std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(msg_pendding_list_.front());
ibv_send_wr wr = ibv_send_wr_sge.first;
wr.sg_list = &ibv_send_wr_sge.second;
msg_pendding_list_.pop();
ibv_send_wr* bad_wr = nullptr;
num_outstanding_send_wr_++;
CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0);
}
std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(msg_pendding_list_.front());
ibv_send_wr wr = ibv_send_wr_sge.first;
wr.sg_list = &ibv_send_wr_sge.second;
msg_pendding_list_.pop();
ibv_send_wr* bad_wr = nullptr;
num_outstanding_send_wr_++;
CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0);
}
}
void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) {
......
......@@ -71,7 +71,7 @@ class IBVerbsQP final {
void ReadDone(WorkRequestId*);
void SendDone(WorkRequestId*);
void RecvDone(WorkRequestId*);
void ReadSendDoneSendQueueMessage();
void EnqueuePostSend();
private:
WorkRequestId* NewWorkRequestId();
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment