diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
index 186f2efe9aeae1fa2cb8faf03cce32ebd908d436..089e52940d651bd951d564d144976c90c734339e 100644
--- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
+++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp
@@ -77,7 +77,7 @@ bool IsBreakpointOpNode(const OpNode* node) {
   if (op_conf.has_user_conf()) {
     const std::string& user_type_name = op_conf.user_conf().op_type_name();
     if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack"
-        || user_type_name == "unpack") {
+        || user_type_name == "unpack" || user_type_name == "identity_buffer") {
       return true;
     }
   }
@@ -102,7 +102,8 @@ bool SharedPtrShapeEqual(const std::shared_ptr<const Shape>& lhs,
   return (*lhs) == (*rhs);
 }
 
-void FindMaxConnectedSubgraphForGpuExecOrder(HashSet<const OpNode*>* ret, const OpGraph& op_graph,
+void FindAllConnectedSubgraphForGpuExecOrder(std::vector<HashSet<const OpNode*>>* ret,
+                                             const OpGraph& op_graph,
                                              const std::vector<const OpNode*>& order) {
   HashSet<const OpNode*> visited;
 
@@ -128,7 +129,7 @@ void FindMaxConnectedSubgraphForGpuExecOrder(HashSet<const OpNode*>* ret, const
       CHECK(this_subgraph.insert(cur_node).second);
 
       cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) {
-        if (visited.find(next_node) == visited.end()
+        if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node))
             && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)
             && SharedPtrShapeEqual(GetOpNodeTimeShape(next_node), seed_time_shape)) {
           CHECK(visited.insert(next_node).second);
@@ -137,8 +138,16 @@ void FindMaxConnectedSubgraphForGpuExecOrder(HashSet<const OpNode*>* ret, const
       });
     }
 
-    if (this_subgraph.size() > ret->size()) { ret->swap(this_subgraph); }
+    if (this_subgraph.size() > 1) {
+      ret->push_back(HashSet<const OpNode*>());
+      ret->back().swap(this_subgraph);
+    }
   }
+
+  std::sort(ret->begin(), ret->end(),
+            [](const HashSet<const OpNode*>& lhs, const HashSet<const OpNode*>& rhs) {
+              return lhs.size() > rhs.size();
+            });
 }
 
 bool ParallelDistributionAllSameSplitParallel(const ParallelDistribution& parallel_distribution) {
@@ -390,7 +399,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(
     HashMap<std::string, OperatorConf>* subgraph_op_name2conf, HashSet<std::string>* mut_op_names,
     std::vector<OperatorConf>* nccl_op_confs, std::vector<ParallelConf>* nccl_op_parallel_confs,
     const std::vector<const OpNode*>& subgraph_order,
-    const HashMap<const OpNode*, int64_t>& node2order) {
+    const HashMap<const OpNode*, int64_t>& node2subgraph_order) {
   for (const OpNode* src_node : subgraph_order) {
     const std::string& src_op_name = src_node->op().op_name();
     for (const OpEdge* op_edge : src_node->out_edges()) {
@@ -419,7 +428,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(
       }
 
       // NOTE(chengcheng): src_node MUST not the last node in subgraph, find the next op
-      int64_t src_order = node2order.at(src_node);
+      int64_t src_order = node2subgraph_order.at(src_node);
       CHECK(src_order + 1 < subgraph_order.size());
       const std::string& next_op_name = subgraph_order.at(src_order + 1)->op().op_name();
       if (dst_op_name != next_op_name) {
@@ -432,7 +441,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(
         LOG(INFO) << " insert nccl op: " << nccl_op.name() << " from: [" << src_op_name
                   << "](order=" << src_order << ", sbp_parallel_dis="
                   << ParallelDistributionToString(src_node->ParallelDistribution4Lbi(lbi))
-                  << ")->[" << dst_op_name << "](order=" << node2order.at(dst_node)
+                  << ")->[" << dst_op_name << "](order=" << node2subgraph_order.at(dst_node)
                   << ", sbp_parallel_dis="
                   << ParallelDistributionToString(dst_node->ParallelDistribution4Lbi(lbi))
                   << ") and before: [" << next_op_name << "](order=" << src_order + 1 << ")\n";
@@ -448,7 +457,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(
     HashMap<std::string, OperatorConf>* subgraph_op_name2conf, HashSet<std::string>* mut_op_names,
     std::vector<OperatorConf>* nccl_op_confs, std::vector<ParallelConf>* nccl_op_parallel_confs,
     const std::vector<const OpNode*>& subgraph_order,
-    const HashMap<const OpNode*, int64_t>& node2order) {
+    const HashMap<const OpNode*, int64_t>& node2subgraph_order) {
   for (const OpNode* dst_node : subgraph_order) {
     const std::string& dst_op_name = dst_node->op().op_name();
     for (const OpEdge* op_edge : dst_node->in_edges()) {
@@ -481,7 +490,7 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(
 
       // NOTE(chengcheng): dst_node MUST not the first node in subgraph, find the Immediately
      // previous op of dst_node.
-      int64_t dst_order = node2order.at(dst_node);
+      int64_t dst_order = node2subgraph_order.at(dst_node);
       CHECK_GT(dst_order, 0);
       const std::string& pre_op_name = subgraph_order.at(dst_order - 1)->op().op_name();
       if (src_op_name != pre_op_name) {
@@ -491,8 +500,9 @@ void InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(
 
       if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
         LOG(INFO) << " insert nccl op: " << nccl_op.name() << " from: [" << src_op_name << "]("
-                  << node2order.at(src_node) << ")->[" << dst_op_name << "](" << dst_order
-                  << ") and after: [" << pre_op_name << "](" << dst_order - 1 << ")\n";
+                  << node2subgraph_order.at(src_node) << ")->[" << dst_op_name << "]("
+                  << dst_order << ") and after: [" << pre_op_name << "](" << dst_order - 1
+                  << ")\n";
       }
       nccl_op_confs->push_back(nccl_op);
       // NOTE(chengcheng, guoran): set nccl op as src_node parallel_conf (hierarchy) may check
@@ -522,18 +532,12 @@ struct InsertedNcclInfo {
 };
 
 void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph,
-                                  const std::vector<const OpNode*>& ordered_op_nodes,
+                                  const HashMap<const OpNode*, int64_t>& op_node2global_order,
                                   const std::vector<const OpNode*>& ordered_acc_op_nodes,
                                   const std::string& bw_sink_tick_op_name,
                                   HashMap<std::string, OperatorConf>* mut_consumer_name2op,
                                   std::vector<OperatorConf>* nccl_op_confs,
                                   std::vector<ParallelConf>* nccl_op_parallel_confs) {
-  HashMap<const OpNode*, int64_t> op_node2global_order;
-  op_node2global_order.reserve(ordered_op_nodes.size());
-  for (int64_t i = 0; i < ordered_op_nodes.size(); ++i) {
-    CHECK(op_node2global_order.emplace(ordered_op_nodes.at(i), i).second);
-  }
-
   HashSet<const OpEdge*> visited;
   std::shared_ptr<const Shape> seed_time_shape = GetOpNodeTimeShape(ordered_acc_op_nodes.front());
   std::vector<InsertedNcclInfo> nccl_op_infos;
@@ -615,156 +619,311 @@ void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph,
   }
 }
 
-Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
-  auto OpGraphForEachInDataAndCtrlNode = [&](OpNode* node,
-                                             const std::function<void(OpNode*)>& Handler) {
-    op_graph.ForEachDataAndCtrlInNode(node, Handler);
-  };
-  auto OpGraphForEachOutDataAndCtrlNode = [&](OpNode* node,
-                                              const std::function<void(OpNode*)>& Handler) {
-    op_graph.ForEachDataAndCtrlOutNode(node, Handler);
-  };
+std::string GenParallelConfKey(const ParallelConf& conf) {
+  std::string ret = conf.device_tag();
+  for (const auto& name : conf.device_name()) { ret += ("-" + name); }
+  return ret;
+}
 
+struct InsertNcclSubGraph {
   std::vector<const OpNode*> ordered_op_nodes;
-  op_graph.TopoForEachNode(op_graph.DataOrCtrlSourceNodes(), OpGraphForEachInDataAndCtrlNode,
-                           OpGraphForEachOutDataAndCtrlNode,
-                           [&](const OpNode* node) { ordered_op_nodes.push_back(node); });
-
-  HashSet<const OpNode*> subgraph;
-  FindMaxConnectedSubgraphForGpuExecOrder(&subgraph, op_graph, ordered_op_nodes);
-  if (subgraph.size() <= 1) { return Maybe<void>::Ok(); }
+  int64_t begin_op_global_order;
+  int64_t end_op_global_order;
+  const OpNode* begin_op;
+  const OpNode* end_op;
+};
 
-  std::vector<const OpNode*> subgraph_order;
-  HashMap<const OpNode*, int64_t> node2order;
+struct PlacementNcclSubGraghsInfo {
+  std::vector<std::shared_ptr<InsertNcclSubGraph>> ordered_subgraph;
   std::vector<const OpNode*> ordered_acc_op_nodes;
-  for (const OpNode* this_node : ordered_op_nodes) {
-    if (subgraph.find(this_node) != subgraph.end()) {
-      subgraph_order.push_back(this_node);
-      node2order.emplace(this_node, subgraph_order.size() - 1);
-    } else if (IsAccOpNode(this_node)) {
-      ordered_acc_op_nodes.push_back(this_node);
-    }
+  const ParallelDesc* seed_parallel_desc;
+  std::shared_ptr<const Shape> seed_time_shape;
+};
+
+void InitInsertNcclSubGraphInfoFromSet(
+    std::shared_ptr<InsertNcclSubGraph> nccl_subgraph_info, const HashSet<const OpNode*>& subgraph,
+    const HashMap<const OpNode*, int64_t>& op_node2global_order,
+    const std::function<bool(const OpNode*, const OpNode*)>& CmpOpNodeOrder) {
+  auto* subgraph_ordered_nodes = &nccl_subgraph_info->ordered_op_nodes;
+  subgraph_ordered_nodes->assign(subgraph.begin(), subgraph.end());
+  std::sort(subgraph_ordered_nodes->begin(), subgraph_ordered_nodes->end(), CmpOpNodeOrder);
+  nccl_subgraph_info->begin_op = subgraph_ordered_nodes->front();
+  nccl_subgraph_info->end_op = subgraph_ordered_nodes->back();
+  nccl_subgraph_info->begin_op_global_order = op_node2global_order.at(nccl_subgraph_info->begin_op);
+  nccl_subgraph_info->end_op_global_order = op_node2global_order.at(nccl_subgraph_info->end_op);
+  CHECK(nccl_subgraph_info->begin_op != nccl_subgraph_info->end_op);
+  CHECK_LT(nccl_subgraph_info->begin_op_global_order, nccl_subgraph_info->end_op_global_order);
+}
+
+void InsertNcclLogicalOpsInSubGraph(
+    const OpGraph& op_graph, JobBuilder* job_builder,
+    const std::vector<const OpNode*>& subgraph_order,
+    const std::function<bool(const std::string&, const std::string&)>& IsReachable) {
+  HashMap<const OpNode*, int64_t> node2subgraph_order;
+  node2subgraph_order.reserve(subgraph_order.size());
+  for (int64_t i = 0; i < subgraph_order.size(); ++i) {
+    CHECK(node2subgraph_order.emplace(subgraph_order.at(i), i).second);
   }
-  CHECK_EQ(subgraph.size(), subgraph_order.size());
 
+  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
+    LOG(INFO) << " Try insert nccl logical ops into job: "
+              << job_builder->job().job_conf().job_name() << ". Begin...\n";
+  }
+
   HashSet<std::string> mut_op_names;
   const OpNode* first_node = subgraph_order.at(0);
   HashMap<std::string, OperatorConf> subgraph_op_name2conf;
   subgraph_op_name2conf.emplace(first_node->op().op_name(), first_node->op().op_conf());
-  auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
+
+  // add ctrl for strict order.
   for (int64_t i = 1; i < subgraph_order.size(); ++i) {
     const OpNode* this_node = subgraph_order.at(i);
     const OpNode* pre_node = subgraph_order.at(i - 1);
     const std::string& this_op_name = this_node->op().op_name();
     const std::string& pre_op_name = pre_node->op().op_name();
     CHECK(subgraph_op_name2conf.emplace(this_op_name, this_node->op().op_conf()).second);
-    // build control edge if need.
+    // build ctrl edge if need.
    if (!IsReachable(pre_op_name, this_op_name)) {
       subgraph_op_name2conf.at(this_op_name).add_ctrl_in_op_name(pre_op_name);
       mut_op_names.insert(this_op_name);
     }
   }
 
-  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
-    LOG(INFO) << " Try insert nccl logical ops into job: "
-              << job_builder->job().job_conf().job_name() << ". Begin...\n";
-  }
-
   std::vector<OperatorConf> nccl_op_confs;
   std::vector<ParallelConf> nccl_op_parallel_confs;
   if (ReverseOrderInsertNcclLogicalOps()) {
     InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(&subgraph_op_name2conf, &mut_op_names,
                                                    &nccl_op_confs, &nccl_op_parallel_confs,
-                                                   subgraph_order, node2order);
+                                                   subgraph_order, node2subgraph_order);
   } else {
     InsertNcclLogicalOpsAsCloseAsPossibleToSrcNode(&subgraph_op_name2conf, &mut_op_names,
                                                    &nccl_op_confs, &nccl_op_parallel_confs,
-                                                   subgraph_order, node2order);
+                                                   subgraph_order, node2subgraph_order);
   }
 
-  if (!ordered_acc_op_nodes.empty()) {
-    const OpNode* bw_sink_op = subgraph_order.back();
-    const OpNode* first_acc_op = ordered_acc_op_nodes.front();
-    std::shared_ptr<const Shape> time_shape_before_acc = GetOpNodeTimeShape(bw_sink_op);
-    std::shared_ptr<const Shape> time_shape_after_acc = GetOpNodeTimeShape(first_acc_op);
-    LOG(WARNING) << " Find acc op in Job: " << job_builder->job().job_conf().job_name()
-                 << ", we will try insert special identity and ctrl for "
-                 << " UNSAFE handle ALL nccl ops between different time shape: "
-                 << time_shape_before_acc->DebugStr() << "->acc->"
-                 << time_shape_after_acc->DebugStr() << "\n\n";
-    CHECK_GT(time_shape_before_acc->elem_cnt(), time_shape_after_acc->elem_cnt());
-    CHECK_EQ(time_shape_before_acc->elem_cnt() % time_shape_after_acc->elem_cnt(), 0);
-
-    for (const OpNode* acc : ordered_acc_op_nodes) {
-      CHECK(SharedPtrShapeEqual(time_shape_before_acc, GetOpNodeInputTimeShape(acc)));
-      CHECK(SharedPtrShapeEqual(time_shape_after_acc, GetOpNodeTimeShape(acc)));
-    }
+  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
+    LOG(INFO) << " Try insert nccl logical ops into job: "
+              << job_builder->job().job_conf().job_name() << ". ...End\n\n";
+  }
 
-    // NOTE(chengcheng): insert acc_tick after bw_sink_op, and this tick op conf will control
-    // after_acc_nccl_ops start.
-    const auto& obns = bw_sink_op->op().output_bns();
-    CHECK(!obns.empty());
-    const std::string bw_sink_op_out_lbn =
-        GenLogicalBlobName(bw_sink_op->op().BnInOp2Lbi(obns.Get(0)));
-    LOG(INFO) << " bw_sink_op : " << bw_sink_op->op().op_conf().DebugString();
-
-    user_op::UserOpConfWrapper cast_to_tick_op =
-        user_op::UserOpConfWrapperBuilder("System-CastToTick-" + NewUniqueId())
-            .OpTypeName("cast_to_tick")
-            .Input("in", bw_sink_op_out_lbn)
-            .Output("out")
-            .Build();
-    job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf());
+  std::vector<OperatorConf> mut_op_confs;
+  mut_op_confs.reserve(mut_op_names.size());
+  for (const std::string& mut_op_name : mut_op_names) {
+    mut_op_confs.push_back(subgraph_op_name2conf.at(mut_op_name));
+  }
+  job_builder->MutOpsOnlyOnce(mut_op_confs);
+
+  CHECK_EQ(nccl_op_confs.size(), nccl_op_parallel_confs.size());
+  for (int64_t i = 0; i < nccl_op_confs.size(); ++i) {
+    CHECK_JUST(job_builder->AddOp(nccl_op_parallel_confs.at(i), nccl_op_confs.at(i)));
+  }
+}
+
+void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc(
+    const OpGraph& op_graph, JobBuilder* job_builder,
+    const std::vector<const OpNode*>& ordered_acc_op_nodes,
+    const HashMap<const OpNode*, int64_t>& op_node2global_order, const OpNode* bw_sink_op) {
+  const OpNode* first_acc_op = ordered_acc_op_nodes.front();
+  std::shared_ptr<const Shape> time_shape_before_acc = GetOpNodeTimeShape(bw_sink_op);
+  std::shared_ptr<const Shape> time_shape_after_acc = GetOpNodeTimeShape(first_acc_op);
+  LOG(INFO) << " Find acc ops (num=" << ordered_acc_op_nodes.size()
+            << ") in Job: " << job_builder->job().job_conf().job_name()
+            << ", we will try insert special identity and ctrl for "
+            << " UNSAFE handle ALL nccl ops between different time shape: "
+            << time_shape_before_acc->DebugStr() << "->acc->" << time_shape_after_acc->DebugStr()
+            << "\n\n";
+  CHECK_GT(time_shape_before_acc->elem_cnt(), time_shape_after_acc->elem_cnt());
+  CHECK_EQ(time_shape_before_acc->elem_cnt() % time_shape_after_acc->elem_cnt(), 0);
+
+  for (const OpNode* acc : ordered_acc_op_nodes) {
+    CHECK(SharedPtrShapeEqual(time_shape_before_acc, GetOpNodeInputTimeShape(acc)));
+    CHECK(SharedPtrShapeEqual(time_shape_after_acc, GetOpNodeTimeShape(acc)));
+  }
+
+  // NOTE(chengcheng): insert acc_tick after bw_sink_op, and this tick op conf will control
+  // after_acc_nccl_ops start.
+  const auto& obns = bw_sink_op->op().output_bns();
+  CHECK(!obns.empty());
+  const std::string bw_sink_op_out_lbn =
+      GenLogicalBlobName(bw_sink_op->op().BnInOp2Lbi(obns.Get(0)));
+  LOG(INFO) << " bw_sink_op : " << bw_sink_op->op().op_conf().DebugString();
+
+  user_op::UserOpConfWrapper cast_to_tick_op =
+      user_op::UserOpConfWrapperBuilder("System-CastToTick-" + NewUniqueId())
+          .OpTypeName("cast_to_tick")
+          .Input("in", bw_sink_op_out_lbn)
+          .Output("out")
+          .Build();
+
+  OperatorConf bw_sink_acc_tick_conf;
+  bw_sink_acc_tick_conf.set_name(std::string("System-BwSinkTick-AccTick_") + NewUniqueId());
+  auto* acc_conf = bw_sink_acc_tick_conf.mutable_acc_tick_conf();
+  acc_conf->set_one(cast_to_tick_op.output("out", 0));
+  acc_conf->set_acc("acc");
+  acc_conf->set_max_acc_num(time_shape_before_acc->elem_cnt() / time_shape_after_acc->elem_cnt());
+
+  OperatorConf bw_sink_final_tick_conf;
+  bw_sink_final_tick_conf.set_name(std::string("System-BwSinkFinalTick-Tick_") + NewUniqueId());
+  auto* tick_conf = bw_sink_final_tick_conf.mutable_tick_conf();
+  tick_conf->add_tick(GenLogicalBlobName(bw_sink_acc_tick_conf.name(), "acc"));
+  tick_conf->set_out("out");
+
+  // insert nccl ops after acc
+  std::vector<OperatorConf> after_acc_nccl_op_confs;
+  std::vector<ParallelConf> after_acc_nccl_parallel_confs;
+  HashMap<std::string, OperatorConf> mut_consumer_name2op;
+
+  InsertNcclLogicalOpsAfterAcc(op_graph, op_node2global_order, ordered_acc_op_nodes,
+                               bw_sink_final_tick_conf.name(), &mut_consumer_name2op,
+                               &after_acc_nccl_op_confs, &after_acc_nccl_parallel_confs);
+
+  if (after_acc_nccl_op_confs.empty()) {
+    CHECK(after_acc_nccl_parallel_confs.empty());
+    CHECK(mut_consumer_name2op.empty());
+  } else {
+    // insert bw sink acc tick ops
+    CHECK_JUST(
+        job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf()));
     LOG(INFO) << " Insert cast_to_tick_op : " << cast_to_tick_op.op_conf().DebugString();
 
-    OperatorConf bw_sink_acc_tick_conf;
-    bw_sink_acc_tick_conf.set_name(std::string("System-BwSinkTick-AccTick_") + NewUniqueId());
-    auto* acc_conf = bw_sink_acc_tick_conf.mutable_acc_tick_conf();
-    acc_conf->set_one(cast_to_tick_op.output("out", 0));
-    acc_conf->set_acc("acc");
-    acc_conf->set_max_acc_num(time_shape_before_acc->elem_cnt() / time_shape_after_acc->elem_cnt());
-    job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_acc_tick_conf);
+    CHECK_JUST(
+        job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_acc_tick_conf));
     LOG(INFO) << " Insert bw_sink_acc_tick_op : " << bw_sink_acc_tick_conf.DebugString();
 
-    OperatorConf bw_sink_final_tick_conf;
-    bw_sink_final_tick_conf.set_name(std::string("System-BwSinkFinalTick-Tick_") + NewUniqueId());
-    auto* tick_conf = bw_sink_final_tick_conf.mutable_tick_conf();
-    tick_conf->add_tick(GenLogicalBlobName(bw_sink_acc_tick_conf.name(), "acc"));
-    tick_conf->set_out("out");
-    job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_final_tick_conf);
+    CHECK_JUST(
+        job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_final_tick_conf));
     LOG(INFO) << " Insert bw_sink_final_tick_op : " << bw_sink_final_tick_conf.DebugString();
 
     // insert nccl ops after acc
-    std::vector<OperatorConf> after_acc_nccl_op_confs;
-    std::vector<ParallelConf> after_acc_nccl_parallel_confs;
-    HashMap<std::string, OperatorConf> mut_consumer_name2op;
-
-    InsertNcclLogicalOpsAfterAcc(op_graph, ordered_op_nodes, ordered_acc_op_nodes,
-                                 bw_sink_final_tick_conf.name(), &mut_consumer_name2op,
-                                 &after_acc_nccl_op_confs, &after_acc_nccl_parallel_confs);
-
-    for (const auto& pair : mut_consumer_name2op) { JUST(job_builder->MutOpOnlyOnce(pair.second)); }
+    for (const auto& pair : mut_consumer_name2op) {
+      CHECK_JUST(job_builder->MutOpOnlyOnce(pair.second));
+    }
     CHECK_EQ(after_acc_nccl_op_confs.size(), after_acc_nccl_parallel_confs.size());
     for (int64_t i = 0; i < after_acc_nccl_op_confs.size(); ++i) {
-      job_builder->AddOp(after_acc_nccl_parallel_confs.at(i), after_acc_nccl_op_confs.at(i));
+      CHECK_JUST(
+          job_builder->AddOp(after_acc_nccl_parallel_confs.at(i), after_acc_nccl_op_confs.at(i)));
     }
   }
+}
 
-  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
-    LOG(INFO) << " Try insert nccl logical ops into job: "
-              << job_builder->job().job_conf().job_name() << ". ...End\n\n";
+Maybe<void> InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
+  auto OpGraphForEachInDataAndCtrlNode = [&](OpNode* node,
+                                             const std::function<void(OpNode*)>& Handler) {
+    op_graph.ForEachDataAndCtrlInNode(node, Handler);
+  };
+  auto OpGraphForEachOutDataAndCtrlNode = [&](OpNode* node,
+                                              const std::function<void(OpNode*)>& Handler) {
+    op_graph.ForEachDataAndCtrlOutNode(node, Handler);
+  };
+
+  std::vector<const OpNode*> ordered_op_nodes;
+  HashMap<const OpNode*, int64_t> op_node2global_order;
+  op_graph.TopoForEachNode(op_graph.DataOrCtrlSourceNodes(), OpGraphForEachInDataAndCtrlNode,
+                           OpGraphForEachOutDataAndCtrlNode, [&](const OpNode* node) {
+                             ordered_op_nodes.push_back(node);
+                             op_node2global_order.emplace(node, ordered_op_nodes.size() - 1);
+                           });
+
+  std::vector<HashSet<const OpNode*>> subgraph_list;
+  FindAllConnectedSubgraphForGpuExecOrder(&subgraph_list, op_graph, ordered_op_nodes);
+  if (subgraph_list.size() == 0) { return Maybe<void>::Ok(); }
+
+  auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) {
+    return op_node2global_order.at(lhs) < op_node2global_order.at(rhs);
+  };
+
+  auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
+
+  HashMap<std::string, PlacementNcclSubGraghsInfo> placement2subgraphs;
+  for (const auto& subgraph : subgraph_list) {
+    const OpNode* rand_node = *subgraph.begin();
+    const ParallelDesc& this_parallel_desc = rand_node->parallel_desc();
+    std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf());
+    const std::shared_ptr<const Shape>& this_time_shape = GetOpNodeTimeShape(rand_node);
+    auto it = placement2subgraphs.find(key);
+    if (it == placement2subgraphs.end()) {
+      it = placement2subgraphs.emplace(key, PlacementNcclSubGraghsInfo()).first;
+      auto& info = it->second;
+      info.seed_parallel_desc = &this_parallel_desc;
+      info.seed_time_shape = this_time_shape;
+      info.ordered_subgraph.push_back(std::make_shared<InsertNcclSubGraph>());
+      InitInsertNcclSubGraphInfoFromSet(info.ordered_subgraph.back(), subgraph,
+                                        op_node2global_order, CmpOpNodeOrder);
+    } else {
+      auto& info = it->second;
+      if (SharedPtrShapeEqual(info.seed_time_shape, this_time_shape)) {
+        CHECK(this_parallel_desc.EqualsIgnoringHierarchy(*info.seed_parallel_desc));
+        std::shared_ptr<InsertNcclSubGraph> nccl_subgraph_info =
+            std::make_shared<InsertNcclSubGraph>();
+        InitInsertNcclSubGraphInfoFromSet(nccl_subgraph_info, subgraph, op_node2global_order,
+                                          CmpOpNodeOrder);
+        CHECK_GT(info.ordered_subgraph.size(), 0);
+        const auto& first_graph = info.ordered_subgraph.front();
+        const auto& last_graph = info.ordered_subgraph.back();
+        int64_t first_order = first_graph->begin_op_global_order;
+        int64_t last_order = last_graph->end_op_global_order;
+        if (nccl_subgraph_info->end_op_global_order < first_order) {
+          if (IsReachable(nccl_subgraph_info->end_op->op().op_name(),
+                          first_graph->begin_op->op().op_name())) {
+            info.ordered_subgraph.insert(info.ordered_subgraph.begin(), nccl_subgraph_info);
+          }
+        } else if (nccl_subgraph_info->begin_op_global_order > last_order) {
+          if (IsReachable(last_graph->end_op->op().op_name(),
+                          nccl_subgraph_info->begin_op->op().op_name())) {
+            info.ordered_subgraph.push_back(nccl_subgraph_info);
+          }
+        } else {
+          auto before = info.ordered_subgraph.begin();
+          auto next = before + 1;
+          while (next != info.ordered_subgraph.end()) {
+            if ((*before)->end_op_global_order < nccl_subgraph_info->begin_op_global_order
+                && nccl_subgraph_info->end_op_global_order < (*next)->begin_op_global_order) {
+              if (IsReachable((*before)->end_op->op().op_name(),
+                              nccl_subgraph_info->begin_op->op().op_name())
+                  && IsReachable(nccl_subgraph_info->end_op->op().op_name(),
+                                 (*next)->begin_op->op().op_name())) {
+                info.ordered_subgraph.insert(next, nccl_subgraph_info);
+              }
+              break;
+            }
+            before = next;
+            next++;
+          }
+        }
+      }
+    }
   }
 
-  std::vector<OperatorConf> mut_op_confs;
-  for (const std::string& mut_op_name : mut_op_names) {
-    mut_op_confs.push_back(subgraph_op_name2conf.at(mut_op_name));
+  for (const OpNode* this_node : ordered_op_nodes) {
+    if (IsAccOpNode(this_node)) {
+      const ParallelDesc& this_parallel_desc = this_node->parallel_desc();
+      std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf());
+      auto it = placement2subgraphs.find(key);
+      if (it != placement2subgraphs.end()) { it->second.ordered_acc_op_nodes.push_back(this_node); }
+    }
   }
-  job_builder->MutOpsOnlyOnce(mut_op_confs);
 
-  CHECK_EQ(nccl_op_confs.size(), nccl_op_parallel_confs.size());
-  for (int64_t i = 0; i < nccl_op_confs.size(); ++i) {
-    JUST(job_builder->AddOp(nccl_op_parallel_confs.at(i), nccl_op_confs.at(i)));
+  for (auto& pair : placement2subgraphs) {
+    PlacementNcclSubGraghsInfo& info = pair.second;
+    for (int i = 0; i < info.ordered_subgraph.size() - 1; i++) {
+      CHECK_LT(info.ordered_subgraph.at(i)->end_op_global_order,
+               info.ordered_subgraph.at(i + 1)->begin_op_global_order);
+    }
+
+    // NOTE(chengcheng): insert nccl ops for each subgraph
+    for (const auto& subgraph_ptr : info.ordered_subgraph) {
+      auto& ordered_op_nodes = subgraph_ptr->ordered_op_nodes;
+      InsertNcclLogicalOpsInSubGraph(op_graph, job_builder, ordered_op_nodes, IsReachable);
+    }
+
+    // NOTE(chengcheng): insert acc for all subgraph with same placement group
+    const OpNode* bw_sink_op = info.ordered_subgraph.back()->end_op;
+    const std::vector<const OpNode*>& ordered_acc_op_nodes = info.ordered_acc_op_nodes;
+
+    if (!ordered_acc_op_nodes.empty()) {
+      InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc(
+          op_graph, job_builder, ordered_acc_op_nodes, op_node2global_order, bw_sink_op);
+    }
   }
 
   return Maybe<void>::Ok();