diff --git a/oneflow/core/job/job_builder.cpp b/oneflow/core/job/job_builder.cpp
index 39c5b0539dfd108be1d5ad49a13d5b71328cfae3..9aaa81a4b2d1685a4a2392fd7cc65cd7e83a3f02 100644
--- a/oneflow/core/job/job_builder.cpp
+++ b/oneflow/core/job/job_builder.cpp
@@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/common/util.h"
+#include "oneflow/core/common/container_util.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
@@ -157,14 +158,14 @@ JobBuilder::JobBuilder(Job* job) : job_(job) {
}
}
-OperatorConf* JobBuilder::MutableOpConf4OpName(const std::string& op_name) {
+Maybe<OperatorConf*> JobBuilder::MutableOpConf4OpName(const std::string& op_name) {
const auto& it = op_name2op_conf_.find(op_name);
- CHECK(it != op_name2op_conf_.end());
+ CHECK_OR_RETURN(it != op_name2op_conf_.end());
return it->second;
}
-const OperatorConf& JobBuilder::OpConf4OpName(const std::string& op_name) const {
- return *op_name2op_conf_.at(op_name);
+Maybe<const OperatorConf&> JobBuilder::OpConf4OpName(const std::string& op_name) const {
+ return *JUST(MapAt(op_name2op_conf_, op_name));
}
const ParallelConf& JobBuilder::ParallelConf4Lbi(const LogicalBlobId& lbi) const {
diff --git a/oneflow/core/job/job_builder.h b/oneflow/core/job/job_builder.h
index b641f46b3f47e5124d3036ed1af18bf1c85d3ffc..8e07abfb81aa553c7016d6a2a25072413c7b9eae 100644
--- a/oneflow/core/job/job_builder.h
+++ b/oneflow/core/job/job_builder.h
@@ -50,8 +50,8 @@ class JobBuilder final {
return job_->mutable_job_parallel_view_conf();
}
- const OperatorConf& OpConf4OpName(const std::string& op_name) const;
- OperatorConf* MutableOpConf4OpName(const std::string& op_name);
+ Maybe<const OperatorConf&> OpConf4OpName(const std::string& op_name) const;
+ Maybe<OperatorConf*> MutableOpConf4OpName(const std::string& op_name);
Maybe<void> AddOp(const ParallelConf& parallel_conf, const OperatorConf& op_conf);
void AddOps(const ParallelConf& parallel_conf, const std::vector<OperatorConf>& op_confs);
diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp
index 93d440009b25a2e886ec8d8d8a33fafcf19bdd91..c186f902021cc22eefc3e45ebf90d26c0d55e7c6 100644
--- a/oneflow/core/job/oneflow.cpp
+++ b/oneflow/core/job/oneflow.cpp
@@ -400,7 +400,7 @@ void GetMemSharingOpBlobInfo(const JobBuilder& job_builder, const std::string& o
std::string obn = "out";
std::string lbn;
{
- const auto& op_conf = job_builder.OpConf4OpName(op_name);
+ const auto& op_conf = CHECK_JUST(job_builder.OpConf4OpName(op_name));
if (op_conf.has_variable_conf()) {
lbn = op_name + "/" + op_conf.variable_conf().out();
} else if (op_conf.has_input_conf()) {
diff --git a/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp b/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp
index e1c540d4a13cbd8b7ceea02dec84d3ca282f0770..16987f3d96b054d2a65e501ea14c57a599876b8c 100644
--- a/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp
+++ b/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp
@@ -48,7 +48,7 @@ Maybe<void> AddLbiDiffWatcherOpConfs::Apply(Job* job) const {
for (const LbiAndDiffWatcherUuidPair& pair : pair_list) {
if (lbi2diff_lbi.find(pair.lbi()) == lbi2diff_lbi.end()) { continue; }
const auto& diff_lbi = lbi2diff_lbi.at(pair.lbi());
- const auto& diff_lbi_op_conf = job_builder.OpConf4OpName(diff_lbi.op_name());
+ const auto& diff_lbi_op_conf = JUST(job_builder.OpConf4OpName(diff_lbi.op_name()));
OperatorConf foreign_watcher_op;
foreign_watcher_op.set_name("System-LbiDiffWatcher-ForeignWatcher-" + NewUniqueId());
foreign_watcher_op.set_scope_symbol_id(diff_lbi_op_conf.scope_symbol_id());
diff --git a/oneflow/core/job_rewriter/autotick.cpp b/oneflow/core/job_rewriter/autotick.cpp
index 1eeb3a358d917d618f5215494ace28d3caf671ff..042c8f6669ec2b3b2cbd4563ca8e7f9d16e05abd 100644
--- a/oneflow/core/job_rewriter/autotick.cpp
+++ b/oneflow/core/job_rewriter/autotick.cpp
@@ -420,6 +420,54 @@ Maybe<void> AddGlobalInputOutputCriticalSection(
return Maybe<void>::Ok();
}
+Maybe<void> MultiClientAddWaitAndSendIds(JobBuilder* job_builder, int64_t machine_id,
+ const std::string& src_op_name) {
+ ParallelConf parallel_conf;
+ {
+ parallel_conf.set_device_tag("cpu");
+ parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0");
+ }
+ OperatorConf wait_and_send_ids_op_conf;
+ {
+ wait_and_send_ids_op_conf.set_name(std::string("System-Src-WaitAndSendIds_") + NewUniqueId());
+ wait_and_send_ids_op_conf.set_pass_tag(kMainOp);
+ auto* wait_and_send_ids_conf = wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf();
+ wait_and_send_ids_conf->set_out("out");
+ wait_and_send_ids_conf->set_wait_buffer_name("UnimplementedBufferName");
+ wait_and_send_ids_conf->set_data_type(DataType::kInt32);
+ // wait_and_send_ids_conf->id_list() is unused in multi-client mode.
+ }
+ JUST(job_builder->AddOp(parallel_conf, wait_and_send_ids_op_conf));
+ OperatorConf source_tick_op = JUST(job_builder->OpConf4OpName(src_op_name));
+ {
+ CHECK_OR_RETURN(source_tick_op.has_source_tick_conf());
+ auto* source_tick_op_conf = source_tick_op.mutable_source_tick_conf();
+ CHECK_OR_RETURN(!source_tick_op_conf->has_wait_in());
+ source_tick_op_conf->set_wait_in(GenLogicalBlobName(wait_and_send_ids_op_conf.name(), "out"));
+ }
+ JUST(job_builder->MutOpOnlyOnce(source_tick_op));
+ return Maybe<void>::Ok();
+}
+
+Maybe<void> MultiClientAddCallbackNotifier(JobBuilder* job_builder, int64_t machine_id,
+ const std::string& sink_op_name) {
+ ParallelConf parallel_conf;
+ {
+ parallel_conf.set_device_tag("cpu");
+ parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0");
+ }
+ OperatorConf callback_notify_op_conf;
+ {
+ callback_notify_op_conf.set_name(std::string("System-Sink-CallbackNotify_") + NewUniqueId());
+ callback_notify_op_conf.set_pass_tag(kMainOp);
+ auto* callback_notify_conf = callback_notify_op_conf.mutable_callback_notify_conf();
+ callback_notify_conf->set_in(GenLogicalBlobName(sink_op_name, "/out"));
+ // callback_notify_conf->callback_buffer_name() is unused in multi-client mode.
+ }
+ JUST(job_builder->AddOp(parallel_conf, callback_notify_op_conf));
+ return Maybe<void>::Ok();
+}
+
} // namespace
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
@@ -490,6 +538,34 @@ Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilde
return Maybe<void>::Ok();
}
+Maybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) {
+ if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
+ HashMap<int64_t, std::string> machine_id2src_op_name;
+ HashMap<int64_t, std::string> machine_id2sink_op_name;
+ {
+ JobBuilder job_builder(job);
+ const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
+ CHECK_OR_RETURN(machine_id2src_op_name.emplace(machine_id, op_name).second);
+ return Maybe<void>::Ok();
+ };
+ const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
+ CHECK_OR_RETURN(machine_id2sink_op_name.emplace(machine_id, op_name).second);
+ return Maybe<void>::Ok();
+ };
+ JUST(AutoSourceAndSinkTick(op_graph, &job_builder, DoEachSrc, DoEachSink));
+ }
+ {
+ JobBuilder job_builder(job);
+ for (const auto& pair : machine_id2src_op_name) {
+ JUST(MultiClientAddWaitAndSendIds(&job_builder, pair.first, pair.second));
+ }
+ for (const auto& pair : machine_id2sink_op_name) {
+ JUST(MultiClientAddCallbackNotifier(&job_builder, pair.first, pair.second));
+ }
+ }
+ return Maybe<void>::Ok();
+}
+
Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder) {
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
diff --git a/oneflow/core/job_rewriter/autotick.h b/oneflow/core/job_rewriter/autotick.h
index 1935ab945b9f05629610c672195ae40ffea9f5d6..66579d8a81dd4296c9fbde416dbef8bc79f792bb 100644
--- a/oneflow/core/job_rewriter/autotick.h
+++ b/oneflow/core/job_rewriter/autotick.h
@@ -29,6 +29,7 @@ Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder);
Maybe<void> SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder);
+Maybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job);
class MutOpConTickInputHelper {
public:
diff --git a/oneflow/core/job_rewriter/clone_grad.cpp b/oneflow/core/job_rewriter/clone_grad.cpp
index e2d677d51586256891f2ff8a7df8ac11c5a91799..a3c52f616cb381d4024e6ed0a961e41d548740d0 100644
--- a/oneflow/core/job_rewriter/clone_grad.cpp
+++ b/oneflow/core/job_rewriter/clone_grad.cpp
@@ -55,10 +55,9 @@ void GenerateCloneGradOpIfNeed(const OpNode& op_node, JobBuilder* job_builder,
add_op_builder.Input("in", GenLogicalBlobName(lbis_to_add.at(i)));
}
lbis_to_add.resize(start);
+ const auto& op_conf = CHECK_JUST(job_builder->OpConf4OpName(lbi.op_name()));
const auto add_op =
- add_op_builder.Output("out")
- .ScopeSymbolId(job_builder->OpConf4OpName(lbi.op_name()).scope_symbol_id())
- .Build();
+ add_op_builder.Output("out").ScopeSymbolId(op_conf.scope_symbol_id()).Build();
job_builder->AddOps(job_builder->ParallelConf4Lbi(lbi), {add_op.op_conf()});
lbis_to_add.push_back(GenLogicalBlobId(add_op.output("out", 0)));
}
diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp
index 151a5bddc857fb6b155d80c84d5897661f7879b1..bb38e9a595eb5017564bbb7909f26f08c54daab5 100644
--- a/oneflow/core/job_rewriter/job_completer.cpp
+++ b/oneflow/core/job_rewriter/job_completer.cpp
@@ -114,6 +114,7 @@ Maybe<void> JobCompleter::Complete(Job* job) const {
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAutoSourceAndSinkTick));
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalInputCriticalSections));
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalOutputCriticalSections));
+ JUST(WithOpGraphAndMutJob(job, &MultiClientAutoSourceAndSinkTick));
JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
if (XrtCompilationEnabled(GlobalJobDesc())) {
#ifdef OF_WITH_XRT
diff --git a/oneflow/core/kernel/callback_notify_kernel.cpp b/oneflow/core/kernel/callback_notify_kernel.cpp
index 5268f351a7185f0dd27327c95e4153accb5fbab0..4a8ce90a8e64a5f08f044bc7a2d2f1ff1acdd52e 100644
--- a/oneflow/core/kernel/callback_notify_kernel.cpp
+++ b/oneflow/core/kernel/callback_notify_kernel.cpp
@@ -15,18 +15,24 @@ limitations under the License.
*/
#include "oneflow/core/kernel/callback_notify_kernel.h"
#include "oneflow/core/job/job_instance.h"
+#include "oneflow/core/job/global_for.h"
+#include "oneflow/core/common/buffer_manager.h"
namespace oneflow {
template<typename T>
void CallbackNotifyKernel<T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
- T job_id = *BnInOp2Blob("in")->dptr<T>();
- const auto& buffer_name = this->op_conf().callback_notify_conf().callback_buffer_name(job_id);
+ auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
+ std::string buffer_name;
+ if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+ buffer_name = GetCallbackNotifierBufferName(this->job_desc().job_name());
+ } else {
+ T job_id = *BnInOp2Blob("in")->dptr<T>();
+ buffer_name = this->op_conf().callback_notify_conf().callback_buffer_name(job_id);
+ }
std::shared_ptr<JobInstance> foreign_job_instance;
- BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get()
- ->Get(buffer_name)
- ->TryReceive(&foreign_job_instance);
+ BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->TryReceive(&foreign_job_instance);
CHECK_NE(buffer_status, kBufferStatusEmpty);
if (buffer_status == kBufferStatusSuccess) { foreign_job_instance->Finish(); }
}
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index 589004526822e7bbb2152350927a0cd6d2168a16..5e85292dfe6404af75bf66e4cf4d0b6be7c26209 100644
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -220,7 +220,8 @@ message DstSubsetTickOpConf {
}
message SourceTickOpConf {
- required string out = 1;
+ optional string wait_in = 1;
+ required string out = 2;
}
message SinkTickOpConf {
diff --git a/oneflow/core/operator/source_tick_op.cpp b/oneflow/core/operator/source_tick_op.cpp
index 52a1f26b50392c08fa94ee810c5e837aa0d17397..da645d9610b0364b56161a1b5a177e5dea5cb17c 100644
--- a/oneflow/core/operator/source_tick_op.cpp
+++ b/oneflow/core/operator/source_tick_op.cpp
@@ -21,6 +21,7 @@ namespace oneflow {
Maybe<void> SourceTickOp::InitFromOpConf() {
CHECK(op_conf().has_source_tick_conf());
CHECK(op_conf().ctrl_in_op_name().empty());
+ if (op_conf().source_tick_conf().has_wait_in()) { EnrollInputBn("wait_in", false); }
EnrollOutputBn("out", false);
return Maybe<void>::Ok();
}
@@ -45,7 +46,8 @@ Maybe<void> SourceTickOp::InferOutBlobDescs(
}
Maybe<void> SourceTickOp::GetSbpSignatures(cfg::SbpSignatureList* sbp_sig_list) const {
- SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add());
+ auto* sbp_signature = sbp_sig_list->mutable_sbp_signature()->Add();
+ SbpSignatureBuilder().Broadcast(input_bns()).Broadcast(output_bns()).Build(sbp_signature);
return Maybe<void>::Ok();
}
diff --git a/oneflow/xrt/passes/rebuild_job_pass.cpp b/oneflow/xrt/passes/rebuild_job_pass.cpp
index 78dbfb4d05f02daef4871f0e27126e860d17b810..439116abb54b3e2150269e8e2eb014030429b7c6 100644
--- a/oneflow/xrt/passes/rebuild_job_pass.cpp
+++ b/oneflow/xrt/passes/rebuild_job_pass.cpp
@@ -311,7 +311,7 @@ void FoldSubgraphBuilder::FixupControlInOpNames() {
};
for (const XrtNode* node : graph_.Nodes()) {
- auto* op_conf = builder_->MutableOpConf4OpName(node->name());
+ auto* op_conf = CHECK_JUST(builder_->MutableOpConf4OpName(node->name()));
if (node->sub_graph() == nullptr) {
auto ctrl_in_op_names = op_conf->ctrl_in_op_name();
op_conf->clear_ctrl_in_op_name();
@@ -319,7 +319,7 @@ void FoldSubgraphBuilder::FixupControlInOpNames() {
} else {
for (const XrtNode* sub_node : node->sub_graph()->Nodes()) {
if (sub_node->IsArgumentNode()) { continue; }
- const auto& folded_op_conf = builder_->OpConf4OpName(sub_node->name());
+ const auto& folded_op_conf = CHECK_JUST(builder_->OpConf4OpName(sub_node->name()));
for (const auto& op_name : folded_op_conf.ctrl_in_op_name()) {
AddControlInOpName(op_conf, op_name);
}
@@ -368,7 +368,7 @@ void FoldSubgraphBuilder::FixupInOutBlobNames() {
// Fix end input blob name
const XrtNode* end = edge->end();
if (end->type() != _XrtLaunchOpType) {
- auto* op_conf = builder_->MutableOpConf4OpName(end->name());
+ auto* op_conf = CHECK_JUST(builder_->MutableOpConf4OpName(end->name()));
const std::string& consume_key = arg.meta_data().consume_key;
SetOpInputBlobName(op_conf, consume_key, arg.name(), fixed_blob_name);
}
@@ -399,7 +399,7 @@ void FoldSubgraphBuilder::FixupSbpSignatures() {
builder_->AddSbpSignature4OpName(node->name(), sbp_conf);
// Add function node sbp signatures.
- auto* op_conf = builder_->MutableOpConf4OpName(node->name());
+ auto* op_conf = CHECK_JUST(builder_->MutableOpConf4OpName(node->name()));
auto* launch_conf = op_conf->mutable_xrt_launch_conf();
auto* sbp_signatures = launch_conf->mutable_sbp_signatures();
for (const auto& node_conf : launch_conf->function().node()) {