Skip to content
Snippets Groups Projects
Unverified Commit 14099cc2 authored by Shenghang Tsai's avatar Shenghang Tsai Committed by GitHub
Browse files

Use multi core to run TaskNode::ToProto (#4820)


* Serialize proto in binary rather than text

* move del ops out from loop

* refine

* Skip GenCollectiveBoxingPlan if no CollectiveBoxingTaskNode

* multi core to proto

* copy pointers explicitly

* make toproto const method

* reorder

* larger tol

* Update test_layers_conv1d.py

* fix deadlock

* remove ForeignCallBack in Operator::ToOpAttribute

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: default avatarliujuncheng <liujuncheng1022@gmail.com>
Co-authored-by: default avatarclackhan <han_binbin@163.com>
parent a274c76a
No related branches found
No related tags found
No related merge requests found
......@@ -67,7 +67,7 @@ std::vector<CompTaskNode*> GetCompTaskNodesOnEdge(
std::string CompTaskNode::VisualStr() const { return op_node_->op().op_name(); }
void CompTaskNode::ToProto(TaskProto* task_proto) {
void CompTaskNode::ToProto(TaskProto* task_proto) const {
TaskNode::ToProto(task_proto);
*(task_proto->mutable_parallel_ctx()) = parallel_ctx_;
}
......
......@@ -36,7 +36,7 @@ class CompTaskNode : public TaskNode {
UNIMPLEMENTED();
#endif
}
virtual void ToProto(TaskProto*) override;
virtual void ToProto(TaskProto*) const override;
// parallel_ctx_
int64_t parallel_id() const { return parallel_ctx_.parallel_id(); }
......
......@@ -200,7 +200,7 @@ std::string TaskNode::VisualStr() const {
bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_regsts_.empty(); }
void TaskNode::ToProto(TaskProto* task_proto) {
void TaskNode::ToProto(TaskProto* task_proto) const {
// Step1: process some scalar items.
CHECK_NE(chain_id_, -1);
task_proto->set_task_type(GetTaskType());
......
......@@ -94,7 +94,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
virtual TaskType GetTaskType() const { return TaskType::kInvalid; }
std::string VisualStr() const override;
virtual bool IsMeaningLess();
virtual void ToProto(TaskProto*);
virtual void ToProto(TaskProto*) const;
virtual bool IsIndependent() const { return false; }
void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
virtual int64_t MemZoneId121() const;
......
......@@ -20,6 +20,8 @@ limitations under the License.
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job_rewriter/job_completer.h"
#include "oneflow/core/thread/thread_pool.h"
#include "oneflow/core/common/blocking_counter.h"
namespace oneflow {
......@@ -114,16 +116,30 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });
// Step4: put infomation from task_gph into plan.
const int64_t node_num = task_gph->node_num();
const int64_t cpu_num = std::thread::hardware_concurrency();
const int64_t thread_pool_size = std::min(node_num, cpu_num);
BlockingCounter counter(node_num);
std::mutex mtx;
ThreadPool thread_pool(thread_pool_size);
task_gph->ForEachNode([&](TaskNode* task_node) {
if (task_node->IsMeaningLess()) { return; }
TaskProto task_proto;
task_node->ToProto(&task_proto);
if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
|| task_node->GetTaskType() == kAcc) {
CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
}
plan->mutable_task()->Add(std::move(task_proto));
});
thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() {
if (!task_node->IsMeaningLess()) {
TaskProto task_proto;
task_node->ToProto(&task_proto);
{
std::unique_lock<std::mutex> guard(mtx);
if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
|| task_node->GetTaskType() == kAcc) {
CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
}
plan->mutable_task()->Add(std::move(task_proto));
} // guard(mtx)
}
counter.Decrease();
} /* thread_pool.AddWork */);
} /* task_gph->ForEachNode */);
counter.WaitUntilCntEqualZero();
// NOTE(levi): release task_gph here to decrise memory peak.
task_gph.reset();
......
......@@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/job/mirrored_sig_infer_hint.h"
......@@ -1200,10 +1201,19 @@ Maybe<void> Operator::ToOpAttribute(OpAttribute* op_attribute) const {
if (*pair.second == *op_parallel_desc_) {
(*symbol_map)[pair.first] = parallel_desc_symbol_id;
} else {
(*symbol_map)[pair.first] =
(*Global<std::shared_ptr<ForeignCallback>>::Get())
->MakeParallelDescSymbol(
std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf()));
const auto parallel_conf =
std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf());
const auto MakeParallelDescSymbol = [&parallel_conf]() -> int64_t {
int64_t symbol_id;
const auto BuildInstruction =
[&symbol_id, &parallel_conf](InstructionsBuilder* builder) -> Maybe<void> {
symbol_id = JUST(JUST(builder->GetParallelDescSymbol(parallel_conf))->symbol_id());
return Maybe<void>::Ok();
};
LogicalRun(BuildInstruction);
return symbol_id;
};
(*symbol_map)[pair.first] = MakeParallelDescSymbol();
}
}
for (const auto& tbn : tmp_bns()) { (*symbol_map)[tbn] = parallel_desc_symbol_id; }
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment