From 14099cc2923ea906b84027bcc61344ae0dfbc4fe Mon Sep 17 00:00:00 2001
From: Shenghang Tsai <jackalcooper@gmail.com>
Date: Sat, 8 May 2021 20:49:23 +0800
Subject: [PATCH] Use multi core to run TaskNode::ToProto (#4820)

* Serialize proto in binary rather than text

* move del ops out from loop

* refine

* Skip GenCollectiveBoxingPlan if no CollectiveBoxingTaskNode

* multi core to proto

* copy pointers explicitly

* make toproto const method

* reorder

* larger tol

* Update test_layers_conv1d.py

* fix deadlock

* remove ForeignCallBack in Operator::ToOpAttribute

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: liujuncheng <liujuncheng1022@gmail.com>
Co-authored-by: clackhan <han_binbin@163.com>
---
 oneflow/core/graph/compute_task_node.cpp |  2 +-
 oneflow/core/graph/compute_task_node.h   |  2 +-
 oneflow/core/graph/task_node.cpp         |  2 +-
 oneflow/core/graph/task_node.h           |  2 +-
 oneflow/core/job/compiler.cpp            | 34 +++++++++++++++++-------
 oneflow/core/operator/operator.cpp       | 18 ++++++++++---
 6 files changed, 43 insertions(+), 17 deletions(-)

diff --git a/oneflow/core/graph/compute_task_node.cpp b/oneflow/core/graph/compute_task_node.cpp
index 4b7550ccd..21e2bcc85 100644
--- a/oneflow/core/graph/compute_task_node.cpp
+++ b/oneflow/core/graph/compute_task_node.cpp
@@ -67,7 +67,7 @@ std::vector<CompTaskNode*> GetCompTaskNodesOnEdge(
 
 std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); }
 
-void CompTaskNode::ToProto(TaskProto* task_proto) {
+void CompTaskNode::ToProto(TaskProto* task_proto) const {
   TaskNode::ToProto(task_proto);
   *(task_proto->mutable_parallel_ctx()) = parallel_ctx_;
 }
diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h
index dfb5c1bb7..8e735f90d 100644
--- a/oneflow/core/graph/compute_task_node.h
+++ b/oneflow/core/graph/compute_task_node.h
@@ -36,7 +36,7 @@ class CompTaskNode : public TaskNode {
     UNIMPLEMENTED();
 #endif
   }
-  virtual void ToProto(TaskProto*) override;
+  virtual void ToProto(TaskProto*) const override;
 
   // parallel_ctx_
   int64_t parallel_id() const { return parallel_ctx_.parallel_id(); }
diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp
index 60645122b..98a1f8d55 100644
--- a/oneflow/core/graph/task_node.cpp
+++ b/oneflow/core/graph/task_node.cpp
@@ -200,7 +200,7 @@ std::string TaskNode::VisualStr() const {
 
 bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); }
 
-void TaskNode::ToProto(TaskProto* task_proto) {
+void TaskNode::ToProto(TaskProto* task_proto) const {
   // Step1: process some scalar items.
   CHECK_NE(chain_id_, -1);
   task_proto->set_task_type(GetTaskType());
diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h
index 969f0bd41..4512f62cc 100644
--- a/oneflow/core/graph/task_node.h
+++ b/oneflow/core/graph/task_node.h
@@ -94,7 +94,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
   virtual TaskType GetTaskType() const { return TaskType::kInvalid; }
   std::string VisualStr() const override;
   virtual bool IsMeaningLess();
-  virtual void ToProto(TaskProto*);
+  virtual void ToProto(TaskProto*) const;
   virtual bool IsIndependent() const { return false; }
   void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
   virtual int64_t MemZoneId121() const;
diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp
index 2462b4ded..1749181ee 100644
--- a/oneflow/core/job/compiler.cpp
+++ b/oneflow/core/job/compiler.cpp
@@ -20,6 +20,8 @@ limitations under the License.
 #include "oneflow/core/persistence/tee_persistent_log_stream.h"
 #include "oneflow/core/graph/op_graph.h"
 #include "oneflow/core/job_rewriter/job_completer.h"
+#include "oneflow/core/thread/thread_pool.h"
+#include "oneflow/core/common/blocking_counter.h"
 
 namespace oneflow {
 
@@ -114,16 +116,30 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
   task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });
 
   // Step4: put infomation from task_gph into plan.
+  const int64_t node_num = task_gph->node_num();
+  const int64_t cpu_num = std::thread::hardware_concurrency();
+  const int64_t thread_pool_size = std::min(node_num, cpu_num);
+  BlockingCounter counter(node_num);
+  std::mutex mtx;
+  ThreadPool thread_pool(thread_pool_size);
   task_gph->ForEachNode([&](TaskNode* task_node) {
-    if (task_node->IsMeaningLess()) { return; }
-    TaskProto task_proto;
-    task_node->ToProto(&task_proto);
-    if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
-        || task_node->GetTaskType() == kAcc) {
-      CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
-    }
-    plan->mutable_task()->Add(std::move(task_proto));
-  });
+    thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() {
+      if (!task_node->IsMeaningLess()) {
+        TaskProto task_proto;
+        task_node->ToProto(&task_proto);
+        {
+          std::unique_lock<std::mutex> guard(mtx);
+          if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
+              || task_node->GetTaskType() == kAcc) {
+            CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
+          }
+          plan->mutable_task()->Add(std::move(task_proto));
+        }  // guard(mtx)
+      }
+      counter.Decrease();
+    } /* thread_pool.AddWork */);
+  } /* task_gph->ForEachNode */);
+  counter.WaitUntilCntEqualZero();
   // NOTE(levi): release task_gph here to decrise memory peak.
   task_gph.reset();
 
diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index a05e1874a..e36104850 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -15,6 +15,7 @@ limitations under the License.
 */
 #include "oneflow/core/common/balanced_splitter.h"
 #include "oneflow/core/vm/symbol_storage.h"
+#include "oneflow/core/framework/instructions_builder.h"
 #include "oneflow/core/framework/to_string.h"
 #include "oneflow/core/framework/user_op_registry_manager.h"
 #include "oneflow/core/job/mirrored_sig_infer_hint.h"
@@ -1200,10 +1201,19 @@ Maybe<void> Operator::ToOpAttribute(OpAttribute* op_attribute) const {
         if (*pair.second == *op_parallel_desc_) {
           (*symbol_map)[pair.first] = parallel_desc_symbol_id;
         } else {
-          (*symbol_map)[pair.first] =
-              (*Global<std::shared_ptr<ForeignCallback>>::Get())
-                  ->MakeParallelDescSymbol(
-                      std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf()));
+          const auto parallel_conf =
+              std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf());
+          const auto MakeParallelDescSymbol = [&parallel_conf]() -> int64_t {
+            int64_t symbol_id;
+            const auto BuildInstruction =
+                [&symbol_id, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
+              symbol_id = JUST(JUST(builder->GetParallelDescSymbol(parallel_conf))->symbol_id());
+              return Maybe<void>::Ok();
+            };
+            LogicalRun(BuildInstruction);
+            return symbol_id;
+          };
+          (*symbol_map)[pair.first] = MakeParallelDescSymbol();
         }
       }
       for (const auto& tbn : tmp_bns()) { (*symbol_map)[tbn] = parallel_desc_symbol_id; }
-- 
GitLab