From 5fa0b3faede64e304aa1afceb162cc42b8d25790 Mon Sep 17 00:00:00 2001 From: qicosmos <383121719@qq.com> Date: Tue, 12 Jun 2018 15:30:51 +0800 Subject: [PATCH] Oneflow qiyu (#928) * remove cplusplus_17.h, it's no need now. * improve make_unique * make macros together * remove OF_CALL_ONCE macro * remove macros in ctrl_server * -Wno-unused-function * remove useless header --- CMakeLists.txt | 2 +- oneflow/core/common/cplusplus_14.h | 27 +- oneflow/core/common/cplusplus_17.h | 10 - oneflow/core/common/meta_util.hpp | 50 ++++ oneflow/core/common/util.h | 3 +- oneflow/core/control/ctrl_call.h | 2 + oneflow/core/control/ctrl_client.h | 44 ++- oneflow/core/control/ctrl_server.cpp | 265 +++++++++--------- oneflow/core/control/ctrl_server.h | 47 +++- oneflow/core/control/ctrl_service.cpp | 6 - oneflow/core/control/ctrl_service.h | 30 +- oneflow/core/kernel/print_kernel.cpp | 2 +- .../persistence/persistent_out_stream.cpp | 4 +- oneflow/core/persistence/snapshot.cpp | 4 +- oneflow/core/persistence/snapshot_manager.cpp | 4 +- 15 files changed, 308 insertions(+), 192 deletions(-) delete mode 100644 oneflow/core/common/cplusplus_17.h create mode 100644 oneflow/core/common/meta_util.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f9883c38c..48557fc2b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ if (WIN32) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0") else() list(APPEND CUDA_NVCC_FLAGS -std=c++11 -w -Wno-deprecated-gpu-targets) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare -Wno-unused-function") if (RELEASE_VERSION) list(APPEND CUDA_NVCC_FLAGS -O3) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") diff --git a/oneflow/core/common/cplusplus_14.h b/oneflow/core/common/cplusplus_14.h index 041a334ef..b6219376d 100644 --- a/oneflow/core/common/cplusplus_14.h +++ b/oneflow/core/common/cplusplus_14.h @@ -4,12 +4,35 @@ #if __cplusplus < 201402L namespace std { +template<class T> +struct unique_if { + typedef unique_ptr<T> single_object; +}; + +template<class T> +struct unique_if<T[]> { + typedef unique_ptr<T[]> unknown_bound; +}; + +template<class T, size_t N> +struct unique_if<T[N]> { + typedef void known_bound; +}; -template<typename T, typename... Args> -unique_ptr<T> make_unique(Args&&... args) { +template<class T, class... Args> +typename unique_if<T>::single_object make_unique(Args&&... args) { return unique_ptr<T>(new T(forward<Args>(args)...)); } +template<class T> +typename unique_if<T>::unknown_bound make_unique(size_t n) { + typedef typename remove_extent<T>::type U; + return unique_ptr<T>(new U[n]()); +} + +template<class T, class... Args> +typename unique_if<T>::known_bound make_unique(Args&&...) = delete; + template<typename T, T... Ints> struct integer_sequence { static_assert(is_integral<T>::value, ""); diff --git a/oneflow/core/common/cplusplus_17.h b/oneflow/core/common/cplusplus_17.h deleted file mode 100644 index a59d89149..000000000 --- a/oneflow/core/common/cplusplus_17.h +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_ -#define ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_ - -#if __cplusplus < 201703L - -namespace std {} - -#endif - -#endif // ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_ diff --git a/oneflow/core/common/meta_util.hpp b/oneflow/core/common/meta_util.hpp new file mode 100644 index 000000000..49210b830 --- /dev/null +++ b/oneflow/core/common/meta_util.hpp @@ -0,0 +1,50 @@ +#ifndef ONEFLOW_META_UTIL_HPP +#define ONEFLOW_META_UTIL_HPP + +#include "oneflow/core/common/cplusplus_14.h" + +namespace oneflow{ + + template <typename... Args, typename Func, std::size_t... Idx> + void for_each(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) { + (void)std::initializer_list<int> { (f(std::get<Idx>(t)), void(), 0)...}; + } + + template <typename... Args, typename Func, std::size_t... Idx> + void for_each_i(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) { + (void)std::initializer_list<int> { (f(std::get<Idx>(t), std::integral_constant<size_t, Idx>{}), void(), 0)...}; + } + + template<typename T> + struct function_traits; + + template<typename Ret, typename... Args> + struct function_traits<Ret(Args...)> + { + public: + enum { arity = sizeof...(Args) }; + typedef Ret function_type(Args...); + typedef Ret return_type; + using stl_function_type = std::function<function_type>; + typedef Ret(*pointer)(Args...); + + typedef std::tuple<Args...> tuple_type; + }; + + template<typename Ret, typename... Args> + struct function_traits<Ret(*)(Args...)> : function_traits<Ret(Args...)>{}; + + template <typename Ret, typename... Args> + struct function_traits<std::function<Ret(Args...)>> : function_traits<Ret(Args...)>{}; + + template <typename ReturnType, typename ClassType, typename... Args> + struct function_traits<ReturnType(ClassType::*)(Args...)> : function_traits<ReturnType(Args...)>{}; + + template <typename ReturnType, typename ClassType, typename... Args> + struct function_traits<ReturnType(ClassType::*)(Args...) const> : function_traits<ReturnType(Args...)>{}; + + template<typename Callable> + struct function_traits : function_traits<decltype(&Callable::operator())>{}; +} + +#endif //ONEFLOW_META_UTIL_HPP diff --git a/oneflow/core/common/util.h b/oneflow/core/common/util.h index 956e92746..9924909aa 100644 --- a/oneflow/core/common/util.h +++ b/oneflow/core/common/util.h @@ -24,8 +24,7 @@ #include <utility> #include "oneflow/core/operator/op_conf.pb.h" -#include "oneflow/core/common/cplusplus_14.h" -#include "oneflow/core/common/cplusplus_17.h" +#include "oneflow/core/common/meta_util.hpp" DECLARE_string(log_dir); diff --git a/oneflow/core/control/ctrl_call.h b/oneflow/core/control/ctrl_call.h index 0404a7975..2f1524bfa 100644 --- a/oneflow/core/control/ctrl_call.h +++ b/oneflow/core/control/ctrl_call.h @@ -26,6 +26,8 @@ class CtrlCall final : public CtrlCallIf { CtrlCall() : status_(Status::kBeforeHandleRequest), responder_(&server_ctx_) {} ~CtrlCall() = default; + static constexpr const size_t value = (size_t)ctrl_method; + const CtrlRequest<ctrl_method>& request() const { return request_; } CtrlRequest<ctrl_method>* mut_request() { return &request_; } CtrlResponse<ctrl_method>* mut_response() { return &response_; } diff --git a/oneflow/core/control/ctrl_client.h b/oneflow/core/control/ctrl_client.h index 6173ae4b5..37c7a6bef 100644 --- a/oneflow/core/control/ctrl_client.h +++ b/oneflow/core/control/ctrl_client.h @@ -68,19 +68,37 @@ class CtrlClient final { #define OF_BARRIER() Global<CtrlClient>::Get()->Barrier(FILE_LINE_STR) -#define OF_CALL_ONCE(name, ...) \ - do { \ - TryLockResult lock_ret = Global<CtrlClient>::Get()->TryLock(name); \ - if (lock_ret == TryLockResult::kLocked) { \ - __VA_ARGS__; \ - Global<CtrlClient>::Get()->NotifyDone(name); \ - } else if (lock_ret == TryLockResult::kDone) { \ - } else if (lock_ret == TryLockResult::kDoing) { \ - Global<CtrlClient>::Get()->WaitUntilDone(name); \ - } else { \ - UNIMPLEMENTED(); \ - } \ - } while (0) +static void OfCallOnce(const std::string& name, std::function<void()> f) { + TryLockResult lock_ret = Global<CtrlClient>::Get()->TryLock(name); + if (lock_ret == TryLockResult::kLocked) { + f(); + Global<CtrlClient>::Get()->NotifyDone(name); + } else if (lock_ret == TryLockResult::kDone) { + } else if (lock_ret == TryLockResult::kDoing) { + Global<CtrlClient>::Get()->WaitUntilDone(name); + } else { + UNIMPLEMENTED(); + } +} + +template<typename Self, typename F, typename Arg, typename... Args> +static void OfCallOnce(const std::string& name, Self self, F f, Arg&& arg, Args&&... args) { + std::function<void()> fn = + std::bind(f, self, std::forward<Arg>(arg), std::forward<Args>(args)...); + OfCallOnce(name, std::move(fn)); +} + +template<typename Self, typename F> +static void OfCallOnce(const std::string& name, Self self, F f) { + std::function<void()> fn = std::bind(f, self, name); + OfCallOnce(name, std::move(fn)); +} + +template<typename F, typename Arg, typename... Args> +static void OfCallOnce(const std::string& name, F f, Arg&& arg, Args&&... args) { + std::function<void()> fn = std::bind(f, std::forward<Arg>(arg), std::forward<Args>(args)...); + OfCallOnce(name, std::move(fn)); +} } // namespace oneflow diff --git a/oneflow/core/control/ctrl_server.cpp b/oneflow/core/control/ctrl_server.cpp index 43f8239ad..0949eadc6 100644 --- a/oneflow/core/control/ctrl_server.cpp +++ b/oneflow/core/control/ctrl_server.cpp @@ -24,6 +24,8 @@ CtrlServer::~CtrlServer() { } CtrlServer::CtrlServer(const std::string& server_addr) { + Init(); + if (FLAGS_grpc_use_no_signal) { grpc_use_signal(-1); } int port = ExtractPortFromAddr(server_addr); grpc::ServerBuilder server_builder; @@ -38,17 +40,8 @@ CtrlServer::CtrlServer(const std::string& server_addr) { loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this); } -#define ENQUEUE_REQUEST(method) \ - do { \ - auto call = new CtrlCall<CtrlMethod::k##method>(); \ - call->set_request_handler(std::bind(&CtrlServer::method##Handler, this, call)); \ - grpc_service_->RequestAsyncUnary(static_cast<int32_t>(CtrlMethod::k##method), \ - call->mut_server_ctx(), call->mut_request(), \ - call->mut_responder(), cq_.get(), cq_.get(), call); \ - } while (0); - void CtrlServer::HandleRpcs() { - OF_PP_FOR_EACH_TUPLE(ENQUEUE_REQUEST, CTRL_METHOD_SEQ); + EnqueueRequests(); void* tag = nullptr; bool ok = false; @@ -64,141 +57,147 @@ void CtrlServer::HandleRpcs() { } } -void CtrlServer::LoadServerHandler(CtrlCall<CtrlMethod::kLoadServer>* call) { - call->SendResponse(); - ENQUEUE_REQUEST(LoadServer); -} +void CtrlServer::Init() { + Add([this](CtrlCall<CtrlMethod::kLoadServer>* call) { + call->SendResponse(); + EnqueueRequest<CtrlMethod::kLoadServer>(); + }); + + Add([this](CtrlCall<CtrlMethod::kBarrier>* call) { + const std::string& barrier_name = call->request().name(); + int32_t barrier_num = call->request().num(); + auto barrier_call_it = barrier_calls_.find(barrier_name); + if (barrier_call_it == barrier_calls_.end()) { + barrier_call_it = + barrier_calls_ + .emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num)) + .first; + } + CHECK_EQ(barrier_num, barrier_call_it->second.second); + barrier_call_it->second.first.push_back(call); + if (barrier_call_it->second.first.size() == barrier_call_it->second.second) { + for (CtrlCallIf* pending_call : barrier_call_it->second.first) { + pending_call->SendResponse(); + } + barrier_calls_.erase(barrier_call_it); + } -void CtrlServer::BarrierHandler(CtrlCall<CtrlMethod::kBarrier>* call) { - const std::string& barrier_name = call->request().name(); - int32_t barrier_num = call->request().num(); - auto barrier_call_it = barrier_calls_.find(barrier_name); - if (barrier_call_it == barrier_calls_.end()) { - barrier_call_it = - barrier_calls_.emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num)) - .first; - } - CHECK_EQ(barrier_num, barrier_call_it->second.second); - barrier_call_it->second.first.push_back(call); - if (barrier_call_it->second.first.size() == barrier_call_it->second.second) { - for (CtrlCallIf* pending_call : barrier_call_it->second.first) { pending_call->SendResponse(); } - barrier_calls_.erase(barrier_call_it); - } - ENQUEUE_REQUEST(Barrier); -} + EnqueueRequest<CtrlMethod::kBarrier>(); + }); -void CtrlServer::TryLockHandler(CtrlCall<CtrlMethod::kTryLock>* call) { - const std::string& lock_name = call->request().name(); - auto name2lock_status_it = name2lock_status_.find(lock_name); - if (name2lock_status_it == name2lock_status_.end()) { - call->mut_response()->set_result(TryLockResult::kLocked); - auto waiting_until_done_calls = new std::list<CtrlCallIf*>; - CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second); - } else { - if (name2lock_status_it->second) { - call->mut_response()->set_result(TryLockResult::kDoing); + Add([this](CtrlCall<CtrlMethod::kTryLock>* call) { + const std::string& lock_name = call->request().name(); + auto name2lock_status_it = name2lock_status_.find(lock_name); + if (name2lock_status_it == name2lock_status_.end()) { + call->mut_response()->set_result(TryLockResult::kLocked); + auto waiting_until_done_calls = new std::list<CtrlCallIf*>; + CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second); } else { - call->mut_response()->set_result(TryLockResult::kDone); + if (name2lock_status_it->second) { + call->mut_response()->set_result(TryLockResult::kDoing); + } else { + call->mut_response()->set_result(TryLockResult::kDone); + } } - } - call->SendResponse(); - ENQUEUE_REQUEST(TryLock); -} - -void CtrlServer::NotifyDoneHandler(CtrlCall<CtrlMethod::kNotifyDone>* call) { - const std::string& lock_name = call->request().name(); - auto name2lock_status_it = name2lock_status_.find(lock_name); - auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second); - for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); } - delete waiting_calls; - name2lock_status_it->second = nullptr; - call->SendResponse(); - ENQUEUE_REQUEST(NotifyDone); -} - -void CtrlServer::WaitUntilDoneHandler(CtrlCall<CtrlMethod::kWaitUntilDone>* call) { - const std::string& lock_name = call->request().name(); - void* lock_status = name2lock_status_.at(lock_name); - if (lock_status) { - auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status); - waiting_calls->push_back(call); - } else { call->SendResponse(); - } - ENQUEUE_REQUEST(WaitUntilDone); -} - -void CtrlServer::PushKVHandler(CtrlCall<CtrlMethod::kPushKV>* call) { - const std::string& k = call->request().key(); - const std::string& v = call->request().val(); - CHECK(kv_.emplace(k, v).second); - - auto pending_kv_calls_it = pending_kv_calls_.find(k); - if (pending_kv_calls_it != pending_kv_calls_.end()) { - for (auto pending_call : pending_kv_calls_it->second) { - pending_call->mut_response()->set_val(v); - pending_call->SendResponse(); + EnqueueRequest<CtrlMethod::kTryLock>(); + }); + + Add([this](CtrlCall<CtrlMethod::kNotifyDone>* call) { + const std::string& lock_name = call->request().name(); + auto name2lock_status_it = name2lock_status_.find(lock_name); + auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second); + for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); } + delete waiting_calls; + name2lock_status_it->second = nullptr; + call->SendResponse(); + EnqueueRequest<CtrlMethod::kNotifyDone>(); + }); + + Add([this](CtrlCall<CtrlMethod::kWaitUntilDone>* call) { + const std::string& lock_name = call->request().name(); + void* lock_status = name2lock_status_.at(lock_name); + if (lock_status) { + auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status); + waiting_calls->push_back(call); + } else { + call->SendResponse(); + } + EnqueueRequest<CtrlMethod::kWaitUntilDone>(); + }); + + Add([this](CtrlCall<CtrlMethod::kPushKV>* call) { + const std::string& k = call->request().key(); + const std::string& v = call->request().val(); + CHECK(kv_.emplace(k, v).second); + + auto pending_kv_calls_it = pending_kv_calls_.find(k); + if (pending_kv_calls_it != pending_kv_calls_.end()) { + for (auto pending_call : pending_kv_calls_it->second) { + pending_call->mut_response()->set_val(v); + pending_call->SendResponse(); + } + pending_kv_calls_.erase(pending_kv_calls_it); } - pending_kv_calls_.erase(pending_kv_calls_it); - } - call->SendResponse(); - ENQUEUE_REQUEST(PushKV); -} - -void CtrlServer::ClearKVHandler(CtrlCall<CtrlMethod::kClearKV>* call) { - const std::string& k = call->request().key(); - CHECK_EQ(kv_.erase(k), 1); - CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end()); - call->SendResponse(); - ENQUEUE_REQUEST(ClearKV); -} - -void CtrlServer::PullKVHandler(CtrlCall<CtrlMethod::kPullKV>* call) { - const std::string& k = call->request().key(); - auto kv_it = kv_.find(k); - if (kv_it != kv_.end()) { - call->mut_response()->set_val(kv_it->second); call->SendResponse(); - } else { - pending_kv_calls_[k].push_back(call); - } - ENQUEUE_REQUEST(PullKV); -} + EnqueueRequest<CtrlMethod::kPushKV>(); + }); -void CtrlServer::PushActEventHandler(CtrlCall<CtrlMethod::kPushActEvent>* call) { - ActEvent act_event = call->request().act_event(); - call->SendResponse(); - Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event); - ENQUEUE_REQUEST(PushActEvent); -} + Add([this](CtrlCall<CtrlMethod::kClearKV>* call) { + const std::string& k = call->request().key(); + CHECK_EQ(kv_.erase(k), 1); + CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end()); + call->SendResponse(); + EnqueueRequest<CtrlMethod::kClearKV>(); + }); + + Add([this](CtrlCall<CtrlMethod::kPullKV>* call) { + const std::string& k = call->request().key(); + auto kv_it = kv_.find(k); + if (kv_it != kv_.end()) { + call->mut_response()->set_val(kv_it->second); + call->SendResponse(); + } else { + pending_kv_calls_[k].push_back(call); + } + EnqueueRequest<CtrlMethod::kPullKV>(); + }); -void CtrlServer::ClearHandler(CtrlCall<CtrlMethod::kClear>* call) { - name2lock_status_.clear(); - kv_.clear(); - CHECK(pending_kv_calls_.empty()); - call->SendResponse(); - ENQUEUE_REQUEST(Clear); -} + Add([this](CtrlCall<CtrlMethod::kPushActEvent>* call) { + ActEvent act_event = call->request().act_event(); + call->SendResponse(); + Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event); + EnqueueRequest<CtrlMethod::kPushActEvent>(); + }); + + Add([this](CtrlCall<CtrlMethod::kClear>* call) { + name2lock_status_.clear(); + kv_.clear(); + CHECK(pending_kv_calls_.empty()); + call->SendResponse(); + EnqueueRequest<CtrlMethod::kClear>(); + }); -void CtrlServer::IncreaseCountHandler(CtrlCall<CtrlMethod::kIncreaseCount>* call) { - int32_t& count = count_[call->request().key()]; - count += call->request().val(); - call->mut_response()->set_val(count); - call->SendResponse(); - ENQUEUE_REQUEST(IncreaseCount); -} + Add([this](CtrlCall<CtrlMethod::kIncreaseCount>* call) { + int32_t& count = count_[call->request().key()]; + count += call->request().val(); + call->mut_response()->set_val(count); + call->SendResponse(); + EnqueueRequest<CtrlMethod::kIncreaseCount>(); + }); -void CtrlServer::EraseCountHandler(CtrlCall<CtrlMethod::kEraseCount>* call) { - CHECK_EQ(count_.erase(call->request().key()), 1); - call->SendResponse(); - ENQUEUE_REQUEST(EraseCount); -} + Add([this](CtrlCall<CtrlMethod::kEraseCount>* call) { + CHECK_EQ(count_.erase(call->request().key()), 1); + call->SendResponse(); + EnqueueRequest<CtrlMethod::kEraseCount>(); + }); -void CtrlServer::PushAvgActIntervalHandler(CtrlCall<CtrlMethod::kPushAvgActInterval>* call) { - Global<Profiler>::Get()->PushAvgActInterval(call->request().actor_id(), - call->request().avg_act_interval()); - call->SendResponse(); - ENQUEUE_REQUEST(PushAvgActInterval); + Add([this](CtrlCall<CtrlMethod::kPushAvgActInterval>* call) { + Global<Profiler>::Get()->PushAvgActInterval(call->request().actor_id(), + call->request().avg_act_interval()); + call->SendResponse(); + EnqueueRequest<CtrlMethod::kPushAvgActInterval>(); + }); } } // namespace oneflow diff --git a/oneflow/core/control/ctrl_server.h b/oneflow/core/control/ctrl_server.h index 7cd008ecf..b2962df7d 100644 --- a/oneflow/core/control/ctrl_server.h +++ b/oneflow/core/control/ctrl_server.h @@ -7,6 +7,15 @@ namespace oneflow { +namespace { +template<size_t... Idx> +static std::tuple<std::function<void(CtrlCall<(CtrlMethod)Idx>*)>...> GetHandlerTuple( + std::index_sequence<Idx...>) { + return {}; +} + +} // namespace + class CtrlServer final { public: OF_DISALLOW_COPY_AND_MOVE(CtrlServer); @@ -17,14 +26,44 @@ class CtrlServer final { private: void HandleRpcs(); + void Init(); + + void EnqueueRequests() { + for_each_i(handlers_, helper{this}, std::make_index_sequence<kCtrlMethodNum>{}); + } + + template<CtrlMethod kMethod> + void EnqueueRequest() { + constexpr const size_t I = (size_t)kMethod; + auto handler = std::get<I>(handlers_); + auto call = new CtrlCall<(CtrlMethod)I>(); + call->set_request_handler(std::bind(handler, call)); + grpc_service_->RequestAsyncUnary(I, call->mut_server_ctx(), call->mut_request(), + call->mut_responder(), cq_.get(), cq_.get(), call); + } + + template<typename F> + void Add(F f) { + using tuple_type = typename function_traits<F>::tuple_type; + using arg_type = + typename std::remove_pointer<typename std::tuple_element<0, tuple_type>::type>::type; + + std::get<arg_type::value>(handlers_) = std::move(f); + } -#define DECLARE_CTRL_METHOD_HANDLE(method) \ - void method##Handler(CtrlCall<CtrlMethod::k##method>* call); + struct helper { + helper(CtrlServer* s) : s_(s) {} + template<typename T, typename V> + void operator()(const T& t, V) { + s_->EnqueueRequest<(CtrlMethod)V::value>(); + } - OF_PP_FOR_EACH_TUPLE(DECLARE_CTRL_METHOD_HANDLE, CTRL_METHOD_SEQ); + CtrlServer* s_; + }; -#undef DECLARE_CTRL_METHOD_HANDLE + using HandlerTuple = decltype(GetHandlerTuple(std::make_index_sequence<kCtrlMethodNum>{})); + HandlerTuple handlers_; std::unique_ptr<CtrlService::AsyncService> grpc_service_; std::unique_ptr<grpc::ServerCompletionQueue> cq_; std::unique_ptr<grpc::Server> grpc_server_; diff --git a/oneflow/core/control/ctrl_service.cpp b/oneflow/core/control/ctrl_service.cpp index 14843bc83..1b89f5b3b 100644 --- a/oneflow/core/control/ctrl_service.cpp +++ b/oneflow/core/control/ctrl_service.cpp @@ -4,12 +4,6 @@ namespace oneflow { namespace { -const char* g_method_name[] = { -#define DEFINE_METHOD_NAME(method) "/oneflow.CtrlService/" OF_PP_STRINGIZE(method), - OF_PP_FOR_EACH_TUPLE(DEFINE_METHOD_NAME, CTRL_METHOD_SEQ)}; - -const char* GetMethodName(CtrlMethod method) { return g_method_name[static_cast<int32_t>(method)]; } - template<size_t method_index> const grpc::RpcMethod BuildOneRpcMethod(std::shared_ptr<grpc::ChannelInterface> channel) { return grpc::RpcMethod(GetMethodName(static_cast<CtrlMethod>(method_index)), diff --git a/oneflow/core/control/ctrl_service.h b/oneflow/core/control/ctrl_service.h index 198e3fb47..a89d690e1 100644 --- a/oneflow/core/control/ctrl_service.h +++ b/oneflow/core/control/ctrl_service.h @@ -32,31 +32,33 @@ namespace oneflow { OF_PP_MAKE_TUPLE_SEQ(EraseCount) \ OF_PP_MAKE_TUPLE_SEQ(PushAvgActInterval) -enum class CtrlMethod { -#define MAKE_ENTRY(method) k##method, - OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) -}; -#undef MAKE_ENTRY +#define CatRequest(method) method##Request, +#define CatReqponse(method) method##Response, +#define CatEnum(method) k##method, +#define CatName(method) "/oneflow.CtrlService/" OF_PP_STRINGIZE(method), -const int32_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ); +#define MAKE_META_DATA() \ + enum class CtrlMethod { OF_PP_FOR_EACH_TUPLE(CatEnum, CTRL_METHOD_SEQ) }; \ + static const char* g_method_name[] = {OF_PP_FOR_EACH_TUPLE(CatName, CTRL_METHOD_SEQ)}; \ + using CtrlRequestTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatRequest, CTRL_METHOD_SEQ) void>; \ + using CtrlResponseTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatReqponse, CTRL_METHOD_SEQ) void>; -using CtrlRequestTuple = std::tuple< -#define MAKE_ENTRY(method) method##Request, - OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>; -#undef MAKE_ENTRY +MAKE_META_DATA() -using CtrlResponseTuple = std::tuple< -#define MAKE_ENTRY(method) method##Response, - OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>; -#undef MAKE_ENTRY +constexpr const size_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ); template<CtrlMethod ctrl_method> using CtrlRequest = typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlRequestTuple>::type; + template<CtrlMethod ctrl_method> using CtrlResponse = typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlResponseTuple>::type; +inline const char* GetMethodName(CtrlMethod method) { + return g_method_name[static_cast<int32_t>(method)]; +} + class CtrlService final { public: class Stub final { diff --git a/oneflow/core/kernel/print_kernel.cpp b/oneflow/core/kernel/print_kernel.cpp index eb80e7265..8d505a499 100644 --- a/oneflow/core/kernel/print_kernel.cpp +++ b/oneflow/core/kernel/print_kernel.cpp @@ -7,7 +7,7 @@ namespace oneflow { void PrintKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) { const auto& conf = op_conf().print_conf(); const std::string& root_path = conf.print_dir(); - OF_CALL_ONCE(root_path, GlobalFS()->RecursivelyCreateDir(root_path)); + OfCallOnce(root_path, GlobalFS(), &fs::FileSystem::RecursivelyCreateDir); int32_t part_name_suffix_length = conf.part_name_suffix_length(); std::string num = std::to_string(parallel_ctx->parallel_id()); int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0); diff --git a/oneflow/core/persistence/persistent_out_stream.cpp b/oneflow/core/persistence/persistent_out_stream.cpp index 6a74bed9c..0ca8fe97e 100644 --- a/oneflow/core/persistence/persistent_out_stream.cpp +++ b/oneflow/core/persistence/persistent_out_stream.cpp @@ -7,8 +7,8 @@ namespace oneflow { PersistentOutStream::PersistentOutStream(fs::FileSystem* fs, const std::string& file_path) { std::string file_dir = Dirname(file_path); - OF_CALL_ONCE(Global<MachineCtx>::Get()->GetThisCtrlAddr() + "/" + file_dir, - fs->RecursivelyCreateDirIfNotExist(file_dir)); + OfCallOnce(Global<MachineCtx>::Get()->GetThisCtrlAddr() + "/" + file_dir, fs, + &fs::FileSystem::RecursivelyCreateDirIfNotExist, file_dir); fs->NewWritableFile(file_path, &file_); } diff --git a/oneflow/core/persistence/snapshot.cpp b/oneflow/core/persistence/snapshot.cpp index b9d2b7a53..b8ab84ca3 100644 --- a/oneflow/core/persistence/snapshot.cpp +++ b/oneflow/core/persistence/snapshot.cpp @@ -18,10 +18,10 @@ std::unique_ptr<PersistentOutStream> Snapshot::GetOutStream(const LogicalBlobId& int32_t part_id) { // op_name_dir std::string op_name_dir = JoinPath(root_path_, lbi.op_name()); - OF_CALL_ONCE(op_name_dir, GlobalFS()->CreateDir(op_name_dir)); + OfCallOnce(op_name_dir, GlobalFS(), &fs::FileSystem::CreateDir); // bn_in_op_tmp_dir std::string bn_in_op_tmp_dir = JoinPath(op_name_dir, lbi.blob_name() + "_tmp4a58"); - OF_CALL_ONCE(bn_in_op_tmp_dir, GlobalFS()->CreateDir(bn_in_op_tmp_dir)); + OfCallOnce(bn_in_op_tmp_dir, GlobalFS(), &fs::FileSystem::CreateDir); // part_file std::string part_file = JoinPath(bn_in_op_tmp_dir, "part_" + std::to_string(part_id)); return std::make_unique<PersistentOutStream>(GlobalFS(), part_file); diff --git a/oneflow/core/persistence/snapshot_manager.cpp b/oneflow/core/persistence/snapshot_manager.cpp index 45062391e..71afa3e94 100644 --- a/oneflow/core/persistence/snapshot_manager.cpp +++ b/oneflow/core/persistence/snapshot_manager.cpp @@ -7,7 +7,7 @@ namespace oneflow { SnapshotMgr::SnapshotMgr(const Plan& plan) { if (Global<JobDesc>::Get()->IsTrain()) { model_save_snapshots_path_ = Global<JobDesc>::Get()->MdSaveSnapshotsPath(); - OF_CALL_ONCE(model_save_snapshots_path_, GlobalFS()->MakeEmptyDir(model_save_snapshots_path_)); + OfCallOnce(model_save_snapshots_path_, GlobalFS(), &fs::FileSystem::MakeEmptyDir); } const std::string& load_path = Global<JobDesc>::Get()->MdLoadSnapshotPath(); if (load_path != "") { readable_snapshot_.reset(new Snapshot(load_path)); } @@ -20,7 +20,7 @@ Snapshot* SnapshotMgr::GetWriteableSnapshot(int64_t snapshot_id) { if (it == snapshot_id2writeable_snapshot_.end()) { std::string snapshot_root_path = JoinPath(model_save_snapshots_path_, "snapshot_" + std::to_string(snapshot_id)); - OF_CALL_ONCE(snapshot_root_path, GlobalFS()->CreateDirIfNotExist(snapshot_root_path)); + OfCallOnce(snapshot_root_path, GlobalFS(), &fs::FileSystem::CreateDirIfNotExist); std::unique_ptr<Snapshot> ret(new Snapshot(snapshot_root_path)); auto emplace_ret = snapshot_id2writeable_snapshot_.emplace(snapshot_id, std::move(ret)); it = emplace_ret.first; -- GitLab