From 14099cc2923ea906b84027bcc61344ae0dfbc4fe Mon Sep 17 00:00:00 2001 From: Shenghang Tsai <jackalcooper@gmail.com> Date: Sat, 8 May 2021 20:49:23 +0800 Subject: [PATCH] Use multi core to run TaskNode::ToProto (#4820) * Serialize proto in binary rather than text * move del ops out from loop * refine * Skip GenCollectiveBoxingPlan if no CollectiveBoxingTaskNode * multi core to proto * copy pointers explicitly * make toproto const method * reorder * larger tol * Update test_layers_conv1d.py * fix deadlock * remove ForeignCallBack in Operator::ToOpAttribute Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: liujuncheng <liujuncheng1022@gmail.com> Co-authored-by: clackhan <han_binbin@163.com> --- oneflow/core/graph/compute_task_node.cpp | 2 +- oneflow/core/graph/compute_task_node.h | 2 +- oneflow/core/graph/task_node.cpp | 2 +- oneflow/core/graph/task_node.h | 2 +- oneflow/core/job/compiler.cpp | 34 +++++++++++++++++------- oneflow/core/operator/operator.cpp | 18 ++++++++++--- 6 files changed, 43 insertions(+), 17 deletions(-) diff --git a/oneflow/core/graph/compute_task_node.cpp b/oneflow/core/graph/compute_task_node.cpp index 4b7550ccd..21e2bcc85 100644 --- a/oneflow/core/graph/compute_task_node.cpp +++ b/oneflow/core/graph/compute_task_node.cpp @@ -67,7 +67,7 @@ std::vector<CompTaskNode*> GetCompTaskNodesOnEdge( std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); } -void CompTaskNode::ToProto(TaskProto* task_proto) { +void CompTaskNode::ToProto(TaskProto* task_proto) const { TaskNode::ToProto(task_proto); *(task_proto->mutable_parallel_ctx()) = parallel_ctx_; } diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h index dfb5c1bb7..8e735f90d 100644 --- a/oneflow/core/graph/compute_task_node.h +++ b/oneflow/core/graph/compute_task_node.h @@ -36,7 +36,7 @@ class CompTaskNode : public TaskNode { UNIMPLEMENTED(); #endif } - virtual void ToProto(TaskProto*) override; + virtual void ToProto(TaskProto*) const override; // parallel_ctx_ int64_t parallel_id() const { return parallel_ctx_.parallel_id(); } diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 60645122b..98a1f8d55 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -200,7 +200,7 @@ std::string TaskNode::VisualStr() const { bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } -void TaskNode::ToProto(TaskProto* task_proto) { +void TaskNode::ToProto(TaskProto* task_proto) const { // Step1: process some scalar items. CHECK_NE(chain_id_, -1); task_proto->set_task_type(GetTaskType()); diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 969f0bd41..4512f62cc 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -94,7 +94,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> { virtual TaskType GetTaskType() const { return TaskType::kInvalid; } std::string VisualStr() const override; virtual bool IsMeaningLess(); - virtual void ToProto(TaskProto*); + virtual void ToProto(TaskProto*) const; virtual bool IsIndependent() const { return false; } void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); virtual int64_t MemZoneId121() const; diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 2462b4ded..1749181ee 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -20,6 +20,8 @@ limitations under the License. #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job_rewriter/job_completer.h" +#include "oneflow/core/thread/thread_pool.h" +#include "oneflow/core/common/blocking_counter.h" namespace oneflow { @@ -114,16 +116,30 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); }); // Step4: put infomation from task_gph into plan. + const int64_t node_num = task_gph->node_num(); + const int64_t cpu_num = std::thread::hardware_concurrency(); + const int64_t thread_pool_size = std::min(node_num, cpu_num); + BlockingCounter counter(node_num); + std::mutex mtx; + ThreadPool thread_pool(thread_pool_size); task_gph->ForEachNode([&](TaskNode* task_node) { - if (task_node->IsMeaningLess()) { return; } - TaskProto task_proto; - task_node->ToProto(&task_proto); - if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat - || task_node->GetTaskType() == kAcc) { - CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); - } - plan->mutable_task()->Add(std::move(task_proto)); - }); + thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() { + if (!task_node->IsMeaningLess()) { + TaskProto task_proto; + task_node->ToProto(&task_proto); + { + std::unique_lock<std::mutex> guard(mtx); + if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat + || task_node->GetTaskType() == kAcc) { + CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); + } + plan->mutable_task()->Add(std::move(task_proto)); + } // guard(mtx) + } + counter.Decrease(); + } /* thread_pool.AddWork */); + } /* task_gph->ForEachNode */); + counter.WaitUntilCntEqualZero(); // NOTE(levi): release task_gph here to decrise memory peak. task_gph.reset(); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index a05e1874a..e36104850 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/vm/symbol_storage.h" +#include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/job/mirrored_sig_infer_hint.h" @@ -1200,10 +1201,19 @@ Maybe<void> Operator::ToOpAttribute(OpAttribute* op_attribute) const { if (*pair.second == *op_parallel_desc_) { (*symbol_map)[pair.first] = parallel_desc_symbol_id; } else { - (*symbol_map)[pair.first] = - (*Global<std::shared_ptr<ForeignCallback>>::Get()) - ->MakeParallelDescSymbol( - std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf())); + const auto parallel_conf = + std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf()); + const auto MakeParallelDescSymbol = [¶llel_conf]() -> int64_t { + int64_t symbol_id; + const auto BuildInstruction = + [&symbol_id, ¶llel_conf](InstructionsBuilder* builder) -> Maybe<void> { + symbol_id = JUST(JUST(builder->GetParallelDescSymbol(parallel_conf))->symbol_id()); + return Maybe<void>::Ok(); + }; + LogicalRun(BuildInstruction); + return symbol_id; + }; + (*symbol_map)[pair.first] = MakeParallelDescSymbol(); } } for (const auto& tbn : tmp_bns()) { (*symbol_map)[tbn] = parallel_desc_symbol_id; } -- GitLab