From 809ffef1359cc3f6e91f29ecdab22acee7b62a03 Mon Sep 17 00:00:00 2001
From: Li Xinqi <lixinqi2010@gmail.com>
Date: Wed, 21 Jul 2021 06:38:34 +0800
Subject: [PATCH] Refactor single client autotick (#5506)

* refactor job_pass by maybe_system

* refactor AutoSourceAndSinkTick to SingleClientAutoSourceAndSinkTick

* remove useless files

Co-authored-by: leaves-zwx <kunta0932@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/core/job_rewriter/autotick.cpp      | 85 ++++++++++++++++-----
 oneflow/core/job_rewriter/autotick.h        |  8 +-
 oneflow/core/job_rewriter/job_completer.cpp |  6 +-
 3 files changed, 74 insertions(+), 25 deletions(-)

diff --git a/oneflow/core/job_rewriter/autotick.cpp b/oneflow/core/job_rewriter/autotick.cpp
index b7efa4966..1eeb3a358 100644
--- a/oneflow/core/job_rewriter/autotick.cpp
+++ b/oneflow/core/job_rewriter/autotick.cpp
@@ -17,6 +17,7 @@ limitations under the License.
 #include "oneflow/core/job/job_builder.h"
 #include "oneflow/core/job/critical_section_desc.h"
 #include "oneflow/core/common/protobuf.h"
+#include "oneflow/core/job/global_for.h"
 
 namespace oneflow {
 
@@ -83,15 +84,14 @@ Maybe<void> BuildDstSubsetTickOpAndParallelConf(const HashSet<LogicalBlobId>& ti
   return Maybe<void>::Ok();
 }
 
-Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section,
-                                            const OperatorConf& src_subset_tick,
-                                            const HashSet<LogicalBlobId>& tick_lbis,
-                                            JobBuilder* job_builder) {
+Maybe<void> CreateDstSubsetTickAndSinkTicks(
+    const OperatorConf& src_subset_tick, const HashSet<LogicalBlobId>& tick_lbis,
+    JobBuilder* job_builder,
+    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSink) {
   OperatorConf dst_subset_tick;
   dst_subset_tick.mutable_dst_subset_tick_conf()->add_in(
       src_subset_tick.name() + "/" + src_subset_tick.src_subset_tick_conf().out());
   JUST(BuildDstSubsetTickOpAndParallelConf(tick_lbis, &dst_subset_tick, job_builder));
-  auto* map = critical_section->mutable_machine_id2sink_tick_op_name();
   for (int64_t machine_id : Global<ResourceDesc, ForSession>::Get()->process_ranks()) {
     ParallelConf parallel_conf;
     parallel_conf.set_device_tag("cpu");
@@ -113,11 +113,24 @@ Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section,
       sink_tick_conf->set_out("out");
       JUST(job_builder->AddOp(parallel_conf, sink_tick_op));
     }
-    (*map)[machine_id] = sink_tick_op.name();
+    JUST(DoEachSink(machine_id, sink_tick_op.name()));
   }
   return Maybe<void>::Ok();
 }
 
+Maybe<void> CreateDstSubsetTickAndSinkTicks(CriticalSection* critical_section,
+                                            const OperatorConf& src_subset_tick,
+                                            const HashSet<LogicalBlobId>& tick_lbis,
+                                            JobBuilder* job_builder) {
+  auto* map = critical_section->mutable_machine_id2sink_tick_op_name();
+  const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
+    (*map)[machine_id] = op_name;
+    return Maybe<void>::Ok();
+  };
+  JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink));
+  return Maybe<void>::Ok();
+}
+
 Maybe<void> BuildSrcSubsetTickOpAndParallelConf(OperatorConf* src_subset_tick_op,
                                                 JobBuilder* job_builder) {
   src_subset_tick_op->set_name("System-AutoTick-SrcSubsetTick_" + NewUniqueId());
@@ -131,10 +144,9 @@ Maybe<void> BuildSrcSubsetTickOpAndParallelConf(OperatorConf* src_subset_tick_op
   return Maybe<void>::Ok();
 }
 
-Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section,
-                                              OperatorConf* src_subset_tick_op,
-                                              JobBuilder* job_builder) {
-  auto* map = critical_section->mutable_machine_id2source_tick_op_name();
+Maybe<void> CreateSourceTicksAndSrcSubsetTick(
+    OperatorConf* src_subset_tick_op, JobBuilder* job_builder,
+    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSrc) {
   for (int64_t machine_id : Global<ResourceDesc, ForSession>::Get()->process_ranks()) {
     ParallelConf parallel_conf;
     parallel_conf.set_device_tag("cpu");
@@ -145,7 +157,7 @@ Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section,
       src_tick_op.mutable_source_tick_conf()->set_out("out");
       JUST(job_builder->AddOp(parallel_conf, src_tick_op));
     }
-    (*map)[machine_id] = src_tick_op.name();
+    JUST(DoEachSrc(machine_id, src_tick_op.name()));
     OperatorConf tick_op;
     {
       tick_op.set_name("System-AutoTick-Tick_" + NewUniqueId());
@@ -159,6 +171,18 @@ Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section,
   return Maybe<void>::Ok();
 }
 
+Maybe<void> CreateSourceTicksAndSrcSubsetTick(CriticalSection* critical_section,
+                                              OperatorConf* src_subset_tick_op,
+                                              JobBuilder* job_builder) {
+  auto* map = critical_section->mutable_machine_id2source_tick_op_name();
+  const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
+    (*map)[machine_id] = op_name;
+    return Maybe<void>::Ok();
+  };
+  JUST(CreateSourceTicksAndSrcSubsetTick(src_subset_tick_op, job_builder, DoEachSrc));
+  return Maybe<void>::Ok();
+}
+
 Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick_op,
                                              JobBuilder* job_builder) {
   CHECK_OR_RETURN(src_subset_tick_op.has_src_subset_tick_conf());
@@ -421,10 +445,10 @@ Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
-  auto* critical_section =
-      Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
-  critical_section->mutable_total_job_critical_section();
+Maybe<void> AutoSourceAndSinkTick(
+    const OpGraph& op_graph, JobBuilder* job_builder,
+    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSrc,
+    const std::function<Maybe<void>(int64_t machine_id, const std::string& op_name)>& DoEachSink) {
   JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> {
     CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf());
     return Maybe<void>::Ok();
@@ -442,12 +466,33 @@ Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_build
     return Maybe<void>::Ok();
   }));
   OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job()));
-  JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
-  JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
+  JUST(CreateSourceTicksAndSrcSubsetTick(&src_subset_tick, job_builder, DoEachSrc));
+  JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
+  if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
+  auto* critical_section =
+      Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
+  critical_section->mutable_total_job_critical_section();
+  auto* src_map = critical_section->mutable_machine_id2source_tick_op_name();
+  const auto& DoEachSrc = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
+    (*src_map)[machine_id] = op_name;
+    return Maybe<void>::Ok();
+  };
+  auto* sink_map = critical_section->mutable_machine_id2sink_tick_op_name();
+  const auto& DoEachSink = [&](int64_t machine_id, const std::string& op_name) -> Maybe<void> {
+    (*sink_map)[machine_id] = op_name;
+    return Maybe<void>::Ok();
+  };
+  JUST(AutoSourceAndSinkTick(op_graph, job_builder, DoEachSrc, DoEachSink));
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
+Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
+                                                       JobBuilder* job_builder) {
+  if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
   JUST(ForEachInputCriticalSectionOpNodes(
       op_graph,
       [&](const HashSet<const OpNode*>& op_nodes,
@@ -458,7 +503,9 @@ Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder*
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
+Maybe<void> SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph,
+                                                        JobBuilder* job_builder) {
+  if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) { return Maybe<void>::Ok(); }
   JUST(ForEachOutputCriticalSectionOpNodes(
       op_graph,
       [&](const HashSet<const OpNode*>& op_nodes,
diff --git a/oneflow/core/job_rewriter/autotick.h b/oneflow/core/job_rewriter/autotick.h
index 2760edda5..1935ab945 100644
--- a/oneflow/core/job_rewriter/autotick.h
+++ b/oneflow/core/job_rewriter/autotick.h
@@ -24,9 +24,11 @@ namespace oneflow {
 
 Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
 Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
-Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
-Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
-Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
+Maybe<void> SingleClientAutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
+Maybe<void> SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph,
+                                                       JobBuilder* job_builder);
+Maybe<void> SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph,
+                                                        JobBuilder* job_builder);
 
 class MutOpConTickInputHelper {
  public:
diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp
index c3833ab33..151a5bddc 100644
--- a/oneflow/core/job_rewriter/job_completer.cpp
+++ b/oneflow/core/job_rewriter/job_completer.cpp
@@ -111,9 +111,9 @@ Maybe<void> JobCompleter::Complete(Job* job) const {
   // complete tick ops
   JUST(WithOpGraphAndMutJobBuilder(job, &AutoPrependTick));
   JUST(WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape));
-  JUST(WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick));
-  JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections));
-  JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections));
+  JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAutoSourceAndSinkTick));
+  JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalInputCriticalSections));
+  JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalOutputCriticalSections));
   JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
   if (XrtCompilationEnabled(GlobalJobDesc())) {
 #ifdef OF_WITH_XRT
-- 
GitLab