Skip to content
Snippets Groups Projects
Unverified Commit 95c20bf0 authored by Juncheng, committed by GitHub
Browse files

Move Global<CommNet> to env scope (#5670)


* Move Global<CommNet> to env scope

* Revert Global RuntimeCtx New

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: cheng cheng <472491134@qq.com>
parent e53e4c14
No related branches found
No related tags found
No related merge requests found
......@@ -173,17 +173,14 @@ IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT
qp_vec_.at(peer_id)->Connect(conn_info);
LOG(INFO) << "Connected to peer " << peer_id;
}
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
OF_ENV_BARRIER();
for (int64_t peer_id : peer_machine_id()) {
qp_vec_.at(peer_id)->PostAllRecvRequest();
Global<CtrlClient>::Get()->ClearKV(GenConnInfoKey(this_machine_id, peer_id));
}
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
OF_ENV_BARRIER();
poll_thread_ = std::thread(&IBVerbsCommNet::PollCQ, this);
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
OF_ENV_BARRIER();
}
void IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
......
......@@ -28,8 +28,9 @@ namespace oneflow {
namespace {
constexpr uint32_t kDefaultQueueDepth = 1024;
constexpr uint64_t kDefaultMemBlockSize = 8388608; // 8M
}
} // namespace
IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, uint8_t port_num, ibv_cq* send_cq,
ibv_cq* recv_cq) {
......@@ -145,7 +146,8 @@ void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,
const IBVerbsMemDesc& local_mem, void* read_id) {
CHECK_EQ(remote_mem.mem_size, local_mem.mem_size());
WorkRequestId* wr_id = NewWorkRequestId();
const size_t block_size = Global<ResourceDesc, ForSession>::Get()->rdma_mem_block_byte();
const size_t block_size =
ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE", kDefaultMemBlockSize);
const size_t block_num = RoundUp(remote_mem.mem_size, block_size) / block_size;
wr_id->outstanding_sge_cnt = static_cast<int32_t>(block_num);
wr_id->read_id = read_id;
......
......@@ -39,6 +39,12 @@ limitations under the License.
#include "oneflow/core/framework/symbol_id_cache.h"
#include "oneflow/core/operator/op_node_signature.cfg.h"
#include "oneflow/core/operator/op_conf.cfg.h"
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/comm_network/epoll/epoll_comm_network.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#ifdef WITH_RDMA
#include "oneflow/core/platform/include/ibv.h"
#endif // WITH_RDMA
namespace oneflow {
......@@ -103,6 +109,19 @@ void ClearAllSymbolAndIdCache() {
Global<symbol::IdCache<cfg::OpNodeSignature>>::Get()->ClearAll();
}
#if defined(__linux__) && defined(WITH_RDMA)
// Decides whether the IBVerbs-based CommNet should be used.
//
// The user must opt in through the ONEFLOW_COMM_NET_IB_ENABLE environment variable
// (ParseBooleanFromEnv is given `false` as its fallback argument). Only when the user
// has opted in is ibv::IsAvailable() consulted; its result is then returned as-is.
// Without the opt-in this returns false and ibv::IsAvailable() is never called.
bool CommNetIBEnabled() {
  const bool opted_in = ParseBooleanFromEnv("ONEFLOW_COMM_NET_IB_ENABLE", false);
  // `&&` short-circuits, so the availability probe runs only after an explicit opt-in.
  return opted_in && ibv::IsAvailable();
}
#endif
} // namespace
Maybe<void> EnvGlobalObjectsScope::Init(const EnvProto& env_proto) {
......@@ -162,7 +181,19 @@ Maybe<void> EnvGlobalObjectsScope::Init(const EnvProto& env_proto) {
#ifdef __linux__
Global<EpollCommNet>::New();
Global<Transport>::New();
if (Global<ResourceDesc, ForSession>::Get()->process_ranks().size() > 1) {
#ifdef WITH_RDMA
if (CommNetIBEnabled()) {
Global<IBVerbsCommNet>::New();
Global<CommNet>::SetAllocated(Global<IBVerbsCommNet>::Get());
} else {
Global<CommNet>::SetAllocated(Global<EpollCommNet>::Get());
}
#else
Global<CommNet>::SetAllocated(Global<EpollCommNet>::Get());
#endif // WITH_RDMA
#endif // __linux__
}
}
return Maybe<void>::Ok();
}
......@@ -170,6 +201,11 @@ Maybe<void> EnvGlobalObjectsScope::Init(const EnvProto& env_proto) {
EnvGlobalObjectsScope::~EnvGlobalObjectsScope() {
if (!Global<ResourceDesc, ForSession>::Get()->enable_dry_run()) {
#ifdef __linux__
if (Global<ResourceDesc, ForSession>::Get()->process_ranks().size() > 1) {
if (Global<EpollCommNet>::Get() != static_cast<EpollCommNet*>(Global<CommNet>::Get())) {
Global<CommNet>::Delete();
}
}
Global<Transport>::Delete();
Global<EpollCommNet>::Delete();
#endif // __linux__
......
......@@ -56,11 +56,6 @@ Maybe<JobDesc> JobDesc::New(int64_t symbol_id, const JobConfigProto& job_conf) {
Maybe<void> JobDesc::Init() {
cfg_job_conf_.reset(new cfg::JobConfigProto(job_conf_));
#ifndef WITH_RDMA
CHECK_NOTNULL_OR_RETURN((Global<ResourceDesc, ForSession>::Get()));
CHECK_EQ_OR_RETURN((Global<ResourceDesc, ForSession>::Get()->use_rdma()), false)
<< "Please compile ONEFLOW with RDMA";
#endif
#ifndef WITH_CUDA
CHECK_EQ_OR_RETURN((Global<ResourceDesc, ForSession>::Get()->GpuDeviceNum()), 0);
......
......@@ -40,8 +40,6 @@ message Resource {
optional int32 cpu_device_num = 5 [default = 0];
optional int32 comm_net_worker_num = 6 [default = 4];
optional int32 max_mdsave_worker_num = 7 [default = 64];
optional bool use_rdma = 8 [default = false];
optional uint64 rdma_mem_block_mbyte = 9 [default = 8];
optional uint64 reserved_host_mem_mbyte = 12 [default = 500];
optional uint64 reserved_device_mem_mbyte = 13 [default = 500];
optional int32 compute_thread_pool_size = 15;
......
......@@ -37,14 +37,12 @@ class ResourceDesc final {
const std::set<int64_t>& process_ranks() const { return process_ranks_; }
__attribute__((deprecated)) Machine machine(int32_t idx) const;
size_t CommNetWorkerNum() const { return resource_.comm_net_worker_num(); }
size_t rdma_mem_block_byte() const { return resource_.rdma_mem_block_mbyte() * kMB; }
int32_t CpuDeviceNum() const { return resource_.cpu_device_num(); }
int32_t GpuDeviceNum() const { return resource_.gpu_device_num(); }
int32_t MemZoneNum() const { return GpuDeviceNum() + 1; }
int32_t MaxMdSaveWorkerNum() const { return resource_.max_mdsave_worker_num(); }
size_t reserved_host_mem_byte() const { return resource_.reserved_host_mem_mbyte() * kMB; }
size_t reserved_device_mem_byte() const { return resource_.reserved_device_mem_mbyte() * kMB; }
bool use_rdma() const { return resource_.use_rdma(); }
bool thread_enable_local_message_queue() const {
return resource_.thread_enable_local_message_queue();
}
......
......@@ -15,8 +15,6 @@ limitations under the License.
*/
#include "oneflow/core/job/runtime.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/comm_network/epoll/epoll_comm_network.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/resource_desc.h"
......@@ -32,9 +30,6 @@ limitations under the License.
#include "oneflow/user/summary/events_writer.h"
#include "oneflow/core/job/collective_boxing_executor.h"
#include "oneflow/core/job/collective_boxing_device_ctx_poller.h"
#ifdef WITH_RDMA
#include "oneflow/core/platform/include/ibv.h"
#endif // WITH_RDMA
namespace oneflow {
......@@ -110,30 +105,6 @@ Runtime::~Runtime() {
void Runtime::NewAllGlobal(const Plan& plan,
const HashMap<std::string, Blob*>& variable_op_name2eager_blob) {
Global<RuntimeCtx>::New();
if (Global<ResourceDesc, ForSession>::Get()->process_ranks().size() > 1) {
#ifdef __linux__
// NOTE(chengcheng): Global<EpollCommNet> will new in any case, and will new in env start.
// if use RDMA,
// The Global<CommNet> is set allocated by new Global<IBVerbsCommNet>
// else,
// The Global<CommNet> is set allocated by Global<EpollCommNet>
if (Global<ResourceDesc, ForSession>::Get()->use_rdma()) {
#ifdef WITH_RDMA
if (ibv::IsAvailable()) {
Global<IBVerbsCommNet>::New();
Global<CommNet>::SetAllocated(Global<IBVerbsCommNet>::Get());
} else {
LOG(ERROR) << "libibverbs not available, falling back to epoll";
Global<CommNet>::SetAllocated(Global<EpollCommNet>::Get());
}
#else
LOG(FATAL) << "RDMA components not found";
#endif
} else {
Global<CommNet>::SetAllocated(Global<EpollCommNet>::Get());
}
#endif
}
Global<boxing::collective::CollectiveBoxingExecutor>::New(plan);
Global<MemoryAllocator>::New();
Global<RegstMgr>::New();
......@@ -154,32 +125,6 @@ void Runtime::DeleteAllGlobal() {
Global<RegstMgr>::Delete();
Global<MemoryAllocator>::Delete();
Global<boxing::collective::CollectiveBoxingExecutor>::Delete();
// should be called after Global<Transport>::Delete()
if (Global<ResourceDesc, ForSession>::Get()->process_ranks().size() > 1) {
#ifdef __linux__
if (Global<ResourceDesc, ForSession>::Get()->use_rdma()) {
#ifdef WITH_RDMA
if (ibv::IsAvailable()) {
CHECK(Global<EpollCommNet>::Get() != static_cast<EpollCommNet*>(Global<CommNet>::Get()));
// NOTE(chengcheng): it means that
// Global<CommNet>::SetAllocated(Global<IBVerbsCommNet>::Get())
// so the Global<CommNet> and Global<EpollCommNet> are NOT same global object
// then need delete both.
Global<CommNet>::Delete();
}
#else
LOG(FATAL) << "RDMA components not found";
#endif
} else {
CHECK(Global<EpollCommNet>::Get() == static_cast<EpollCommNet*>(Global<CommNet>::Get()));
// NOTE(chengcheng): it means that Global<CommNet>::SetAllocated(Global<EpollCommNet>::Get())
// so the Global<CommNet> and Global<EpollCommNet> are same global object
// then only need delete once.
}
#endif
}
Global<ActEventLogger>::Delete();
Global<RuntimeCtx>::Delete();
Global<summary::EventsWriter>::Delete();
......
......@@ -174,14 +174,9 @@ def api_rdma_mem_block_mbyte(val: int) -> None:
Args:
val (int): size of block, e.g. 1024(mb)
"""
return enable_if.unique([rdma_mem_block_mbyte, do_nothing])(val)
@enable_if.condition(hob.in_normal_mode & ~hob.session_initialized)
def rdma_mem_block_mbyte(val):
sess = session_ctx.GetDefaultSession()
assert type(val) is int
sess.config_proto.resource.rdma_mem_block_mbyte = val
print(
"'rdma_mem_block_mbyte' has been deprecated, has no effect and will be removed in the future. Use environment variable 'ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE' instead."
)
def api_rdma_recv_msg_buf_mbyte(val: int) -> None:
......@@ -190,14 +185,9 @@ def api_rdma_recv_msg_buf_mbyte(val: int) -> None:
Args:
val (int): buffer size, e.g. 1024(mb)
"""
return enable_if.unique([rdma_recv_msg_buf_mbyte, do_nothing])(val)
@enable_if.condition(hob.in_normal_mode & ~hob.session_initialized)
def rdma_recv_msg_buf_mbyte(val):
sess = session_ctx.GetDefaultSession()
assert type(val) is int
sess.config_proto.resource.rdma_recv_msg_buf_mbyte = val
print(
"'rdma_recv_msg_buf_mbyte' has been deprecated, has no effect and will be removed in the future."
)
def api_reserved_host_mem_mbyte(val: int) -> None:
......@@ -239,14 +229,9 @@ def api_use_rdma(val: bool = True) -> None:
Args:
val (bool, optional): Defaults to True.
"""
return enable_if.unique([use_rdma, do_nothing])(val=val)
@enable_if.condition(hob.in_normal_mode & ~hob.session_initialized)
def use_rdma(val=True):
sess = session_ctx.GetDefaultSession()
assert type(val) is bool
sess.config_proto.resource.use_rdma = val
print(
"'use_rdma' has been deprecated, has no effect and will be removed in the future. Use environment variable 'ONEFLOW_COMM_NET_IB_ENABLE' instead."
)
def api_thread_enable_local_message_queue(val: bool) -> None:
......
......@@ -168,22 +168,6 @@ def compute_thread_pool_size(val):
sess.config_proto.resource.compute_thread_pool_size = val
def api_rdma_mem_block_mbyte(val: int) -> None:
"""Set up the memory block size in rdma mode.
Args:
val (int): size of block, e.g. 1024(mb)
"""
return enable_if.unique([rdma_mem_block_mbyte, do_nothing])(val)
@enable_if.condition(hob.in_normal_mode & ~hob.session_initialized)
def rdma_mem_block_mbyte(val):
sess = session_ctx.GetDefaultSession()
assert type(val) is int
sess.config_proto.resource.rdma_mem_block_mbyte = val
def api_reserved_host_mem_mbyte(val: int) -> None:
"""Set up the memory size of reserved host
......@@ -216,23 +200,6 @@ def reserved_device_mem_mbyte(val):
sess.config_proto.resource.reserved_device_mem_mbyte = val
def api_use_rdma(val: bool = True) -> None:
"""Whether use RDMA to speed up data transmission in cluster nodes or not.
if not, then use normal epoll mode.
Args:
val (bool, optional): Defaults to True.
"""
return enable_if.unique([use_rdma, do_nothing])(val=val)
@enable_if.condition(hob.in_normal_mode & ~hob.session_initialized)
def use_rdma(val=True):
sess = session_ctx.GetDefaultSession()
assert type(val) is bool
sess.config_proto.resource.use_rdma = val
def api_thread_enable_local_message_queue(val: bool) -> None:
"""Whether or not enable thread using local message queue.
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment