Skip to content
Snippets Groups Projects
Commit ce8bca8b authored by lixinqi's avatar lixinqi
Browse files

fix the "input_op -> output_op" bug

parent f7623e9f
No related branches found
No related tags found
No related merge requests found
......@@ -126,8 +126,9 @@ OperatorConf PrependTick(const std::list<const OpNode*>& op_nodes, JobBuilder* j
OperatorConf tick_op_conf = MakeTickOpConf();
std::vector<OperatorConf> op_confs;
for (const OpNode* op_node : op_nodes) {
op_confs.push_back(op_node->op().op_conf());
op_confs.back().add_ctrl_in_op_name(tick_op_conf.name());
OperatorConf op_conf(op_node->op().op_conf());
op_conf.add_ctrl_in_op_name(tick_op_conf.name());
op_confs.push_back(op_conf);
}
job_builder->MutOps({op_confs});
job_builder->AddOps(op_nodes.front()->parallel_desc().parallel_conf(), {tick_op_conf});
......@@ -176,21 +177,27 @@ std::vector<std::string> GetOpNames(const HashSet<const OpNode*>& op_nodes) {
return ret;
};
void ForEachInputOutputCriticalSectionOpNodes(
void InitOpTypeCase2OpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>>* op_type_case2op_nodes) {
HashSet<std::string> arg_op_names;
for (const auto& name : GlobalJobDesc().arg_op_name()) { arg_op_names.insert(name); }
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
op_graph.ForEachNode([&](OpNode* op_node) {
const auto& op_name = op_node->op().op_name();
const auto& op_conf = op_node->op().op_conf();
if (IsInterfaceOpConf(op_conf)
&& (op_conf.has_variable_conf() || arg_op_names.find(op_name) != arg_op_names.end())) {
CHECK(op_type_case2op_nodes[op_conf.op_type_case()].emplace(op_node).second);
CHECK((*op_type_case2op_nodes)[op_conf.op_type_case()].emplace(op_node).second);
}
});
}
void ForEachInputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
for (OperatorConf::OpTypeCase op_type_case :
{OperatorConf::kVariableConf, OperatorConf::kInputConf}) {
if (op_type_case2op_nodes[op_type_case].empty()) { continue; }
......@@ -200,10 +207,21 @@ void ForEachInputOutputCriticalSectionOpNodes(
}
Handler(consumer_op_nodes, GetOpNames(op_type_case2op_nodes[op_type_case]));
}
}
void ForEachOutputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
InitOpTypeCase2OpNodes(op_graph, &op_type_case2op_nodes);
if (op_type_case2op_nodes[OperatorConf::kOutputConf].empty() == false) {
Handler(op_type_case2op_nodes[OperatorConf::kOutputConf],
GetOpNames(op_type_case2op_nodes[OperatorConf::kOutputConf]));
}
for (const auto* op_node : op_type_case2op_nodes[OperatorConf::kSwitchOutputConf]) {
Handler({op_node}, GetOpNames({op_node}));
}
}
std::vector<OperatorConf> AddTickForTimeShape(const Shape& src_time_shape,
......@@ -334,8 +352,16 @@ void AddGlobalTotalJobCriticalSection(const Job& job) {
->mutable_total_job_critical_section();
}
void AddGlobalInputOutputCriticalSections(const OpGraph& op_graph, Job* job) {
ForEachInputOutputCriticalSectionOpNodes(
void AddGlobalInputCriticalSections(const OpGraph& op_graph, Job* job) {
ForEachInputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job);
});
}
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, Job* job) {
ForEachOutputCriticalSectionOpNodes(
op_graph, [&](const HashSet<const OpNode*>& op_nodes,
const std::vector<std::string>& lbi_producer_op_names) {
AddGlobalInputOutputCriticalSection(op_nodes, lbi_producer_op_names, job);
......
......@@ -11,7 +11,8 @@ void AutoSourceTick(const OpGraph& op_graph, Job* job);
void AddTickForTimeShape(const OpGraph& op_graph, Job* job);
void AutoSinkTick(const OpGraph& op_graph, Job* job);
void AddGlobalTotalJobCriticalSection(const Job& job);
void AddGlobalInputOutputCriticalSections(const OpGraph& op_graph, Job* job);
void AddGlobalInputCriticalSections(const OpGraph& op_graph, Job* job);
void AddGlobalOutputCriticalSections(const OpGraph& op_graph, Job* job);
class MutOpConTickInputHelper {
public:
......
......@@ -484,7 +484,8 @@ void JobCompleter::Complete(Job* job) const {
WithOpGraphAndMutJob(job, &AddTickForTimeShape);
WithOpGraphAndMutJob(job, &AutoSinkTick);
AddGlobalTotalJobCriticalSection(*job);
WithOpGraphAndMutJob(job, &AddGlobalInputOutputCriticalSections);
WithOpGraphAndMutJob(job, &AddGlobalInputCriticalSections);
WithOpGraphAndMutJob(job, &AddGlobalOutputCriticalSections);
WithOpGraphAndMutJob(job, &OpGraph::DumpLogicalBlobDescAndSbpSignature);
WithOpGraphAndMutJob(job, &SetOpTimeShape7ModelLbis);
CheckOpGraph(OpGraph(*job));
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment