Skip to content
Snippets Groups Projects
Commit 5fa0b3fa authored by qicosmos's avatar qicosmos Committed by Jinhui Yuan
Browse files

Oneflow qiyu (#928)

* remove cplusplus_17.h, it's no need now.

* improve make_unique

* make macros together

* remove OF_CALL_ONCE macro

* remove macros in ctrl_server

* -Wno-unused-function

* remove useless header
parent 210ead4b
No related branches found
No related tags found
No related merge requests found
Showing
with 308 additions and 192 deletions
......@@ -40,7 +40,7 @@ if (WIN32)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0")
else()
list(APPEND CUDA_NVCC_FLAGS -std=c++11 -w -Wno-deprecated-gpu-targets)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare -Wno-unused-function")
if (RELEASE_VERSION)
list(APPEND CUDA_NVCC_FLAGS -O3)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
......
......@@ -4,12 +4,35 @@
#if __cplusplus < 201402L
namespace std {
template<class T>
struct unique_if {
typedef unique_ptr<T> single_object;
};
template<class T>
struct unique_if<T[]> {
typedef unique_ptr<T[]> unknown_bound;
};
template<class T, size_t N>
struct unique_if<T[N]> {
typedef void known_bound;
};
template<typename T, typename... Args>
unique_ptr<T> make_unique(Args&&... args) {
template<class T, class... Args>
typename unique_if<T>::single_object make_unique(Args&&... args) {
return unique_ptr<T>(new T(forward<Args>(args)...));
}
template<class T>
typename unique_if<T>::unknown_bound make_unique(size_t n) {
typedef typename remove_extent<T>::type U;
return unique_ptr<T>(new U[n]());
}
template<class T, class... Args>
typename unique_if<T>::known_bound make_unique(Args&&...) = delete;
template<typename T, T... Ints>
struct integer_sequence {
static_assert(is_integral<T>::value, "");
......
#ifndef ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
#define ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
#if __cplusplus < 201703L
namespace std {}
#endif
#endif // ONEFLOW_CORE_COMMON_CPLUSPLUS_17_H_
#ifndef ONEFLOW_META_UTIL_HPP
#define ONEFLOW_META_UTIL_HPP
#include "oneflow/core/common/cplusplus_14.h"
namespace oneflow{
template <typename... Args, typename Func, std::size_t... Idx>
void for_each(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {
(void)std::initializer_list<int> { (f(std::get<Idx>(t)), void(), 0)...};
}
template <typename... Args, typename Func, std::size_t... Idx>
void for_each_i(const std::tuple<Args...>& t, Func&& f, std::index_sequence<Idx...>) {
(void)std::initializer_list<int> { (f(std::get<Idx>(t), std::integral_constant<size_t, Idx>{}), void(), 0)...};
}
template<typename T>
struct function_traits;
template<typename Ret, typename... Args>
struct function_traits<Ret(Args...)>
{
public:
enum { arity = sizeof...(Args) };
typedef Ret function_type(Args...);
typedef Ret return_type;
using stl_function_type = std::function<function_type>;
typedef Ret(*pointer)(Args...);
typedef std::tuple<Args...> tuple_type;
};
template<typename Ret, typename... Args>
struct function_traits<Ret(*)(Args...)> : function_traits<Ret(Args...)>{};
template <typename Ret, typename... Args>
struct function_traits<std::function<Ret(Args...)>> : function_traits<Ret(Args...)>{};
template <typename ReturnType, typename ClassType, typename... Args>
struct function_traits<ReturnType(ClassType::*)(Args...)> : function_traits<ReturnType(Args...)>{};
template <typename ReturnType, typename ClassType, typename... Args>
struct function_traits<ReturnType(ClassType::*)(Args...) const> : function_traits<ReturnType(Args...)>{};
template<typename Callable>
struct function_traits : function_traits<decltype(&Callable::operator())>{};
}
#endif //ONEFLOW_META_UTIL_HPP
......@@ -24,8 +24,7 @@
#include <utility>
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/common/cplusplus_14.h"
#include "oneflow/core/common/cplusplus_17.h"
#include "oneflow/core/common/meta_util.hpp"
DECLARE_string(log_dir);
......
......@@ -26,6 +26,8 @@ class CtrlCall final : public CtrlCallIf {
CtrlCall() : status_(Status::kBeforeHandleRequest), responder_(&server_ctx_) {}
~CtrlCall() = default;
static constexpr const size_t value = (size_t)ctrl_method;
const CtrlRequest<ctrl_method>& request() const { return request_; }
CtrlRequest<ctrl_method>* mut_request() { return &request_; }
CtrlResponse<ctrl_method>* mut_response() { return &response_; }
......
......@@ -68,19 +68,37 @@ class CtrlClient final {
#define OF_BARRIER() Global<CtrlClient>::Get()->Barrier(FILE_LINE_STR)
#define OF_CALL_ONCE(name, ...) \
do { \
TryLockResult lock_ret = Global<CtrlClient>::Get()->TryLock(name); \
if (lock_ret == TryLockResult::kLocked) { \
__VA_ARGS__; \
Global<CtrlClient>::Get()->NotifyDone(name); \
} else if (lock_ret == TryLockResult::kDone) { \
} else if (lock_ret == TryLockResult::kDoing) { \
Global<CtrlClient>::Get()->WaitUntilDone(name); \
} else { \
UNIMPLEMENTED(); \
} \
} while (0)
static void OfCallOnce(const std::string& name, std::function<void()> f) {
TryLockResult lock_ret = Global<CtrlClient>::Get()->TryLock(name);
if (lock_ret == TryLockResult::kLocked) {
f();
Global<CtrlClient>::Get()->NotifyDone(name);
} else if (lock_ret == TryLockResult::kDone) {
} else if (lock_ret == TryLockResult::kDoing) {
Global<CtrlClient>::Get()->WaitUntilDone(name);
} else {
UNIMPLEMENTED();
}
}
template<typename Self, typename F, typename Arg, typename... Args>
static void OfCallOnce(const std::string& name, Self self, F f, Arg&& arg, Args&&... args) {
std::function<void()> fn =
std::bind(f, self, std::forward<Arg>(arg), std::forward<Args>(args)...);
OfCallOnce(name, std::move(fn));
}
template<typename Self, typename F>
static void OfCallOnce(const std::string& name, Self self, F f) {
std::function<void()> fn = std::bind(f, self, name);
OfCallOnce(name, std::move(fn));
}
template<typename F, typename Arg, typename... Args>
static void OfCallOnce(const std::string& name, F f, Arg&& arg, Args&&... args) {
std::function<void()> fn = std::bind(f, std::forward<Arg>(arg), std::forward<Args>(args)...);
OfCallOnce(name, std::move(fn));
}
} // namespace oneflow
......
......@@ -24,6 +24,8 @@ CtrlServer::~CtrlServer() {
}
CtrlServer::CtrlServer(const std::string& server_addr) {
Init();
if (FLAGS_grpc_use_no_signal) { grpc_use_signal(-1); }
int port = ExtractPortFromAddr(server_addr);
grpc::ServerBuilder server_builder;
......@@ -38,17 +40,8 @@ CtrlServer::CtrlServer(const std::string& server_addr) {
loop_thread_ = std::thread(&CtrlServer::HandleRpcs, this);
}
#define ENQUEUE_REQUEST(method) \
do { \
auto call = new CtrlCall<CtrlMethod::k##method>(); \
call->set_request_handler(std::bind(&CtrlServer::method##Handler, this, call)); \
grpc_service_->RequestAsyncUnary(static_cast<int32_t>(CtrlMethod::k##method), \
call->mut_server_ctx(), call->mut_request(), \
call->mut_responder(), cq_.get(), cq_.get(), call); \
} while (0);
void CtrlServer::HandleRpcs() {
OF_PP_FOR_EACH_TUPLE(ENQUEUE_REQUEST, CTRL_METHOD_SEQ);
EnqueueRequests();
void* tag = nullptr;
bool ok = false;
......@@ -64,141 +57,147 @@ void CtrlServer::HandleRpcs() {
}
}
void CtrlServer::LoadServerHandler(CtrlCall<CtrlMethod::kLoadServer>* call) {
call->SendResponse();
ENQUEUE_REQUEST(LoadServer);
}
void CtrlServer::Init() {
Add([this](CtrlCall<CtrlMethod::kLoadServer>* call) {
call->SendResponse();
EnqueueRequest<CtrlMethod::kLoadServer>();
});
Add([this](CtrlCall<CtrlMethod::kBarrier>* call) {
const std::string& barrier_name = call->request().name();
int32_t barrier_num = call->request().num();
auto barrier_call_it = barrier_calls_.find(barrier_name);
if (barrier_call_it == barrier_calls_.end()) {
barrier_call_it =
barrier_calls_
.emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))
.first;
}
CHECK_EQ(barrier_num, barrier_call_it->second.second);
barrier_call_it->second.first.push_back(call);
if (barrier_call_it->second.first.size() == barrier_call_it->second.second) {
for (CtrlCallIf* pending_call : barrier_call_it->second.first) {
pending_call->SendResponse();
}
barrier_calls_.erase(barrier_call_it);
}
void CtrlServer::BarrierHandler(CtrlCall<CtrlMethod::kBarrier>* call) {
const std::string& barrier_name = call->request().name();
int32_t barrier_num = call->request().num();
auto barrier_call_it = barrier_calls_.find(barrier_name);
if (barrier_call_it == barrier_calls_.end()) {
barrier_call_it =
barrier_calls_.emplace(barrier_name, std::make_pair(std::list<CtrlCallIf*>{}, barrier_num))
.first;
}
CHECK_EQ(barrier_num, barrier_call_it->second.second);
barrier_call_it->second.first.push_back(call);
if (barrier_call_it->second.first.size() == barrier_call_it->second.second) {
for (CtrlCallIf* pending_call : barrier_call_it->second.first) { pending_call->SendResponse(); }
barrier_calls_.erase(barrier_call_it);
}
ENQUEUE_REQUEST(Barrier);
}
EnqueueRequest<CtrlMethod::kBarrier>();
});
void CtrlServer::TryLockHandler(CtrlCall<CtrlMethod::kTryLock>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
if (name2lock_status_it == name2lock_status_.end()) {
call->mut_response()->set_result(TryLockResult::kLocked);
auto waiting_until_done_calls = new std::list<CtrlCallIf*>;
CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second);
} else {
if (name2lock_status_it->second) {
call->mut_response()->set_result(TryLockResult::kDoing);
Add([this](CtrlCall<CtrlMethod::kTryLock>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
if (name2lock_status_it == name2lock_status_.end()) {
call->mut_response()->set_result(TryLockResult::kLocked);
auto waiting_until_done_calls = new std::list<CtrlCallIf*>;
CHECK(name2lock_status_.emplace(lock_name, waiting_until_done_calls).second);
} else {
call->mut_response()->set_result(TryLockResult::kDone);
if (name2lock_status_it->second) {
call->mut_response()->set_result(TryLockResult::kDoing);
} else {
call->mut_response()->set_result(TryLockResult::kDone);
}
}
}
call->SendResponse();
ENQUEUE_REQUEST(TryLock);
}
void CtrlServer::NotifyDoneHandler(CtrlCall<CtrlMethod::kNotifyDone>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second);
for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); }
delete waiting_calls;
name2lock_status_it->second = nullptr;
call->SendResponse();
ENQUEUE_REQUEST(NotifyDone);
}
void CtrlServer::WaitUntilDoneHandler(CtrlCall<CtrlMethod::kWaitUntilDone>* call) {
const std::string& lock_name = call->request().name();
void* lock_status = name2lock_status_.at(lock_name);
if (lock_status) {
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status);
waiting_calls->push_back(call);
} else {
call->SendResponse();
}
ENQUEUE_REQUEST(WaitUntilDone);
}
void CtrlServer::PushKVHandler(CtrlCall<CtrlMethod::kPushKV>* call) {
const std::string& k = call->request().key();
const std::string& v = call->request().val();
CHECK(kv_.emplace(k, v).second);
auto pending_kv_calls_it = pending_kv_calls_.find(k);
if (pending_kv_calls_it != pending_kv_calls_.end()) {
for (auto pending_call : pending_kv_calls_it->second) {
pending_call->mut_response()->set_val(v);
pending_call->SendResponse();
EnqueueRequest<CtrlMethod::kTryLock>();
});
Add([this](CtrlCall<CtrlMethod::kNotifyDone>* call) {
const std::string& lock_name = call->request().name();
auto name2lock_status_it = name2lock_status_.find(lock_name);
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(name2lock_status_it->second);
for (CtrlCallIf* waiting_call : *waiting_calls) { waiting_call->SendResponse(); }
delete waiting_calls;
name2lock_status_it->second = nullptr;
call->SendResponse();
EnqueueRequest<CtrlMethod::kNotifyDone>();
});
Add([this](CtrlCall<CtrlMethod::kWaitUntilDone>* call) {
const std::string& lock_name = call->request().name();
void* lock_status = name2lock_status_.at(lock_name);
if (lock_status) {
auto waiting_calls = static_cast<std::list<CtrlCallIf*>*>(lock_status);
waiting_calls->push_back(call);
} else {
call->SendResponse();
}
EnqueueRequest<CtrlMethod::kWaitUntilDone>();
});
Add([this](CtrlCall<CtrlMethod::kPushKV>* call) {
const std::string& k = call->request().key();
const std::string& v = call->request().val();
CHECK(kv_.emplace(k, v).second);
auto pending_kv_calls_it = pending_kv_calls_.find(k);
if (pending_kv_calls_it != pending_kv_calls_.end()) {
for (auto pending_call : pending_kv_calls_it->second) {
pending_call->mut_response()->set_val(v);
pending_call->SendResponse();
}
pending_kv_calls_.erase(pending_kv_calls_it);
}
pending_kv_calls_.erase(pending_kv_calls_it);
}
call->SendResponse();
ENQUEUE_REQUEST(PushKV);
}
void CtrlServer::ClearKVHandler(CtrlCall<CtrlMethod::kClearKV>* call) {
const std::string& k = call->request().key();
CHECK_EQ(kv_.erase(k), 1);
CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end());
call->SendResponse();
ENQUEUE_REQUEST(ClearKV);
}
void CtrlServer::PullKVHandler(CtrlCall<CtrlMethod::kPullKV>* call) {
const std::string& k = call->request().key();
auto kv_it = kv_.find(k);
if (kv_it != kv_.end()) {
call->mut_response()->set_val(kv_it->second);
call->SendResponse();
} else {
pending_kv_calls_[k].push_back(call);
}
ENQUEUE_REQUEST(PullKV);
}
EnqueueRequest<CtrlMethod::kPushKV>();
});
void CtrlServer::PushActEventHandler(CtrlCall<CtrlMethod::kPushActEvent>* call) {
ActEvent act_event = call->request().act_event();
call->SendResponse();
Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event);
ENQUEUE_REQUEST(PushActEvent);
}
Add([this](CtrlCall<CtrlMethod::kClearKV>* call) {
const std::string& k = call->request().key();
CHECK_EQ(kv_.erase(k), 1);
CHECK(pending_kv_calls_.find(k) == pending_kv_calls_.end());
call->SendResponse();
EnqueueRequest<CtrlMethod::kClearKV>();
});
Add([this](CtrlCall<CtrlMethod::kPullKV>* call) {
const std::string& k = call->request().key();
auto kv_it = kv_.find(k);
if (kv_it != kv_.end()) {
call->mut_response()->set_val(kv_it->second);
call->SendResponse();
} else {
pending_kv_calls_[k].push_back(call);
}
EnqueueRequest<CtrlMethod::kPullKV>();
});
void CtrlServer::ClearHandler(CtrlCall<CtrlMethod::kClear>* call) {
name2lock_status_.clear();
kv_.clear();
CHECK(pending_kv_calls_.empty());
call->SendResponse();
ENQUEUE_REQUEST(Clear);
}
Add([this](CtrlCall<CtrlMethod::kPushActEvent>* call) {
ActEvent act_event = call->request().act_event();
call->SendResponse();
Global<ActEventLogger>::Get()->PrintActEventToLogDir(act_event);
EnqueueRequest<CtrlMethod::kPushActEvent>();
});
Add([this](CtrlCall<CtrlMethod::kClear>* call) {
name2lock_status_.clear();
kv_.clear();
CHECK(pending_kv_calls_.empty());
call->SendResponse();
EnqueueRequest<CtrlMethod::kClear>();
});
void CtrlServer::IncreaseCountHandler(CtrlCall<CtrlMethod::kIncreaseCount>* call) {
int32_t& count = count_[call->request().key()];
count += call->request().val();
call->mut_response()->set_val(count);
call->SendResponse();
ENQUEUE_REQUEST(IncreaseCount);
}
Add([this](CtrlCall<CtrlMethod::kIncreaseCount>* call) {
int32_t& count = count_[call->request().key()];
count += call->request().val();
call->mut_response()->set_val(count);
call->SendResponse();
EnqueueRequest<CtrlMethod::kIncreaseCount>();
});
void CtrlServer::EraseCountHandler(CtrlCall<CtrlMethod::kEraseCount>* call) {
CHECK_EQ(count_.erase(call->request().key()), 1);
call->SendResponse();
ENQUEUE_REQUEST(EraseCount);
}
Add([this](CtrlCall<CtrlMethod::kEraseCount>* call) {
CHECK_EQ(count_.erase(call->request().key()), 1);
call->SendResponse();
EnqueueRequest<CtrlMethod::kEraseCount>();
});
void CtrlServer::PushAvgActIntervalHandler(CtrlCall<CtrlMethod::kPushAvgActInterval>* call) {
Global<Profiler>::Get()->PushAvgActInterval(call->request().actor_id(),
call->request().avg_act_interval());
call->SendResponse();
ENQUEUE_REQUEST(PushAvgActInterval);
Add([this](CtrlCall<CtrlMethod::kPushAvgActInterval>* call) {
Global<Profiler>::Get()->PushAvgActInterval(call->request().actor_id(),
call->request().avg_act_interval());
call->SendResponse();
EnqueueRequest<CtrlMethod::kPushAvgActInterval>();
});
}
} // namespace oneflow
......@@ -7,6 +7,15 @@
namespace oneflow {
namespace {
template<size_t... Idx>
static std::tuple<std::function<void(CtrlCall<(CtrlMethod)Idx>*)>...> GetHandlerTuple(
std::index_sequence<Idx...>) {
return {};
}
} // namespace
class CtrlServer final {
public:
OF_DISALLOW_COPY_AND_MOVE(CtrlServer);
......@@ -17,14 +26,44 @@ class CtrlServer final {
private:
void HandleRpcs();
void Init();
void EnqueueRequests() {
for_each_i(handlers_, helper{this}, std::make_index_sequence<kCtrlMethodNum>{});
}
template<CtrlMethod kMethod>
void EnqueueRequest() {
constexpr const size_t I = (size_t)kMethod;
auto handler = std::get<I>(handlers_);
auto call = new CtrlCall<(CtrlMethod)I>();
call->set_request_handler(std::bind(handler, call));
grpc_service_->RequestAsyncUnary(I, call->mut_server_ctx(), call->mut_request(),
call->mut_responder(), cq_.get(), cq_.get(), call);
}
template<typename F>
void Add(F f) {
using tuple_type = typename function_traits<F>::tuple_type;
using arg_type =
typename std::remove_pointer<typename std::tuple_element<0, tuple_type>::type>::type;
std::get<arg_type::value>(handlers_) = std::move(f);
}
#define DECLARE_CTRL_METHOD_HANDLE(method) \
void method##Handler(CtrlCall<CtrlMethod::k##method>* call);
struct helper {
helper(CtrlServer* s) : s_(s) {}
template<typename T, typename V>
void operator()(const T& t, V) {
s_->EnqueueRequest<(CtrlMethod)V::value>();
}
OF_PP_FOR_EACH_TUPLE(DECLARE_CTRL_METHOD_HANDLE, CTRL_METHOD_SEQ);
CtrlServer* s_;
};
#undef DECLARE_CTRL_METHOD_HANDLE
using HandlerTuple = decltype(GetHandlerTuple(std::make_index_sequence<kCtrlMethodNum>{}));
HandlerTuple handlers_;
std::unique_ptr<CtrlService::AsyncService> grpc_service_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> grpc_server_;
......
......@@ -4,12 +4,6 @@ namespace oneflow {
namespace {
const char* g_method_name[] = {
#define DEFINE_METHOD_NAME(method) "/oneflow.CtrlService/" OF_PP_STRINGIZE(method),
OF_PP_FOR_EACH_TUPLE(DEFINE_METHOD_NAME, CTRL_METHOD_SEQ)};
const char* GetMethodName(CtrlMethod method) { return g_method_name[static_cast<int32_t>(method)]; }
template<size_t method_index>
const grpc::RpcMethod BuildOneRpcMethod(std::shared_ptr<grpc::ChannelInterface> channel) {
return grpc::RpcMethod(GetMethodName(static_cast<CtrlMethod>(method_index)),
......
......@@ -32,31 +32,33 @@ namespace oneflow {
OF_PP_MAKE_TUPLE_SEQ(EraseCount) \
OF_PP_MAKE_TUPLE_SEQ(PushAvgActInterval)
enum class CtrlMethod {
#define MAKE_ENTRY(method) k##method,
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ)
};
#undef MAKE_ENTRY
#define CatRequest(method) method##Request,
#define CatReqponse(method) method##Response,
#define CatEnum(method) k##method,
#define CatName(method) "/oneflow.CtrlService/" OF_PP_STRINGIZE(method),
const int32_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ);
#define MAKE_META_DATA() \
enum class CtrlMethod { OF_PP_FOR_EACH_TUPLE(CatEnum, CTRL_METHOD_SEQ) }; \
static const char* g_method_name[] = {OF_PP_FOR_EACH_TUPLE(CatName, CTRL_METHOD_SEQ)}; \
using CtrlRequestTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatRequest, CTRL_METHOD_SEQ) void>; \
using CtrlResponseTuple = std::tuple<OF_PP_FOR_EACH_TUPLE(CatReqponse, CTRL_METHOD_SEQ) void>;
using CtrlRequestTuple = std::tuple<
#define MAKE_ENTRY(method) method##Request,
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>;
#undef MAKE_ENTRY
MAKE_META_DATA()
using CtrlResponseTuple = std::tuple<
#define MAKE_ENTRY(method) method##Response,
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, CTRL_METHOD_SEQ) void>;
#undef MAKE_ENTRY
constexpr const size_t kCtrlMethodNum = OF_PP_SEQ_SIZE(CTRL_METHOD_SEQ);
template<CtrlMethod ctrl_method>
using CtrlRequest =
typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlRequestTuple>::type;
template<CtrlMethod ctrl_method>
using CtrlResponse =
typename std::tuple_element<static_cast<size_t>(ctrl_method), CtrlResponseTuple>::type;
inline const char* GetMethodName(CtrlMethod method) {
return g_method_name[static_cast<int32_t>(method)];
}
class CtrlService final {
public:
class Stub final {
......
......@@ -7,7 +7,7 @@ namespace oneflow {
void PrintKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) {
const auto& conf = op_conf().print_conf();
const std::string& root_path = conf.print_dir();
OF_CALL_ONCE(root_path, GlobalFS()->RecursivelyCreateDir(root_path));
OfCallOnce(root_path, GlobalFS(), &fs::FileSystem::RecursivelyCreateDir);
int32_t part_name_suffix_length = conf.part_name_suffix_length();
std::string num = std::to_string(parallel_ctx->parallel_id());
int32_t zero_count = std::max(part_name_suffix_length - static_cast<int32_t>(num.length()), 0);
......
......@@ -7,8 +7,8 @@ namespace oneflow {
PersistentOutStream::PersistentOutStream(fs::FileSystem* fs, const std::string& file_path) {
std::string file_dir = Dirname(file_path);
OF_CALL_ONCE(Global<MachineCtx>::Get()->GetThisCtrlAddr() + "/" + file_dir,
fs->RecursivelyCreateDirIfNotExist(file_dir));
OfCallOnce(Global<MachineCtx>::Get()->GetThisCtrlAddr() + "/" + file_dir, fs,
&fs::FileSystem::RecursivelyCreateDirIfNotExist, file_dir);
fs->NewWritableFile(file_path, &file_);
}
......
......@@ -18,10 +18,10 @@ std::unique_ptr<PersistentOutStream> Snapshot::GetOutStream(const LogicalBlobId&
int32_t part_id) {
// op_name_dir
std::string op_name_dir = JoinPath(root_path_, lbi.op_name());
OF_CALL_ONCE(op_name_dir, GlobalFS()->CreateDir(op_name_dir));
OfCallOnce(op_name_dir, GlobalFS(), &fs::FileSystem::CreateDir);
// bn_in_op_tmp_dir
std::string bn_in_op_tmp_dir = JoinPath(op_name_dir, lbi.blob_name() + "_tmp4a58");
OF_CALL_ONCE(bn_in_op_tmp_dir, GlobalFS()->CreateDir(bn_in_op_tmp_dir));
OfCallOnce(bn_in_op_tmp_dir, GlobalFS(), &fs::FileSystem::CreateDir);
// part_file
std::string part_file = JoinPath(bn_in_op_tmp_dir, "part_" + std::to_string(part_id));
return std::make_unique<PersistentOutStream>(GlobalFS(), part_file);
......
......@@ -7,7 +7,7 @@ namespace oneflow {
SnapshotMgr::SnapshotMgr(const Plan& plan) {
if (Global<JobDesc>::Get()->IsTrain()) {
model_save_snapshots_path_ = Global<JobDesc>::Get()->MdSaveSnapshotsPath();
OF_CALL_ONCE(model_save_snapshots_path_, GlobalFS()->MakeEmptyDir(model_save_snapshots_path_));
OfCallOnce(model_save_snapshots_path_, GlobalFS(), &fs::FileSystem::MakeEmptyDir);
}
const std::string& load_path = Global<JobDesc>::Get()->MdLoadSnapshotPath();
if (load_path != "") { readable_snapshot_.reset(new Snapshot(load_path)); }
......@@ -20,7 +20,7 @@ Snapshot* SnapshotMgr::GetWriteableSnapshot(int64_t snapshot_id) {
if (it == snapshot_id2writeable_snapshot_.end()) {
std::string snapshot_root_path =
JoinPath(model_save_snapshots_path_, "snapshot_" + std::to_string(snapshot_id));
OF_CALL_ONCE(snapshot_root_path, GlobalFS()->CreateDirIfNotExist(snapshot_root_path));
OfCallOnce(snapshot_root_path, GlobalFS(), &fs::FileSystem::CreateDirIfNotExist);
std::unique_ptr<Snapshot> ret(new Snapshot(snapshot_root_path));
auto emplace_ret = snapshot_id2writeable_snapshot_.emplace(snapshot_id, std::move(ret));
it = emplace_ret.first;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment