Skip to content
Snippets Groups Projects
Unverified Commit 3dbf6ae5 authored by cheng cheng's avatar cheng cheng Committed by GitHub
Browse files

Register ForeignCallback and Watcher in Multi-Client (#5591)

parent fc546b93
No related branches found
No related tags found
No related merge requests found
......@@ -22,8 +22,10 @@ limitations under the License.
namespace py = pybind11;
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("RegisterForeignCallbackOnlyOnce", &RegisterForeignCallbackOnlyOnce);
m.def("RegisterWatcherOnlyOnce", &RegisterWatcherOnlyOnce);
m.def("RegisterGlobalForeignCallback", &RegisterGlobalForeignCallback);
m.def("DestroyGlobalForeignCallback", &DestroyGlobalForeignCallback);
m.def("RegisterGlobalWatcher", &RegisterGlobalWatcher);
m.def("DestroyGlobalWatcher", &DestroyGlobalWatcher);
m.def("LaunchJob", &LaunchJob, py::call_guard<py::gil_scoped_release>());
m.def("GetSerializedInterUserJobInfo",
......
......@@ -37,8 +37,7 @@ limitations under the License.
namespace oneflow {
inline Maybe<void> RegisterForeignCallbackOnlyOnce(
const std::shared_ptr<ForeignCallback>& callback) {
inline Maybe<void> RegisterGlobalForeignCallback(const std::shared_ptr<ForeignCallback>& callback) {
CHECK_ISNULL_OR_RETURN(Global<std::shared_ptr<ForeignCallback>>::Get())
<< "foreign callback registered";
// Global<T>::SetAllocated is preferred since Global<T>::New will output logs but
......@@ -48,7 +47,14 @@ inline Maybe<void> RegisterForeignCallbackOnlyOnce(
return Maybe<void>::Ok();
}
inline Maybe<void> RegisterWatcherOnlyOnce(const std::shared_ptr<ForeignWatcher>& watcher) {
inline Maybe<void> DestroyGlobalForeignCallback() {
if (Global<std::shared_ptr<ForeignCallback>>::Get()) {
Global<std::shared_ptr<ForeignCallback>>::Delete();
}
return Maybe<void>::Ok();
}
inline Maybe<void> RegisterGlobalWatcher(const std::shared_ptr<ForeignWatcher>& watcher) {
CHECK_ISNULL_OR_RETURN(Global<std::shared_ptr<ForeignWatcher>>::Get())
<< "foreign watcher registered";
// Global<T>::SetAllocated is preferred since Global<T>::New will output logs but
......@@ -58,6 +64,13 @@ inline Maybe<void> RegisterWatcherOnlyOnce(const std::shared_ptr<ForeignWatcher>
return Maybe<void>::Ok();
}
inline Maybe<void> DestroyGlobalWatcher() {
if (Global<std::shared_ptr<ForeignWatcher>>::Get()) {
Global<std::shared_ptr<ForeignWatcher>>::Delete();
}
return Maybe<void>::Ok();
}
inline Maybe<void> LaunchJob(const std::shared_ptr<oneflow::JobInstance>& cb) {
CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
CHECK_NOTNULL_OR_RETURN(Global<Oneflow>::Get());
......
......@@ -19,15 +19,21 @@ limitations under the License.
#include "oneflow/api/python/framework/framework.h"
#include "oneflow/core/serving/saved_model.cfg.h"
inline void RegisterForeignCallbackOnlyOnce(
inline void RegisterGlobalForeignCallback(
const std::shared_ptr<oneflow::ForeignCallback>& callback) {
return oneflow::RegisterForeignCallbackOnlyOnce(callback).GetOrThrow();
return oneflow::RegisterGlobalForeignCallback(callback).GetOrThrow();
}
inline void RegisterWatcherOnlyOnce(const std::shared_ptr<oneflow::ForeignWatcher>& watcher) {
return oneflow::RegisterWatcherOnlyOnce(watcher).GetOrThrow();
inline void DestroyGlobalForeignCallback() {
return oneflow::DestroyGlobalForeignCallback().GetOrThrow();
}
inline void RegisterGlobalWatcher(const std::shared_ptr<oneflow::ForeignWatcher>& watcher) {
return oneflow::RegisterGlobalWatcher(watcher).GetOrThrow();
}
inline void DestroyGlobalWatcher() { return oneflow::DestroyGlobalWatcher().GetOrThrow(); }
inline void LaunchJob(const std::shared_ptr<oneflow::JobInstance>& cb) {
return oneflow::LaunchJob(cb).GetOrThrow();
}
......
......@@ -82,6 +82,21 @@ oneflow._oneflow_internal.EnableEagerEnvironment(True)
del env_util
# NOTE(chengcheng): register ForeignCallback and Watcher used by nn.Graph train job compelete
from oneflow.python.framework import register_python_callback
from oneflow.python.framework import python_callback
oneflow._oneflow_internal.RegisterGlobalForeignCallback(
python_callback.global_python_callback
)
del python_callback
del register_python_callback
from oneflow.python.framework import watcher
oneflow._oneflow_internal.RegisterGlobalWatcher(watcher._global_watcher)
del watcher
def _SyncOnMasterFn():
import oneflow
......
......@@ -50,6 +50,9 @@ from oneflow.compatible.single_client.python.framework import session_context
from oneflow.compatible.single_client.python.framework import env_util
# NOTE(chengcheng): Destroy ForeignCallback and Watcher for created by Multi-Client init.
oneflow._oneflow_internal.DestroyGlobalWatcher()
oneflow._oneflow_internal.DestroyGlobalForeignCallback()
oneflow._oneflow_internal.DestroyEnv()
import time
......@@ -76,7 +79,7 @@ import oneflow.compatible.single_client.python.framework.c_api_util
from oneflow.compatible.single_client.python.framework import register_python_callback
from oneflow.compatible.single_client.python.framework import python_callback
oneflow._oneflow_internal.RegisterForeignCallbackOnlyOnce(
oneflow._oneflow_internal.RegisterGlobalForeignCallback(
python_callback.global_python_callback
)
del python_callback
......@@ -85,7 +88,7 @@ del register_python_callback
# register Watcher
from oneflow.compatible.single_client.python.framework import watcher
oneflow._oneflow_internal.RegisterWatcherOnlyOnce(watcher._global_watcher)
oneflow._oneflow_internal.RegisterGlobalWatcher(watcher._global_watcher)
del watcher
# register BoxingUtil
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment