Skip to content
Snippets Groups Projects
Unverified Commit 50e1c346 authored by Li Xinqi's avatar Li Xinqi Committed by GitHub
Browse files

Job pass maybe system (#5503)


* refactor job_pass by maybe_system

* remove useless files

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 810d8db5
No related branches found
No related tags found
No related merge requests found
...@@ -532,8 +532,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node ...@@ -532,8 +532,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size()); CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) { FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
std::string regst_desc_name; std::string regst_desc_name;
RegstDesc* ctrl_regst_desc = src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
TaskEdge* edge = NewEdge(); TaskEdge* edge = NewEdge();
Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i)); Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i));
src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name); src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name);
......
...@@ -47,7 +47,7 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) { ...@@ -47,7 +47,7 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) {
void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
// Step1: ensure job is completed. // Step1: ensure job is completed.
if (need_job_complete) { JobCompleter().Complete(job); } if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); }
// Step2: new Global<OpGraph> and set log configs. // Step2: new Global<OpGraph> and set log configs.
Global<OpGraph>::New(*job); Global<OpGraph>::New(*job);
......
...@@ -171,21 +171,22 @@ Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick ...@@ -171,21 +171,22 @@ Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick
auto mut_helper = NewMutOpConTickInputHelper(op.op_conf()); auto mut_helper = NewMutOpConTickInputHelper(op.op_conf());
if (!mut_helper) { return Maybe<void>::Ok(); } if (!mut_helper) { return Maybe<void>::Ok(); }
if (mut_helper->IsTickInputBound() == true) { return Maybe<void>::Ok(); } if (mut_helper->IsTickInputBound() == true) { return Maybe<void>::Ok(); }
job_builder->MutOpsOnlyOnce({mut_helper->NewTickInputBoundOpConf(src_lbn)}); JUST(job_builder->MutOpOnlyOnce(mut_helper->NewTickInputBoundOpConf(src_lbn)));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
})); }));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
const OpNode* GetSrcSubsetTickOpNode(const OpGraph& op_graph) { Maybe<const OpNode*> GetSrcSubsetTickOpNode(const OpGraph& op_graph) {
const OpNode* src_subset_tick = nullptr; const OpNode* src_subset_tick = nullptr;
op_graph.ForEachNode([&](OpNode* op_node) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_src_subset_tick_conf()) { if (op_node->op().op_conf().has_src_subset_tick_conf()) {
CHECK_ISNULL(src_subset_tick); CHECK_ISNULL_OR_RETURN(src_subset_tick);
src_subset_tick = op_node; src_subset_tick = op_node;
} }
}); return Maybe<void>::Ok();
CHECK_NOTNULL(src_subset_tick); }));
CHECK_NOTNULL_OR_RETURN(src_subset_tick);
return src_subset_tick; return src_subset_tick;
} }
...@@ -277,77 +278,81 @@ std::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) { ...@@ -277,77 +278,81 @@ std::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) {
return ret; return ret;
}; };
void InitOpTypeCase2OpNodes( Maybe<void> InitOpTypeCase2OpNodes(
const OpGraph& op_graph, const OpGraph& op_graph,
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) { HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) {
op_graph.ForEachNode([&](OpNode* op_node) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
const auto& op_conf = op_node->op().op_conf(); const auto& op_conf = op_node->op().op_conf();
if (IsInterfaceOpConf(op_conf)) { if (IsInterfaceOpConf(op_conf)) {
CHECK((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second); CHECK_OR_RETURN((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
} }
}); return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
} }
void ForEachInputCriticalSectionOpNodes( Maybe<void> ForEachInputCriticalSectionOpNodes(
const OpGraph& op_graph, const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>& const std::function<Maybe<void>(const HashSet<const OpNode*>&,
Handler) { const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes; HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes); JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf; OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf;
if (op_type_case2op_nodes[op_type_case].empty()) { return; } if (op_type_case2op_nodes[op_type_case].empty()) { return Maybe<void>::Ok(); }
HashSet<const OpNode*> op_nodes = op_type_case2op_nodes[op_type_case]; HashSet<const OpNode*> op_nodes = op_type_case2op_nodes[op_type_case];
for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) { for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) {
op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); }); op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); });
} }
Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case])); JUST(Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case])));
return Maybe<void>::Ok();
} }
void ForEachOutputCriticalSectionOpNodes( Maybe<void> ForEachOutputCriticalSectionOpNodes(
const OpGraph& op_graph, const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>& const std::function<Maybe<void>(const HashSet<const OpNode*>&,
Handler) { const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes; HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes); JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) { if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kReturnConf], JUST(Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf])); GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf])));
} }
if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) { if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kOutputConf], JUST(Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf])); GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf])));
} }
return Maybe<void>::Ok();
} }
std::vector<OperatorConf> AddTickForTimeShape(const Shape& src_time_shape, Maybe<std::vector<OperatorConf>> AddTickForTimeShape(const Shape& src_time_shape,
const HashSet<const OpNode*>& op_nodes, const HashSet<const OpNode*>& op_nodes,
JobBuilder* job_builder) { JobBuilder* job_builder) {
HashMap<std::pair<ParallelDesc, std::pair<Shape, Shape>>, std::list<const OpNode*>> HashMap<std::pair<ParallelDesc, std::pair<Shape, Shape>>, std::list<const OpNode*>>
pd7ts2op_nodes; pd7ts2op_nodes;
for (const OpNode* op_node : op_nodes) { for (const OpNode* op_node : op_nodes) {
auto ts = std::make_pair(*CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()), auto ts = std::make_pair(*JUST(op_node->op().GetInputOutputFastestTimeShape()),
*CHECK_JUST(op_node->op().GetOpTimeShape())); *JUST(op_node->op().GetOpTimeShape()));
pd7ts2op_nodes[{op_node->parallel_desc(), ts}].push_back(op_node); pd7ts2op_nodes[{op_node->parallel_desc(), ts}].push_back(op_node);
} }
std::vector<OperatorConf> op_confs; std::vector<OperatorConf> op_confs;
for (const auto& pair : pd7ts2op_nodes) { for (const auto& pair : pd7ts2op_nodes) {
const std::pair<Shape, Shape>& ts = pair.first.second; const std::pair<Shape, Shape>& ts = pair.first.second;
if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) { if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) {
CHECK_GE(ts.first.elem_cnt(), ts.second.elem_cnt()); CHECK_GE_OR_RETURN(ts.first.elem_cnt(), ts.second.elem_cnt());
op_confs.push_back( op_confs.push_back(
AppendTick("Append", pair.second, std::make_shared<const Shape>(ts.second), job_builder)); AppendTick("Append", pair.second, std::make_shared<const Shape>(ts.second), job_builder));
} else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) { } else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) {
op_confs.push_back(AppendAccTick(src_time_shape, pair.second, job_builder)); op_confs.push_back(AppendAccTick(src_time_shape, pair.second, job_builder));
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED_THEN_RETURN();
} }
} }
return op_confs; return op_confs;
} }
void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes, Maybe<void> AddGlobalInputOutputCriticalSection(
const std::vector<std::string>& lbi_producer_op_names, const HashSet<const OpNode*>& op_nodes, const std::vector<std::string>& lbi_producer_op_names,
JobBuilder* job_builder) { JobBuilder* job_builder) {
auto* critical_section = auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id()); Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
{ {
...@@ -358,22 +363,20 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes, ...@@ -358,22 +363,20 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
auto time_shape = std::make_unique<Shape>(DimVector{1, 1}); auto time_shape = std::make_unique<Shape>(DimVector{1, 1});
HashMap<ParallelDesc, HashSet<const OpNode*>> parallel_desc2op_nodes; HashMap<ParallelDesc, HashSet<const OpNode*>> parallel_desc2op_nodes;
for (const OpNode* op_node : op_nodes) { for (const OpNode* op_node : op_nodes) {
CHECK(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second); CHECK_OR_RETURN(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);
} }
std::vector<OperatorConf> source_ticks; std::vector<OperatorConf> source_ticks;
std::vector<OperatorConf> sink_ticks; std::vector<OperatorConf> sink_ticks;
for (const auto& pair : parallel_desc2op_nodes) { for (const auto& pair : parallel_desc2op_nodes) {
source_ticks.push_back(PrependTick(pair.second, job_builder)); source_ticks.push_back(PrependTick(pair.second, job_builder));
for (const auto& sink_tick : AddTickForTimeShape(*time_shape, pair.second, job_builder)) { const auto& ops = JUST(AddTickForTimeShape(*time_shape, pair.second, job_builder));
sink_ticks.push_back(sink_tick); for (const auto& sink_tick : *ops) { sink_ticks.push_back(sink_tick); }
}
} }
OperatorConf src_subset_tick_op; OperatorConf src_subset_tick_op;
{ {
CHECK_EQ(source_ticks.empty(), false); CHECK_EQ_OR_RETURN(source_ticks.empty(), false);
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST( JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
for (auto& op_conf : source_ticks) { for (auto& op_conf : source_ticks) {
op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + "/" op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + "/"
+ src_subset_tick_op.src_subset_tick_conf().out()); + src_subset_tick_op.src_subset_tick_conf().out());
...@@ -384,70 +387,86 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes, ...@@ -384,70 +387,86 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
for (const auto& op_conf : sink_ticks) { for (const auto& op_conf : sink_ticks) {
LogicalBlobId lbi; LogicalBlobId lbi;
lbi.set_op_name(op_conf.name()); lbi.set_op_name(op_conf.name());
CHECK(op_conf.has_device_tick_conf()); CHECK_OR_RETURN(op_conf.has_device_tick_conf());
lbi.set_blob_name(op_conf.device_tick_conf().out()); lbi.set_blob_name(op_conf.device_tick_conf().out());
CHECK(tick_lbis.insert(lbi).second); CHECK_OR_RETURN(tick_lbis.insert(lbi).second);
} }
CHECK_JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis, JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,
job_builder)); job_builder));
return Maybe<void>::Ok();
} }
} // namespace } // namespace
void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) { Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
PrependTickByParallelDesc(op_graph, job_builder); PrependTickByParallelDesc(op_graph, job_builder);
OperatorConf src_subset_tick_op; OperatorConf src_subset_tick_op;
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder)); JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));
return Maybe<void>::Ok();
} }
void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) { Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape()); const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = *JUST(op_node->op().GetOpTimeShape());
HashSet<const OpNode*> sink_op_nodes; HashSet<const OpNode*> sink_op_nodes;
op_graph.ForEachNode([&](OpNode* op_node) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
CHECK(!op_node->op().op_conf().has_sink_tick_conf()); CHECK_OR_RETURN(!op_node->op().op_conf().has_sink_tick_conf());
size_t out_cnt = 0; size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt == 0) { sink_op_nodes.insert(op_node); } if (out_cnt == 0) { sink_op_nodes.insert(op_node); }
}); return Maybe<void>::Ok();
AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder); }));
JUST(AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder));
return Maybe<void>::Ok();
} }
void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) { Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
auto* critical_section = auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id()); Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
critical_section->mutable_total_job_critical_section(); critical_section->mutable_total_job_critical_section();
op_graph.ForEachNode([&](OpNode* node) { CHECK(!node->op().op_conf().has_sink_tick_conf()); }); JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> {
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape()); CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf());
return Maybe<void>::Ok();
}));
const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = JUST(op_node->op().GetOpTimeShape());
HashSet<LogicalBlobId> tick_lbis; HashSet<LogicalBlobId> tick_lbis;
op_graph.ForEachNode([&](OpNode* op_node) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
size_t out_cnt = 0; size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt > 0) { return; } if (out_cnt > 0) { return Maybe<void>::Ok(); }
CHECK(op_node->op().op_conf().has_device_tick_conf()); CHECK_OR_RETURN(op_node->op().op_conf().has_device_tick_conf());
CHECK(CHECK_JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape.elem_cnt()); CHECK_OR_RETURN(JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape->elem_cnt());
CHECK(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second); CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
}); return Maybe<void>::Ok();
OperatorConf src_subset_tick = CHECK_JUST(FindSrcSubsetTickOpConf(job_builder->job())); }));
CHECK_JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder)); OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job()));
CHECK_JUST( JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder)); JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
return Maybe<void>::Ok();
} }
void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachInputCriticalSectionOpNodes( JUST(ForEachInputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes, op_graph,
const std::vector<std::string>& lbi_producer_op_names) { [&](const HashSet<const OpNode*>& op_nodes,
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder); const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
}); JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
} }
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachOutputCriticalSectionOpNodes( JUST(ForEachOutputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes, op_graph,
const std::vector<std::string>& lbi_producer_op_names) { [&](const HashSet<const OpNode*>& op_nodes,
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder); const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
}); JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
} }
} // namespace oneflow } // namespace oneflow
...@@ -22,11 +22,11 @@ limitations under the License. ...@@ -22,11 +22,11 @@ limitations under the License.
namespace oneflow { namespace oneflow {
void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder); Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder); Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder); Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
class MutOpConTickInputHelper { class MutOpConTickInputHelper {
public: public:
......
...@@ -19,7 +19,7 @@ limitations under the License. ...@@ -19,7 +19,7 @@ limitations under the License.
namespace oneflow { namespace oneflow {
void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) { Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {
HashMap<LogicalBlobId, HashMap<std::pair<ParallelDesc, cfg::ParallelDistribution>, HashMap<LogicalBlobId, HashMap<std::pair<ParallelDesc, cfg::ParallelDistribution>,
std::vector<std::pair<const OpNode*, std::string>>>> std::vector<std::pair<const OpNode*, std::string>>>>
lbi2consumer_grouped_by_parallel; lbi2consumer_grouped_by_parallel;
...@@ -76,13 +76,14 @@ void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) ...@@ -76,13 +76,14 @@ void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder)
OperatorConf& consumer_op_conf = op_node2op_conf[consumer]; OperatorConf& consumer_op_conf = op_node2op_conf[consumer];
const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn,
GenLogicalBlobName(grouped_lbi)); GenLogicalBlobName(grouped_lbi));
CHECK_EQ(GenLogicalBlobName(lbi), old_val); CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val);
} }
} }
} }
for (const auto& op_node7op_conf : op_node2op_conf) { for (const auto& op_node7op_conf : op_node2op_conf) {
job_builder->MutOpsOnlyOnce({op_node7op_conf.second}); JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second));
} }
return Maybe<void>::Ok();
} }
} // namespace oneflow } // namespace oneflow
...@@ -23,7 +23,7 @@ namespace oneflow { ...@@ -23,7 +23,7 @@ namespace oneflow {
class OpGraph; class OpGraph;
class Job; class Job;
void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder); Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);
} // namespace oneflow } // namespace oneflow
......
...@@ -27,30 +27,35 @@ namespace oneflow { ...@@ -27,30 +27,35 @@ namespace oneflow {
namespace { namespace {
void CheckOpGraph(const OpGraph& op_graph) { Maybe<void> CheckOpGraph(const OpGraph& op_graph) {
op_graph.ForEachNode([&](OpNode* op_node) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
size_t in_cnt = 0; size_t in_cnt = 0;
op_graph.ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_cnt; }); op_graph.ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_cnt; });
if (in_cnt == 0) { CHECK(op_node->op().op_conf().has_source_tick_conf()); } if (in_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_source_tick_conf()); }
size_t out_cnt = 0; size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt == 0) { CHECK(op_node->op().op_conf().has_sink_tick_conf()); } if (out_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_sink_tick_conf()); }
}); return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
} }
void WithOpGraphAndMutJob(Job* job, const std::function<void(const OpGraph&, Job*)>& Handler) { Maybe<void> WithOpGraphAndMutJob(Job* job,
const std::function<Maybe<void>(const OpGraph&, Job*)>& Handler) {
OpGraph op_graph(*job); OpGraph op_graph(*job);
Handler(op_graph, job); JUST(Handler(op_graph, job));
return Maybe<void>::Ok();
} }
void WithOpGraphAndMutJobBuilder(Job* job, Maybe<void> WithOpGraphAndMutJobBuilder(
const std::function<void(const OpGraph&, JobBuilder*)>& Handler) { Job* job, const std::function<Maybe<void>(const OpGraph&, JobBuilder*)>& Handler) {
OpGraph op_graph(*job); OpGraph op_graph(*job);
JobBuilder job_builder(job); JobBuilder job_builder(job);
Handler(op_graph, &job_builder); JUST(Handler(op_graph, &job_builder));
return Maybe<void>::Ok();
} }
void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) { Maybe<void> SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) {
auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool { auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {
for (const std::string& bn : op.input_bns()) { for (const std::string& bn : op.input_bns()) {
if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; } if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; }
...@@ -59,9 +64,9 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder ...@@ -59,9 +64,9 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
}; };
auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
HashMap<const OperatorConf*, HashSet<std::string>> op_conf2ctrl_in_op_names; HashMap<const OperatorConf*, HashSet<std::string>> op_conf2ctrl_in_op_names;
op_graph.ForEachNode([&](OpNode* op_node) { JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_variable_conf() == false) { return; } if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }
if (op_node->out_edges().size() <= 1) { return; } if (op_node->out_edges().size() <= 1) { return Maybe<void>::Ok(); }
const Operator& variable_op = op_node->op(); const Operator& variable_op = op_node->op();
const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn()); const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn());
const OperatorConf* mutable_consumer = nullptr; const OperatorConf* mutable_consumer = nullptr;
...@@ -69,17 +74,18 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder ...@@ -69,17 +74,18 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
for (OpEdge* edge : op_node->out_edges()) { for (OpEdge* edge : op_node->out_edges()) {
const auto& op_conf = edge->dst_node()->op().op_conf(); const auto& op_conf = edge->dst_node()->op().op_conf();
if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) { if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) {
CHECK(mutable_consumer == nullptr); CHECK_OR_RETURN(mutable_consumer == nullptr);
mutable_consumer = &op_conf; mutable_consumer = &op_conf;
} else { } else {
naive_consumers.push_back(&op_conf); naive_consumers.push_back(&op_conf);
} }
} }
if (mutable_consumer == nullptr) { return; } if (mutable_consumer == nullptr) { return Maybe<void>::Ok(); }
for (const auto* fw_bw_op : naive_consumers) { for (const auto* fw_bw_op : naive_consumers) {
op_conf2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name()); op_conf2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name());
} }
}); return Maybe<void>::Ok();
}));
for (const auto& pair : op_conf2ctrl_in_op_names) { for (const auto& pair : op_conf2ctrl_in_op_names) {
OperatorConf mut_mutable_consumer_op_conf(*pair.first); OperatorConf mut_mutable_consumer_op_conf(*pair.first);
for (const auto& fw_bw_op_name : pair.second) { for (const auto& fw_bw_op_name : pair.second) {
...@@ -87,45 +93,46 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder ...@@ -87,45 +93,46 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
mut_mutable_consumer_op_conf.add_ctrl_in_op_name(fw_bw_op_name); mut_mutable_consumer_op_conf.add_ctrl_in_op_name(fw_bw_op_name);
} }
} }
job_builder->MutOpsOnlyOnce({mut_mutable_consumer_op_conf}); JUST(job_builder->MutOpOnlyOnce(mut_mutable_consumer_op_conf));
} }
return Maybe<void>::Ok();
} }
} // namespace } // namespace
void JobCompleter::Complete(Job* job) const { Maybe<void> JobCompleter::Complete(Job* job) const {
JobPassCtx job_pass_ctx(GlobalJobDesc()); JobPassCtx job_pass_ctx(GlobalJobDesc());
CHECK_JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
// NOTE(chengcheng): disable this pass for reduce boxing memory life cycle to memory cost. // NOTE(chengcheng): disable this pass for reduce boxing memory life cycle to memory cost.
if (!Global<ResourceDesc, ForSession>::Get()->resource().disable_group_boxing_by_dst_parallel()) { if (!Global<ResourceDesc, ForSession>::Get()->resource().disable_group_boxing_by_dst_parallel()) {
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel); JUST(WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel));
} }
WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp); JUST(WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp));
// complete tick ops // complete tick ops
WithOpGraphAndMutJobBuilder(job, &AutoPrependTick); JUST(WithOpGraphAndMutJobBuilder(job, &AutoPrependTick));
WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape); JUST(WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape));
WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick); JUST(WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick));
WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections); JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections));
WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections); JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections));
CHECK_JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
if (XrtCompilationEnabled(GlobalJobDesc())) { if (XrtCompilationEnabled(GlobalJobDesc())) {
#ifdef OF_WITH_XRT #ifdef OF_WITH_XRT
WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob); JUST(WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob));
#else #else
LOG(WARNING) << "It will not use XLA or TensorRT since WITH_XLA or " LOG(WARNING) << "It will not use XLA or TensorRT since WITH_XLA or "
"WITH_TENSORRT was not enabled when compiling the project."; "WITH_TENSORRT was not enabled when compiling the project.";
#endif // OF_WITH_XRT #endif // OF_WITH_XRT
} }
#ifdef WITH_CUDA #ifdef WITH_CUDA
if (Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) { if (Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {
// NOTE(chengcheng): this pass need as last pass for insert correct op with nccl boxing. // NOTE(chengcheng): this pass need as last pass for insert correct op with nccl boxing.
JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx); JUST(JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx));
// NOTE(chengcheng): Becasue insert new logical nccl op, MUST dump time shape, sbp again. // NOTE(chengcheng): Becasue insert new logical nccl op, MUST dump time shape, sbp again.
JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
} }
#endif // WITH_CUDA #endif // WITH_CUDA
CheckOpGraph(OpGraph(*job)); JUST(CheckOpGraph(OpGraph(*job)));
return Maybe<void>::Ok();
} }
} // namespace oneflow } // namespace oneflow
...@@ -28,7 +28,7 @@ class JobCompleter final { ...@@ -28,7 +28,7 @@ class JobCompleter final {
JobCompleter() = default; JobCompleter() = default;
~JobCompleter() = default; ~JobCompleter() = default;
void Complete(Job* job) const; Maybe<void> Complete(Job* job) const;
}; };
} // namespace oneflow } // namespace oneflow
......
...@@ -27,7 +27,7 @@ limitations under the License. ...@@ -27,7 +27,7 @@ limitations under the License.
namespace oneflow { namespace oneflow {
inline void RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) { inline Maybe<void> RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) {
#ifdef OF_WITH_XRT #ifdef OF_WITH_XRT
const auto& job_desc = GlobalJobDesc(); const auto& job_desc = GlobalJobDesc();
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
...@@ -43,6 +43,7 @@ inline void RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) { ...@@ -43,6 +43,7 @@ inline void RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) {
->Write(*job); ->Write(*job);
} }
#endif // OF_WITH_XRT #endif // OF_WITH_XRT
return Maybe<void>::Ok();
} }
inline bool XrtCompilationEnabled(const JobDesc& job_desc) { inline bool XrtCompilationEnabled(const JobDesc& job_desc) {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment