diff --git a/oneflow/core/common/cplusplus_14.h b/oneflow/core/common/cplusplus_14.h index 969006f5ce40aad133a5ebed199ce3be897680df..041a334effe1b698a11b20d4783e75cee39e8ebe 100644 --- a/oneflow/core/common/cplusplus_14.h +++ b/oneflow/core/common/cplusplus_14.h @@ -6,10 +6,41 @@ namespace std { template<typename T, typename... Args> -std::unique_ptr<T> make_unique(Args&&... args) { - return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); +unique_ptr<T> make_unique(Args&&... args) { + return unique_ptr<T>(new T(forward<Args>(args)...)); } +template<typename T, T... Ints> +struct integer_sequence { + static_assert(is_integral<T>::value, ""); + using value_type = T; + static constexpr size_t size() { return sizeof...(Ints); } +}; + +template<typename T, T index, typename U, T... Ints> +struct make_integer_sequence_impl; + +template<typename T, T index, T... Ints> +struct make_integer_sequence_impl<T, index, typename enable_if<index != 0>::type, Ints...> + : make_integer_sequence_impl<T, index - 1, void, index, Ints...> {}; + +template<typename T, T index, T... Ints> +struct make_integer_sequence_impl<T, index, typename enable_if<index == 0>::type, Ints...> { + using type = integer_sequence<T, index, Ints...>; +}; + +template<typename T, T n> +using make_integer_sequence = typename make_integer_sequence_impl<T, n - 1, void>::type; + +template<size_t... Ints> +using index_sequence = integer_sequence<size_t, Ints...>; + +template<size_t n> +using make_index_sequence = make_integer_sequence<size_t, n>; + +template<typename... T> +using index_sequence_for = make_index_sequence<sizeof...(T)>; + } // namespace std #endif diff --git a/oneflow/core/control/ctrl_client.cpp b/oneflow/core/control/ctrl_client.cpp index 3e5a39e96f41d86e798958561d89b59fefef7a54..df3e3f2d20466901279c2887b32ffa5f71ed5fe0 100644 --- a/oneflow/core/control/ctrl_client.cpp +++ b/oneflow/core/control/ctrl_client.cpp @@ -10,25 +10,24 @@ const int64_t sleep_seconds = 10; #define GRPC_CHECK(x) CHECK_EQ(x.error_code(), grpc::StatusCode::OK) -#define DEFINE_CLIENT_CALL(method) \ - class method##ClientCall final { \ - public: \ - OF_DISALLOW_COPY_AND_MOVE(method##ClientCall); \ - method##ClientCall() = default; \ - ~method##ClientCall() = default; \ - method##Request* mut_request() { return &request_; } \ - const method##Response& response() const { return response_; } \ - void operator()(CtrlService::Stub* stub) { \ - grpc::ClientContext client_ctx; \ - GRPC_CHECK(stub->method(&client_ctx, request_, &response_)); \ - } \ - \ - private: \ - method##Request request_; \ - method##Response response_; \ - }; - -OF_PP_FOR_EACH_TUPLE(DEFINE_CLIENT_CALL, CTRL_METHOD_SEQ); +template<CtrlMethod ctrl_method> +class ClientCall final { + public: + OF_DISALLOW_COPY_AND_MOVE(ClientCall); + ClientCall() = default; + ~ClientCall() = default; + + CtrlRequest<ctrl_method>* mut_request() { return &request_; } + const CtrlResponse<ctrl_method>& response() const { return response_; } + void operator()(CtrlService::Stub* stub) { + grpc::ClientContext client_ctx; + GRPC_CHECK(stub->CallMethod<ctrl_method>(&client_ctx, request_, &response_)); + } + + private: + CtrlRequest<ctrl_method> request_; + CtrlResponse<ctrl_method> response_; +}; } // namespace @@ -46,7 +45,7 @@ void CtrlClient::Barrier(const std::string& barrier_name) { } void CtrlClient::Barrier(const std::string& barrier_name, int32_t barrier_num) { - BarrierClientCall call; + ClientCall<CtrlMethod::kBarrier> call; call.mut_request()->set_name(barrier_name); call.mut_request()->set_num(barrier_num); call(GetMasterStub()); @@ -57,7 +56,7 @@ TryLockResult CtrlClient::TryLock(const std::string& name) { std::unique_lock<std::mutex> lck(done_names_mtx_); if (done_names_.find(name) != done_names_.end()) { return TryLockResult::kDone; } } - TryLockClientCall call; + ClientCall<CtrlMethod::kTryLock> call; call.mut_request()->set_name(name); call(GetResponsibleStub(name)); if (call.response().result() == TryLockResult::kDone) { @@ -68,19 +67,19 @@ TryLockResult CtrlClient::TryLock(const std::string& name) { } void CtrlClient::NotifyDone(const std::string& name) { - NotifyDoneClientCall call; + ClientCall<CtrlMethod::kNotifyDone> call; call.mut_request()->set_name(name); call(GetResponsibleStub(name)); } void CtrlClient::WaitUntilDone(const std::string& name) { - WaitUntilDoneClientCall call; + ClientCall<CtrlMethod::kWaitUntilDone> call; call.mut_request()->set_name(name); call(GetResponsibleStub(name)); } void CtrlClient::PushKV(const std::string& k, std::function<void(std::string*)> VSetter) { - PushKVClientCall call; + ClientCall<CtrlMethod::kPushKV> call; call.mut_request()->set_key(k); VSetter(call.mut_request()->mutable_val()); call(GetResponsibleStub(k)); @@ -95,13 +94,13 @@ void CtrlClient::PushKV(const std::string& k, const PbMessage& msg) { } void CtrlClient::ClearKV(const std::string& k) { - ClearKVClientCall call; + ClientCall<CtrlMethod::kClearKV> call; call.mut_request()->set_key(k); call(GetResponsibleStub(k)); } void CtrlClient::PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) { - PullKVClientCall call; + ClientCall<CtrlMethod::kPullKV> call; call.mut_request()->set_key(k); call(GetResponsibleStub(k)); VGetter(call.response().val()); @@ -116,20 +115,20 @@ void CtrlClient::PullKV(const std::string& k, PbMessage* msg) { } void CtrlClient::PushActEvent(const ActEvent& act_event) { - PushActEventClientCall call; + ClientCall<CtrlMethod::kPushActEvent> call; *(call.mut_request()->mutable_act_event()) = act_event; call(GetMasterStub()); } void CtrlClient::Clear() { - ClearClientCall call; + ClientCall<CtrlMethod::kClear> call; call(GetThisStub()); std::unique_lock<std::mutex> lck(done_names_mtx_); done_names_.clear(); } int32_t CtrlClient::IncreaseCount(const std::string& k, int32_t v) { - IncreaseCountClientCall call; + ClientCall<CtrlMethod::kIncreaseCount> call; call.mut_request()->set_key(k); call.mut_request()->set_val(v); call(GetResponsibleStub(k)); @@ -137,13 +136,13 @@ int32_t CtrlClient::IncreaseCount(const std::string& k, int32_t v) { } void CtrlClient::EraseCount(const std::string& k) { - EraseCountClientCall call; + ClientCall<CtrlMethod::kEraseCount> call; call.mut_request()->set_key(k); call(GetResponsibleStub(k)); } void CtrlClient::PushAvgActInterval(int64_t actor_id, double avg_act_interval) { - PushAvgActIntervalClientCall call; + ClientCall<CtrlMethod::kPushAvgActInterval> call; call.mut_request()->set_actor_id(actor_id); call.mut_request()->set_avg_act_interval(avg_act_interval); call(GetMasterStub()); @@ -169,7 +168,7 @@ CtrlClient::CtrlClient() { } for (size_t i = 0; i < stubs_.size(); ++i) { grpc::ClientContext client_ctx; - GRPC_CHECK(stubs_[i]->LoadServer(&client_ctx, request, &response)) + GRPC_CHECK(stubs_[i]->CallMethod<CtrlMethod::kLoadServer>(&client_ctx, request, &response)) << "Machine " << i << " lost"; } std::this_thread::sleep_for(std::chrono::seconds(sleep_second_dis(gen))); @@ -183,7 +182,7 @@ void CtrlClient::LoadServer(const std::string& server_addr, CtrlService::Stub* s grpc::ClientContext client_ctx; LoadServerRequest request; LoadServerResponse response; - grpc::Status st = stub->LoadServer(&client_ctx, request, &response); + grpc::Status st = stub->CallMethod<CtrlMethod::kLoadServer>(&client_ctx, request, &response); if (st.error_code() == grpc::StatusCode::OK) { LOG(INFO) << "LoadServer " << server_addr << " Successful at " << retry_idx << " times"; break; diff --git a/oneflow/core/control/ctrl_service.cpp b/oneflow/core/control/ctrl_service.cpp index 3b109c8a8f995b692625bfbae9ea1676549aefb7..14843bc83bb96efc73b703182ff51b2288688288 100644 --- a/oneflow/core/control/ctrl_service.cpp +++ b/oneflow/core/control/ctrl_service.cpp @@ -1,5 +1,4 @@ #include "oneflow/core/control/ctrl_service.h" -#include <grpc++/impl/codegen/client_unary_call.h> namespace oneflow { @@ -11,23 +10,23 @@ const char* g_method_name[] = { const char* GetMethodName(CtrlMethod method) { return g_method_name[static_cast<int32_t>(method)]; } -} // namespace +template<size_t method_index> +const grpc::RpcMethod BuildOneRpcMethod(std::shared_ptr<grpc::ChannelInterface> channel) { + return grpc::RpcMethod(GetMethodName(static_cast<CtrlMethod>(method_index)), + grpc::RpcMethod::NORMAL_RPC, channel); +} -CtrlService::Stub::Stub(std::shared_ptr<grpc::ChannelInterface> channel) - : -#define INIT_RPC_METHOD_OBJ(method) \ - rpcmethod_##method##_(GetMethodName(CtrlMethod::k##method), grpc::RpcMethod::NORMAL_RPC, channel), - OF_PP_FOR_EACH_TUPLE(INIT_RPC_METHOD_OBJ, CTRL_METHOD_SEQ) channel_(channel) { +template<size_t... method_indices> +std::array<const grpc::RpcMethod, kCtrlMethodNum> BuildRpcMethods( + std::index_sequence<method_indices...>, std::shared_ptr<grpc::ChannelInterface> channel) { + return {BuildOneRpcMethod<method_indices>(channel)...}; } -#define DEFINE_STUB_METHOD(method) \ - grpc::Status CtrlService::Stub::method( \ - grpc::ClientContext* context, const method##Request& request, method##Response* response) { \ - return grpc::BlockingUnaryCall(channel_.get(), rpcmethod_##method##_, context, request, \ - response); \ - } +} // namespace -OF_PP_FOR_EACH_TUPLE(DEFINE_STUB_METHOD, CTRL_METHOD_SEQ) +CtrlService::Stub::Stub(std::shared_ptr<grpc::ChannelInterface> channel) + : rpcmethods_(BuildRpcMethods(std::make_index_sequence<kCtrlMethodNum>{}, channel)), + channel_(channel) {} std::unique_ptr<CtrlService::Stub> CtrlService::NewStub(const std::string& addr) { return std::make_unique<Stub>(grpc::CreateChannel(addr, grpc::InsecureChannelCredentials())); diff --git a/oneflow/core/control/ctrl_service.h b/oneflow/core/control/ctrl_service.h index ef89a7ba3acf84ba0fe132ad3a684c4d3b28bb4c..198e3fb476179d8a02b35cb92b8eaaf5697f8673 100644 --- a/oneflow/core/control/ctrl_service.h +++ b/oneflow/core/control/ctrl_service.h @@ -10,6 +10,7 @@ #include <grpc++/impl/codegen/status.h> #include <grpc++/impl/codegen/stub_options.h> #include <grpc++/impl/codegen/sync_stream.h> +#include <grpc++/impl/codegen/client_unary_call.h> #include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/util.h" #include "oneflow/core/control/control.pb.h" @@ -34,22 +35,20 @@ namespace oneflow { enum class CtrlMethod { #define MAKE_ENTRY(method) k##method, OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) -#undef MAKE_ENTRY }; +#undef MAKE_ENTRY const int32_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ); using CtrlRequestTuple = std::tuple< #define MAKE_ENTRY(method) method##Request, - OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>; #undef MAKE_ENTRY - void>; using CtrlResponseTuple = std::tuple< #define MAKE_ENTRY(method) method##Response, - OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>; #undef MAKE_ENTRY - void>; template<CtrlMethod ctrl_method> using CtrlRequest = @@ -63,18 +62,17 @@ class CtrlService final { class Stub final { public: Stub(std::shared_ptr<grpc::ChannelInterface> channel); -#define DECLARE_STUB_METHOD(method) \ - grpc::Status method(grpc::ClientContext* context, const method##Request& request, \ - method##Response* response); - - OF_PP_FOR_EACH_TUPLE(DECLARE_STUB_METHOD, CTRL_METHOD_SEQ); -#undef DECLARE_STUB_METHOD + template<CtrlMethod ctrl_method> + grpc::Status CallMethod(grpc::ClientContext* context, const CtrlRequest<ctrl_method>& request, + CtrlResponse<ctrl_method>* response) { + return grpc::BlockingUnaryCall(channel_.get(), + rpcmethods_.at(static_cast<size_t>(ctrl_method)), context, + request, response); + } private: -#define DECLARE_RPC_METHOD(method) const grpc::RpcMethod rpcmethod_##method##_; - OF_PP_FOR_EACH_TUPLE(DECLARE_RPC_METHOD, CTRL_METHOD_SEQ); -#undef DECLARE_RPC_METHOD + std::array<const grpc::RpcMethod, kCtrlMethodNum> rpcmethods_; std::shared_ptr<grpc::ChannelInterface> channel_; };