diff --git a/CMakeLists.txt b/CMakeLists.txt
index f9883c38cf37f188f5ec9a018ef9414385b5562d..48557fc2b4a56438ded4944434197a661bc7cf77 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -40,7 +40,7 @@ if (WIN32)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0")
else()
list(APPEND CUDA_NVCC_FLAGS -std=c++11 -w -Wno-deprecated-gpu-targets)
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare -Wno-unused-function")
if (RELEASE_VERSION)
list(APPEND CUDA_NVCC_FLAGS -O3)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
diff --git a/oneflow/core/common/cplusplus_14.h b/oneflow/core/common/cplusplus_14.h
index 041a334effe1b698a11b20d4783e75cee39e8ebe..b6219376d7cccc5a765bf57dbd7d07f3db1e5ded 100644
--- a/oneflow/core/common/cplusplus_14.h
+++ b/oneflow/core/common/cplusplus_14.h
@@ -4,12 +4,35 @@
#if __cplusplus < 201402L
namespace std {
+template<class T>
+struct unique_if {
+ typedef unique_ptr<T> single_object;
+};
+
+template<class T>
+struct unique_if<T[]> {
+ typedef unique_ptr<T[]> unknown_bound;
+};
+
+template<class T, size_t N>
+struct unique_if<T[N]> {
+ typedef void known_bound;
+};
-template<typename T, typename... Args>
-unique_ptr<T> make_unique(Args&&... args) {
+template<class T, class... Args>
+typename unique_if<T>::single_object make_unique(Args&&... args) {
return unique_ptr<T>(new T(forward<Args>(args)...));
}
+template<class T>
+typename unique_if<T>::unknown_bound make_unique(size_t n) {
+ typedef typename remove_extent<T>::type U;
+ return unique_ptr<T>(new U[n]());
+}
+
+template<class T, class... Args>
+typename unique_if<T>::known_bound make_unique(Args&&...) = delete;
+
template<typename T, T... Ints>
struct integer_sequence {
static_assert(is_integral<T>::value, "");
diff --git a/oneflow/core/common/cplusplus_17.h b/oneflow/core/common/cplusplus_17.h
deleted file mode 100644
index a59d891492cc2e7876ae3b01ff47fb035797a5a8..0000000000000000000000000000000000000000
--- a/oneflow/core/common/cplusplus_17.h
+++ /dev/null
@@ -1,10 +0,0 @@
-#ifndef ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
-#define ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
-
-#if __cplusplus < 201703L
-
-namespace std {}
-
-#endif
-
-#endif // ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
diff --git a/oneflow/core/common/meta_util.hpp b/oneflow/core/common/meta_util.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..49210b830c8bc20968aa7b99a577df837bcf2a58
--- /dev/null
+++ b/oneflow/core/common/meta_util.hpp
@@ -0,0 +1,50 @@
+#ifndef ONEFLOW_META_UTIL_HPP
+#define ONEFLOW_META_UTIL_HPP
+
+#include "oneflow/core/common/cplusplus_14.h"
+
+namespace oneflow{
+
+ template <typename... Args, typename Func, std::size_t... Idx>
+ void for_each(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {
+ (void)std::initializer_list<int> { (f(std::get<Idx>(t)), void(), 0)...};
+ }
+
+ template <typename... Args, typename Func, std::size_t... Idx>
+ void for_each_i(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {
+ (void)std::initializer_list<int> { (f(std::get<Idx>(t), std::integral_constant<size_t, Idx>{}), void(), 0)...};
+ }
+
+ template<typename T>
+ struct function_traits;
+
+ template<typename Ret, typename... Args>
+ struct function_traits<Ret(Args...)>
+ {
+ public:
+ enum { arity = sizeof...(Args) };
+ typedef Ret function_type(Args...);
+ typedef Ret return_type;
+ using stl_function_type = std::function<function_type>;
+ typedef Ret(*pointer)(Args...);
+
+ typedef std::tuple<Args...> tuple_type;
+ };
+
+ template<typename Ret, typename... Args>
+ struct function_traits<Ret(*)(Args...)> : function_traits<Ret(Args...)>{};
+
+ template <typename Ret, typename... Args>
+ struct function_traits<std::function<Ret(Args...)>> : function_traits<Ret(Args...)>{};
+
+ template <typename ReturnType, typename ClassType, typename... Args>
+ struct function_traits<ReturnType(ClassType::*)(Args...)> : function_traits<ReturnType(Args...)>{};
+
+ template <typename ReturnType, typename ClassType, typename... Args>
+ struct function_traits<ReturnType(ClassType::*)(Args...) const> : function_traits<ReturnType(Args...)>{};
+
+ template<typename Callable>
+ struct function_traits : function_traits<decltype(&Callable::operator())>{};
+}
+
+#endif //ONEFLOW_META_UTIL_HPP
diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h
index 956e927467cb3d4224b738366bf64f128d897b13..9924909aa525aa9d3f0ea8b28071ea032ee3e0d1 100644
--- a/oneflow/core/common/util.h
+++ b/oneflow/core/common/util.h
@@ -24,8 +24,7 @@
#include <utility>
#include "oneflow/core/operator/op_conf.pb.h"
-#include "oneflow/core/common/cplusplus_14.h"
-#include "oneflow/core/common/cplusplus_17.h"
+#include "oneflow/core/common/meta_util.hpp"
DECLARE_string(log_dir);
diff --git a/oneflow/core/control/ctrl_call.h b/oneflow/core/control/ctrl_call.h
index 0404a79756606f5c2e7fcdef4271f7a764563780..2f1524bfa14146b14fcb96aad305680070de1876 100644
--- a/oneflow/core/control/ctrl_call.h
+++ b/oneflow/core/control/ctrl_call.h
@@ -26,6 +26,8 @@ class CtrlCall final : public CtrlCallIf {
CtrlCall() : status_(Status::kBeforeHandleRequest), responder_(&server_ctx_) {}
~CtrlCall() = default;
+ static constexpr const size_t value = (size_t)ctrl_method;
+
const CtrlRequest<ctrl_method>& request() const { return request_; }
CtrlRequest<ctrl_method>* mut_request() { return &request_; }
CtrlResponse<ctrl_method>* mut_response() { return &response_; }
diff --git a/oneflow/core/control/ctrl_client.h b/oneflow/core/control/ctrl_client.h
index 6173ae4b5a0c8f8654799dd87f7eeca2d5b68216..37c7a6bef9a06d6396663a97e3be574064c3546a 100644
--- a/oneflow/core/control/ctrl_client.h
+++ b/oneflow/core/control/ctrl_client.h
@@ -68,19 +68,37 @@ class CtrlClient final {
#define OF_BARRIER() Global<CtrlClient>::Get()->Barrier(FILE_LINE_STR)
-#define OF_CALL_ONCE(name, ...) \
- do { \
- TryLockResult lock_ret = Global<CtrlClient>::Get()->TryLock(name); \
- if (lock_ret == TryLockResult::kLocked) { \
- __VA_ARGS__; \
- Global<CtrlClient>::Get()->NotifyDone(name); \
- } else if (lock_ret == TryLockResult::kDone) { \
- } else if (lock_ret == TryLockResult::kDoing) { \
- Global<CtrlClient>::Get()->WaitUntilDone(name); \
- } else { \
- UNIMPLEMENTED(); \
- } \
- } while (0)
+static void OfCallOnce(const std::string& name, std::function<void()> f) {
+ TryLockResult lock_ret = Global<CtrlClient>::Get()->TryLock(name);
+ if (lock_ret == TryLockResult::kLocked) {
+ f();
+ Global<CtrlClient>::Get()->NotifyDone(name);
+ } else if (lock_ret == TryLockResult::kDone) {
+ } else if (lock_ret == TryLockResult::kDoing) {
+ Global<CtrlClient>::Get()->WaitUntilDone(name);
+ } else {
+ UNIMPLEMENTED();
+ }
+}
+
+template<typename Self, typename F, typename Arg, typename... Args>
+static void OfCallOnce(const std::string& name, Self self, F f, Arg&& arg, Args&&... args) {
+ std::function<void()> fn =
+ std::bind(f, self, std::forward<Arg>(arg), std::forward<Args>(args)...);
+ OfCallOnce(name, std::move(fn));
+}
+
+template<typename Self, typename F>
+static void OfCallOnce(const std::string& name, Self self, F f) {
+ std::function<void()> fn = std::bind(f, self, name);
+ OfCallOnce(name, std::move(fn));
+}
+
+template<typename F, typename Arg, typename... Args>
+static void OfCallOnce(const std::string& name, F f, Arg&& arg, Args&&... args) {
+ std::function<void()> fn = std::bind(f, std::forward<Arg>(arg), std::forward<Args>(args)...);
+ OfCallOnce(name, std::move(fn));
+}
} // namespace oneflow
diff --git a/oneflow/core/control/ctrl_server.cpp b/oneflow/core/control/ctrl_server.cpp
index 43f8239ad91fe25f8fbb1478809fdebb1ecd50b9..0949eadc6232ca9bb728258d7006cf290c1dcda6 100644
--- a/oneflow/core/control/ctrl_server.cpp
+++ b/oneflow/core/control/ctrl_server.cpp
@@ -24,6 +24,8 @@ CtrlServer::~CtrlServer() {
}
CtrlServer::CtrlServer(const std::string& server_addr) {
+ Init();
+
if (FLAGS_grpc_use_no_signal) { grpc_use_signal(-1); }
int port = ExtractPortFromAddr(server_addr);
grpc::ServerBuilder server_builder;
@@ -38,17 +40,8 @@ CtrlServer::CtrlServer(const std::string& server_addr) {
loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this);
}
-#define ENQUEUE_REQUEST(method) \
- do { \
- auto call = new CtrlCall<CtrlMethod::k##method>(); \
- call->set_request_handler(std::bind(&CtrlServer::method##Handler, this, call)); \
- grpc_service_->RequestAsyncUnary(static_cast<int32_t>(CtrlMethod::k##method), \
- call->mut_server_ctx(), call->mut_request(), \
- call->mut_responder(), cq_.get(), cq_.get(), call); \
- } while (0);
-
void CtrlServer::HandleRpcs() {
- OF_PP_FOR_EACH_TUPLE(ENQUEUE_REQUEST, CTRL_METHOD_SEQ);
+ EnqueueRequests();
void* tag = nullptr;
bool ok = false;
@@ -64,141 +57,147 @@ void CtrlServer::HandleRpcs() {
}
}
-void CtrlServer::LoadServerHandler(CtrlCall<CtrlMethod::kLoadServer>* call) {
- call->SendResponse();
- ENQUEUE_REQUEST(LoadServer);
-}
+void CtrlServer::Init() {
+ Add([this](CtrlCall<CtrlMethod::kLoadServer>* call) {
+ call->SendResponse();
+ EnqueueRequest<CtrlMethod::kLoadServer>();
+ });
+
+ Add([this](CtrlCall<CtrlMethod::kBarrier>* call) {
+ const std::string& barrier_name = call->request().name();
+ int32_t barrier_num = call->request().num();
+ auto barrier_call_it = barrier_calls_.find(barrier_name);
+ if (barrier_call_it == barrier_calls_.end()) {
+ barrier_call_it =
+ barrier_calls_
+ .emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))
+ .first;
+ }
+ CHECK_EQ(barrier_num, barrier_call_it->second.second);
+ barrier_call_it->second.first.push_back(call);
+ if (barrier_call_it->second.first.size() == barrier_call_it->second.second) {
+ for (CtrlCallIf* pending_call : barrier_call_it->second.first) {
+ pending_call->SendResponse();
+ }
+ barrier_calls_.erase(barrier_call_it);
+ }
-void CtrlServer::BarrierHandler(CtrlCall<CtrlMethod::kBarrier>* call) {
- const std::string& barrier_name = call->request().name();
- int32_t barrier_num = call->request().num();
- auto barrier_call_it = barrier_calls_.find(barrier_name);
- if (barrier_call_it == barrier_calls_.end()) {
- barrier_call_it =
- barrier_calls_.emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))
- .first;
- }
- CHECK_EQ(barrier_num, barrier_call_it->second.second);
- barrier_call_it->second.first.push_back(call);
- if (barrier_call_it->second.first.size() == barrier_call_it->second.second) {
- for (CtrlCallIf* pending_call : barrier_call_it->second.first) { pending_call->SendResponse(); }
- barrier_calls_.erase(barrier_call_it);
- }
- ENQUEUE_REQUEST(Barrier);
-}
+ EnqueueRequest<CtrlMethod::kBarrier>();
+ });
-void CtrlServer::TryLockHandler(CtrlCall<CtrlMethod::kTryLock>* call) {
- const std::string& lock_name = call->request().name();
- auto name2lock_status_it = name2lock_status_.find(lock_name);
- if (name2lock_status_it == name2lock_status_.end()) {
- call->mut_response()->set_result(TryLockResult::kLocked);
- auto waiting_until_done_calls = new std::list<CtrlCallIf*>;
- CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second);
- } else {
- if (name2lock_status_it->second) {
- call->mut_response()->set_result(TryLockResult::kDoing);
+ Add([this](CtrlCall<CtrlMethod::kTryLock>* call) {
+ const std::string& lock_name = call->request().name();
+ auto name2lock_status_it = name2lock_status_.find(lock_name);
+ if (name2lock_status_it == name2lock_status_.end()) {
+ call->mut_response()->set_result(TryLockResult::kLocked);
+ auto waiting_until_done_calls = new std::list<CtrlCallIf*>;
+ CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second);
} else {
- call->mut_response()->set_result(TryLockResult::kDone);
+ if (name2lock_status_it->second) {
+ call->mut_response()->set_result(TryLockResult::kDoing);
+ } else {
+ call->mut_response()->set_result(TryLockResult::kDone);
+ }
}
- }
- call->SendResponse();
- ENQUEUE_REQUEST(TryLock);
-}
-
-void CtrlServer::NotifyDoneHandler(CtrlCall<CtrlMethod::kNotifyDone>* call) {
- const std::string& lock_name = call->request().name();
- auto name2lock_status_it = name2lock_status_.find(lock_name);
- auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second);
- for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); }
- delete waiting_calls;
- name2lock_status_it->second = nullptr;
- call->SendResponse();
- ENQUEUE_REQUEST(NotifyDone);
-}
-
-void CtrlServer::WaitUntilDoneHandler(CtrlCall<CtrlMethod::kWaitUntilDone>* call) {
- const std::string& lock_name = call->request().name();
- void* lock_status = name2lock_status_.at(lock_name);
- if (lock_status) {
- auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status);
- waiting_calls->push_back(call);
- } else {
call->SendResponse();
- }
- ENQUEUE_REQUEST(WaitUntilDone);
-}
-
-void CtrlServer::PushKVHandler(CtrlCall<CtrlMethod::kPushKV>* call) {
- const std::string& k = call->request().key();
- const std::string& v = call->request().val();
- CHECK(kv_.emplace(k, v).second);
-
- auto pending_kv_calls_it = pending_kv_calls_.find(k);
- if (pending_kv_calls_it != pending_kv_calls_.end()) {
- for (auto pending_call : pending_kv_calls_it->second) {
- pending_call->mut_response()->set_val(v);
- pending_call->SendResponse();
+ EnqueueRequest<CtrlMethod::kTryLock>();
+ });
+
+ Add([this](CtrlCall<CtrlMethod::kNotifyDone>* call) {
+ const std::string& lock_name = call->request().name();
+ auto name2lock_status_it = name2lock_status_.find(lock_name);
+ auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second);
+ for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); }
+ delete waiting_calls;
+ name2lock_status_it->second = nullptr;
+ call->SendResponse();
+ EnqueueRequest<CtrlMethod::kNotifyDone>();
+ });
+
+ Add([this](CtrlCall<CtrlMethod::kWaitUntilDone>* call) {
+ const std::string& lock_name = call->request().name();
+ void* lock_status = name2lock_status_.at(lock_name);
+ if (lock_status) {
+ auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status);
+ waiting_calls->push_back(call);
+ } else {
+ call->SendResponse();
+ }
+ EnqueueRequest<CtrlMethod::kWaitUntilDone>();
+ });
+
+ Add([this](CtrlCall<CtrlMethod::kPushKV>* call) {
+ const std::string& k = call->request().key();
+ const std::string& v = call->request().val();
+ CHECK(kv_.emplace(k, v).second);
+
+ auto pending_kv_calls_it = pending_kv_calls_.find(k);
+ if (pending_kv_calls_it != pending_kv_calls_.end()) {
+ for (auto pending_call : pending_kv_calls_it->second) {
+ pending_call->mut_response()->set_val(v);
+ pending_call->SendResponse();
+ }
+ pending_kv_calls_.erase(pending_kv_calls_it);
}
- pending_kv_calls_.erase(pending_kv_calls_it);
- }
- call->SendResponse();
- ENQUEUE_REQUEST(PushKV);
-}
-
-void CtrlServer::ClearKVHandler(CtrlCall<CtrlMethod::kClearKV>* call) {
- const std::string& k = call->request().key();
- CHECK_EQ(kv_.erase(k), 1);
- CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end());
- call->SendResponse();
- ENQUEUE_REQUEST(ClearKV);
-}
-
-void CtrlServer::PullKVHandler(CtrlCall<CtrlMethod::kPullKV>* call) {
- const std::string& k = call->request().key();
- auto kv_it = kv_.find(k);
- if (kv_it != kv_.end()) {
- call->mut_response()->set_val(kv_it->second);
call->SendResponse();
- } else {
- pending_kv_calls_[k].push_back(call);
- }
- ENQUEUE_REQUEST(PullKV);
-}
+ EnqueueRequest<CtrlMethod::kPushKV>();
+ });
-void CtrlServer::PushActEventHandler(CtrlCall<CtrlMethod::kPushActEvent>* call) {
- ActEvent act_event = call->request().act_event();
- call->SendResponse();
- Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event);
- ENQUEUE_REQUEST(PushActEvent);
-}
+ Add([this](CtrlCall<CtrlMethod::kClearKV>* call) {
+ const std::string& k = call->request().key();
+ CHECK_EQ(kv_.erase(k), 1);
+ CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end());
+ call->SendResponse();
+ EnqueueRequest<CtrlMethod::kClearKV>();
+ });
+
+ Add([this](CtrlCall<CtrlMethod::kPullKV>* call) {
+ const std::string& k = call->request().key();
+ auto kv_it = kv_.find(k);
+ if (kv_it != kv_.end()) {
+ call->mut_response()->set_val(kv_it->second);
+ call->SendResponse();
+ } else {
+ pending_kv_calls_[k].push_back(call);
+ }
+ EnqueueRequest<CtrlMethod::kPullKV>();
+ });
-void CtrlServer::ClearHandler(CtrlCall<CtrlMethod::kClear>* call) {
- name2lock_status_.clear();
- kv_.clear();
- CHECK(pending_kv_calls_.empty());
- call->SendResponse();
- ENQUEUE_REQUEST(Clear);
-}
+ Add([this](CtrlCall<CtrlMethod::kPushActEvent>* call) {
+ ActEvent act_event = call->request().act_event();
+ call->SendResponse();
+ Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event);
+ EnqueueRequest<CtrlMethod::kPushActEvent>();
+ });
+
+ Add([this](CtrlCall<CtrlMethod::kClear>* call) {
+ name2lock_status_.clear();
+ kv_.clear();
+ CHECK(pending_kv_calls_.empty());
+ call->SendResponse();
+ EnqueueRequest<CtrlMethod::kClear>();
+ });
-void CtrlServer::IncreaseCountHandler(CtrlCall<CtrlMethod::kIncreaseCount>* call) {
- int32_t& count = count_[call->request().key()];
- count += call->request().val();
- call->mut_response()->set_val(count);
- call->SendResponse();
- ENQUEUE_REQUEST(IncreaseCount);
-}
+ Add([this](CtrlCall<CtrlMethod::kIncreaseCount>* call) {
+ int32_t& count = count_[call->request().key()];
+ count += call->request().val();
+ call->mut_response()->set_val(count);
+ call->SendResponse();
+ EnqueueRequest<CtrlMethod::kIncreaseCount>();
+ });
-void CtrlServer::EraseCountHandler(CtrlCall<CtrlMethod::kEraseCount>* call) {
- CHECK_EQ(count_.erase(call->request().key()), 1);
- call->SendResponse();
- ENQUEUE_REQUEST(EraseCount);
-}
+ Add([this](CtrlCall<CtrlMethod::kEraseCount>* call) {
+ CHECK_EQ(count_.erase(call->request().key()), 1);
+ call->SendResponse();
+ EnqueueRequest<CtrlMethod::kEraseCount>();
+ });
-void CtrlServer::PushAvgActIntervalHandler(CtrlCall<CtrlMethod::kPushAvgActInterval>* call) {
- Global<Profiler>::Get()->PushAvgActInterval(call->request().actor_id(),
- call->request().avg_act_interval());
- call->SendResponse();
- ENQUEUE_REQUEST(PushAvgActInterval);
+ Add([this](CtrlCall<CtrlMethod::kPushAvgActInterval>* call) {
+ Global<Profiler>::Get()->PushAvgActInterval(call->request().actor_id(),
+ call->request().avg_act_interval());
+ call->SendResponse();
+ EnqueueRequest<CtrlMethod::kPushAvgActInterval>();
+ });
}
} // namespace oneflow
diff --git a/oneflow/core/control/ctrl_server.h b/oneflow/core/control/ctrl_server.h
index 7cd008ecfef93fe572c19cfa764d94b479afd193..b2962df7d2b24f640b9f014eb85ec23da030d54b 100644
--- a/oneflow/core/control/ctrl_server.h
+++ b/oneflow/core/control/ctrl_server.h
@@ -7,6 +7,15 @@
namespace oneflow {
+namespace {
+template<size_t... Idx>
+static std::tuple<std::function<void(CtrlCall<(CtrlMethod)Idx>*)>...> GetHandlerTuple(
+ std::index_sequence<Idx...>) {
+ return {};
+}
+
+} // namespace
+
class CtrlServer final {
public:
OF_DISALLOW_COPY_AND_MOVE(CtrlServer);
@@ -17,14 +26,44 @@ class CtrlServer final {
private:
void HandleRpcs();
+ void Init();
+
+ void EnqueueRequests() {
+ for_each_i(handlers_, helper{this}, std::make_index_sequence<kCtrlMethodNum>{});
+ }
+
+ template<CtrlMethod kMethod>
+ void EnqueueRequest() {
+ constexpr const size_t I = (size_t)kMethod;
+ auto handler = std::get<I>(handlers_);
+ auto call = new CtrlCall<(CtrlMethod)I>();
+ call->set_request_handler(std::bind(handler, call));
+ grpc_service_->RequestAsyncUnary(I, call->mut_server_ctx(), call->mut_request(),
+ call->mut_responder(), cq_.get(), cq_.get(), call);
+ }
+
+ template<typename F>
+ void Add(F f) {
+ using tuple_type = typename function_traits<F>::tuple_type;
+ using arg_type =
+ typename std::remove_pointer<typename std::tuple_element<0, tuple_type>::type>::type;
+
+ std::get<arg_type::value>(handlers_) = std::move(f);
+ }
-#define DECLARE_CTRL_METHOD_HANDLE(method) \
- void method##Handler(CtrlCall<CtrlMethod::k##method>* call);
+ struct helper {
+ helper(CtrlServer* s) : s_(s) {}
+ template<typename T, typename V>
+ void operator()(const T& t, V) {
+ s_->EnqueueRequest<(CtrlMethod)V::value>();
+ }
- OF_PP_FOR_EACH_TUPLE(DECLARE_CTRL_METHOD_HANDLE, CTRL_METHOD_SEQ);
+ CtrlServer* s_;
+ };
-#undef DECLARE_CTRL_METHOD_HANDLE
+ using HandlerTuple = decltype(GetHandlerTuple(std::make_index_sequence<kCtrlMethodNum>{}));
+ HandlerTuple handlers_;
std::unique_ptr<CtrlService::AsyncService> grpc_service_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> grpc_server_;
diff --git a/oneflow/core/control/ctrl_service.cpp b/oneflow/core/control/ctrl_service.cpp
index 14843bc83bb96efc73b703182ff51b2288688288..1b89f5b3baddfcaaa79ba7d677dc0e6a2803c140 100644
--- a/oneflow/core/control/ctrl_service.cpp
+++ b/oneflow/core/control/ctrl_service.cpp
@@ -4,12 +4,6 @@ namespace oneflow {
namespace {
-const char* g_method_name[] = {
-#define DEFINE_METHOD_NAME(method) "/oneflow.CtrlService/" OF_PP_STRINGIZE(method),
- OF_PP_FOR_EACH_TUPLE(DEFINE_METHOD_NAME, CTRL_METHOD_SEQ)};
-
-const char* GetMethodName(CtrlMethod method) { return g_method_name[static_cast<int32_t>(method)]; }
-
template<size_t method_index>
const grpc::RpcMethod BuildOneRpcMethod(std::shared_ptr<grpc::ChannelInterface> channel) {
return grpc::RpcMethod(GetMethodName(static_cast<CtrlMethod>(method_index)),
diff --git a/oneflow/core/control/ctrl_service.h b/oneflow/core/control/ctrl_service.h
index 198e3fb476179d8a02b35cb92b8eaaf5697f8673..a89d690e181af59631b65091c74a00a012916a09 100644
--- a/oneflow/core/control/ctrl_service.h
+++ b/oneflow/core/control/ctrl_service.h
@@ -32,31 +32,33 @@ namespace oneflow {
OF_PP_MAKE_TUPLE_SEQ(EraseCount) \
OF_PP_MAKE_TUPLE_SEQ(PushAvgActInterval)
-enum class CtrlMethod {
-#define MAKE_ENTRY(method) k##method,
- OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ)
-};
-#undef MAKE_ENTRY
+#define CatRequest(method) method##Request,
+#define CatReqponse(method) method##Response,
+#define CatEnum(method) k##method,
+#define CatName(method) "/oneflow.CtrlService/" OF_PP_STRINGIZE(method),
-const int32_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ);
+#define MAKE_META_DATA() \
+ enum class CtrlMethod { OF_PP_FOR_EACH_TUPLE(CatEnum, CTRL_METHOD_SEQ) }; \
+ static const char* g_method_name[] = {OF_PP_FOR_EACH_TUPLE(CatName, CTRL_METHOD_SEQ)}; \
+ using CtrlRequestTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatRequest, CTRL_METHOD_SEQ) void>; \
+ using CtrlResponseTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatReqponse, CTRL_METHOD_SEQ) void>;
-using CtrlRequestTuple = std::tuple<
-#define MAKE_ENTRY(method) method##Request,
- OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>;
-#undef MAKE_ENTRY
+MAKE_META_DATA()
-using CtrlResponseTuple = std::tuple<
-#define MAKE_ENTRY(method) method##Response,
- OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>;
-#undef MAKE_ENTRY
+constexpr const size_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ);
template<CtrlMethod ctrl_method>
using CtrlRequest =
typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlRequestTuple>::type;
+
template<CtrlMethod ctrl_method>
using CtrlResponse =
typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlResponseTuple>::type;
+inline const char* GetMethodName(CtrlMethod method) {
+ return g_method_name[static_cast<int32_t>(method)];
+}
+
class CtrlService final {
public:
class Stub final {
diff --git a/oneflow/core/kernel/print_kernel.cpp b/oneflow/core/kernel/print_kernel.cpp
index eb80e72655e78837d55a9ff67e6be730be645afb..8d505a499e42c3e5aa540b619d7a4732dedd488e 100644
--- a/oneflow/core/kernel/print_kernel.cpp
+++ b/oneflow/core/kernel/print_kernel.cpp
@@ -7,7 +7,7 @@ namespace oneflow {
void PrintKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) {
const auto& conf = op_conf().print_conf();
const std::string& root_path = conf.print_dir();
- OF_CALL_ONCE(root_path, GlobalFS()->RecursivelyCreateDir(root_path));
+ OfCallOnce(root_path, GlobalFS(), &fs::FileSystem::RecursivelyCreateDir);
int32_t part_name_suffix_length = conf.part_name_suffix_length();
std::string num = std::to_string(parallel_ctx->parallel_id());
int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0);
diff --git a/oneflow/core/persistence/persistent_out_stream.cpp b/oneflow/core/persistence/persistent_out_stream.cpp
index 6a74bed9cc1545d55c117f2524d608ed257dfc83..0ca8fe97e40b5e262f88c03790f6f4c6d0f833c2 100644
--- a/oneflow/core/persistence/persistent_out_stream.cpp
+++ b/oneflow/core/persistence/persistent_out_stream.cpp
@@ -7,8 +7,8 @@ namespace oneflow {
PersistentOutStream::PersistentOutStream(fs::FileSystem* fs, const std::string& file_path) {
std::string file_dir = Dirname(file_path);
- OF_CALL_ONCE(Global<MachineCtx>::Get()->GetThisCtrlAddr() + "/" + file_dir,
- fs->RecursivelyCreateDirIfNotExist(file_dir));
+ OfCallOnce(Global<MachineCtx>::Get()->GetThisCtrlAddr() + "/" + file_dir, fs,
+ &fs::FileSystem::RecursivelyCreateDirIfNotExist, file_dir);
fs->NewWritableFile(file_path, &file_);
}
diff --git a/oneflow/core/persistence/snapshot.cpp b/oneflow/core/persistence/snapshot.cpp
index b9d2b7a5353d551bf887af291ad96cbfa9d5679a..b8ab84ca3d2c02382f4757ee68ed399591ee03b0 100644
--- a/oneflow/core/persistence/snapshot.cpp
+++ b/oneflow/core/persistence/snapshot.cpp
@@ -18,10 +18,10 @@ std::unique_ptr<PersistentOutStream> Snapshot::GetOutStream(const LogicalBlobId&
int32_t part_id) {
// op_name_dir
std::string op_name_dir = JoinPath(root_path_, lbi.op_name());
- OF_CALL_ONCE(op_name_dir, GlobalFS()->CreateDir(op_name_dir));
+ OfCallOnce(op_name_dir, GlobalFS(), &fs::FileSystem::CreateDir);
// bn_in_op_tmp_dir
std::string bn_in_op_tmp_dir = JoinPath(op_name_dir, lbi.blob_name() + "_tmp4a58");
- OF_CALL_ONCE(bn_in_op_tmp_dir, GlobalFS()->CreateDir(bn_in_op_tmp_dir));
+ OfCallOnce(bn_in_op_tmp_dir, GlobalFS(), &fs::FileSystem::CreateDir);
// part_file
std::string part_file = JoinPath(bn_in_op_tmp_dir, "part_" + std::to_string(part_id));
return std::make_unique<PersistentOutStream>(GlobalFS(), part_file);
diff --git a/oneflow/core/persistence/snapshot_manager.cpp b/oneflow/core/persistence/snapshot_manager.cpp
index 45062391e91a87a75dd861c23d772e7d030744ba..71afa3e94d052ea9d0468a9d42373fa39e35a247 100644
--- a/oneflow/core/persistence/snapshot_manager.cpp
+++ b/oneflow/core/persistence/snapshot_manager.cpp
@@ -7,7 +7,7 @@ namespace oneflow {
SnapshotMgr::SnapshotMgr(const Plan& plan) {
if (Global<JobDesc>::Get()->IsTrain()) {
model_save_snapshots_path_ = Global<JobDesc>::Get()->MdSaveSnapshotsPath();
- OF_CALL_ONCE(model_save_snapshots_path_, GlobalFS()->MakeEmptyDir(model_save_snapshots_path_));
+ OfCallOnce(model_save_snapshots_path_, GlobalFS(), &fs::FileSystem::MakeEmptyDir);
}
const std::string& load_path = Global<JobDesc>::Get()->MdLoadSnapshotPath();
if (load_path != "") { readable_snapshot_.reset(new Snapshot(load_path)); }
@@ -20,7 +20,7 @@ Snapshot* SnapshotMgr::GetWriteableSnapshot(int64_t snapshot_id) {
if (it == snapshot_id2writeable_snapshot_.end()) {
std::string snapshot_root_path =
JoinPath(model_save_snapshots_path_, "snapshot_" + std::to_string(snapshot_id));
- OF_CALL_ONCE(snapshot_root_path, GlobalFS()->CreateDirIfNotExist(snapshot_root_path));
+ OfCallOnce(snapshot_root_path, GlobalFS(), &fs::FileSystem::CreateDirIfNotExist);
std::unique_ptr<Snapshot> ret(new Snapshot(snapshot_root_path));
auto emplace_ret = snapshot_id2writeable_snapshot_.emplace(snapshot_id, std::move(ret));
it = emplace_ret.first;