Skip to content
Snippets Groups Projects
Unverified Commit 50e1c346 authored by Li Xinqi, committed by GitHub
Browse files

Job pass maybe system (#5503)


* refactor job_pass by maybe_system

* remove useless files

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 810d8db5
No related branches found
No related tags found
No related merge requests found
......@@ -532,8 +532,7 @@ void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_node
CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
std::string regst_desc_name;
RegstDesc* ctrl_regst_desc =
src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
src_task_nodes.at(i)->BuildCtrlRegstDesc(dst_task_nodes.at(i), &regst_desc_name);
TaskEdge* edge = NewEdge();
Connect<TaskNode>(src_task_nodes.at(i), edge, dst_task_nodes.at(i));
src_task_nodes.at(i)->BindEdgeWithProducedRegst(edge, regst_desc_name);
......
......@@ -47,7 +47,7 @@ void CreateOpAttributeRef(Plan* plan, int64_t job_id, TaskProto* task_proto) {
void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
// Step1: ensure job is completed.
if (need_job_complete) { JobCompleter().Complete(job); }
if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); }
// Step2: new Global<OpGraph> and set log configs.
Global<OpGraph>::New(*job);
......
......@@ -171,21 +171,22 @@ Maybe<void> ConnectSrcSubsetTickAndOtherTick(const OperatorConf& src_subset_tick
auto mut_helper = NewMutOpConTickInputHelper(op.op_conf());
if (!mut_helper) { return Maybe<void>::Ok(); }
if (mut_helper->IsTickInputBound() == true) { return Maybe<void>::Ok(); }
job_builder->MutOpsOnlyOnce({mut_helper->NewTickInputBoundOpConf(src_lbn)});
JUST(job_builder->MutOpOnlyOnce(mut_helper->NewTickInputBoundOpConf(src_lbn)));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
const OpNode* GetSrcSubsetTickOpNode(const OpGraph& op_graph) {
Maybe<const OpNode*> GetSrcSubsetTickOpNode(const OpGraph& op_graph) {
const OpNode* src_subset_tick = nullptr;
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_src_subset_tick_conf()) {
CHECK_ISNULL(src_subset_tick);
CHECK_ISNULL_OR_RETURN(src_subset_tick);
src_subset_tick = op_node;
}
});
CHECK_NOTNULL(src_subset_tick);
return Maybe<void>::Ok();
}));
CHECK_NOTNULL_OR_RETURN(src_subset_tick);
return src_subset_tick;
}
......@@ -277,77 +278,81 @@ std::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) {
return ret;
};
void InitOpTypeCase2OpNodes(
Maybe<void> InitOpTypeCase2OpNodes(
const OpGraph& op_graph,
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) {
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
const auto& op_conf = op_node->op().op_conf();
if (IsInterfaceOpConf(op_conf)) {
CHECK((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
CHECK_OR_RETURN((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
}
});
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
void ForEachInputCriticalSectionOpNodes(
Maybe<void> ForEachInputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
const std::function<Maybe<void>(const HashSet<const OpNode*>&,
const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
OperatorConf::OpTypeCase op_type_case = OperatorConf::kInputConf;
if (op_type_case2op_nodes[op_type_case].empty()) { return; }
if (op_type_case2op_nodes[op_type_case].empty()) { return Maybe<void>::Ok(); }
HashSet<const OpNode*> op_nodes = op_type_case2op_nodes[op_type_case];
for (const OpNode* op_node : op_type_case2op_nodes[op_type_case]) {
op_node->ForEachNodeOnOutEdge([&](OpNode* out_node) { op_nodes.insert(out_node); });
}
Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case]));
JUST(Handler(op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case])));
return Maybe<void>::Ok();
}
void ForEachOutputCriticalSectionOpNodes(
Maybe<void> ForEachOutputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
const std::function<Maybe<void>(const HashSet<const OpNode*>&,
const std::vector<std::string>&)>& Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
JUST(InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes));
if (op_type_case2op_nodes[OperatorConf::kReturnConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf]));
JUST(Handler(op_type_case2op_nodes[OperatorConf::kReturnConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kReturnConf])));
}
if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf]));
JUST(Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf])));
}
return Maybe<void>::Ok();
}
std::vector<OperatorConf> AddTickForTimeShape(const Shape& src_time_shape,
const HashSet<const OpNode*>& op_nodes,
JobBuilder* job_builder) {
Maybe<std::vector<OperatorConf>> AddTickForTimeShape(const Shape& src_time_shape,
const HashSet<const OpNode*>& op_nodes,
JobBuilder* job_builder) {
HashMap<std::pair<ParallelDesc, std::pair<Shape, Shape>>, std::list<const OpNode*>>
pd7ts2op_nodes;
for (const OpNode* op_node : op_nodes) {
auto ts = std::make_pair(*CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()),
*CHECK_JUST(op_node->op().GetOpTimeShape()));
auto ts = std::make_pair(*JUST(op_node->op().GetInputOutputFastestTimeShape()),
*JUST(op_node->op().GetOpTimeShape()));
pd7ts2op_nodes[{op_node->parallel_desc(), ts}].push_back(op_node);
}
std::vector<OperatorConf> op_confs;
for (const auto& pair : pd7ts2op_nodes) {
const std::pair<Shape, Shape>& ts = pair.first.second;
if (ts.second.elem_cnt() == src_time_shape.elem_cnt()) {
CHECK_GE(ts.first.elem_cnt(), ts.second.elem_cnt());
CHECK_GE_OR_RETURN(ts.first.elem_cnt(), ts.second.elem_cnt());
op_confs.push_back(
AppendTick("Append", pair.second, std::make_shared<const Shape>(ts.second), job_builder));
} else if (ts.second.elem_cnt() > src_time_shape.elem_cnt()) {
op_confs.push_back(AppendAccTick(src_time_shape, pair.second, job_builder));
} else {
UNIMPLEMENTED();
UNIMPLEMENTED_THEN_RETURN();
}
}
return op_confs;
}
void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names,
JobBuilder* job_builder) {
Maybe<void> AddGlobalInputOutputCriticalSection(
const HashSet<const OpNode*>& op_nodes, const std::vector<std::string>& lbi_producer_op_names,
JobBuilder* job_builder) {
auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
{
......@@ -358,22 +363,20 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
auto time_shape = std::make_unique<Shape>(DimVector{1, 1});
HashMap<ParallelDesc, HashSet<const OpNode*>> parallel_desc2op_nodes;
for (const OpNode* op_node : op_nodes) {
CHECK(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);
CHECK_OR_RETURN(parallel_desc2op_nodes[op_node->parallel_desc()].insert(op_node).second);
}
std::vector<OperatorConf> source_ticks;
std::vector<OperatorConf> sink_ticks;
for (const auto& pair : parallel_desc2op_nodes) {
source_ticks.push_back(PrependTick(pair.second, job_builder));
for (const auto& sink_tick : AddTickForTimeShape(*time_shape, pair.second, job_builder)) {
sink_ticks.push_back(sink_tick);
}
const auto& ops = JUST(AddTickForTimeShape(*time_shape, pair.second, job_builder));
for (const auto& sink_tick : *ops) { sink_ticks.push_back(sink_tick); }
}
OperatorConf src_subset_tick_op;
{
CHECK_EQ(source_ticks.empty(), false);
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST(
CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
CHECK_EQ_OR_RETURN(source_ticks.empty(), false);
JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick_op, job_builder));
for (auto& op_conf : source_ticks) {
op_conf.mutable_tick_conf()->add_tick(src_subset_tick_op.name() + "/"
+ src_subset_tick_op.src_subset_tick_conf().out());
......@@ -384,70 +387,86 @@ void AddGlobalInputOutputCriticalSection(const HashSet<const OpNode*>& op_nodes,
for (const auto& op_conf : sink_ticks) {
LogicalBlobId lbi;
lbi.set_op_name(op_conf.name());
CHECK(op_conf.has_device_tick_conf());
CHECK_OR_RETURN(op_conf.has_device_tick_conf());
lbi.set_blob_name(op_conf.device_tick_conf().out());
CHECK(tick_lbis.insert(lbi).second);
CHECK_OR_RETURN(tick_lbis.insert(lbi).second);
}
CHECK_JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,
job_builder));
JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick_op, tick_lbis,
job_builder));
return Maybe<void>::Ok();
}
} // namespace
void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder) {
PrependTickByParallelDesc(op_graph, job_builder);
OperatorConf src_subset_tick_op;
CHECK_JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
CHECK_JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));
JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick_op, job_builder));
JUST(ConnectSrcSubsetTickAndOtherTick(src_subset_tick_op, job_builder));
return Maybe<void>::Ok();
}
void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape());
Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder) {
const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = *JUST(op_node->op().GetOpTimeShape());
HashSet<const OpNode*> sink_op_nodes;
op_graph.ForEachNode([&](OpNode* op_node) {
CHECK(!op_node->op().op_conf().has_sink_tick_conf());
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
CHECK_OR_RETURN(!op_node->op().op_conf().has_sink_tick_conf());
size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt == 0) { sink_op_nodes.insert(op_node); }
});
AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder);
return Maybe<void>::Ok();
}));
JUST(AddTickForTimeShape(src_time_shape, sink_op_nodes, job_builder));
return Maybe<void>::Ok();
}
void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder) {
auto* critical_section =
Global<CriticalSectionDesc>::Get()->AddCriticalSection(GlobalJobDesc().job_id());
critical_section->mutable_total_job_critical_section();
op_graph.ForEachNode([&](OpNode* node) { CHECK(!node->op().op_conf().has_sink_tick_conf()); });
const auto& src_time_shape = *CHECK_JUST(GetSrcSubsetTickOpNode(op_graph)->op().GetOpTimeShape());
JUST(op_graph.MaybeForEachNode([&](OpNode* node) -> Maybe<void> {
CHECK_OR_RETURN(!node->op().op_conf().has_sink_tick_conf());
return Maybe<void>::Ok();
}));
const auto* op_node = JUST(GetSrcSubsetTickOpNode(op_graph));
const auto& src_time_shape = JUST(op_node->op().GetOpTimeShape());
HashSet<LogicalBlobId> tick_lbis;
op_graph.ForEachNode([&](OpNode* op_node) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt > 0) { return; }
CHECK(op_node->op().op_conf().has_device_tick_conf());
CHECK(CHECK_JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape.elem_cnt());
CHECK(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
});
OperatorConf src_subset_tick = CHECK_JUST(FindSrcSubsetTickOpConf(job_builder->job()));
CHECK_JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
CHECK_JUST(
CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
if (out_cnt > 0) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(op_node->op().op_conf().has_device_tick_conf());
CHECK_OR_RETURN(JUST(op_node->op().GetOpTimeShape())->elem_cnt() == src_time_shape->elem_cnt());
CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second);
return Maybe<void>::Ok();
}));
OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job()));
JUST(CreateSourceTicksAndSrcSubsetTick(critical_section, &src_subset_tick, job_builder));
JUST(CreateDstSubsetTickAndSinkTicks(critical_section, src_subset_tick, tick_lbis, job_builder));
return Maybe<void>::Ok();
}
void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachInputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder);
});
Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
JUST(ForEachInputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
ForEachOutputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder);
});
Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) {
JUST(ForEachOutputCriticalSectionOpNodes(
op_graph,
[&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) -> Maybe<void> {
JUST(AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job_builder));
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
} // namespace oneflow
......@@ -22,11 +22,11 @@ limitations under the License.
namespace oneflow {
void AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
void AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AutoPrependTick(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddTickForTimeShape(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AutoSourceAndSinkTick(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> AddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder);
class MutOpConTickInputHelper {
public:
......
......@@ -19,7 +19,7 @@ limitations under the License.
namespace oneflow {
void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder) {
HashMap<LogicalBlobId, HashMap<std::pair<ParallelDesc, cfg::ParallelDistribution>,
std::vector<std::pair<const OpNode*, std::string>>>>
lbi2consumer_grouped_by_parallel;
......@@ -76,13 +76,14 @@ void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder)
OperatorConf& consumer_op_conf = op_node2op_conf[consumer];
const auto& old_val = ReplaceInputLbnInOpCustomizedConf(&consumer_op_conf, ibn,
GenLogicalBlobName(grouped_lbi));
CHECK_EQ(GenLogicalBlobName(lbi), old_val);
CHECK_EQ_OR_RETURN(GenLogicalBlobName(lbi), old_val);
}
}
}
for (const auto& op_node7op_conf : op_node2op_conf) {
job_builder->MutOpsOnlyOnce({op_node7op_conf.second});
JUST(job_builder->MutOpOnlyOnce(op_node7op_conf.second));
}
return Maybe<void>::Ok();
}
} // namespace oneflow
......@@ -23,7 +23,7 @@ namespace oneflow {
class OpGraph;
class Job;
void GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);
Maybe<void> GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_builder);
} // namespace oneflow
......
......@@ -27,30 +27,35 @@ namespace oneflow {
namespace {
void CheckOpGraph(const OpGraph& op_graph) {
op_graph.ForEachNode([&](OpNode* op_node) {
Maybe<void> CheckOpGraph(const OpGraph& op_graph) {
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
size_t in_cnt = 0;
op_graph.ForEachDataAndCtrlInNode(op_node, [&](OpNode*) { ++in_cnt; });
if (in_cnt == 0) { CHECK(op_node->op().op_conf().has_source_tick_conf()); }
if (in_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_source_tick_conf()); }
size_t out_cnt = 0;
op_graph.ForEachDataAndCtrlOutNode(op_node, [&](OpNode*) { ++out_cnt; });
if (out_cnt == 0) { CHECK(op_node->op().op_conf().has_sink_tick_conf()); }
});
if (out_cnt == 0) { CHECK_OR_RETURN(op_node->op().op_conf().has_sink_tick_conf()); }
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
}
void WithOpGraphAndMutJob(Job* job, const std::function<void(const OpGraph&, Job*)>& Handler) {
Maybe<void> WithOpGraphAndMutJob(Job* job,
const std::function<Maybe<void>(const OpGraph&, Job*)>& Handler) {
OpGraph op_graph(*job);
Handler(op_graph, job);
JUST(Handler(op_graph, job));
return Maybe<void>::Ok();
}
void WithOpGraphAndMutJobBuilder(Job* job,
const std::function<void(const OpGraph&, JobBuilder*)>& Handler) {
Maybe<void> WithOpGraphAndMutJobBuilder(
Job* job, const std::function<Maybe<void>(const OpGraph&, JobBuilder*)>& Handler) {
OpGraph op_graph(*job);
JobBuilder job_builder(job);
Handler(op_graph, &job_builder);
JUST(Handler(op_graph, &job_builder));
return Maybe<void>::Ok();
}
void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) {
Maybe<void> SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder) {
auto IsMutableConsumedLbi = [](const Operator& op, const LogicalBlobId& lbi) -> bool {
for (const std::string& bn : op.input_bns()) {
if (op.BnInOp2Lbi(bn) == lbi && op.InputBlobModifier4Ibn(bn).is_mutable()) { return true; }
......@@ -59,9 +64,9 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
};
auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
HashMap<const OperatorConf*, HashSet<std::string>> op_conf2ctrl_in_op_names;
op_graph.ForEachNode([&](OpNode* op_node) {
if (op_node->op().op_conf().has_variable_conf() == false) { return; }
if (op_node->out_edges().size() <= 1) { return; }
JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe<void> {
if (op_node->op().op_conf().has_variable_conf() == false) { return Maybe<void>::Ok(); }
if (op_node->out_edges().size() <= 1) { return Maybe<void>::Ok(); }
const Operator& variable_op = op_node->op();
const LogicalBlobId& variable_lbi = variable_op.BnInOp2Lbi(variable_op.SoleObn());
const OperatorConf* mutable_consumer = nullptr;
......@@ -69,17 +74,18 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
for (OpEdge* edge : op_node->out_edges()) {
const auto& op_conf = edge->dst_node()->op().op_conf();
if (IsMutableConsumedLbi(edge->dst_node()->op(), variable_lbi)) {
CHECK(mutable_consumer == nullptr);
CHECK_OR_RETURN(mutable_consumer == nullptr);
mutable_consumer = &op_conf;
} else {
naive_consumers.push_back(&op_conf);
}
}
if (mutable_consumer == nullptr) { return; }
if (mutable_consumer == nullptr) { return Maybe<void>::Ok(); }
for (const auto* fw_bw_op : naive_consumers) {
op_conf2ctrl_in_op_names[mutable_consumer].insert(fw_bw_op->name());
}
});
return Maybe<void>::Ok();
}));
for (const auto& pair : op_conf2ctrl_in_op_names) {
OperatorConf mut_mutable_consumer_op_conf(*pair.first);
for (const auto& fw_bw_op_name : pair.second) {
......@@ -87,45 +93,46 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
mut_mutable_consumer_op_conf.add_ctrl_in_op_name(fw_bw_op_name);
}
}
job_builder->MutOpsOnlyOnce({mut_mutable_consumer_op_conf});
JUST(job_builder->MutOpOnlyOnce(mut_mutable_consumer_op_conf));
}
return Maybe<void>::Ok();
}
} // namespace
void JobCompleter::Complete(Job* job) const {
Maybe<void> JobCompleter::Complete(Job* job) const {
JobPassCtx job_pass_ctx(GlobalJobDesc());
CHECK_JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
// NOTE(chengcheng): disable this pass for reduce boxing memory life cycle to memory cost.
if (!Global<ResourceDesc, ForSession>::Get()->resource().disable_group_boxing_by_dst_parallel()) {
WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
JUST(WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel));
}
WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp);
JUST(WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp));
// complete tick ops
WithOpGraphAndMutJobBuilder(job, &AutoPrependTick);
WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape);
WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick);
WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections);
WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections);
CHECK_JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
JUST(WithOpGraphAndMutJobBuilder(job, &AutoPrependTick));
JUST(WithOpGraphAndMutJobBuilder(job, &AddTickForTimeShape));
JUST(WithOpGraphAndMutJobBuilder(job, &AutoSourceAndSinkTick));
JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections));
JUST(WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections));
JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
if (XrtCompilationEnabled(GlobalJobDesc())) {
#ifdef OF_WITH_XRT
WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob);
JUST(WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob));
#else
LOG(WARNING) << "It will not use XLA or TensorRT since WITH_XLA or "
"WITH_TENSORRT was not enabled when compiling the project.";
#endif // OF_WITH_XRT
}
#ifdef WITH_CUDA
if (Global<ResourceDesc, ForSession>::Get()->nccl_use_compute_stream()) {
// NOTE(chengcheng): this pass needs to be the last pass in order to insert the correct ops with nccl boxing.
JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx);
JUST(JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx));
// NOTE(chengcheng): Because new logical nccl ops are inserted, MUST dump time shape and sbp again.
JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx);
JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
}
#endif // WITH_CUDA
CheckOpGraph(OpGraph(*job));
JUST(CheckOpGraph(OpGraph(*job)));
return Maybe<void>::Ok();
}
} // namespace oneflow
......@@ -28,7 +28,7 @@ class JobCompleter final {
JobCompleter() = default;
~JobCompleter() = default;
void Complete(Job* job) const;
Maybe<void> Complete(Job* job) const;
};
} // namespace oneflow
......
......@@ -27,7 +27,7 @@ limitations under the License.
namespace oneflow {
inline void RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) {
inline Maybe<void> RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) {
#ifdef OF_WITH_XRT
const auto& job_desc = GlobalJobDesc();
if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
......@@ -43,6 +43,7 @@ inline void RebuildXrtCompiledJob(const OpGraph& op_graph, Job* job) {
->Write(*job);
}
#endif // OF_WITH_XRT
return Maybe<void>::Ok();
}
inline bool XrtCompilationEnabled(const JobDesc& job_desc) {
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment