diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 744754156941a78838db74bda27016edd1a70d8e..eb44f7b846a536c29f84fb7f458b1f3477e17018 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -532,8 +532,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size()); FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) { std::string regst_desc_name; - RegstDesc* ctrl_regst_desc = - src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name); + src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name); TaskEdge* edge = NewEdge(); Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i)); src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name); diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index d918bd41c4fae5d9e0a55cc028231a4630d92250..92df9f1c6a0c9bae77e8a74574f06025f2f84bc4 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -47,7 +47,7 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) { void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { // Step1: ensure job is completed. - if (need_job_complete) { JobCompleter().Complete(job); } + if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); } // Step2: new Global<OpGraph> and set log configs. 
Global<OpGraph>::New(*job); diff --git a/oneflow/core/job_rewriter/autotick.cpp b/oneflow/core/job_rewriter/autotick.cpp index 92fcb12efec680c792d1078f45bf2bf954379435..b7efa49665ec85b950e034d7d41278b7661c1203 100644 --- a/oneflow/core/job_rewriter/autotick.cpp +++ b/oneflow/core/job_rewriter/autotick.cpp @@ -171,21 +171,22 @@ Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick auto mut_helper = NewMutOpConTickInputHelper(op.op_conf()); if (!mut_helper) { return Maybe<void>::Ok(); } if (mut_helper->IsTickInputBound() == true) { return Maybe<void>::Ok(); } - job_builder->MutOpsOnlyOnce({mut_helper->NewTickInputBoundOpConf(src_lbn)}); + JUST(job_builder->MutOpOnlyOnce(mut_helper->NewTickInputBoundOpConf(src_lbn))); return Maybe<void>::Ok(); })); return Maybe<void>::Ok(); } -const OpNode* GetSrcSubsetTickOpNode(const OpGraph& op_graph) { +Maybe<const OpNode*> GetSrcSubsetTickOpNode(const OpGraph& op_graph) { const OpNode* src_subset_tick = nullptr; - op_graph.ForEachNode([&](OpNode* op_node) { + JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { if (op_node->op().op_conf().has_src_subset_tick_conf()) { - CHECK_ISNULL(src_subset_tick); + CHECK_ISNULL_OR_RETURN(src_subset_tick); src_subset_tick = op_node; } - }); - CHECK_NOTNULL(src_subset_tick); + return Maybe<void>::Ok(); + })); + CHECK_NOTNULL_OR_RETURN(src_subset_tick); return src_subset_tick; } @@ -277,77 +278,81 @@ std::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) { return ret; }; -void InitOpTypeCase2OpNodes( +Maybe<void> InitOpTypeCase2OpNodes( const OpGraph& op_graph, HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) { - op_graph.ForEachNode([&](OpNode* op_node) { + JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { const auto& op_conf = op_node->op().op_conf(); if (IsInterfaceOpConf(op_conf)) { - CHECK((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second); + 
CHECK_OR_RETURN((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second); } - }); + return Maybe<void>::Ok(); + })); + return Maybe<void>::Ok(); } -void ForEachInputCriticalSectionOpNodes( +Maybe<void> ForEachInputCriticalSectionOpNodes( const OpGraph& op_graph, - const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>& - Handler) { + const std::function<Maybe<void>(const HashSet<const OpNode*>&, + const std::vector<std::string>&)>& Handler) { HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes; - InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes); + JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes)); OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf; - if (op_type_case2op_nodes[op_type_case].empty()) { return; } + if (op_type_case2op_nodes[op_type_case].empty()) { return Maybe<void>::Ok(); } HashSet<const OpNode*> op_nodes = op_type_case2op_nodes[op_type_case]; for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) { op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); }); } - Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case])); + JUST(Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case]))); + return Maybe<void>::Ok(); } -void ForEachOutputCriticalSectionOpNodes( +Maybe<void> ForEachOutputCriticalSectionOpNodes( const OpGraph& op_graph, - const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>& - Handler) { + const std::function<Maybe<void>(const HashSet<const OpNode*>&, + const std::vector<std::string>&)>& Handler) { HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes; - InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes); + JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes)); if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) { - Handler(op_type_case2op_nodes[OperatorConf::kReturnConf], - 
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf])); + JUST(Handler(op_type_case2op_nodes[OperatorConf::kReturnConf], + GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf]))); } if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) { - Handler(op_type_case2op_nodes[OperatorConf::kOutputConf], - GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf])); + JUST(Handler(op_type_case2op_nodes[OperatorConf::kOutputConf], + GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf]))); } + return Maybe<void>::Ok(); } -std::vector<OperatorConf> AddTickForTimeShape(const Shape& src_time_shape, - const HashSet<const OpNode*>& op_nodes, - JobBuilder* job_builder) { +Maybe<std::vector<OperatorConf>> AddTickForTimeShape(const Shape& src_time_shape, + const HashSet<const OpNode*>& op_nodes, + JobBuilder* job_builder) { HashMap<std::pair<ParallelDesc, std::pair<Shape, Shape>>, std::list<const OpNode*>> pd7ts2op_nodes; for (const OpNode* op_node : op_nodes) { - auto ts = std::make_pair(*CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()), - *CHECK_JUST(op_node->op().GetOpTimeShape())); + auto ts = std::make_pair(*JUST(op_node->op().GetInputOutputFastestTimeShape()), + *JUST(op_node->op().GetOpTimeShape())); pd7ts2op_nodes[{op_node->parallel_desc(), ts}].push_back(op_node); } std::vector<OperatorConf> op_confs; for (const auto& pair : pd7ts2op_nodes) { const std::pair<Shape, Shape>& ts = pair.first.second; if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) { - CHECK_GE(ts.first.elem_cnt(), ts.second.elem_cnt()); + CHECK_GE_OR_RETURN(ts.first.elem_cnt(), ts.second.elem_cnt()); op_confs.push_back( AppendTick("Append", pair.second, std::make_shared<const Shape>(ts.second), job_builder)); } else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) { op_confs.push_back(AppendAccTick(src_time_shape, pair.second, job_builder)); } else { - UNIMPLEMENTED(); + UNIMPLEMENTED_THEN_RETURN(); } } return op_confs; } -void 
AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes, - const std::vector<std::string>& lbi_producer_op_names, - JobBuilder* job_builder) { +Maybe<void> AddGlobalInputOutputCriticalSection( + const HashSet<const OpNode*>& op_nodes, const std::vector<std::string>& lbi_producer_op_names, + JobBuilder* job_builder) { auto* critical_section = Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id()); { @@ -358,22 +363,20 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes, auto time_shape = std::make_unique<Shape>(DimVector{1, 1}); HashMap<ParallelDesc, HashSet<const OpNode*>> parallel_desc2op_nodes; for (const OpNode* op_node : op_nodes) { - CHECK(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second); + CHECK_OR_RETURN(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second); } std::vector<OperatorConf> source_ticks; std::vector<OperatorConf> sink_ticks; for (const auto& pair : parallel_desc2op_nodes) { source_ticks.push_back(PrependTick(pair.second, job_builder)); - for (const auto& sink_tick : AddTickForTimeShape(*time_shape, pair.second, job_builder)) { - sink_ticks.push_back(sink_tick); - } + const auto& ops = JUST(AddTickForTimeShape(*time_shape, pair.second, job_builder)); + for (const auto& sink_tick : *ops) { sink_ticks.push_back(sink_tick); } } OperatorConf src_subset_tick_op; { - CHECK_EQ(source_ticks.empty(), false); - CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); - CHECK_JUST( - CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder)); + CHECK_EQ_OR_RETURN(source_ticks.empty(), false); + JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); + JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder)); for (auto& op_conf : source_ticks) { op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + "/" + 
src_subset_tick_op.src_subset_tick_conf().out()); @@ -384,70 +387,86 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes, for (const auto& op_conf : sink_ticks) { LogicalBlobId lbi; lbi.set_op_name(op_conf.name()); - CHECK(op_conf.has_device_tick_conf()); + CHECK_OR_RETURN(op_conf.has_device_tick_conf()); lbi.set_blob_name(op_conf.device_tick_conf().out()); - CHECK(tick_lbis.insert(lbi).second); + CHECK_OR_RETURN(tick_lbis.insert(lbi).second); } - CHECK_JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis, - job_builder)); + JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis, + job_builder)); + return Maybe<void>::Ok(); } } // namespace -void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) { +Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) { PrependTickByParallelDesc(op_graph, job_builder); OperatorConf src_subset_tick_op; - CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); - CHECK_JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder)); + JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder)); + JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder)); + return Maybe<void>::Ok(); } -void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) { - const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape()); +Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) { + const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph)); + const auto& src_time_shape = *JUST(op_node->op().GetOpTimeShape()); HashSet<const OpNode*> sink_op_nodes; - op_graph.ForEachNode([&](OpNode* op_node) { - CHECK(!op_node->op().op_conf().has_sink_tick_conf()); + JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { + 
CHECK_OR_RETURN(!op_node->op().op_conf().has_sink_tick_conf()); size_t out_cnt = 0; op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); if (out_cnt == 0) { sink_op_nodes.insert(op_node); } - }); - AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder); + return Maybe<void>::Ok(); + })); + JUST(AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder)); + return Maybe<void>::Ok(); } -void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) { +Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) { auto* critical_section = Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id()); critical_section->mutable_total_job_critical_section(); - op_graph.ForEachNode([&](OpNode* node) { CHECK(!node->op().op_conf().has_sink_tick_conf()); }); - const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape()); + JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> { + CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf()); + return Maybe<void>::Ok(); + })); + const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph)); + const auto& src_time_shape = JUST(op_node->op().GetOpTimeShape()); HashSet<LogicalBlobId> tick_lbis; - op_graph.ForEachNode([&](OpNode* op_node) { + JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { size_t out_cnt = 0; op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); - if (out_cnt > 0) { return; } - CHECK(op_node->op().op_conf().has_device_tick_conf()); - CHECK(CHECK_JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape.elem_cnt()); - CHECK(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second); - }); - OperatorConf src_subset_tick = CHECK_JUST(FindSrcSubsetTickOpConf(job_builder->job())); - CHECK_JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder)); - CHECK_JUST( - 
CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder)); + if (out_cnt > 0) { return Maybe<void>::Ok(); } + CHECK_OR_RETURN(op_node->op().op_conf().has_device_tick_conf()); + CHECK_OR_RETURN(JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape->elem_cnt()); + CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second); + return Maybe<void>::Ok(); + })); + OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job())); + JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder)); + JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder)); + return Maybe<void>::Ok(); } -void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { - ForEachInputCriticalSectionOpNodes( - op_graph, [&](const HashSet<const OpNode*>& op_nodes, - const std::vector<std::string>& lbi_producer_op_names) { - AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder); - }); +Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { + JUST(ForEachInputCriticalSectionOpNodes( + op_graph, + [&](const HashSet<const OpNode*>& op_nodes, + const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> { + JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder)); + return Maybe<void>::Ok(); + })); + return Maybe<void>::Ok(); } -void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { - ForEachOutputCriticalSectionOpNodes( - op_graph, [&](const HashSet<const OpNode*>& op_nodes, - const std::vector<std::string>& lbi_producer_op_names) { - AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder); - }); +Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { + JUST(ForEachOutputCriticalSectionOpNodes( + op_graph, + 
[&](const HashSet<const OpNode*>& op_nodes, + const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> { + JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder)); + return Maybe<void>::Ok(); + })); + return Maybe<void>::Ok(); } } // namespace oneflow diff --git a/oneflow/core/job_rewriter/autotick.h b/oneflow/core/job_rewriter/autotick.h index ffb1ae432240e73e276f10995c30e2ca343155ec..2760edda5bec86fe16626fbc01d2978c5da2195e 100644 --- a/oneflow/core/job_rewriter/autotick.h +++ b/oneflow/core/job_rewriter/autotick.h @@ -22,11 +22,11 @@ limitations under the License. namespace oneflow { -void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder); -void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder); -void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder); -void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); -void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); class MutOpConTickInputHelper { public: diff --git a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp index 5957d4316dbbc6f3fa511e4ca0dc36b6d1ea9da9..a7f144cd61dc94de43010e09c5b525a499c7a47d 100644 --- a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp +++ b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp @@ -19,7 +19,7 @@ limitations under the License. 
namespace oneflow { -void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) { +Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) { HashMap<LogicalBlobId, HashMap<std::pair<ParallelDesc, cfg::ParallelDistribution>, std::vector<std::pair<const OpNode*, std::string>>>> lbi2consumer_grouped_by_parallel; @@ -76,13 +76,14 @@ void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) OperatorConf& consumer_op_conf = op_node2op_conf[consumer]; const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn, GenLogicalBlobName(grouped_lbi)); - CHECK_EQ(GenLogicalBlobName(lbi), old_val); + CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val); } } } for (const auto& op_node7op_conf : op_node2op_conf) { - job_builder->MutOpsOnlyOnce({op_node7op_conf.second}); + JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second)); } + return Maybe<void>::Ok(); } } // namespace oneflow diff --git a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h index 910a891b229c8b318c5d13cbeda4fdc0dd601777..761eba7f912c23fd2de8068f7be8d032acd02e19 100644 --- a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h +++ b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.h @@ -23,7 +23,7 @@ namespace oneflow { class OpGraph; class Job; -void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder); +Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder); } // namespace oneflow diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp index f388af817744dd2915fe275ce5b4392e9c3c0c02..c3833ab332c86862bc34b944c56d18fcd8951026 100644 --- a/oneflow/core/job_rewriter/job_completer.cpp +++ b/oneflow/core/job_rewriter/job_completer.cpp @@ -27,30 +27,35 @@ namespace oneflow { namespace { -void CheckOpGraph(const OpGraph& op_graph) { - 
op_graph.ForEachNode([&](OpNode* op_node) { +Maybe<void> CheckOpGraph(const OpGraph& op_graph) { + JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { size_t in_cnt = 0; op_graph.ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_cnt; }); - if (in_cnt == 0) { CHECK(op_node->op().op_conf().has_source_tick_conf()); } + if (in_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_source_tick_conf()); } size_t out_cnt = 0; op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; }); - if (out_cnt == 0) { CHECK(op_node->op().op_conf().has_sink_tick_conf()); } - }); + if (out_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_sink_tick_conf()); } + return Maybe<void>::Ok(); + })); + return Maybe<void>::Ok(); } -void WithOpGraphAndMutJob(Job* job, const std::function<void(const OpGraph&, Job*)>& Handler) { +Maybe<void> WithOpGraphAndMutJob(Job* job, + const std::function<Maybe<void>(const OpGraph&, Job*)>& Handler) { OpGraph op_graph(*job); - Handler(op_graph, job); + JUST(Handler(op_graph, job)); + return Maybe<void>::Ok(); } -void WithOpGraphAndMutJobBuilder(Job* job, - const std::function<void(const OpGraph&, JobBuilder*)>& Handler) { +Maybe<void> WithOpGraphAndMutJobBuilder( + Job* job, const std::function<Maybe<void>(const OpGraph&, JobBuilder*)>& Handler) { OpGraph op_graph(*job); JobBuilder job_builder(job); - Handler(op_graph, &job_builder); + JUST(Handler(op_graph, &job_builder)); + return Maybe<void>::Ok(); } -void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) { +Maybe<void> SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) { auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool { for (const std::string& bn : op.input_bns()) { if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; } @@ -59,9 +64,9 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder }; auto IsReachable = 
op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); HashMap<const OperatorConf*, HashSet<std::string>> op_conf2ctrl_in_op_names; - op_graph.ForEachNode([&](OpNode* op_node) { - if (op_node->op().op_conf().has_variable_conf() == false) { return; } - if (op_node->out_edges().size() <= 1) { return; } + JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> { + if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); } + if (op_node->out_edges().size() <= 1) { return Maybe<void>::Ok(); } const Operator& variable_op = op_node->op(); const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn()); const OperatorConf* mutable_consumer = nullptr; @@ -69,17 +74,18 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder for (OpEdge* edge : op_node->out_edges()) { const auto& op_conf = edge->dst_node()->op().op_conf(); if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) { - CHECK(mutable_consumer == nullptr); + CHECK_OR_RETURN(mutable_consumer == nullptr); mutable_consumer = &op_conf; } else { naive_consumers.push_back(&op_conf); } } - if (mutable_consumer == nullptr) { return; } + if (mutable_consumer == nullptr) { return Maybe<void>::Ok(); } for (const auto* fw_bw_op : naive_consumers) { op_conf2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name()); } - }); + return Maybe<void>::Ok(); + })); for (const auto& pair : op_conf2ctrl_in_op_names) { OperatorConf mut_mutable_consumer_op_conf(*pair.first); for (const auto& fw_bw_op_name : pair.second) { @@ -87,45 +93,46 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder mut_mutable_consumer_op_conf.add_ctrl_in_op_name(fw_bw_op_name); } } - job_builder->MutOpsOnlyOnce({mut_mutable_consumer_op_conf}); + JUST(job_builder->MutOpOnlyOnce(mut_mutable_consumer_op_conf)); } + return Maybe<void>::Ok(); } } // namespace -void JobCompleter::Complete(Job* job) const { +Maybe<void> 
JobCompleter::Complete(Job* job) const { JobPassCtx job_pass_ctx(GlobalJobDesc()); - CHECK_JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); + JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); // NOTE(chengcheng): disable this pass for reduce boxing memory life cycle to memory cost. if (!Global<ResourceDesc, ForSession>::Get()->resource().disable_group_boxing_by_dst_parallel()) { - WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel); + JUST(WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel)); } - WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp); + JUST(WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp)); // complete tick ops - WithOpGraphAndMutJobBuilder(job, &AutoPrependTick); - WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape); - WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick); - WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections); - WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections); - CHECK_JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); + JUST(WithOpGraphAndMutJobBuilder(job, &AutoPrependTick)); + JUST(WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape)); + JUST(WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick)); + JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections)); + JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections)); + JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); if (XrtCompilationEnabled(GlobalJobDesc())) { #ifdef OF_WITH_XRT - WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob); + JUST(WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob)); #else LOG(WARNING) << "It will not use XLA or TensorRT since WITH_XLA or " "WITH_TENSORRT was not enabled when compiling the project."; #endif // OF_WITH_XRT } - #ifdef WITH_CUDA if (Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) { // NOTE(chengcheng): this pass need as last pass for insert 
correct op with nccl boxing. - JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx); + JUST(JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx)); // NOTE(chengcheng): Becasue insert new logical nccl op, MUST dump time shape, sbp again. - JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx); + JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); } #endif // WITH_CUDA - CheckOpGraph(OpGraph(*job)); + JUST(CheckOpGraph(OpGraph(*job))); + return Maybe<void>::Ok(); } } // namespace oneflow diff --git a/oneflow/core/job_rewriter/job_completer.h b/oneflow/core/job_rewriter/job_completer.h index b63952645ec2f27b8867cf85de7733981a2cb20e..56763e83a46343b6d444c54bfb1679bd976cad80 100644 --- a/oneflow/core/job_rewriter/job_completer.h +++ b/oneflow/core/job_rewriter/job_completer.h @@ -28,7 +28,7 @@ class JobCompleter final { JobCompleter() = default; ~JobCompleter() = default; - void Complete(Job* job) const; + Maybe<void> Complete(Job* job) const; }; } // namespace oneflow diff --git a/oneflow/core/job_rewriter/xrt_compilation.h b/oneflow/core/job_rewriter/xrt_compilation.h index 7bd06bc4191398c50cf88586bf44393079f2ef84..4f03342db8cb5a3c2ea7aee75a537776882cfe15 100644 --- a/oneflow/core/job_rewriter/xrt_compilation.h +++ b/oneflow/core/job_rewriter/xrt_compilation.h @@ -27,7 +27,7 @@ limitations under the License. namespace oneflow { -inline void RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) { +inline Maybe<void> RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) { #ifdef OF_WITH_XRT const auto& job_desc = GlobalJobDesc(); if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { @@ -43,6 +43,7 @@ inline void RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) { ->Write(*job); } #endif // OF_WITH_XRT + return Maybe<void>::Ok(); } inline bool XrtCompilationEnabled(const JobDesc& job_desc) {