Skip to content
Snippets Groups Projects
Unverified Commit 94ffe857 authored by liufengwei0103's avatar liufengwei0103 Committed by GitHub
Browse files

replace ForeignJobInstance using JobInstance (#5374)


Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 5806e2ba
No related branches found
No related tags found
No related merge requests found
......@@ -27,7 +27,7 @@ limitations under the License.
#include "oneflow/core/job/inter_user_job_info.pb.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/job/foreign_watcher.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/oneflow.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/framework/config_def.h"
......@@ -58,11 +58,11 @@ inline Maybe<void> RegisterWatcherOnlyOnce(const std::shared_ptr<ForeignWatcher>
return Maybe<void>::Ok();
}
inline Maybe<void> LaunchJob(const std::shared_ptr<oneflow::ForeignJobInstance>& cb) {
inline Maybe<void> LaunchJob(const std::shared_ptr<oneflow::JobInstance>& cb) {
CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
CHECK_NOTNULL_OR_RETURN(Global<Oneflow>::Get());
const auto& job_name = cb->job_name();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::Get();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
int64_t job_id = Global<JobName2JobId>::Get()->at(job_name);
if (IsPullJob(job_name, *Global<InterUserJobInfo>::Get())) {
buffer_mgr->Get(GetForeignOutputBufferName(job_name))->Send(cb);
......
......@@ -28,7 +28,7 @@ inline void RegisterWatcherOnlyOnce(const std::shared_ptr<oneflow::ForeignWatche
return oneflow::RegisterWatcherOnlyOnce(watcher).GetOrThrow();
}
inline void LaunchJob(const std::shared_ptr<oneflow::ForeignJobInstance>& cb) {
inline void LaunchJob(const std::shared_ptr<oneflow::JobInstance>& cb) {
return oneflow::LaunchJob(cb).GetOrThrow();
}
......
......@@ -18,42 +18,42 @@ limitations under the License.
#include <memory>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
namespace py = pybind11;
namespace oneflow {
class PyForeignJobInstance : public ForeignJobInstance {
class PyJobInstance : public JobInstance {
public:
// Inherit the constructors
using ForeignJobInstance::ForeignJobInstance;
using JobInstance::JobInstance;
// Trampoline (need one for each virtual function)
std::string job_name() const override {
PYBIND11_OVERRIDE(std::string, /* Return type */
ForeignJobInstance, /* Parent class */
job_name, /* Name of function in C++ (must match Python name) */
PYBIND11_OVERRIDE(std::string, /* Return type */
JobInstance, /* Parent class */
job_name, /* Name of function in C++ (must match Python name) */
);
}
std::string sole_input_op_name_in_user_job() const override {
PYBIND11_OVERRIDE(std::string, ForeignJobInstance, sole_input_op_name_in_user_job, );
PYBIND11_OVERRIDE(std::string, JobInstance, sole_input_op_name_in_user_job, );
}
std::string sole_output_op_name_in_user_job() const override {
PYBIND11_OVERRIDE(std::string, ForeignJobInstance, sole_output_op_name_in_user_job, );
PYBIND11_OVERRIDE(std::string, JobInstance, sole_output_op_name_in_user_job, );
}
void PushBlob(uint64_t ofblob_ptr) const override {
PYBIND11_OVERRIDE(void, ForeignJobInstance, PushBlob, ofblob_ptr);
PYBIND11_OVERRIDE(void, JobInstance, PushBlob, ofblob_ptr);
}
void PullBlob(uint64_t ofblob_ptr) const override {
PYBIND11_OVERRIDE(void, ForeignJobInstance, PullBlob, ofblob_ptr);
PYBIND11_OVERRIDE(void, JobInstance, PullBlob, ofblob_ptr);
}
void Finish() const override { PYBIND11_OVERRIDE(void, ForeignJobInstance, Finish, ); }
void Finish() const override { PYBIND11_OVERRIDE(void, JobInstance, Finish, ); }
};
} // namespace oneflow
......@@ -61,13 +61,12 @@ class PyForeignJobInstance : public ForeignJobInstance {
ONEFLOW_API_PYBIND11_MODULE("", m) {
using namespace oneflow;
py::class_<ForeignJobInstance, PyForeignJobInstance, std::shared_ptr<ForeignJobInstance>>(
m, "ForeignJobInstance")
py::class_<JobInstance, PyJobInstance, std::shared_ptr<JobInstance>>(m, "JobInstance")
.def(py::init<>())
.def("job_name", &ForeignJobInstance::job_name)
.def("sole_input_op_name_in_user_job", &ForeignJobInstance::sole_input_op_name_in_user_job)
.def("sole_output_op_name_in_user_job", &ForeignJobInstance::sole_output_op_name_in_user_job)
.def("PushBlob", &ForeignJobInstance::PushBlob)
.def("PullBlob", &ForeignJobInstance::PullBlob)
.def("Finish", &ForeignJobInstance::Finish);
.def("job_name", &JobInstance::job_name)
.def("sole_input_op_name_in_user_job", &JobInstance::sole_input_op_name_in_user_job)
.def("sole_output_op_name_in_user_job", &JobInstance::sole_output_op_name_in_user_job)
.def("PushBlob", &JobInstance::PushBlob)
.def("PullBlob", &JobInstance::PullBlob)
.def("Finish", &JobInstance::Finish);
}
......@@ -20,11 +20,11 @@ limitations under the License.
namespace oneflow {
class ForeignJobInstance {
class JobInstance {
public:
ForeignJobInstance() = default;
JobInstance() = default;
virtual ~ForeignJobInstance() = default;
virtual ~JobInstance() = default;
virtual std::string job_name() const { UNIMPLEMENTED(); }
virtual std::string sole_input_op_name_in_user_job() const { UNIMPLEMENTED(); }
......
......@@ -15,17 +15,17 @@ limitations under the License.
*/
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/runtime_buffer_managers_scope.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
namespace oneflow {
RuntimeBufferManagersScope::RuntimeBufferManagersScope() {
Global<BufferMgr<int64_t>>::New();
Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::New();
Global<BufferMgr<std::shared_ptr<JobInstance>>>::New();
}
RuntimeBufferManagersScope::~RuntimeBufferManagersScope() {
Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::Delete();
Global<BufferMgr<std::shared_ptr<JobInstance>>>::Delete();
Global<BufferMgr<int64_t>>::Delete();
}
......
......@@ -16,14 +16,14 @@ limitations under the License.
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/runtime_buffers_scope.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
namespace oneflow {
RuntimeBuffersScope::RuntimeBuffersScope(const JobConfs& job_confs) {
size_t job_size = Global<JobName2JobId>::Get()->size();
Global<BufferMgr<int64_t>>::Get()->NewBuffer(kBufferNameGlobalWaitJobId, job_size);
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::Get();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
for (const auto& pair : job_confs.job_id2job_conf()) {
const auto& job_name = pair.second.job_name();
CHECK_EQ(pair.first, Global<JobName2JobId>::Get()->at(job_name));
......@@ -35,7 +35,7 @@ RuntimeBuffersScope::RuntimeBuffersScope(const JobConfs& job_confs) {
}
RuntimeBuffersScope::~RuntimeBuffersScope() {
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::Get();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
for (const auto& pair : *Global<JobName2JobId>::Get()) {
const auto& job_name = pair.first;
buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Close();
......
......@@ -21,7 +21,7 @@ limitations under the License.
#include "oneflow/core/job/available_memory_desc.pb.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/profiler.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/inter_user_job_info.pb.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/critical_section_desc.h"
......
......@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/callback_notify_kernel.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
namespace oneflow {
......@@ -23,8 +23,8 @@ void CallbackNotifyKernel<T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
T job_id = *BnInOp2Blob("in")->dptr<T>();
const auto& buffer_name = this->op_conf().callback_notify_conf().callback_buffer_name(job_id);
std::shared_ptr<ForeignJobInstance> foreign_job_instance;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::Get()
std::shared_ptr<JobInstance> foreign_job_instance;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get()
->Get(buffer_name)
->TryReceive(&foreign_job_instance);
CHECK_NE(buffer_status, kBufferStatusEmpty);
......
......@@ -16,15 +16,15 @@ limitations under the License.
#include "oneflow/core/kernel/foreign_input_kernel.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
namespace oneflow {
void ForeignInputKernel::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const auto& buffer_name = op_conf().foreign_input_conf().ofblob_buffer_name();
std::shared_ptr<ForeignJobInstance> foreign_job_instance;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::Get()
std::shared_ptr<JobInstance> foreign_job_instance;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get()
->Get(buffer_name)
->TryReceive(&foreign_job_instance);
CHECK_NE(buffer_status, kBufferStatusEmpty);
......
......@@ -16,15 +16,15 @@ limitations under the License.
#include "oneflow/core/kernel/foreign_output_kernel.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/job/foreign_job_instance.h"
#include "oneflow/core/job/job_instance.h"
namespace oneflow {
void ForeignOutputKernel::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const auto& buffer_name = op_conf().foreign_output_conf().ofblob_buffer_name();
std::shared_ptr<ForeignJobInstance> foreign_job_instance;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<ForeignJobInstance>>>::Get()
std::shared_ptr<JobInstance> foreign_job_instance;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get()
->Get(buffer_name)
->TryReceive(&foreign_job_instance);
CHECK_NE(buffer_status, kBufferStatusEmpty);
......
......@@ -73,7 +73,7 @@ def MakeJobInstance(*arg, **kw):
return job_instance
class JobInstance(oneflow._oneflow_internal.ForeignJobInstance):
class JobInstance(oneflow._oneflow_internal.JobInstance):
def __init__(
self,
job_name,
......@@ -83,7 +83,7 @@ class JobInstance(oneflow._oneflow_internal.ForeignJobInstance):
pull_cb=None,
finish_cb=None,
):
oneflow._oneflow_internal.ForeignJobInstance.__init__(self)
oneflow._oneflow_internal.JobInstance.__init__(self)
self.thisown = 0
self.job_name_ = str(job_name)
self.sole_input_op_name_in_user_job_ = str(sole_input_op_name_in_user_job)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment