diff --git a/oneflow/core/graph/compute_task_node.cpp b/oneflow/core/graph/compute_task_node.cpp index 4b7550ccd4ef9f719072ad3c3a40b9b4c948a8e9..21e2bcc85cd5b9c1b3ba6a4e82c9e7fadc9ebc35 100644 --- a/oneflow/core/graph/compute_task_node.cpp +++ b/oneflow/core/graph/compute_task_node.cpp @@ -67,7 +67,7 @@ std::vector<CompTaskNode*> GetCompTaskNodesOnEdge( std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); } -void CompTaskNode::ToProto(TaskProto* task_proto) { +void CompTaskNode::ToProto(TaskProto* task_proto) const { TaskNode::ToProto(task_proto); *(task_proto->mutable_parallel_ctx()) = parallel_ctx_; } diff --git a/oneflow/core/graph/compute_task_node.h b/oneflow/core/graph/compute_task_node.h index dfb5c1bb7df031f9aecb416729ae4e910776d928..8e735f90d762c1f2d917a56687e11c482523de3e 100644 --- a/oneflow/core/graph/compute_task_node.h +++ b/oneflow/core/graph/compute_task_node.h @@ -36,7 +36,7 @@ class CompTaskNode : public TaskNode { UNIMPLEMENTED(); #endif } - virtual void ToProto(TaskProto*) override; + virtual void ToProto(TaskProto*) const override; // parallel_ctx_ int64_t parallel_id() const { return parallel_ctx_.parallel_id(); } diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 60645122b75dfdd3d8cae06998ee6f5c81678dc5..98a1f8d5548f8eeae6246cd7e1ef67eb892742e1 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -200,7 +200,7 @@ std::string TaskNode::VisualStr() const { bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); } -void TaskNode::ToProto(TaskProto* task_proto) { +void TaskNode::ToProto(TaskProto* task_proto) const { // Step1: process some scalar items. CHECK_NE(chain_id_, -1); task_proto->set_task_type(GetTaskType()); diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 969f0bd413b3c47b164c84ee2ebb9ee044d0ef97..4512f62ccae63db9cd9934b94bf9d35e4a798aca 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -94,7 +94,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> { virtual TaskType GetTaskType() const { return TaskType::kInvalid; } std::string VisualStr() const override; virtual bool IsMeaningLess(); - virtual void ToProto(TaskProto*); + virtual void ToProto(TaskProto*) const; virtual bool IsIndependent() const { return false; } void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); virtual int64_t MemZoneId121() const; diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 2462b4ded84fcd0ba0361f9432cce688539316ae..1749181ee22602fd5694c1909e28bab1d481c01a 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -20,6 +20,8 @@ limitations under the License. #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job_rewriter/job_completer.h" +#include "oneflow/core/thread/thread_pool.h" +#include "oneflow/core/common/blocking_counter.h" namespace oneflow { @@ -114,16 +116,30 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); }); // Step4: put infomation from task_gph into plan. + const int64_t node_num = task_gph->node_num(); + const int64_t cpu_num = std::thread::hardware_concurrency(); + const int64_t thread_pool_size = std::min(node_num, cpu_num); + BlockingCounter counter(node_num); + std::mutex mtx; + ThreadPool thread_pool(thread_pool_size); task_gph->ForEachNode([&](TaskNode* task_node) { - if (task_node->IsMeaningLess()) { return; } - TaskProto task_proto; - task_node->ToProto(&task_proto); - if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat - || task_node->GetTaskType() == kAcc) { - CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); - } - plan->mutable_task()->Add(std::move(task_proto)); - }); + thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() { + if (!task_node->IsMeaningLess()) { + TaskProto task_proto; + task_node->ToProto(&task_proto); + { + std::unique_lock<std::mutex> guard(mtx); + if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat + || task_node->GetTaskType() == kAcc) { + CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto); + } + plan->mutable_task()->Add(std::move(task_proto)); + } // guard(mtx) + } + counter.Decrease(); + } /* thread_pool.AddWork */); + } /* task_gph->ForEachNode */); + counter.WaitUntilCntEqualZero(); // NOTE(levi): release task_gph here to decrise memory peak. task_gph.reset(); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index a05e1874ac2001cbc803d0dc7cb0e69148b57182..e36104850cac7d8a329e72e707e0daa61d9beae4 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/vm/symbol_storage.h" +#include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/job/mirrored_sig_infer_hint.h" @@ -1200,10 +1201,19 @@ Maybe<void> Operator::ToOpAttribute(OpAttribute* op_attribute) const { if (*pair.second == *op_parallel_desc_) { (*symbol_map)[pair.first] = parallel_desc_symbol_id; } else { - (*symbol_map)[pair.first] = - (*Global<std::shared_ptr<ForeignCallback>>::Get()) - ->MakeParallelDescSymbol( - std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf())); + const auto parallel_conf = + std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf()); + const auto MakeParallelDescSymbol = [¶llel_conf]() -> int64_t { + int64_t symbol_id; + const auto BuildInstruction = + [&symbol_id, ¶llel_conf](InstructionsBuilder* builder) -> Maybe<void> { + symbol_id = JUST(JUST(builder->GetParallelDescSymbol(parallel_conf))->symbol_id()); + return Maybe<void>::Ok(); + }; + LogicalRun(BuildInstruction); + return symbol_id; + }; + (*symbol_map)[pair.first] = MakeParallelDescSymbol(); } } for (const auto& tbn : tmp_bns()) { (*symbol_map)[tbn] = parallel_desc_symbol_id; }