From 809ffef1359cc3f6e91f29ecdab22acee7b62a03 Mon Sep 17 00:00:00 2001 From: Li Xinqi <lixinqi2010@gmail.com> Date: Wed, 21 Jul 2021 06:38:34 +0800 Subject: [PATCH] Refactor single client autotick (#5506) * refactor job_pass by maybe_system * refactor AutoSourceAndSinkTick to SingleClientAutoSourceAndSinkTick * remove useless files Co-authored-by: leaves-zwx <kunta0932@gmail.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/core/job_rewriter/autotick.cpp | 85 ++++++++++++++++----- oneflow/core/job_rewriter/autotick.h | 8 +- oneflow/core/job_rewriter/job_completer.cpp | 6 +- 3 files changed, 74 insertions(+), 25 deletions(-) diff --git a/oneflow/core/job_rewriter/autotick.cpp b/oneflow/core/job_rewriter/autotick.cpp index b7efa4966..1eeb3a358 100644 --- a/oneflow/core/job_rewriter/autotick.cpp +++ b/oneflow/core/job_rewriter/autotick.cpp @@ -17,6 +17,7 @@ limitations under the License. #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/critical_section_desc.h" #include "oneflow/core/common/protobuf.h" +#include "oneflow/core/job/global_for.h" namespace oneflow { @@ -83,15 +84,14 @@ Maybe<void> BuildDstSubsetTickOpAndParallelConf(const HashSet<LogicalBlobId>& ti return Maybe<void>::Ok(); } -Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section, - const OperatorConf& src_subset_tick, - const HashSet<LogicalBlobId>& tick_lbis, - JobBuilder* job_builder) { +Maybe<void> CreateDstSubsetTickAndSinkTicks( + const OperatorConf& src_subset_tick, const HashSet<LogicalBlobId>& tick_lbis, + JobBuilder* job_builder, + const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSink) { OperatorConf dst_subset_tick; dst_subset_tick.mutable_dst_subset_tick_conf()->add_in( src_subset_tick.name() + "/" + src_subset_tick.src_subset_tick_conf().out()); JUST(BuildDstSubsetTickOpAndParallelConf(tick_lbis, &dst_subset_tick, job_builder)); - auto* map = critical_section->mutable_machine_id2sink_tick_op_name(); for (int64_t machine_id : Global<ResourceDesc, ForSession>::Get()->process_ranks()) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); @@ -113,11 +113,24 @@ Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section, sink_tick_conf->set_out("out"); JUST(job_builder->AddOp(parallel_conf, sink_tick_op)); } - (*map)[machine_id] = sink_tick_op.name(); + JUST(DoEachSink(machine_id, sink_tick_op.name())); } return Maybe<void>::Ok(); } +Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section, + const OperatorConf& src_subset_tick, + const HashSet<LogicalBlobId>& tick_lbis, + JobBuilder* job_builder) { + auto* map = critical_section->mutable_machine_id2sink_tick_op_name(); + const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> { + (*map)[machine_id] = op_name; + return Maybe<void>::Ok(); + }; + JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink)); + return Maybe<void>::Ok(); +} + Maybe<void> BuildSrcSubsetTickOpAndParallelConf(OperatorConf* src_subset_tick_op, JobBuilder* job_builder) { src_subset_tick_op->set_name("System-AutoTick-SrcSubsetTick_" + NewUniqueId()); @@ -131,10 +144,9 @@ Maybe<void> BuildSrcSubsetTickOpAndParallelConf(OperatorConf* src_subset_tick_op return Maybe<void>::Ok(); } -Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section, - OperatorConf* src_subset_tick_op, - JobBuilder* job_builder) { - auto* map = critical_section->mutable_machine_id2source_tick_op_name(); +Maybe<void> CreateSourceTicksAndSrcSubsetTick( + OperatorConf* src_subset_tick_op, JobBuilder* job_builder, + const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSrc) { for (int64_t machine_id : Global<ResourceDesc, ForSession>::Get()->process_ranks()) { ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); @@ -145,7 +157,7 @@ Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section, src_tick_op.mutable_source_tick_conf()->set_out("out"); JUST(job_builder->AddOp(parallel_conf, src_tick_op)); } - (*map)[machine_id] = src_tick_op.name(); + JUST(DoEachSrc(machine_id, src_tick_op.name())); OperatorConf tick_op; { tick_op.set_name("System-AutoTick-Tick_" + NewUniqueId()); @@ -159,6 +171,18 @@ Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section, return Maybe<void>::Ok(); } +Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section, + OperatorConf* src_subset_tick_op, + JobBuilder* job_builder) { + auto* map = critical_section->mutable_machine_id2source_tick_op_name(); + const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> { + (*map)[machine_id] = op_name; + return Maybe<void>::Ok(); + }; + JUST(CreateSourceTicksAndSrcSubsetTick(src_subset_tick_op, job_builder, DoEachSrc)); + return Maybe<void>::Ok(); +} + Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick_op, JobBuilder* job_builder) { CHECK_OR_RETURN(src_subset_tick_op.has_src_subset_tick_conf()); @@ -421,10 +445,10 @@ Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder return Maybe<void>::Ok(); } -Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) { - auto* critical_section = - Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id()); - critical_section->mutable_total_job_critical_section(); +Maybe<void> AutoSourceAndSinkTick( + const OpGraph& op_graph, JobBuilder* job_builder, + const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSrc, + const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSink) { JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> { CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf()); return Maybe<void>::Ok(); @@ -442,12 +466,33 @@ Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_build return Maybe<void>::Ok(); })); OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job())); - JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder)); - JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder)); + JUST(CreateSourceTicksAndSrcSubsetTick(&src_subset_tick, job_builder, DoEachSrc)); + JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink)); + return Maybe<void>::Ok(); +} + +Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) { + if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); } + auto* critical_section = + Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id()); + critical_section->mutable_total_job_critical_section(); + auto* src_map = critical_section->mutable_machine_id2source_tick_op_name(); + const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> { + (*src_map)[machine_id] = op_name; + return Maybe<void>::Ok(); + }; + auto* sink_map = critical_section->mutable_machine_id2sink_tick_op_name(); + const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> { + (*sink_map)[machine_id] = op_name; + return Maybe<void>::Ok(); + }; + JUST(AutoSourceAndSinkTick(op_graph, job_builder, DoEachSrc, DoEachSink)); return Maybe<void>::Ok(); } -Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { +Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph, + JobBuilder* job_builder) { + if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); } JUST(ForEachInputCriticalSectionOpNodes( op_graph, [&](const HashSet<const OpNode*>& op_nodes, @@ -458,7 +503,9 @@ Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* return Maybe<void>::Ok(); } -Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { +Maybe<void> SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph, + JobBuilder* job_builder) { + if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); } JUST(ForEachOutputCriticalSectionOpNodes( op_graph, [&](const HashSet<const OpNode*>& op_nodes, diff --git a/oneflow/core/job_rewriter/autotick.h b/oneflow/core/job_rewriter/autotick.h index 2760edda5..1935ab945 100644 --- a/oneflow/core/job_rewriter/autotick.h +++ b/oneflow/core/job_rewriter/autotick.h @@ -24,9 +24,11 @@ namespace oneflow { Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder); Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder); -Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder); -Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); -Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph, + JobBuilder* job_builder); +Maybe<void> SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph, + JobBuilder* job_builder); class MutOpConTickInputHelper { public: diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp index c3833ab33..151a5bddc 100644 --- a/oneflow/core/job_rewriter/job_completer.cpp +++ b/oneflow/core/job_rewriter/job_completer.cpp @@ -111,9 +111,9 @@ Maybe<void> JobCompleter::Complete(Job* job) const { // complete tick ops JUST(WithOpGraphAndMutJobBuilder(job, &AutoPrependTick)); JUST(WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape)); - JUST(WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick)); - JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections)); - JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections)); + JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAutoSourceAndSinkTick)); + JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalInputCriticalSections)); + JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalOutputCriticalSections)); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); if (XrtCompilationEnabled(GlobalJobDesc())) { #ifdef OF_WITH_XRT -- GitLab