Skip to content
Snippets Groups Projects
Unverified commit 4804c301, authored by Li Xinqi, committed by GitHub
Browse files

add job_pass MultiClientAutoSourceAndSinkTick (#5507)


* refactor job_pass by maybe_system

* refactor AutoSourceAndSinkTick to SingleClientAutoSourceAndSinkTick

* remove useless files

* add job_pass MultiClientAutoSourceAndSinkTick

* address review

* fix check

* fix OpConf4OpName return

Co-authored-by: leaves-zwx <kunta0932@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent e5589c0a
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
......@@ -157,14 +158,14 @@ JobBuilder::JobBuilder(Job* job) : job_(job) {
}
}
OperatorConf* JobBuilder::MutableOpConf4OpName(const std::string& op_name) {
Maybe<OperatorConf*> JobBuilder::MutableOpConf4OpName(const std::string& op_name) {
const auto& it = op_name2op_conf_.find(op_name);
CHECK(it != op_name2op_conf_.end());
CHECK_OR_RETURN(it != op_name2op_conf_.end());
return it->second;
}
const OperatorConf& JobBuilder::OpConf4OpName(const std::string& op_name) const {
return *op_name2op_conf_.at(op_name);
Maybe<const OperatorConf&> JobBuilder::OpConf4OpName(const std::string& op_name) const {
return *JUST(MapAt(op_name2op_conf_, op_name));
}
const ParallelConf& JobBuilder::ParallelConf4Lbi(const LogicalBlobId& lbi) const {
......
......@@ -50,8 +50,8 @@ class JobBuilder final {
return job_->mutable_job_parallel_view_conf();
}
const OperatorConf& OpConf4OpName(const std::string& op_name) const;
OperatorConf* MutableOpConf4OpName(const std::string& op_name);
Maybe<const OperatorConf&> OpConf4OpName(const std::string& op_name) const;
Maybe<OperatorConf*> MutableOpConf4OpName(const std::string& op_name);
Maybe<void> AddOp(const ParallelConf& parallel_conf, const OperatorConf& op_conf);
void AddOps(const ParallelConf& parallel_conf, const std::vector<OperatorConf>& op_confs);
......
......@@ -400,7 +400,7 @@ void GetMemSharingOpBlobInfo(const JobBuilder& job_builder, const std::string& o
std::string obn = "out";
std::string lbn;
{
const auto& op_conf = job_builder.OpConf4OpName(op_name);
const auto& op_conf = CHECK_JUST(job_builder.OpConf4OpName(op_name));
if (op_conf.has_variable_conf()) {
lbn = op_name + "/" + op_conf.variable_conf().out();
} else if (op_conf.has_input_conf()) {
......
......@@ -48,7 +48,7 @@ Maybe<void> AddLbiDiffWatcherOpConfs::Apply(Job* job) const {
for (const LbiAndDiffWatcherUuidPair& pair : pair_list) {
if (lbi2diff_lbi.find(pair.lbi()) == lbi2diff_lbi.end()) { continue; }
const auto& diff_lbi = lbi2diff_lbi.at(pair.lbi());
const auto& diff_lbi_op_conf = job_builder.OpConf4OpName(diff_lbi.op_name());
const auto& diff_lbi_op_conf = JUST(job_builder.OpConf4OpName(diff_lbi.op_name()));
OperatorConf foreign_watcher_op;
foreign_watcher_op.set_name("System-LbiDiffWatcher-ForeignWatcher-" + NewUniqueId());
foreign_watcher_op.set_scope_symbol_id(diff_lbi_op_conf.scope_symbol_id());
......
......@@ -420,6 +420,54 @@ Maybe<void> AddGlobalInputOutputCriticalSection(
return Maybe<void>::Ok();
}
// Adds a WaitAndSendIds op for `machine_id` and makes the source tick op named `src_op_name`
// wait on it.
//
// The new op is tagged kMainOp and added under a "cpu" ParallelConf whose single device name is
// "@<machine_id>:0". A copy of the source tick op conf then gets its `wait_in` pointed at the new
// op's "out" blob and is written back through JobBuilder::MutOpOnlyOnce.
//
// Returns an error if adding or looking up an op fails, if `src_op_name` is not a source tick op,
// or if its `wait_in` is already set.
Maybe<void> MultiClientAddWaitAndSendIds(JobBuilder* job_builder, int64_t machine_id,
                                         const std::string& src_op_name) {
  ParallelConf cpu_placement;
  cpu_placement.set_device_tag("cpu");
  cpu_placement.add_device_name(std::string("@") + std::to_string(machine_id) + ":0");

  OperatorConf waiter_op;
  waiter_op.set_name(std::string("System-Src-WaitAndSendIds_") + NewUniqueId());
  waiter_op.set_pass_tag(kMainOp);
  auto* waiter_conf = waiter_op.mutable_wait_and_send_ids_conf();
  waiter_conf->set_out("out");
  // Placeholder value; NOTE(review): presumably replaced once a real buffer name exists.
  waiter_conf->set_wait_buffer_name("UnimplementedBufferName");
  waiter_conf->set_data_type(DataType::kInt32);
  // id_list() of the WaitAndSendIds conf is unused in multi-client mode, so it is left empty.
  JUST(job_builder->AddOp(cpu_placement, waiter_op));

  // Work on a copy of the source tick conf; the job is only updated by MutOpOnlyOnce below.
  OperatorConf source_tick_op = JUST(job_builder->OpConf4OpName(src_op_name));
  CHECK_OR_RETURN(source_tick_op.has_source_tick_conf());
  auto* source_tick_op_conf = source_tick_op.mutable_source_tick_conf();
  CHECK_OR_RETURN(!source_tick_op_conf->has_wait_in());
  const std::string waiter_out_lbn = GenLogicalBlobName(waiter_op.name(), "out");
  source_tick_op_conf->set_wait_in(waiter_out_lbn);
  JUST(job_builder->MutOpOnlyOnce(source_tick_op));
  return Maybe<void>::Ok();
}
// Adds a CallbackNotify op for `machine_id` that consumes the "out" blob of the sink tick op
// named `sink_op_name`.
//
// The new op is tagged kMainOp and added under a "cpu" ParallelConf whose single device name is
// "@<machine_id>:0". Returns the error from JobBuilder::AddOp if the op cannot be added.
Maybe<void> MultiClientAddCallbackNotifier(JobBuilder* job_builder, int64_t machine_id,
                                           const std::string& sink_op_name) {
  ParallelConf parallel_conf;
  {
    parallel_conf.set_device_tag("cpu");
    parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0");
  }
  OperatorConf callback_notify_op_conf;
  {
    callback_notify_op_conf.set_name(std::string("System-Sink-CallbackNotify_") + NewUniqueId());
    callback_notify_op_conf.set_pass_tag(kMainOp);
    auto* callback_notify_conf = callback_notify_op_conf.mutable_callback_notify_conf();
    // Pass the bare blob name "out", exactly as MultiClientAddWaitAndSendIds does. The previous
    // "/out" carried its own leading slash; NOTE(review): GenLogicalBlobName presumably inserts
    // the '/' separator itself, so that produced "<sink_op_name>//out" — confirm.
    callback_notify_conf->set_in(GenLogicalBlobName(sink_op_name, "out"));
    // callback_notify_conf->callback_buffer_name() is unused in multi-client mode.
  }
  JUST(job_builder->AddOp(parallel_conf, callback_notify_op_conf));
  return Maybe<void>::Ok();
}
} // namespace
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
......@@ -490,6 +538,34 @@ Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilde
return Maybe<void>::Ok();
}
// Multi-client variant of the source/sink tick pass.
//
// Does nothing unless the MultiClient global is true. Otherwise it runs AutoSourceAndSinkTick on
// `job`, recording the source and sink tick op names reported per machine id, and then adds one
// WaitAndSendIds op per recorded source tick and one CallbackNotify op per recorded sink tick.
//
// Returns an error if any step fails or if a machine id is reported twice for the same role.
Maybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) {
  const bool is_multi_client = JUST(*Global<Maybe<bool>, MultiClient>::Get());
  if (!is_multi_client) { return Maybe<void>::Ok(); }

  HashMap<int64_t, std::string> machine_id2src_op_name;
  HashMap<int64_t, std::string> machine_id2sink_op_name;
  {
    // Phase 1: insert the tick ops and remember which op serves each machine.
    // emplace(...).second is false on a duplicate machine id, which is reported as an error.
    auto RecordSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
      CHECK_OR_RETURN(machine_id2src_op_name.emplace(machine_id, op_name).second);
      return Maybe<void>::Ok();
    };
    auto RecordSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
      CHECK_OR_RETURN(machine_id2sink_op_name.emplace(machine_id, op_name).second);
      return Maybe<void>::Ok();
    };
    JobBuilder tick_builder(job);
    JUST(AutoSourceAndSinkTick(op_graph, &tick_builder, RecordSrc, RecordSink));
  }
  {
    // Phase 2: uses a fresh JobBuilder over the same job.
    // NOTE(review): presumably so the builder indexes the ops added in phase 1 — confirm.
    JobBuilder io_builder(job);
    for (auto it = machine_id2src_op_name.begin(); it != machine_id2src_op_name.end(); ++it) {
      JUST(MultiClientAddWaitAndSendIds(&io_builder, it->first, it->second));
    }
    for (auto it = machine_id2sink_op_name.begin(); it != machine_id2sink_op_name.end(); ++it) {
      JUST(MultiClientAddCallbackNotifier(&io_builder, it->first, it->second));
    }
  }
  return Maybe<void>::Ok();
}
Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder) {
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
......
......@@ -29,6 +29,7 @@ Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder);
Maybe<void> SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph,
JobBuilder* job_builder);
Maybe<void> MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job);
class MutOpConTickInputHelper {
public:
......
......@@ -55,10 +55,9 @@ void GenerateCloneGradOpIfNeed(const OpNode& op_node, JobBuilder* job_builder,
add_op_builder.Input("in", GenLogicalBlobName(lbis_to_add.at(i)));
}
lbis_to_add.resize(start);
const auto& op_conf = CHECK_JUST(job_builder->OpConf4OpName(lbi.op_name()));
const auto add_op =
add_op_builder.Output("out")
.ScopeSymbolId(job_builder->OpConf4OpName(lbi.op_name()).scope_symbol_id())
.Build();
add_op_builder.Output("out").ScopeSymbolId(op_conf.scope_symbol_id()).Build();
job_builder->AddOps(job_builder->ParallelConf4Lbi(lbi), {add_op.op_conf()});
lbis_to_add.push_back(GenLogicalBlobId(add_op.output("out", 0)));
}
......
......@@ -114,6 +114,7 @@ Maybe<void> JobCompleter::Complete(Job* job) const {
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAutoSourceAndSinkTick));
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalInputCriticalSections));
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalOutputCriticalSections));
JUST(WithOpGraphAndMutJob(job, &MultiClientAutoSourceAndSinkTick));
JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
if (XrtCompilationEnabled(GlobalJobDesc())) {
#ifdef OF_WITH_XRT
......
......@@ -15,18 +15,24 @@ limitations under the License.
*/
#include "oneflow/core/kernel/callback_notify_kernel.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/common/buffer_manager.h"
namespace oneflow {
template<typename T>
void CallbackNotifyKernel<T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
T job_id = *BnInOp2Blob("in")->dptr<T>();
const auto& buffer_name = this->op_conf().callback_notify_conf().callback_buffer_name(job_id);
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
std::string buffer_name;
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
buffer_name = GetCallbackNotifierBufferName(this->job_desc().job_name());
} else {
T job_id = *BnInOp2Blob("in")->dptr<T>();
buffer_name = this->op_conf().callback_notify_conf().callback_buffer_name(job_id);
}
std::shared_ptr<JobInstance> foreign_job_instance;
BufferStatus buffer_status = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get()
->Get(buffer_name)
->TryReceive(&foreign_job_instance);
BufferStatus buffer_status = buffer_mgr->Get(buffer_name)->TryReceive(&foreign_job_instance);
CHECK_NE(buffer_status, kBufferStatusEmpty);
if (buffer_status == kBufferStatusSuccess) { foreign_job_instance->Finish(); }
}
......
......@@ -220,7 +220,8 @@ message DstSubsetTickOpConf {
}
message SourceTickOpConf {
required string out = 1;
optional string wait_in = 1;
required string out = 2;
}
message SinkTickOpConf {
......
......@@ -21,6 +21,7 @@ namespace oneflow {
// Checks the op conf and enrolls this op's blob names.
//
// Requires a source_tick_conf and no ctrl_in_op_name entries. The "wait_in" input is enrolled
// only when the conf sets it; the "out" output is always enrolled.
Maybe<void> SourceTickOp::InitFromOpConf() {
  CHECK(op_conf().has_source_tick_conf());
  CHECK(op_conf().ctrl_in_op_name().empty());
  // wait_in is optional in the conf, so the input only exists when it was set.
  const bool waits_on_input = op_conf().source_tick_conf().has_wait_in();
  if (waits_on_input) {
    EnrollInputBn("wait_in", false);
  }
  EnrollOutputBn("out", false);
  return Maybe<void>::Ok();
}
......@@ -45,7 +46,8 @@ Maybe<void> SourceTickOp::InferOutBlobDescs(
}
Maybe<void> SourceTickOp::GetSbpSignatures(cfg::SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add());
auto* sbp_signature = sbp_sig_list->mutable_sbp_signature()->Add();
SbpSignatureBuilder().Broadcast(input_bns()).Broadcast(output_bns()).Build(sbp_signature);
return Maybe<void>::Ok();
}
......
......@@ -311,7 +311,7 @@ void FoldSubgraphBuilder::FixupControlInOpNames() {
};
for (const XrtNode* node : graph_.Nodes()) {
auto* op_conf = builder_->MutableOpConf4OpName(node->name());
auto* op_conf = CHECK_JUST(builder_->MutableOpConf4OpName(node->name()));
if (node->sub_graph() == nullptr) {
auto ctrl_in_op_names = op_conf->ctrl_in_op_name();
op_conf->clear_ctrl_in_op_name();
......@@ -319,7 +319,7 @@ void FoldSubgraphBuilder::FixupControlInOpNames() {
} else {
for (const XrtNode* sub_node : node->sub_graph()->Nodes()) {
if (sub_node->IsArgumentNode()) { continue; }
const auto& folded_op_conf = builder_->OpConf4OpName(sub_node->name());
const auto& folded_op_conf = CHECK_JUST(builder_->OpConf4OpName(sub_node->name()));
for (const auto& op_name : folded_op_conf.ctrl_in_op_name()) {
AddControlInOpName(op_conf, op_name);
}
......@@ -368,7 +368,7 @@ void FoldSubgraphBuilder::FixupInOutBlobNames() {
// Fix end input blob name
const XrtNode* end = edge->end();
if (end->type() != _XrtLaunchOpType) {
auto* op_conf = builder_->MutableOpConf4OpName(end->name());
auto* op_conf = CHECK_JUST(builder_->MutableOpConf4OpName(end->name()));
const std::string& consume_key = arg.meta_data().consume_key;
SetOpInputBlobName(op_conf, consume_key, arg.name(), fixed_blob_name);
}
......@@ -399,7 +399,7 @@ void FoldSubgraphBuilder::FixupSbpSignatures() {
builder_->AddSbpSignature4OpName(node->name(), sbp_conf);
// Add function node sbp signatures.
auto* op_conf = builder_->MutableOpConf4OpName(node->name());
auto* op_conf = CHECK_JUST(builder_->MutableOpConf4OpName(node->name()));
auto* launch_conf = op_conf->mutable_xrt_launch_conf();
auto* sbp_signatures = launch_conf->mutable_sbp_signatures();
for (const auto& node_conf : launch_conf->function().node()) {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment