diff --git a/oneflow/core/actor/naive_actor.cpp b/oneflow/core/actor/naive_actor.cpp index 461444ceaefd097e081fcdd6ab054d28377c0813..8d0eaed3ea337eb7f4dbc6300621959e002a8d0c 100644 --- a/oneflow/core/actor/naive_actor.cpp +++ b/oneflow/core/actor/naive_actor.cpp @@ -29,8 +29,8 @@ void NaiveActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { REGISTER_ACTOR(TaskType::kSliceBoxing, NaiveActor); REGISTER_ACTOR(TaskType::kBoxingIdentity, NaiveActor); -REGISTER_ACTOR(TaskType::kBoxingS2SAll2AllPack, NaiveActor); -REGISTER_ACTOR(TaskType::kBoxingS2SAll2AllUnpack, NaiveActor); +REGISTER_ACTOR(TaskType::kCollectiveBoxingPack, NaiveActor); +REGISTER_ACTOR(TaskType::kCollectiveBoxingUnpack, NaiveActor); REGISTER_ACTOR(TaskType::kDecodeH2D, NaiveActor); } // namespace oneflow diff --git a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp index 47b77cc99ded16b490f6f2d51f55600371e04706..02b4268c7cbd153fc86fbd2b70fe9dbfe6e600f5 100644 --- a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp @@ -19,24 +19,22 @@ limitations under the License. namespace oneflow { Maybe<SubTskGphBuilderStatus> B21SubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { - if ((src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel()) - && dst_parallel_desc.parallel_num() == 1) { - CompTaskNode* dst_node = sorted_dst_comp_tasks.front(); - CompTaskNode* nearest_src_node = - SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node); - CHECK_NOTNULL(nearest_src_node); - TaskNode* proxy = ctx->GetProxyNode(nearest_src_node, nearest_src_node->MemZoneId121(), - dst_node->machine_id(), dst_node->MemZoneId121()); - Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node); - return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(), - sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - lbi, logical_blob_desc, "B21SubTskGphBuilder", "")); + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { + if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel()) + && out_parallel_desc.parallel_num() == 1) { + const int64_t out_parallel_id = 0; + const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( + in_parallel_desc, out_parallel_desc, out_parallel_id); + TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id); + TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(), + out_parallel_desc, out_parallel_id); + sorted_out_tasks->push_back(proxy); + return TRY(BuildSubTskGphBuilderStatus("B21SubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } diff --git a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.h b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.h index 139d3422382fa4c7fedce7f8472028d5dc4f174e..e6894f7da7c3a83bc61722c06128d5a8f5efebc9 100644 --- a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.h @@ -26,14 +26,13 @@ class B21SubTskGphBuilder final : public SubTskGphBuilder { B21SubTskGphBuilder() = default; ~B21SubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing/boxing_logger.cpp b/oneflow/core/graph/boxing/boxing_logger.cpp index 42e47cc4ddbcde870f271bf1aae8cd86af96b5f9..844459c6adeaeb2fa21aa5d46d097c8d2a798558 100644 --- a/oneflow/core/graph/boxing/boxing_logger.cpp +++ b/oneflow/core/graph/boxing/boxing_logger.cpp @@ -58,17 +58,23 @@ std::string ShapeToString(const Shape& shape) { return shape_ss.str(); } -std::string SubTskGphBuilderStatusToCsvLine(const SubTskGphBuilderStatus& status) { +std::string MakeBoxingLoggerCsvRow(const SubTskGphBuilderStatus& status, + const std::string& src_op_name, const std::string& dst_op_name, + const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, + const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc) { std::string serialized_status; - serialized_status += status.src_op_name() + ","; - serialized_status += status.dst_op_name() + ","; - serialized_status += ParallelDescToString(status.src_parallel_desc()) + ","; - serialized_status += ParallelDescToString(status.dst_parallel_desc()) + ","; - serialized_status += SbpParallelToString(status.src_sbp_parallel()) + ","; - serialized_status += SbpParallelToString(status.dst_sbp_parallel()) + ","; - serialized_status += GenLogicalBlobName(status.lbi()) + ","; - serialized_status += DataType_Name(status.logical_blob_desc().data_type()) + ","; - serialized_status += ShapeToString(status.logical_blob_desc().shape()) + ","; + serialized_status += src_op_name + ","; + serialized_status += dst_op_name + ","; + serialized_status += ParallelDescToString(src_parallel_desc) + ","; + serialized_status += ParallelDescToString(dst_parallel_desc) + ","; + serialized_status += SbpParallelToString(src_sbp_parallel) + ","; + serialized_status += SbpParallelToString(dst_sbp_parallel) + ","; + serialized_status += GenLogicalBlobName(lbi) + ","; + serialized_status += DataType_Name(logical_blob_desc.data_type()) + ","; + serialized_status += ShapeToString(logical_blob_desc.shape()) + ","; serialized_status += status.builder_name() + ","; if (status.comment().empty()) { serialized_status += "-"; @@ -88,8 +94,14 @@ CsvBoxingLogger::CsvBoxingLogger(std::string path) { CsvBoxingLogger::~CsvBoxingLogger() { log_stream_->Flush(); } -void CsvBoxingLogger::Log(const SubTskGphBuilderStatus& status) { - log_stream_ << SubTskGphBuilderStatusToCsvLine(status); +void CsvBoxingLogger::Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name, + const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, + const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, + const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc) { + log_stream_ << MakeBoxingLoggerCsvRow(status, src_op_name, dst_op_name, src_parallel_desc, + dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, + logical_blob_desc); } } // namespace oneflow diff --git a/oneflow/core/graph/boxing/boxing_logger.h b/oneflow/core/graph/boxing/boxing_logger.h index cf9c7880b34d0fa7e316697d5c7498e92a0a16b6..ba3c9678192469c75434b4161f9f270932331d4a 100644 --- a/oneflow/core/graph/boxing/boxing_logger.h +++ b/oneflow/core/graph/boxing/boxing_logger.h @@ -27,7 +27,11 @@ class BoxingLogger { BoxingLogger() = default; virtual ~BoxingLogger() = default; - virtual void Log(const SubTskGphBuilderStatus& status) = 0; + virtual void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name, + const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc) = 0; }; class NullBoxingLogger final : public BoxingLogger { @@ -36,7 +40,11 @@ class NullBoxingLogger final : public BoxingLogger { NullBoxingLogger() = default; ~NullBoxingLogger() override = default; - void Log(const SubTskGphBuilderStatus& status) override{}; + void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name, + const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc) override{}; }; class CsvBoxingLogger final : public BoxingLogger { @@ -46,7 +54,11 @@ class CsvBoxingLogger final : public BoxingLogger { CsvBoxingLogger(std::string path); ~CsvBoxingLogger() override; - void Log(const SubTskGphBuilderStatus& status) override; + void Log(const SubTskGphBuilderStatus& status, const std::string& src_op_name, + const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, + const ParallelDesc& dst_parallel_desc, const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc) override; private: std::unique_ptr<TeePersistentLogStream> log_stream_; diff --git a/oneflow/core/graph/boxing/chain_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/chain_sub_task_graph_builder.cpp index 57a68b096e4aa7b66efc7dcba2ba7eb56dfcb6d2..8ea54f918b8ca6b2abaf09d9177621553e4c685f 100644 --- a/oneflow/core/graph/boxing/chain_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/chain_sub_task_graph_builder.cpp @@ -19,15 +19,16 @@ limitations under the License. namespace oneflow { Maybe<SubTskGphBuilderStatus> ChainSubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { for (const auto& builder : builders_) { Maybe<SubTskGphBuilderStatus> boxing_builder_status = TRY(builder->Build( - ctx, sorted_src_comp_tasks, sorted_dst_comp_tasks, src_parallel_desc, dst_parallel_desc, - lbi, logical_blob_desc, src_sbp_parallel, dst_sbp_parallel)); + ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, in_parallel_desc, + out_parallel_desc, lbi, logical_blob_desc, in_sbp_parallel, out_sbp_parallel, time_shape)); if (!boxing_builder_status.IsOk() && SubTskGphBuilderUtil::IsErrorBoxingNotSupported(*boxing_builder_status.error())) { continue; diff --git a/oneflow/core/graph/boxing/chain_sub_task_graph_builder.h b/oneflow/core/graph/boxing/chain_sub_task_graph_builder.h index 6cac4da80dc4d7b52087636a4e63d82cccc66eda..f3132d3d1268819b6350e6dcbdb392c973074dbb 100644 --- a/oneflow/core/graph/boxing/chain_sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/chain_sub_task_graph_builder.h @@ -27,14 +27,13 @@ class ChainSubTskGphBuilder final : public SubTskGphBuilder { : builders_(std::move(builders)) {} ~ChainSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: std::vector<std::shared_ptr<SubTskGphBuilder>> builders_; diff --git a/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp index d9a71d3c2c759d44cbf803a95142ce8f0d381519..a95f39152c6129bbd700194076ef7ef4eb42df2f 100644 --- a/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp @@ -19,8 +19,8 @@ limitations under the License. #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" #include "oneflow/core/graph/collective_boxing_task_node.h" #include "oneflow/core/graph/slice_boxing_task_node.h" -#include "oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.h" -#include "oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.h" +#include "oneflow/core/graph/collective_boxing_pack_task_node.h" +#include "oneflow/core/graph/collective_boxing_unpack_task_node.h" #ifdef WITH_CUDA #include <nccl.h> #endif @@ -92,33 +92,28 @@ class NcclCollectiveBoxingAllReduceSubTskGphBuilder final : public SubTskGphBuil NcclCollectiveBoxingAllReduceSubTskGphBuilder() = default; ~NcclCollectiveBoxingAllReduceSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override { - if (dst_parallel_desc.Equals(src_parallel_desc) + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override { + if (out_parallel_desc.Equals(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) - && dst_parallel_desc.device_type() == DeviceType::kGPU - && dst_parallel_desc.parallel_num() > 1 - && SubTskGphBuilderUtil::IsBoxingP2B(src_sbp_parallel, dst_sbp_parallel)) { + && out_parallel_desc.device_type() == DeviceType::kGPU + && out_parallel_desc.parallel_num() > 1 + && SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel)) { const std::string op_name = "System-Boxing-NcclCollectiveBoxingAllReduce-" + NewUniqueId(); - FOR_RANGE(int64_t, i, 0, src_parallel_desc.parallel_num()) { - CompTaskNode* src_node = sorted_src_comp_tasks.at(i); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(i); + FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { + TaskNode* in_node = sorted_in_tasks.at(i); auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>(); - NcclInitCollectiveNode(collective_node, src_parallel_desc, i, op_name, lbi, + NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllReduce, -1); - Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), collective_node); - Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node); + Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), collective_node); + sorted_out_tasks->push_back(collective_node); } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "NcclCollectiveBoxingAllReduceSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAllReduceSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } @@ -131,36 +126,32 @@ class NcclCollectiveBoxingReduceScatterSubTskGphBuilder final : public SubTskGph NcclCollectiveBoxingReduceScatterSubTskGphBuilder() = default; ~NcclCollectiveBoxingReduceScatterSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override { - if (dst_parallel_desc.Equals(src_parallel_desc) + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override { + if (out_parallel_desc.Equals(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) - && dst_parallel_desc.device_type() == DeviceType::kGPU - && dst_parallel_desc.parallel_num() > 1 - && logical_blob_desc.shape().At(0) % dst_parallel_desc.parallel_num() == 0 - && SubTskGphBuilderUtil::IsBoxingP2S(src_sbp_parallel, dst_sbp_parallel) - && dst_sbp_parallel.split_parallel().axis() == 0) { + && out_parallel_desc.device_type() == DeviceType::kGPU + && out_parallel_desc.parallel_num() > 1 + && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0 + && SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel) + && out_sbp_parallel.split_parallel().axis() == 0) { const std::string op_name = "System-Boxing-NcclCollectiveBoxingReduceScatter-" + NewUniqueId(); - FOR_RANGE(int64_t, i, 0, src_parallel_desc.parallel_num()) { - CompTaskNode* src_node = sorted_src_comp_tasks.at(i); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(i); + FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { + TaskNode* in_node = sorted_in_tasks.at(i); auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>(); - NcclInitCollectiveNode(collective_node, src_parallel_desc, i, op_name, lbi, + NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeReduceScatter, -1); - Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), collective_node); - Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node); + Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), collective_node); + sorted_out_tasks->push_back(collective_node); } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "NcclCollectiveBoxingReduceScatterSubTskGphBuilder", "")); + return TRY( + BuildSubTskGphBuilderStatus("NcclCollectiveBoxingReduceScatterSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } @@ -173,38 +164,33 @@ class NcclCollectiveBoxingAllGatherSubTskGphBuilder final : public SubTskGphBuil NcclCollectiveBoxingAllGatherSubTskGphBuilder() = default; ~NcclCollectiveBoxingAllGatherSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override { - if (dst_parallel_desc.EqualsIgnoringDeviceType(src_parallel_desc) + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override { + if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) - && SubTskGphBuilderUtil::IsDeviceTypeCPUOrGPU(src_parallel_desc) - && dst_parallel_desc.device_type() == DeviceType::kGPU - && dst_parallel_desc.parallel_num() > 1 - && logical_blob_desc.shape().At(0) % dst_parallel_desc.parallel_num() == 0 - && SubTskGphBuilderUtil::IsBoxingS2B(src_sbp_parallel, dst_sbp_parallel) - && src_sbp_parallel.split_parallel().axis() == 0) { + && SubTskGphBuilderUtil::IsDeviceTypeCPUOrGPU(in_parallel_desc) + && out_parallel_desc.device_type() == DeviceType::kGPU + && out_parallel_desc.parallel_num() > 1 + && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0 + && SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel) + && in_sbp_parallel.split_parallel().axis() == 0) { const std::string op_name = "System-Boxing-NcclCollectiveBoxingAllGather-" + NewUniqueId(); - FOR_RANGE(int64_t, i, 0, src_parallel_desc.parallel_num()) { - CompTaskNode* src_comp_task = sorted_src_comp_tasks.at(i); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(i); - TaskNode* src_node = ctx->GetProxyNode(src_comp_task, src_comp_task->MemZoneId121(), - dst_node->machine_id(), dst_node->MemZoneId121()); + FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { + TaskNode* in_node = sorted_in_tasks.at(i); + TaskNode* in_node_proxy = + ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_parallel_desc, i); auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>(); - NcclInitCollectiveNode(collective_node, dst_parallel_desc, i, op_name, lbi, + NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllGather, -1); - Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), collective_node); - Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node); + Connect<TaskNode>(in_node_proxy, ctx->task_graph()->NewEdge(), collective_node); + sorted_out_tasks->push_back(collective_node); } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "NcclCollectiveBoxingAllGatherSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAllGatherSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } @@ -217,41 +203,36 @@ class NcclCollectiveBoxingReduceSubTskGphBuilder final : public SubTskGphBuilder NcclCollectiveBoxingReduceSubTskGphBuilder() = default; ~NcclCollectiveBoxingReduceSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override { - if (src_parallel_desc.parallel_num() > 1 && dst_parallel_desc.parallel_num() == 1 - && src_parallel_desc.device_type() == DeviceType::kGPU - && dst_parallel_desc.device_type() == DeviceType::kGPU + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override { + if (in_parallel_desc.parallel_num() > 1 && out_parallel_desc.parallel_num() == 1 + && in_parallel_desc.device_type() == DeviceType::kGPU + && out_parallel_desc.device_type() == DeviceType::kGPU && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) - && src_sbp_parallel.has_partial_sum_parallel()) { - const int64_t root_parallel_id = FindRootParallelId(src_parallel_desc, dst_parallel_desc); + && in_sbp_parallel.has_partial_sum_parallel()) { + const int64_t root_parallel_id = FindRootParallelId(in_parallel_desc, out_parallel_desc); if (root_parallel_id == -1) { return Error::BoxingNotSupportedError(); } const std::string op_name = "System-Boxing-NcclCollectiveBoxingReduce-" + NewUniqueId(); - FOR_RANGE(int64_t, i, 0, src_parallel_desc.parallel_num()) { - CompTaskNode* src_node = sorted_src_comp_tasks.at(i); + sorted_ctrl_tasks->resize(out_parallel_desc.parallel_num()); + FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { + TaskNode* in_node = sorted_in_tasks.at(i); auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>(); - NcclInitCollectiveNode(collective_node, src_parallel_desc, i, op_name, lbi, + NcclInitCollectiveNode(collective_node, in_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeReduce, root_parallel_id); - Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), collective_node); - CompTaskNode* dst_node = sorted_dst_comp_tasks.front(); + Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), collective_node); if (i == root_parallel_id) { - Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node); + sorted_out_tasks->push_back(collective_node); } else { - collective_node->BuildCtrlRegstDesc(dst_node); - Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node); + sorted_ctrl_tasks->at(0).push_back(collective_node); } } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "NcclCollectiveBoxingReduceSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingReduceSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } @@ -264,54 +245,51 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder() = default; ~CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override { - if (src_parallel_desc.parallel_num() == 1 && dst_parallel_desc.parallel_num() > 1 - && src_parallel_desc.device_type() == DeviceType::kCPU - && dst_parallel_desc.device_type() == DeviceType::kGPU + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override { + if (in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() > 1 + && in_parallel_desc.device_type() == DeviceType::kCPU + && out_parallel_desc.device_type() == DeviceType::kGPU && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) && logical_blob_desc.shape().elem_cnt() >= 1024 - && dst_sbp_parallel.has_broadcast_parallel() + && out_sbp_parallel.has_broadcast_parallel() // a potential optimization: flat the blob and then relax this requirement - && logical_blob_desc.shape().At(0) % dst_parallel_desc.parallel_num() == 0) { + && logical_blob_desc.shape().At(0) % out_parallel_desc.parallel_num() == 0) { const TensorSliceView in_slice = SubTskGphBuilderUtil::GetBroadcastTensorSliceView(logical_blob_desc); SbpParallel split_sbp_parallel; split_sbp_parallel.mutable_split_parallel()->set_axis(0); std::vector<TensorSliceView> out_slices = SubTskGphBuilderUtil::GetTensorSliceView( - dst_parallel_desc.parallel_num(), split_sbp_parallel, logical_blob_desc); + out_parallel_desc.parallel_num(), split_sbp_parallel, logical_blob_desc); const std::string op_name = "System-Boxing-NcclCollectiveBoxingAllGather-" + NewUniqueId(); - FOR_RANGE(int64_t, out_id, 0, dst_parallel_desc.parallel_num()) { + FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { const TensorSliceView& out_slice = out_slices.at(out_id); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(out_id); - CompTaskNode* src_node = - SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node); + const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( + in_parallel_desc, out_parallel_desc, out_id); + + TaskNode* in_node = sorted_in_tasks.at(nearest_in_parallel_id); SliceBoxingTaskNode* slice_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>(); // slice on cpu - const auto src_machine_id = CHECK_JUST(src_parallel_desc.MachineId4ParallelId(0)); - slice_node->Init(lbi, out_slice, kSliceBoxingTaskModeCopy, src_machine_id, - Global<IDMgr>::Get()->PickCpuThrdIdEvenly(src_machine_id)); - slice_node->ConnectToSrcNodeWithSlice(src_node, ctx->task_graph()->NewEdge(), in_slice); + const auto in_machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(0)); + slice_node->Init(lbi, out_slice, kSliceBoxingTaskModeCopy, in_machine_id, + Global<IDMgr>::Get()->PickCpuThrdIdEvenly(in_machine_id)); + slice_node->ConnectToSrcNodeWithSlice(in_node, ctx->task_graph()->NewEdge(), in_slice); // copy to dst gpu TaskNode* slice_node_proxy = - ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), dst_node->machine_id(), - dst_node->MemZoneId121()); + ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), out_parallel_desc, out_id); // allgather auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>(); - NcclInitCollectiveNode(collective_node, dst_parallel_desc, out_id, op_name, lbi, + NcclInitCollectiveNode(collective_node, out_parallel_desc, out_id, op_name, lbi, logical_blob_desc, OpType::kOpTypeAllGather, -1); Connect<TaskNode>(slice_node_proxy, ctx->task_graph()->NewEdge(), collective_node); - Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node); + sorted_out_tasks->push_back(collective_node); } return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, "CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); @@ -325,57 +303,51 @@ class NcclCollectiveBoxingBroadcastSubTskGphBuilder final : public SubTskGphBuil NcclCollectiveBoxingBroadcastSubTskGphBuilder() = default; ~NcclCollectiveBoxingBroadcastSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override { - if (src_parallel_desc.parallel_num() == 1 && dst_parallel_desc.parallel_num() > 1 - && (src_parallel_desc.device_type() == DeviceType::kGPU - || (src_parallel_desc.device_type() == DeviceType::kCPU + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override { + if (in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() > 1 + && (in_parallel_desc.device_type() == DeviceType::kGPU + || (in_parallel_desc.device_type() == DeviceType::kCPU && logical_blob_desc.shape().elem_cnt() >= 1024)) - && dst_parallel_desc.device_type() == DeviceType::kGPU + && out_parallel_desc.device_type() == DeviceType::kGPU && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) - && dst_sbp_parallel.has_broadcast_parallel()) { - TaskNode* gpu_src_node = nullptr; + && out_sbp_parallel.has_broadcast_parallel()) { + TaskNode* gpu_in_node = nullptr; int64_t root_parallel_id = -1; - if (src_parallel_desc.device_type() == DeviceType::kCPU) { - auto* cpu_src_node = sorted_src_comp_tasks.front(); + if (in_parallel_desc.device_type() == DeviceType::kCPU) { + auto* cpu_in_node = sorted_in_tasks.front(); root_parallel_id = - SubTskGphBuilderUtil::FindNearestNodeIndex(sorted_dst_comp_tasks, cpu_src_node); - auto* nearest_dst_node = sorted_dst_comp_tasks.at(root_parallel_id); - gpu_src_node = - ctx->GetProxyNode(cpu_src_node, cpu_src_node->MemZoneId121(), - nearest_dst_node->machine_id(), nearest_dst_node->MemZoneId121()); - } else if (src_parallel_desc.device_type() == DeviceType::kGPU) { - root_parallel_id = FindRootParallelId(dst_parallel_desc, src_parallel_desc); - gpu_src_node = sorted_src_comp_tasks.front(); + SubTskGphBuilderUtil::FindNearestSrcParallelId(out_parallel_desc, in_parallel_desc, 0); + gpu_in_node = ctx->GetProxyNode(cpu_in_node, cpu_in_node->MemZoneId121(), out_parallel_desc, + root_parallel_id); + + } else if (in_parallel_desc.device_type() == DeviceType::kGPU) { + root_parallel_id = FindRootParallelId(out_parallel_desc, in_parallel_desc); + gpu_in_node = sorted_in_tasks.front(); } else { return Error::BoxingNotSupportedError(); } if (root_parallel_id == -1) { return Error::BoxingNotSupportedError(); } const std::string op_name = "System-Boxing-NcclCollectiveBoxingBroadcast-" + NewUniqueId(); - FOR_RANGE(int64_t, i, 0, dst_parallel_desc.parallel_num()) { - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(i); + FOR_RANGE(int64_t, i, 0, out_parallel_desc.parallel_num()) { auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>(); - NcclInitCollectiveNode(collective_node, dst_parallel_desc, i, op_name, lbi, + NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeBroadcast, root_parallel_id); if (i == root_parallel_id) { - Connect<TaskNode>(gpu_src_node, ctx->task_graph()->NewEdge(), collective_node); + Connect<TaskNode>(gpu_in_node, ctx->task_graph()->NewEdge(), collective_node); } else { - gpu_src_node->BuildCtrlRegstDesc(collective_node); - Connect<TaskNode>(gpu_src_node, ctx->task_graph()->NewEdge(), collective_node); + gpu_in_node->BuildCtrlRegstDesc(collective_node); + Connect<TaskNode>(gpu_in_node, ctx->task_graph()->NewEdge(), collective_node); } - Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), dst_node); + sorted_out_tasks->push_back(collective_node); } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "NcclCollectiveBoxingBroadcastSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingBroadcastSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } @@ -388,55 +360,51 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde NcclCollectiveBoxingAll2AllSubTskGphBuilder() = default; ~NcclCollectiveBoxingAll2AllSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override { - if (dst_parallel_desc.EqualsIgnoringDeviceType(src_parallel_desc) + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override { + if (out_parallel_desc.EqualsIgnoringDeviceType(in_parallel_desc) && !SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc) - && src_parallel_desc.device_type() == DeviceType::kGPU - && dst_parallel_desc.device_type() == DeviceType::kGPU - && dst_parallel_desc.parallel_num() > 1 - && logical_blob_desc.shape().At(src_sbp_parallel.split_parallel().axis()) - % src_parallel_desc.parallel_num() + && in_parallel_desc.device_type() == DeviceType::kGPU + && out_parallel_desc.device_type() == DeviceType::kGPU + && out_parallel_desc.parallel_num() > 1 + && logical_blob_desc.shape().At(in_sbp_parallel.split_parallel().axis()) + % in_parallel_desc.parallel_num() == 0 - && logical_blob_desc.shape().At(dst_sbp_parallel.split_parallel().axis()) - % dst_parallel_desc.parallel_num() + && logical_blob_desc.shape().At(out_sbp_parallel.split_parallel().axis()) + % out_parallel_desc.parallel_num() == 0 - && src_sbp_parallel.split_parallel().axis() != dst_sbp_parallel.split_parallel().axis() - && SubTskGphBuilderUtil::IsBoxingS2S(src_sbp_parallel, dst_sbp_parallel)) { + && in_sbp_parallel.split_parallel().axis() != out_sbp_parallel.split_parallel().axis() + && SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel)) { const std::string op_name = "System-Boxing-NcclCollectiveBoxingAll2All-" + NewUniqueId(); - FOR_RANGE(int64_t, i, 0, src_parallel_desc.parallel_num()) { - CompTaskNode* src_node = sorted_src_comp_tasks.at(i); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(i); - - BoxingS2SAll2AllPackCompTaskNode* pack_node = - ctx->task_graph()->NewNode<BoxingS2SAll2AllPackCompTaskNode>(); - pack_node->Init(src_node, lbi, dst_sbp_parallel.split_parallel().axis()); - Connect<TaskNode>(src_node, ctx->task_graph()->NewEdge(), pack_node); + FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { + const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i)); + const int64_t device_id = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i)); + const int64_t thrd_id = Global<IDMgr>::Get()->GetGpuComputeThrdId(device_id); + TaskNode* in_node = sorted_in_tasks.at(i); + CollectiveBoxingPackTaskNode* pack_node = + ctx->task_graph()->NewNode<CollectiveBoxingPackTaskNode>(); + pack_node->Init(machine_id, thrd_id, NewAreaId(), lbi, logical_blob_desc.shape(), + in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); + Connect<TaskNode>(in_node, ctx->task_graph()->NewEdge(), pack_node); auto* collective_node = ctx->task_graph()->NewNode<CollectiveBoxingGenericTaskNode>(); - NcclInitCollectiveNode(collective_node, dst_parallel_desc, i, op_name, lbi, + NcclInitCollectiveNode(collective_node, out_parallel_desc, i, op_name, lbi, logical_blob_desc, OpType::kOpTypeAll2All, -1); Connect<TaskNode>(pack_node, ctx->task_graph()->NewEdge(), collective_node); - BoxingS2SAll2AllUnpackCompTaskNode* unpack_node = - ctx->task_graph()->NewNode<BoxingS2SAll2AllUnpackCompTaskNode>(); - unpack_node->Init(src_node, lbi, logical_blob_desc.shape(), - src_sbp_parallel.split_parallel().axis(), - dst_sbp_parallel.split_parallel().axis()); - + CollectiveBoxingUnpackTaskNode* unpack_node = + ctx->task_graph()->NewNode<CollectiveBoxingUnpackTaskNode>(); + unpack_node->Init(machine_id, thrd_id, NewAreaId(), lbi, logical_blob_desc.shape(), + in_sbp_parallel, out_sbp_parallel, in_parallel_desc.parallel_num()); Connect<TaskNode>(collective_node, ctx->task_graph()->NewEdge(), unpack_node); - Connect<TaskNode>(unpack_node, ctx->task_graph()->NewEdge(), dst_node); + sorted_out_tasks->push_back(unpack_node); } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "NcclCollectiveBoxingAll2AllSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("NcclCollectiveBoxingAll2AllSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } @@ -466,18 +434,17 @@ CollectiveBoxingSubTskGphBuilder::CollectiveBoxingSubTskGphBuilder() { } Maybe<SubTskGphBuilderStatus> CollectiveBoxingSubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (!GlobalJobDesc().Bool("__is_user_function__")) { return Error::BoxingNotSupportedError(); } - if (!IsSourceTimeShape(*sorted_src_comp_tasks.front()->logical_node()->out_blob_time_shape())) { - return Error::BoxingNotSupportedError(); - } - return chain_builder_->Build(ctx, sorted_src_comp_tasks, sorted_dst_comp_tasks, src_parallel_desc, - dst_parallel_desc, lbi, logical_blob_desc, src_sbp_parallel, - dst_sbp_parallel); + if (!IsSourceTimeShape(time_shape)) { return Error::BoxingNotSupportedError(); } + return chain_builder_->Build(ctx, sorted_in_tasks, sorted_out_tasks, sorted_ctrl_tasks, + in_parallel_desc, out_parallel_desc, lbi, logical_blob_desc, + in_sbp_parallel, out_sbp_parallel, time_shape); } } // namespace oneflow diff --git a/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h b/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h index 3f55c86c52e274826232679e24b6540aa0a161b6..d18a926f5de6883ce963856c171ee624c83013fa 100644 --- a/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.h @@ -26,14 +26,13 @@ class CollectiveBoxingSubTskGphBuilder final : public SubTskGphBuilder { CollectiveBoxingSubTskGphBuilder(); ~CollectiveBoxingSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; private: std::unique_ptr<SubTskGphBuilder> chain_builder_; diff --git a/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp index b349515d5486a0da31bd5ce579565e54e5d0ae04..a483ae9b22fb06c8135a846eb093d9933b12e771 100644 --- a/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.cpp @@ -19,26 +19,23 @@ limitations under the License. namespace oneflow { Maybe<SubTskGphBuilderStatus> NaiveB2BSubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { - if ((src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel()) - && (dst_parallel_desc.parallel_num() == 1 || dst_sbp_parallel.has_broadcast_parallel())) { - std::vector<CompTaskNode*> nearest_src_comp_tasks; - for (CompTaskNode* dst_node : sorted_dst_comp_tasks) { - CompTaskNode* nearest_src_node = - SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node); - CHECK_NOTNULL(nearest_src_node); - TaskNode* proxy = ctx->GetProxyNode(nearest_src_node, nearest_src_node->MemZoneId121(), - dst_node->machine_id(), dst_node->MemZoneId121()); - Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node); + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { + if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel()) + && (out_parallel_desc.parallel_num() == 1 || out_sbp_parallel.has_broadcast_parallel())) { + FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { + const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( + in_parallel_desc, out_parallel_desc, out_id); + TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id); + TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(), + out_parallel_desc, out_id); + sorted_out_tasks->push_back(proxy); } - return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(), - sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - lbi, logical_blob_desc, "NaiveB2BSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("NaiveB2BSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } diff --git a/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h b/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h index 71356bdb232a90e4c9268d55749c55343e16e4cd..f54e5fd48c5992728a2269a25f1327f0bfa0e0db 100644 --- a/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/naive_b2b_sub_task_graph_builder.h @@ -26,14 +26,13 @@ class NaiveB2BSubTskGphBuilder final : public SubTskGphBuilder { NaiveB2BSubTskGphBuilder() = default; ~NaiveB2BSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp index d8f94084f00e5b9e4fe1437c561df363d91e63af..d70b9dcc51467618496d29010d81dc3e34e52610 100644 --- a/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp @@ -20,51 +20,61 @@ limitations under the License. namespace oneflow { Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { - if ((src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel()) - && dst_parallel_desc.parallel_num() != 1 && dst_sbp_parallel.has_partial_sum_parallel()) { - HashMap<CompTaskNode*, CompTaskNode*> dst_node2nearest_src_node; - int64_t nearest_dst_node_idx = -1; - int64_t nearest_dst_node_distance = -1; - std::vector<CompTaskNode*> nearest_src_comp_tasks; - for (int64_t dst_node_idx = 0; dst_node_idx < sorted_dst_comp_tasks.size(); ++dst_node_idx) { - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(dst_node_idx); - const int64_t nearest_src_node_idx = - SubTskGphBuilderUtil::FindNearestNodeIndex(sorted_src_comp_tasks, dst_node); - CHECK_NE_OR_RETURN(nearest_src_node_idx, -1); - CompTaskNode* nearest_src_node = sorted_src_comp_tasks.at(nearest_src_node_idx); - CHECK_OR_RETURN(dst_node2nearest_src_node.emplace(dst_node, nearest_src_node).second); - const int64_t distance = SubTskGphBuilderUtil::GetDistance(nearest_src_node, dst_node); - if (nearest_dst_node_idx == -1 || distance < nearest_dst_node_distance) { - nearest_dst_node_idx = dst_node_idx; - nearest_dst_node_distance = distance; + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { + if ((in_parallel_desc.parallel_num() == 1 || in_sbp_parallel.has_broadcast_parallel()) + && out_parallel_desc.parallel_num() != 1 && out_sbp_parallel.has_partial_sum_parallel()) { + HashMap<int64_t, int64_t> out_id2nearest_in_id; + int64_t nearest_out_node_idx = -1; + int64_t nearest_out_node_distance = -1; + + FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { + const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( + in_parallel_desc, out_parallel_desc, out_id); + out_id2nearest_in_id.emplace(out_id, nearest_in_parallel_id); + const int64_t distance = SubTskGphBuilderUtil::GetDistance( + in_parallel_desc, nearest_in_parallel_id, out_parallel_desc, out_id); + if (nearest_out_node_idx == -1 || distance < nearest_out_node_distance) { + nearest_out_node_idx = out_id; + nearest_out_node_distance = distance; } } - for (int64_t dst_node_idx = 0; dst_node_idx < sorted_dst_comp_tasks.size(); ++dst_node_idx) { - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(dst_node_idx); - CompTaskNode* nearest_src_node = dst_node2nearest_src_node.at(dst_node); - if (dst_node_idx == nearest_dst_node_idx) { - TaskNode* proxy = ctx->GetProxyNode(nearest_src_node, nearest_src_node->MemZoneId121(), - dst_node->machine_id(), dst_node->MemZoneId121()); - Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node); + FOR_RANGE(int64_t, out_id, 0, out_parallel_desc.parallel_num()) { + const int64_t nearest_in_id = out_id2nearest_in_id.at(out_id); + TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_id); + if (out_id == nearest_out_node_idx) { + TaskNode* proxy = ctx->GetProxyNode(nearest_in_node, nearest_in_node->MemZoneId121(), + out_parallel_desc, out_id); + + sorted_out_tasks->push_back(proxy); } else { + const int64_t out_machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(out_id)); + const int64_t out_dev_phy_id = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(out_id)); + int64_t thrd_id; + if (out_parallel_desc.device_type() == DeviceType::kGPU) { +#ifdef WITH_CUDA + thrd_id = Global<IDMgr>::Get()->GetGpuComputeThrdId(out_dev_phy_id); +#else + UNIMPLEMENTED(); +#endif + } else if (out_parallel_desc.device_type() == DeviceType::kCPU) { + thrd_id = Global<IDMgr>::Get()->PickCpuThrdIdEvenly(out_machine_id); + } else { + UNIMPLEMENTED(); + } auto* zeros_node = ctx->task_graph()->NewNode<BoxingZerosTaskNode>(); - zeros_node->Init(dst_node->machine_id(), dst_node->thrd_id(), dst_node->area_id(), lbi, - logical_blob_desc.shape(), logical_blob_desc.data_type(), - *nearest_src_node->logical_node()->out_blob_time_shape()); - nearest_src_node->BuildCtrlRegstDesc(zeros_node); - Connect<TaskNode>(nearest_src_node, ctx->task_graph()->NewEdge(), zeros_node); - Connect<TaskNode>(zeros_node, ctx->task_graph()->NewEdge(), dst_node); + zeros_node->Init(out_machine_id, thrd_id, NewAreaId(), lbi, logical_blob_desc.shape(), + logical_blob_desc.data_type(), time_shape); + nearest_in_node->BuildCtrlRegstDesc(zeros_node); + Connect<TaskNode>(nearest_in_node, ctx->task_graph()->NewEdge(), zeros_node); + sorted_out_tasks->push_back(zeros_node); } } - return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(), - sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - lbi, logical_blob_desc, "NaiveB2PSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("NaiveB2PSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } diff --git a/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h b/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h index dc5caac9bb2acd3f3c2c5327c384422f38676695..8133ed2209bd813037a0373db3413f9a69644521 100644 --- a/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h @@ -26,14 +26,13 @@ class NaiveB2PSubTskGphBuilder final : public SubTskGphBuilder { NaiveB2PSubTskGphBuilder() = default; ~NaiveB2PSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp index 77ebb6e8af2c1e4548627a44ec00e7c845558296..f40a71cf05e1c67dd094bd22a9c161b78a7b5265 100644 --- a/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.cpp @@ -19,26 +19,22 @@ limitations under the License. namespace oneflow { Maybe<SubTskGphBuilderStatus> OneToOneSubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { - if ((src_parallel_desc.parallel_num() == 1 && dst_parallel_desc.parallel_num() == 1) - || (src_parallel_desc.parallel_num() == dst_parallel_desc.parallel_num() - && src_sbp_parallel == dst_sbp_parallel)) { - for (int64_t i = 0; i < src_parallel_desc.parallel_num(); ++i) { - CompTaskNode* src_node = sorted_src_comp_tasks.at(i); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(i); + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { + if ((in_parallel_desc.parallel_num() == 1 && out_parallel_desc.parallel_num() == 1) + || (in_parallel_desc.parallel_num() == out_parallel_desc.parallel_num() + && in_sbp_parallel == out_sbp_parallel)) { + for (int64_t i = 0; i < in_parallel_desc.parallel_num(); ++i) { + TaskNode* in_node = sorted_in_tasks.at(i); // TODO(liujuncheng): use lbi - TaskNode* proxy = ctx->GetProxyNode(src_node, src_node->MemZoneId121(), - dst_node->machine_id(), dst_node->MemZoneId121()); - Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node); + TaskNode* proxy = ctx->GetProxyNode(in_node, in_node->MemZoneId121(), out_parallel_desc, i); + sorted_out_tasks->push_back(proxy); } - return TRY(BuildSubTskGphBuilderStatus(sorted_src_comp_tasks.front(), - sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - lbi, logical_blob_desc, "OneToOneSubTskGphBuilder", "")); + return TRY(BuildSubTskGphBuilderStatus("OneToOneSubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError(); } diff --git a/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h b/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h index 457822eb6184dc925d5f197bb1d609695d0ad280..acf475c7221f44912b00d30d8af966a16d8f3192 100644 --- a/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h @@ -26,14 +26,13 @@ class OneToOneSubTskGphBuilder final : public SubTskGphBuilder { OneToOneSubTskGphBuilder() = default; ~OneToOneSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp index cfa9939baa58c24d73dc2a62507ee24c498c5294..07e9535a90ae3a9cf90335e6c0d7e81767afcf73 100644 --- a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp @@ -54,33 +54,34 @@ bool IsSameDevice(const ParallelDesc& in_pd, const ParallelDesc& out_pd, } // namespace Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const { if (SubTskGphBuilderUtil::BlobHasDynamicShape(logical_blob_desc)) { return Error::BoxingNotSupportedError(); } - if (!SubTskGphBuilderUtil::IsDeviceTypeCPUOrGPU(src_parallel_desc)) { + if (!SubTskGphBuilderUtil::IsDeviceTypeCPUOrGPU(in_parallel_desc)) { return Error::BoxingNotSupportedError(); } - if (!SubTskGphBuilderUtil::IsDeviceTypeCPUOrGPU(dst_parallel_desc)) { + if (!SubTskGphBuilderUtil::IsDeviceTypeCPUOrGPU(out_parallel_desc)) { return Error::BoxingNotSupportedError(); } - if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(src_parallel_desc.parallel_num(), src_sbp_parallel, + if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(in_parallel_desc.parallel_num(), in_sbp_parallel, logical_blob_desc)) { return Error::BoxingNotSupportedError(); } - if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(dst_parallel_desc.parallel_num(), dst_sbp_parallel, + if (SubTskGphBuilderUtil::HasEmptySliceIfSplit(out_parallel_desc.parallel_num(), out_sbp_parallel, logical_blob_desc)) { return Error::BoxingNotSupportedError(); } - if (!(SubTskGphBuilderUtil::IsBoxingS2B(src_sbp_parallel, dst_sbp_parallel) - || SubTskGphBuilderUtil::IsBoxingS2S(src_sbp_parallel, dst_sbp_parallel) - || SubTskGphBuilderUtil::IsBoxingP2S(src_sbp_parallel, dst_sbp_parallel) - || SubTskGphBuilderUtil::IsBoxingP2B(src_sbp_parallel, dst_sbp_parallel) - || SubTskGphBuilderUtil::IsBoxingB2S(src_sbp_parallel, dst_sbp_parallel))) { + if (!(SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel) + || SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel) + || SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel) + || SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel) + || SubTskGphBuilderUtil::IsBoxingB2S(in_sbp_parallel, out_sbp_parallel))) { return Error::BoxingNotSupportedError(); } const auto GetBoxingGpuThrdId = [](const int64_t dev_id, CudaWorkType work_type) -> int64_t { @@ -419,8 +420,7 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build( }; const auto BuildSubTaskGphB2S = - [&ctx, &lbi, &CreateBoxingNode121, &CreateBoxingNodeToHost, &GetBoxingGpuThrdId, &NewEdge, - &sorted_src_comp_tasks, &sorted_dst_comp_tasks]( + [&ctx, &lbi, &CreateBoxingNode121, &CreateBoxingNodeToHost, &GetBoxingGpuThrdId, &NewEdge]( const ParallelDesc& in_pd, const ParallelDesc& out_pd, const SbpParallel& in_sbp, const SbpParallel& out_sbp, const BlobDesc& blob_desc, const std::vector<TaskNode*>& in_nodes, std::vector<TaskNode*>* out_nodes) { @@ -432,39 +432,36 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build( CHECK(!ContainsEmptySlice(out_slices)); FOR_RANGE(int64_t, out_id, 0, out_pd.parallel_num()) { const TensorSliceView& out_slice = out_slices.at(out_id); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(out_id); const int64_t nearest_idx = - SubTskGphBuilderUtil::FindNearestNodeIndex(sorted_src_comp_tasks, dst_node); - CompTaskNode* src_node = sorted_src_comp_tasks.at(nearest_idx); + SubTskGphBuilderUtil::FindNearestSrcParallelId(in_pd, out_pd, out_id); + TaskNode* in_node = in_nodes.at(nearest_idx); SliceBoxingTaskNode* slice_node = CreateBoxingNode121(in_pd, nearest_idx, out_slice, kSliceBoxingTaskModeCopy); - slice_node->ConnectToSrcNodeWithSlice(src_node, NewEdge(), in_slice); - TaskNode* out_node = ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), - dst_node->machine_id(), dst_node->MemZoneId121()); + slice_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice); + TaskNode* out_node = + ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), out_pd, out_id); + out_nodes->push_back(out_node); } }; - std::vector<TaskNode*> in_nodes; - in_nodes.assign(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end()); - std::vector<TaskNode*> out_nodes; std::string comment; - if (SubTskGphBuilderUtil::IsBoxingS2B(src_sbp_parallel, dst_sbp_parallel)) { - BuildSubTaskGphS2B(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - logical_blob_desc, in_nodes, &out_nodes); + if (SubTskGphBuilderUtil::IsBoxingS2B(in_sbp_parallel, out_sbp_parallel)) { + BuildSubTaskGphS2B(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, + logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphS2B"; - } else if (SubTskGphBuilderUtil::IsBoxingS2S(src_sbp_parallel, dst_sbp_parallel)) { - BuildSubTaskGphS2S(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - logical_blob_desc, in_nodes, &out_nodes); + } else if (SubTskGphBuilderUtil::IsBoxingS2S(in_sbp_parallel, out_sbp_parallel)) { + BuildSubTaskGphS2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, + logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphS2S"; - } else if (SubTskGphBuilderUtil::IsBoxingP2S(src_sbp_parallel, dst_sbp_parallel)) { - BuildSubTaskGphP2S(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - logical_blob_desc, in_nodes, &out_nodes); + } else if (SubTskGphBuilderUtil::IsBoxingP2S(in_sbp_parallel, out_sbp_parallel)) { + BuildSubTaskGphP2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, + logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphP2S"; - } else if (SubTskGphBuilderUtil::IsBoxingP2B(src_sbp_parallel, dst_sbp_parallel)) { - if (logical_blob_desc.shape().elem_cnt() < dst_parallel_desc.parallel_num()) { - BuildSubTaskGphP2B(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - logical_blob_desc, in_nodes, &out_nodes); + } else if (SubTskGphBuilderUtil::IsBoxingP2B(in_sbp_parallel, out_sbp_parallel)) { + if (logical_blob_desc.shape().elem_cnt() < out_parallel_desc.parallel_num()) { + BuildSubTaskGphP2B(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, + logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphP2B"; } else { BlobDesc flat_blob_desc(logical_blob_desc.data_type()); @@ -472,30 +469,26 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build( std::vector<TaskNode*> middle_nodes; SbpParallel middle_sbp; middle_sbp.mutable_split_parallel()->set_axis(0); - BuildSubTaskGphP2S(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, middle_sbp, - flat_blob_desc, in_nodes, &middle_nodes); - BuildSubTaskGphS2B(dst_parallel_desc, dst_parallel_desc, middle_sbp, dst_sbp_parallel, - flat_blob_desc, middle_nodes, &out_nodes); + BuildSubTaskGphP2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, middle_sbp, + flat_blob_desc, sorted_in_tasks, &middle_nodes); + BuildSubTaskGphS2B(out_parallel_desc, out_parallel_desc, middle_sbp, out_sbp_parallel, + flat_blob_desc, middle_nodes, sorted_out_tasks); comment = "BuildSubTaskGphP2S->BuildSubTaskGphS2B"; - for (TaskNode* out_node : out_nodes) { + for (TaskNode* out_node : *sorted_out_tasks) { auto* slice_boxing_node = dynamic_cast<SliceBoxingTaskNode*>(out_node); CHECK_NOTNULL(slice_boxing_node); slice_boxing_node->SetOutShape(logical_blob_desc.shape()); } } - } else if (SubTskGphBuilderUtil::IsBoxingB2S(src_sbp_parallel, dst_sbp_parallel)) { - BuildSubTaskGphB2S(src_parallel_desc, dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, - logical_blob_desc, in_nodes, &out_nodes); + } else if (SubTskGphBuilderUtil::IsBoxingB2S(in_sbp_parallel, out_sbp_parallel)) { + BuildSubTaskGphB2S(in_parallel_desc, out_parallel_desc, in_sbp_parallel, out_sbp_parallel, + logical_blob_desc, sorted_in_tasks, sorted_out_tasks); comment = "BuildSubTaskGphB2S"; } else { UNIMPLEMENTED(); } - ctx->ConnectAll121(out_nodes, sorted_dst_comp_tasks); - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "SliceBoxingSubTskGphBuilder", comment)); + return TRY(BuildSubTskGphBuilderStatus("SliceBoxingSubTskGphBuilder", comment)); } } // namespace oneflow diff --git a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h index 3ac0a6c2dd4d6b9a3f56eefaa07bfbcaef8ba733..0129f57ef8ab7a59450b5d4405e0bb5be887b3d8 100644 --- a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.h @@ -26,14 +26,13 @@ class SliceBoxingSubTskGphBuilder final : public SubTskGphBuilder { SliceBoxingSubTskGphBuilder() = default; ~SliceBoxingSubTskGphBuilder() override = default; - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; + Maybe<SubTskGphBuilderStatus> Build( + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const override; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder.h b/oneflow/core/graph/boxing/sub_task_graph_builder.h index 56f378a88717e4160900df8de11d49e2c172b0cd..c9c70bcd1fb1aa226b5a4044591d1bd28f7f4ca8 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder.h +++ b/oneflow/core/graph/boxing/sub_task_graph_builder.h @@ -29,11 +29,12 @@ class SubTskGphBuilder { virtual ~SubTskGphBuilder() = default; virtual Maybe<SubTskGphBuilderStatus> Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel) const = 0; + SubTskGphBuilderCtx* ctx, const std::vector<TaskNode*>& sorted_in_tasks, + std::vector<TaskNode*>* sorted_out_tasks, + std::vector<std::vector<TaskNode*>>* sorted_ctrl_tasks, const ParallelDesc& in_parallel_desc, + const ParallelDesc& out_parallel_desc, const LogicalBlobId& lbi, + const BlobDesc& logical_blob_desc, const SbpParallel& in_sbp_parallel, + const SbpParallel& out_sbp_parallel, const Shape& time_shape) const = 0; }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp b/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp index 78b3a0303ba61a1b9beb68eb3a2c63ff526cf609..e0e690888c3cf03c906e91688fede226c9f67602 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_context.cpp @@ -68,4 +68,23 @@ TaskNode* SubTskGphBuilderCtx::GetProxyNode(TaskNode* src_node, int64_t src_mem_ } } +TaskNode* SubTskGphBuilderCtx::GetProxyNode(TaskNode* src_node, const int64_t src_mem_zone_id, + const ParallelDesc& dst_parallel_desc, + const int64_t dst_parallel_id) { + const int64_t dst_machine_id = + CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id)); + int64_t dst_mem_zone_id; + const IDMgr* id_mgr = Global<IDMgr>::Get(); + if (dst_parallel_desc.device_type() == DeviceType::kCPU) { + dst_mem_zone_id = id_mgr->CpuMemZoneId(); + } else if (dst_parallel_desc.device_type() == DeviceType::kGPU) { + const int64_t dst_dev_phy_id = + CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); + dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id); + } else { + UNIMPLEMENTED(); + } + return GetProxyNode(src_node, src_mem_zone_id, dst_machine_id, dst_mem_zone_id); +} + } // namespace oneflow diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_context.h b/oneflow/core/graph/boxing/sub_task_graph_builder_context.h index a49dd5dbca22b60a9163dfd8ea1fe6443e5e2c1d..62056c2a08cd6d083a08d8c73ad75770d964fb9c 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_context.h +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_context.h @@ -34,6 +34,8 @@ class SubTskGphBuilderCtx final { virtual TaskGraph* task_graph(); TaskNode* GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id, int64_t dst_machine_id, int64_t dst_mem_zone_id); + TaskNode* GetProxyNode(TaskNode* src_node, int64_t src_mem_zone_id, + const ParallelDesc& dst_parallel_desc, const int64_t dst_parallel_id); template<typename T1, typename T2> void ConnectAll121(const std::vector<T1*>& src_nodes, const std::vector<T2*>& dst_nodes) { CHECK_EQ(src_nodes.size(), dst_nodes.size()); diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.cpp b/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.cpp index daef939c2b24a414b7d16ea4cc691f4d0080e7b4..497714bc7d550756e47084d8cce820561f8a90ce 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.cpp +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.cpp @@ -18,19 +18,28 @@ limitations under the License. namespace oneflow { -Maybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus( - const CompTaskNode* src_node, const CompTaskNode* dst_node, - const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, - const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const std::string& builder_name, - const std::string& comment) { - std::string src_op_name = src_node->logical_node()->op_vec().at(0)->op_name(); - std::string dst_op_name = dst_node->logical_node()->op_vec().at(0)->op_name(); - SubTskGphBuilderStatus status(src_op_name, dst_op_name, src_parallel_desc, dst_parallel_desc, - src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - builder_name, comment); - +Maybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus(const std::string& builder_name, + const std::string& comment) { + SubTskGphBuilderStatus status(builder_name, comment); return status; } +Maybe<SubTskGphBuilderStatus> MakeComposedSubTskGphBuilderStatus( + const std::vector<SubTskGphBuilderStatus>& status_vec) { + std::string builder_name = "ComposedBuilder:"; + std::string comment = "ComposedComment:"; + for (auto status : status_vec) { + builder_name += " "; + builder_name += status.builder_name(); + comment += " "; + if (status.comment().empty()) { + comment += "None"; + } else { + comment += status.comment(); + } + } + SubTskGphBuilderStatus composed_status(builder_name, comment); + return composed_status; +} + } // namespace oneflow diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h b/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h index 6b8dcc553d8d1bf9997668f7fad081ae9bcec74f..b01bbac332ab3c23e662e0f6e9323e83422e1992 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_status_util.h @@ -22,54 +22,23 @@ namespace oneflow { class SubTskGphBuilderStatus; -Maybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus( - const CompTaskNode* src_node, const CompTaskNode* dst_node, - const ParallelDesc& src_parallel_desc, const ParallelDesc& dst_parallel_desc, - const SbpParallel& src_sbp_parallel, const SbpParallel& dst_sbp_parallel, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, const std::string& builder_name, - const std::string& comment); +Maybe<SubTskGphBuilderStatus> BuildSubTskGphBuilderStatus(const std::string& builder_name, + const std::string& comment); + +Maybe<SubTskGphBuilderStatus> MakeComposedSubTskGphBuilderStatus( + const std::vector<SubTskGphBuilderStatus>& status); class SubTskGphBuilderStatus final { public: - SubTskGphBuilderStatus(const std::string& src_op_name, const std::string& dst_op_name, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const SbpParallel& src_sbp_parallel_, const SbpParallel& dst_sbp_parallel, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const std::string& builder_name, const std::string& comment) - : src_op_name_(src_op_name), - dst_op_name_(dst_op_name), - src_parallel_desc_(src_parallel_desc), - dst_parallel_desc_(dst_parallel_desc), - src_sbp_parallel_(src_sbp_parallel_), - dst_sbp_parallel_(dst_sbp_parallel), - lbi_(lbi), - logical_blob_desc_(logical_blob_desc), - builder_name_(builder_name), - comment_(comment){}; + SubTskGphBuilderStatus(const std::string& builder_name, const std::string& comment) + : builder_name_(builder_name), comment_(comment){}; ~SubTskGphBuilderStatus() = default; // Getters - const std::string& src_op_name() const { return src_op_name_; } - const std::string& dst_op_name() const { return dst_op_name_; } - const ParallelDesc& src_parallel_desc() const { return src_parallel_desc_; } - const ParallelDesc& dst_parallel_desc() const { return dst_parallel_desc_; } - const SbpParallel& src_sbp_parallel() const { return src_sbp_parallel_; } - const SbpParallel& dst_sbp_parallel() const { return dst_sbp_parallel_; } - const LogicalBlobId& lbi() const { return lbi_; } - const BlobDesc& logical_blob_desc() const { return logical_blob_desc_; } const std::string& builder_name() const { return builder_name_; } const std::string& comment() const { return comment_; } private: - std::string src_op_name_; - std::string dst_op_name_; - ParallelDesc src_parallel_desc_; - ParallelDesc dst_parallel_desc_; - SbpParallel src_sbp_parallel_; - SbpParallel dst_sbp_parallel_; - LogicalBlobId lbi_; - BlobDesc logical_blob_desc_; std::string builder_name_; std::string comment_; }; diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_util.cpp b/oneflow/core/graph/boxing/sub_task_graph_builder_util.cpp index 704f9fd132eacf665ab7f6b5bbad44d829a13df0..a285797d89976d894f529fb31e206d5056da47b0 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_util.cpp +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_util.cpp @@ -101,16 +101,17 @@ bool SubTskGphBuilderUtil::IsErrorBoxingNotSupported(const cfg::ErrorProto& erro return error.has_boxing_not_supported_error(); } -int64_t SubTskGphBuilderUtil::GetDistance(const TaskNode* src, const TaskNode* dst) { - if (src->machine_id() != dst->machine_id()) { +int64_t SubTskGphBuilderUtil::GetDistance( + const int64_t src_machine_id, const int64_t src_dev_phy_id, const DeviceType src_device_type, + const int64_t dst_machine_id, const int64_t dst_dev_phy_id, const DeviceType dst_device_type) { + if (src_machine_id != dst_machine_id) { return kDistanceDiffMachine; - } else if (src->device_type() != dst->device_type()) { + } else if (src_device_type != dst_device_type) { return kDistanceSameMachine; - } else if (src->device_type() == DeviceType::kCPU) { + } else if (src_device_type == DeviceType::kCPU) { return kDistanceSameDevice; } else { - if (Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(src->thrd_id()) - == Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(dst->thrd_id())) { + if (src_dev_phy_id == dst_dev_phy_id) { return kDistanceSameDevice; } else { return kDistanceSameMachine; @@ -118,4 +119,53 @@ int64_t SubTskGphBuilderUtil::GetDistance(const TaskNode* src, const TaskNode* d } } +int64_t SubTskGphBuilderUtil::GetDistance(const ParallelDesc& src_parallel_desc, + const int64_t src_parallel_id, + const ParallelDesc& dst_parallel_desc, + const int64_t dst_parallel_id) { + const int64_t src_machine_id = + CHECK_JUST(src_parallel_desc.MachineId4ParallelId(src_parallel_id)); + const int64_t src_dev_phy_id = CHECK_JUST(src_parallel_desc.DeviceId4ParallelId(src_parallel_id)); + const int64_t dst_machine_id = + CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id)); + const int64_t dst_dev_phy_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); + return GetDistance(src_machine_id, src_dev_phy_id, src_parallel_desc.device_type(), + dst_machine_id, dst_dev_phy_id, dst_parallel_desc.device_type()); +} + +int64_t SubTskGphBuilderUtil::GetDistance(const TaskNode* src, const TaskNode* dst) { + const auto GetDevPhyId = [](const DeviceType device_type, const int64_t thrd_id) -> int64_t { + if (device_type == DeviceType::kGPU) { + return Global<IDMgr>::Get()->GetGpuPhyIdFromThrdId(thrd_id); + } else if (device_type == DeviceType::kCPU) { + return 0; + } else { + UNIMPLEMENTED(); + } + }; + const DeviceType src_device_type = src->device_type(); + const int64_t src_dev_phy_id = GetDevPhyId(src_device_type, src->thrd_id()); + const DeviceType dst_device_type = dst->device_type(); + const int64_t dst_dev_phy_id = GetDevPhyId(dst_device_type, dst->thrd_id()); + return GetDistance(src->machine_id(), src_dev_phy_id, src_device_type, dst->machine_id(), + dst_dev_phy_id, dst_device_type); +} + +int64_t SubTskGphBuilderUtil::FindNearestSrcParallelId(const ParallelDesc& from_parallel_desc, + const ParallelDesc& to_parallel_desc, + const int64_t to_parallel_id) { + int64_t nearest_from_parallel_idx = -1; + int64_t nearest_distance = SubTskGphBuilderUtil::kDistanceMax; + for (int64_t i = 0; i < from_parallel_desc.parallel_num(); ++i) { + const int64_t distance = + SubTskGphBuilderUtil::GetDistance(from_parallel_desc, i, to_parallel_desc, to_parallel_id); + if (distance < nearest_distance) { + nearest_from_parallel_idx = i; + nearest_distance = distance; + } + } + CHECK_NE(nearest_from_parallel_idx, -1); + return nearest_from_parallel_idx; +} + } // namespace oneflow diff --git a/oneflow/core/graph/boxing/sub_task_graph_builder_util.h b/oneflow/core/graph/boxing/sub_task_graph_builder_util.h index a82dc6effcb731d4bd88fe163d4b29970091a28d..e02caf8a33e085363d34b16f14d9c7a2ee534105 100644 --- a/oneflow/core/graph/boxing/sub_task_graph_builder_util.h +++ b/oneflow/core/graph/boxing/sub_task_graph_builder_util.h @@ -45,7 +45,13 @@ struct SubTskGphBuilderUtil { static bool IsBoxingB2S(const SbpParallel& src, const SbpParallel& dst); static bool BlobHasDynamicShape(const BlobDesc& blob_desc); static bool IsErrorBoxingNotSupported(const cfg::ErrorProto& error); + static int64_t GetDistance(int64_t src_machine_id, int64_t src_dev_phy_id, + DeviceType src_device_type, int64_t dst_machine_id, + int64_t dst_dev_phy_id, DeviceType dst_device_type); + static int64_t GetDistance(const ParallelDesc& src_parallel_desc, int64_t src_parallel_id, + const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id); static int64_t GetDistance(const TaskNode* src, const TaskNode* dst); + template<typename NodeType> static int64_t FindNearestNodeIndex(const std::vector<NodeType*> from_nodes, const NodeType* to_node) { @@ -69,6 +75,10 @@ struct SubTskGphBuilderUtil { const int64_t idx = FindNearestNodeIndex<NodeType>(from_nodes, to_node); return from_nodes.at(idx); } + + static int64_t FindNearestSrcParallelId(const ParallelDesc& from_parallel_desc, + const ParallelDesc& to_parallel_desc, + int64_t to_parallel_id); }; } // namespace oneflow diff --git a/oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.cpp deleted file mode 100644 index 894185284a65904c4e065d94c33a7de059b9577c..0000000000000000000000000000000000000000 --- a/oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.cpp +++ /dev/null @@ -1,93 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.h" -#include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" -#include "oneflow/core/graph/slice_boxing_task_node.h" - -namespace oneflow { - -Maybe<SubTskGphBuilderStatus> ToInterfaceSubTskGphBuilder::Build( - SubTskGphBuilderCtx* ctx, const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, const LogicalBlobId& lbi, - const BlobDesc& logical_blob_desc, const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const { - const LogicalNode* dst_logical_node = sorted_dst_comp_tasks.front()->logical_node(); - if (dst_logical_node->op_vec().size() != 1) { return Error::BoxingNotSupportedError(); } - if (!IsClassRegistered<int32_t, IsInterfaceOpConf4OpTypeCase>( - dst_logical_node->SoleOp()->op_conf().op_type_case())) { - return Error::BoxingNotSupportedError(); - } - if ((src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel()) - && (dst_parallel_desc.parallel_num() == 1 || dst_sbp_parallel.has_broadcast_parallel())) { - std::vector<CompTaskNode*> nearest_src_comp_tasks; - for (CompTaskNode* dst_node : sorted_dst_comp_tasks) { - CompTaskNode* nearest_src_node = - SubTskGphBuilderUtil::FindNearestNode(sorted_src_comp_tasks, dst_node); - CHECK_NOTNULL(nearest_src_node); - if (SubTskGphBuilderUtil::IsOnSameGPU(nearest_src_node, dst_node)) { - Connect<TaskNode>(nearest_src_node, ctx->task_graph()->NewEdge(), dst_node); - } else { - TaskNode* proxy = - ctx->GetProxyNode(nearest_src_node, nearest_src_node->MemZoneId121(), - dst_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId()); - Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node); - } - } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "ToInterfaceSubTskGphBuilder", "BuildSubTaskGphB2B")); - } else if ((src_parallel_desc.parallel_num() == 1 || src_sbp_parallel.has_broadcast_parallel()) - && (dst_parallel_desc.parallel_num() > 1 || dst_sbp_parallel.has_split_parallel())) { - const TensorSliceView in_slice = - SubTskGphBuilderUtil::GetBroadcastTensorSliceView(logical_blob_desc); - const std::vector<TensorSliceView> out_slices = SubTskGphBuilderUtil::GetTensorSliceView( - dst_parallel_desc.parallel_num(), dst_sbp_parallel, logical_blob_desc); - FOR_RANGE(int64_t, out_id, 0, dst_parallel_desc.parallel_num()) { - const TensorSliceView& out_slice = out_slices.at(out_id); - CompTaskNode* dst_node = sorted_dst_comp_tasks.at(out_id); - const int64_t nearest_idx = - SubTskGphBuilderUtil::FindNearestNodeIndex(sorted_src_comp_tasks, dst_node); - CompTaskNode* src_node = sorted_src_comp_tasks.at(nearest_idx); - SliceBoxingTaskNode* slice_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>(); - const auto src_machine_id = CHECK_JUST(src_parallel_desc.MachineId4ParallelId(0)); - if (src_parallel_desc.device_type() == DeviceType::kCPU) { - slice_node->Init(lbi, out_slice, kSliceBoxingTaskModeCopy, src_machine_id, - Global<IDMgr>::Get()->PickCpuThrdIdEvenly(src_machine_id)); - } else if (src_parallel_desc.device_type() == DeviceType::kGPU) { - slice_node->Init(lbi, out_slice, kSliceBoxingTaskModeCopy, src_machine_id, - Global<IDMgr>::Get()->GetGpuD2HThrdId(src_node->GpuPhyId()), - Global<IDMgr>::Get()->CpuMemZoneId()); - } else { - UNIMPLEMENTED(); - } - slice_node->ConnectToSrcNodeWithSlice(src_node, ctx->task_graph()->NewEdge(), in_slice); - TaskNode* proxy = - ctx->GetProxyNode(slice_node, slice_node->MemZoneId121(), dst_node->machine_id(), - Global<IDMgr>::Get()->CpuMemZoneId()); - Connect<TaskNode>(proxy, ctx->task_graph()->NewEdge(), dst_node); - } - return TRY(BuildSubTskGphBuilderStatus( - sorted_src_comp_tasks.front(), sorted_dst_comp_tasks.front(), src_parallel_desc, - dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, lbi, logical_blob_desc, - "ToInterfaceSubTskGphBuilder", "BuildSubTaskGphB2S")); - } else { - return Error::BoxingNotSupportedError(); - } -} - -} // namespace oneflow diff --git a/oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.h b/oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.h deleted file mode 100644 index 2d27b20c7955483f68a0a818631139f372646d0c..0000000000000000000000000000000000000000 --- a/oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.h +++ /dev/null @@ -1,41 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_GRAPH_BOXING_TO_INTERFACE_SUB_TASK_GRAPH_BUILDER_H_ -#define ONEFLOW_CORE_GRAPH_BOXING_TO_INTERFACE_SUB_TASK_GRAPH_BUILDER_H_ - -#include "oneflow/core/graph/boxing/sub_task_graph_builder.h" - -namespace oneflow { - -class ToInterfaceSubTskGphBuilder final : public SubTskGphBuilder { - public: - OF_DISALLOW_COPY_AND_MOVE(ToInterfaceSubTskGphBuilder); - ToInterfaceSubTskGphBuilder() = default; - ~ToInterfaceSubTskGphBuilder() override = default; - - Maybe<SubTskGphBuilderStatus> Build(SubTskGphBuilderCtx* ctx, - const std::vector<CompTaskNode*>& sorted_src_comp_tasks, - const std::vector<CompTaskNode*>& sorted_dst_comp_tasks, - const ParallelDesc& src_parallel_desc, - const ParallelDesc& dst_parallel_desc, - const LogicalBlobId& lbi, const BlobDesc& logical_blob_desc, - const SbpParallel& src_sbp_parallel, - const SbpParallel& dst_sbp_parallel) const override; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_GRAPH_BOXING_TO_INTERFACE_SUB_TASK_GRAPH_BUILDER_H_ diff --git a/oneflow/core/graph/boxing_identity_compute_task_node.cpp b/oneflow/core/graph/boxing_identity_task_node.cpp similarity index 70% rename from oneflow/core/graph/boxing_identity_compute_task_node.cpp rename to oneflow/core/graph/boxing_identity_task_node.cpp index 035700f09d395ad97abd3fd497d1892f321fd3c6..6a66c234df567cf8cf8ac40a4b00a0b5d0b14b90 100644 --- a/oneflow/core/graph/boxing_identity_compute_task_node.cpp +++ b/oneflow/core/graph/boxing_identity_task_node.cpp @@ -14,31 +14,29 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/to_string.h" -#include "oneflow/core/graph/boxing_identity_compute_task_node.h" -#include "oneflow/core/graph/logical_node.h" +#include "oneflow/core/graph/boxing_identity_task_node.h" namespace oneflow { -void BoxingIdentityCompTaskNode::Init(const CompTaskNode* src_node, const LogicalBlobId& lbi) { +void BoxingIdentityTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, + const LogicalBlobId& lbi) { lbi_ = lbi; - set_logical_node(src_node->logical_node()); - *mut_parallel_ctx() = *src_node->parallel_ctx(); - set_machine_id(src_node->machine_id()); - set_thrd_id(src_node->thrd_id()); - set_area_id(src_node->area_id()); + set_machine_id(machine_id); + set_thrd_id(thrd_id); + set_area_id(area_id); } -void BoxingIdentityCompTaskNode::ProduceAllRegstsAndBindEdges() { +void BoxingIdentityTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", true, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } -void BoxingIdentityCompTaskNode::ConsumeAllRegsts() { +void BoxingIdentityTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } -void BoxingIdentityCompTaskNode::BuildExecGphAndRegst() { +void BoxingIdentityTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); OperatorConf op_conf; op_conf.set_name("System-Boxing-Identity-" + NewUniqueId()); @@ -50,10 +48,10 @@ void BoxingIdentityCompTaskNode::BuildExecGphAndRegst() { std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(parallel_ctx()); + node->InferBlobDescs(nullptr); } -void BoxingIdentityCompTaskNode::InferProducedDataRegstTimeShape() { +void BoxingIdentityTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } diff --git a/oneflow/core/graph/boxing_identity_compute_task_node.h b/oneflow/core/graph/boxing_identity_task_node.h similarity index 64% rename from oneflow/core/graph/boxing_identity_compute_task_node.h rename to oneflow/core/graph/boxing_identity_task_node.h index edd756a788d6edad8b3130bc57d1b25f4614ce37..a8fcf502396e48a2dc3bee5672a4cad84062e079 100644 --- a/oneflow/core/graph/boxing_identity_compute_task_node.h +++ b/oneflow/core/graph/boxing_identity_task_node.h @@ -13,20 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_COMPUTE_TASK_NODE_H_ -#define ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_COMPUTE_TASK_NODE_H_ - -#include "oneflow/core/graph/compute_task_node.h" +#ifndef ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_ +#include "oneflow/core/graph/task_node.h" namespace oneflow { -class BoxingIdentityCompTaskNode : public CompTaskNode { +class BoxingIdentityTaskNode : public TaskNode { public: - OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityCompTaskNode); - BoxingIdentityCompTaskNode() = default; - ~BoxingIdentityCompTaskNode() override = default; + OF_DISALLOW_COPY_AND_MOVE(BoxingIdentityTaskNode); + BoxingIdentityTaskNode() = default; + ~BoxingIdentityTaskNode() override = default; - void Init(const CompTaskNode* src_node, const LogicalBlobId& lbi); + void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi); TaskType GetTaskType() const override { return TaskType::kBoxingIdentity; } private: @@ -40,4 +39,4 @@ class BoxingIdentityCompTaskNode : public CompTaskNode { } // namespace oneflow -#endif // ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_COMPUTE_TASK_NODE_H_ +#endif // ONEFLOW_CORE_GRAPH_BOXING_IDENTITY_TASK_NODE_H_ diff --git a/oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.h b/oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.h deleted file mode 100644 index fb77a2425801bf710589bbd0a6bce7cb1333aad0..0000000000000000000000000000000000000000 --- a/oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.h +++ /dev/null @@ -1,43 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_GRAPH_BOXING_S2S_ALL2ALL_PACK_COMPUTE_TASK_NODE_H_ -#define ONEFLOW_CORE_GRAPH_BOXING_S2S_ALL2ALL_PACK_COMPUTE_TASK_NODE_H_ - -#include "oneflow/core/graph/compute_task_node.h" - -namespace oneflow { - -class BoxingS2SAll2AllPackCompTaskNode : public CompTaskNode { - public: - OF_DISALLOW_COPY_AND_MOVE(BoxingS2SAll2AllPackCompTaskNode); - BoxingS2SAll2AllPackCompTaskNode() = default; - ~BoxingS2SAll2AllPackCompTaskNode() override = default; - - void Init(const CompTaskNode* src_node, const LogicalBlobId& lbi, const int64_t dst_split_axis); - TaskType GetTaskType() const override { return TaskType::kBoxingS2SAll2AllPack; } - - private: - void BuildExecGphAndRegst() override; - void ProduceAllRegstsAndBindEdges() override; - void ConsumeAllRegsts() final; - void InferProducedDataRegstTimeShape() final; - int64_t dst_split_axis_; - LogicalBlobId lbi_; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_GRAPH_BOXING_S2S_ALL2ALL_PACK_COMPUTE_TASK_NODE_H_ diff --git a/oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.cpp b/oneflow/core/graph/collective_boxing_pack_task_node.cpp similarity index 50% rename from oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.cpp rename to oneflow/core/graph/collective_boxing_pack_task_node.cpp index a0a11e97f8a10e7396b0bda7f1090949355d45da..0acfc972628a4c55be01583a7d35e32ad6929efc 100644 --- a/oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_pack_task_node.cpp @@ -14,50 +14,56 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/to_string.h" -#include "oneflow/core/graph/boxing_s2s_all2all_pack_compute_task_node.h" -#include "oneflow/core/graph/logical_node.h" +#include "oneflow/core/graph/collective_boxing_pack_task_node.h" namespace oneflow { -void BoxingS2SAll2AllPackCompTaskNode::Init(const CompTaskNode* src_node, const LogicalBlobId& lbi, - const int64_t dst_split_axis) { +void CollectiveBoxingPackTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, + const LogicalBlobId& lbi, const Shape& logical_shape, + const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, + const int64_t parallel_num) { lbi_ = lbi; - set_logical_node(src_node->logical_node()); - *mut_parallel_ctx() = *src_node->parallel_ctx(); - set_machine_id(src_node->machine_id()); - set_thrd_id(src_node->thrd_id()); - set_area_id(src_node->area_id()); - dst_split_axis_ = dst_split_axis; + set_machine_id(machine_id); + set_thrd_id(thrd_id); + set_area_id(area_id); + logical_shape_ = logical_shape; + parallel_num_ = parallel_num; + src_sbp_parallel_ = src_sbp_parallel; + dst_sbp_parallel_ = dst_sbp_parallel; } -void BoxingS2SAll2AllPackCompTaskNode::ProduceAllRegstsAndBindEdges() { +void CollectiveBoxingPackTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", true, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } -void BoxingS2SAll2AllPackCompTaskNode::ConsumeAllRegsts() { +void CollectiveBoxingPackTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } -void BoxingS2SAll2AllPackCompTaskNode::BuildExecGphAndRegst() { +void CollectiveBoxingPackTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); OperatorConf op_conf; - op_conf.set_name("System-Boxing-S2S-All2All-Pack-" + NewUniqueId()); + op_conf.set_name("System-Collective-Boxing-Pack-" + NewUniqueId()); op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); - *op_conf.mutable_boxing_s2s_all2all_pack_conf()->mutable_lbi() = lbi_; - op_conf.mutable_boxing_s2s_all2all_pack_conf()->set_dst_split_axis(dst_split_axis_); - op_conf.mutable_boxing_s2s_all2all_pack_conf()->set_num_ranks(parallel_ctx()->parallel_num()); + auto* collective_boxing_pack_conf = op_conf.mutable_collective_boxing_pack_conf(); + *collective_boxing_pack_conf->mutable_lbi() = lbi_; + logical_shape_.ToProto(collective_boxing_pack_conf->mutable_logical_shape()); + *collective_boxing_pack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_; + *collective_boxing_pack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_; + collective_boxing_pack_conf->set_num_ranks(parallel_num_); std::shared_ptr<Operator> sole_op = ConstructOp(op_conf, &GlobalJobDesc()); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(parallel_ctx()); + node->InferBlobDescs(nullptr); } -void BoxingS2SAll2AllPackCompTaskNode::InferProducedDataRegstTimeShape() { +void CollectiveBoxingPackTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } diff --git a/oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.h b/oneflow/core/graph/collective_boxing_pack_task_node.h similarity index 50% rename from oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.h rename to oneflow/core/graph/collective_boxing_pack_task_node.h index 03805142d3d0ce656e5cb752598030110472e81b..992d4cd68a8d41789b5bed0c53f9e72bd5c903d1 100644 --- a/oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.h +++ b/oneflow/core/graph/collective_boxing_pack_task_node.h @@ -13,23 +13,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_CORE_GRAPH_BOXING_S2S_ALL2ALL_UNPACK_COMPUTE_TASK_NODE_H_ -#define ONEFLOW_CORE_GRAPH_BOXING_S2S_ALL2ALL_UNPACK_COMPUTE_TASK_NODE_H_ - -#include "oneflow/core/graph/compute_task_node.h" +#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ +#include "oneflow/core/graph/task_node.h" namespace oneflow { -class BoxingS2SAll2AllUnpackCompTaskNode : public CompTaskNode { +class CollectiveBoxingPackTaskNode : public TaskNode { public: - OF_DISALLOW_COPY_AND_MOVE(BoxingS2SAll2AllUnpackCompTaskNode); - BoxingS2SAll2AllUnpackCompTaskNode() = default; - ~BoxingS2SAll2AllUnpackCompTaskNode() override = default; - - void Init(const CompTaskNode* src_node, const LogicalBlobId& lbi, const Shape& logical_shape, - const int64_t src_split_axis, const int64_t dst_split_axis); + OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackTaskNode); + CollectiveBoxingPackTaskNode() = default; + ~CollectiveBoxingPackTaskNode() override = default; - TaskType GetTaskType() const override { return TaskType::kBoxingS2SAll2AllUnpack; } + void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi, + const Shape& logical_shape, const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, const int64_t parallel_num); + TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingPack; } private: void BuildExecGphAndRegst() override; @@ -39,10 +38,11 @@ class BoxingS2SAll2AllUnpackCompTaskNode : public CompTaskNode { LogicalBlobId lbi_; Shape logical_shape_; - int64_t src_split_axis_; - int64_t dst_split_axis_; + SbpParallel src_sbp_parallel_; + SbpParallel dst_sbp_parallel_; + int64_t parallel_num_; }; } // namespace oneflow -#endif // ONEFLOW_CORE_GRAPH_BOXING_S2S_ALL2ALL_UNPACK_COMPUTE_TASK_NODE_H_ +#endif // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_PACK_TASK_NODE_H_ diff --git a/oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.cpp b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp similarity index 50% rename from oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.cpp rename to oneflow/core/graph/collective_boxing_unpack_task_node.cpp index 3ffc12fe4acd4078ace489b55f9073d78530e5b7..12bdfe41506bdfbf2b67ecfdd8975935f33efcdb 100644 --- a/oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.cpp +++ b/oneflow/core/graph/collective_boxing_unpack_task_node.cpp @@ -14,57 +14,56 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/to_string.h" -#include "oneflow/core/graph/boxing_s2s_all2all_unpack_compute_task_node.h" -#include "oneflow/core/graph/logical_node.h" +#include "oneflow/core/graph/collective_boxing_unpack_task_node.h" namespace oneflow { -void BoxingS2SAll2AllUnpackCompTaskNode::Init(const CompTaskNode* src_node, - const LogicalBlobId& lbi, const Shape& logical_shape, - const int64_t src_split_axis, - const int64_t dst_split_axis) { +void CollectiveBoxingUnpackTaskNode::Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, + const LogicalBlobId& lbi, const Shape& logical_shape, + const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, + const int64_t parallel_num) { lbi_ = lbi; - set_logical_node(src_node->logical_node()); - *mut_parallel_ctx() = *src_node->parallel_ctx(); - set_machine_id(src_node->machine_id()); - set_thrd_id(src_node->thrd_id()); - set_area_id(src_node->area_id()); + set_machine_id(machine_id); + set_thrd_id(thrd_id); + set_area_id(area_id); logical_shape_ = logical_shape; - src_split_axis_ = src_split_axis; - dst_split_axis_ = dst_split_axis; + parallel_num_ = parallel_num; + src_sbp_parallel_ = src_sbp_parallel; + dst_sbp_parallel_ = dst_sbp_parallel; } -void BoxingS2SAll2AllUnpackCompTaskNode::ProduceAllRegstsAndBindEdges() { +void CollectiveBoxingUnpackTaskNode::ProduceAllRegstsAndBindEdges() { std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", true, 1, 1); this->ForEachOutDataEdge([&](TaskEdge* out_dege) { out_dege->AddRegst("out", out_regst); }); } -void BoxingS2SAll2AllUnpackCompTaskNode::ConsumeAllRegsts() { +void CollectiveBoxingUnpackTaskNode::ConsumeAllRegsts() { this->ForEachInDataEdge( [&](TaskEdge* in_edge) { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }); } -void BoxingS2SAll2AllUnpackCompTaskNode::BuildExecGphAndRegst() { +void CollectiveBoxingUnpackTaskNode::BuildExecGphAndRegst() { ExecNode* node = mut_exec_gph().NewNode(); OperatorConf op_conf; - op_conf.set_name("System-Boxing-S2S-All2All-Unpack-" + NewUniqueId()); + op_conf.set_name("System-Collective-Boxing-Unpack-" + NewUniqueId()); op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(this->device_type()))); - *op_conf.mutable_boxing_s2s_all2all_unpack_conf()->mutable_lbi() = lbi_; - logical_shape_.ToProto(op_conf.mutable_boxing_s2s_all2all_unpack_conf()->mutable_logical_shape()); - op_conf.mutable_boxing_s2s_all2all_unpack_conf()->set_src_split_axis(src_split_axis_); - op_conf.mutable_boxing_s2s_all2all_unpack_conf()->set_dst_split_axis(dst_split_axis_); - op_conf.mutable_boxing_s2s_all2all_unpack_conf()->set_num_ranks(parallel_ctx()->parallel_num()); - + auto* collective_boxing_unpack_conf = op_conf.mutable_collective_boxing_unpack_conf(); + *collective_boxing_unpack_conf->mutable_lbi() = lbi_; + logical_shape_.ToProto(collective_boxing_unpack_conf->mutable_logical_shape()); + *collective_boxing_unpack_conf->mutable_src_sbp_parallel() = src_sbp_parallel_; + *collective_boxing_unpack_conf->mutable_dst_sbp_parallel() = dst_sbp_parallel_; + collective_boxing_unpack_conf->set_num_ranks(parallel_num_); std::shared_ptr<Operator> sole_op = ConstructOp(op_conf, &GlobalJobDesc()); node->mut_op() = sole_op; node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); - node->InferBlobDescs(parallel_ctx()); + node->InferBlobDescs(nullptr); } -void BoxingS2SAll2AllUnpackCompTaskNode::InferProducedDataRegstTimeShape() { +void CollectiveBoxingUnpackTaskNode::InferProducedDataRegstTimeShape() { NaiveInferProducedDataRegstTimeShape(); } diff --git a/oneflow/core/graph/collective_boxing_unpack_task_node.h b/oneflow/core/graph/collective_boxing_unpack_task_node.h new file mode 100644 index 0000000000000000000000000000000000000000..8e2be08373aae18207b1b941d37be79af4a69b64 --- /dev/null +++ b/oneflow/core/graph/collective_boxing_unpack_task_node.h @@ -0,0 +1,50 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ +#define ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ + +#include "oneflow/core/graph/task_node.h" + +namespace oneflow { + +class CollectiveBoxingUnpackTaskNode : public TaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackTaskNode); + CollectiveBoxingUnpackTaskNode() = default; + ~CollectiveBoxingUnpackTaskNode() override = default; + + void Init(int64_t machine_id, int64_t thrd_id, int64_t area_id, const LogicalBlobId& lbi, + const Shape& logical_shape, const SbpParallel& src_sbp_parallel, + const SbpParallel& dst_sbp_parallel, const int64_t parallel_num); + + TaskType GetTaskType() const override { return TaskType::kCollectiveBoxingUnpack; } + + private: + void BuildExecGphAndRegst() override; + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() final; + void InferProducedDataRegstTimeShape() final; + + LogicalBlobId lbi_; + Shape logical_shape_; + SbpParallel src_sbp_parallel_; + SbpParallel dst_sbp_parallel_; + int64_t parallel_num_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_GRAPH_COLLECTIVE_BOXING_UNPACK_TASK_NODE_H_ diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 436204b5d6158619b1f296b8f9218105a91f831c..0b55bb23bd40036ae1546d1ef1624f56afb776bf 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -33,9 +33,8 @@ limitations under the License. #include "oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/b21_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/one_to_one_sub_task_graph_builder.h" -#include "oneflow/core/graph/boxing/to_interface_sub_task_graph_builder.h" #include "oneflow/core/graph/boxing/sub_task_graph_builder_util.h" -#include "oneflow/core/graph/boxing_identity_compute_task_node.h" +#include "oneflow/core/graph/boxing_identity_task_node.h" namespace oneflow { @@ -219,7 +218,6 @@ TaskGraph::TaskGraph(std::unique_ptr<const LogicalGraph>&& logical_gph) { sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); std::vector<std::shared_ptr<SubTskGphBuilder>> builders; - builders.emplace_back(new ToInterfaceSubTskGphBuilder()); builders.emplace_back(new OneToOneSubTskGphBuilder()); builders.emplace_back(new B21SubTskGphBuilder()); builders.emplace_back(new CollectiveBoxingSubTskGphBuilder()); @@ -531,17 +529,20 @@ void TaskGraph::SetAreaIdForNewNodes(const LogicalNode* src_logical, DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { const std::vector<LogicalBlobId> lbis = src_logical->GetLbisTo(dst_logical); for (const LogicalBlobId& lbi : lbis) { - std::vector<CompTaskNode*> src_nodes; + std::vector<TaskNode*> in_nodes; if (lbis.size() == 1) { - src_nodes = sorted_src_comp_tasks; + in_nodes.assign(sorted_src_comp_tasks.begin(), sorted_src_comp_tasks.end()); } else { for (CompTaskNode* src_node : sorted_src_comp_tasks) { - auto* identity_node = NewNode<BoxingIdentityCompTaskNode>(); - identity_node->Init(src_node, lbi); + auto* identity_node = NewNode<BoxingIdentityTaskNode>(); + identity_node->Init(src_node->machine_id(), src_node->thrd_id(), src_node->area_id(), lbi); Connect<TaskNode>(src_node, NewEdge(), identity_node); - src_nodes.push_back(identity_node); + in_nodes.push_back(identity_node); } } + std::vector<TaskNode*> out_nodes; + out_nodes.reserve(sorted_dst_comp_tasks.size()); + std::vector<std::vector<TaskNode*>> sorted_ctrl_tasks; const SbpParallel& src_sbp_parallel = Global<OpGraph>::Get()->GetSbpParallel(src_logical->SoleOp()->op_name(), lbi); const SbpParallel& dst_sbp_parallel = @@ -550,9 +551,22 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { const std::shared_ptr<const ParallelDesc>& dst_parallel_desc = dst_logical->parallel_desc(); const BlobDesc& blob_desc = Global<OpGraph>::Get()->GetLogicalBlobDesc(lbi); auto status = CHECK_JUST(sub_tsk_gph_builder_->Build( - sub_tsk_gph_builder_ctx_.get(), src_nodes, sorted_dst_comp_tasks, *src_parallel_desc, - *dst_parallel_desc, lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel)); - boxing_logger_->Log(*status); + sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, + *src_parallel_desc, *dst_parallel_desc, lbi, blob_desc, src_sbp_parallel, dst_sbp_parallel, + *src_logical->out_blob_time_shape())); + boxing_logger_->Log(*status, src_logical->SoleOp()->op_name(), dst_logical->SoleOp()->op_name(), + *src_parallel_desc, *dst_parallel_desc, src_sbp_parallel, dst_sbp_parallel, + lbi, blob_desc); + sub_tsk_gph_builder_ctx_->ConnectAll121(out_nodes, sorted_dst_comp_tasks); + if (!sorted_ctrl_tasks.empty()) { + CHECK_EQ(sorted_ctrl_tasks.size(), sorted_dst_comp_tasks.size()); + FOR_RANGE(size_t, i, 0, sorted_dst_comp_tasks.size()) { + for (TaskNode* ctrl_node : sorted_ctrl_tasks.at(i)) { + Connect<TaskNode>(ctrl_node, NewEdge(), sorted_dst_comp_tasks.at(i)); + ctrl_node->BuildCtrlRegstDesc(sorted_dst_comp_tasks.at(i)); + } + } + } } } diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index cacfd89ecb0b2cfd79ee937dd88bf5b54284ea5f..f1f42da38976f1b3ac2e355e73edc090746f70ef 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -35,8 +35,8 @@ enum TaskType { kCollectiveBoxingGeneric = 58; kBoxingIdentity = 59; kDecodeH2D = 60; - kBoxingS2SAll2AllPack = 61; - kBoxingS2SAll2AllUnpack = 62; + kCollectiveBoxingPack = 61; + kCollectiveBoxingUnpack = 62; kSspVariableProxy = 63; kBoxingZeros = 64; }; diff --git a/oneflow/core/kernel/boxing_s2s_all2all_pack_kernel.cpp b/oneflow/core/kernel/collective_boxing_pack_kernel.cpp similarity index 55% rename from oneflow/core/kernel/boxing_s2s_all2all_pack_kernel.cpp rename to oneflow/core/kernel/collective_boxing_pack_kernel.cpp index 8cc69ebe7e948db6258b015dc4acdead910c0e1c..c2b823de330d676093070c10fa85e202f7a5f3b9 100644 --- a/oneflow/core/kernel/boxing_s2s_all2all_pack_kernel.cpp +++ b/oneflow/core/kernel/collective_boxing_pack_kernel.cpp @@ -19,11 +19,11 @@ limitations under the License. namespace oneflow { template<DeviceType device_type, typename T> -class BoxingS2SAll2AllPackKernel final : public KernelIf<device_type> { +class CollectiveBoxingPackKernel final : public KernelIf<device_type> { public: - OF_DISALLOW_COPY_AND_MOVE(BoxingS2SAll2AllPackKernel); - BoxingS2SAll2AllPackKernel() = default; - ~BoxingS2SAll2AllPackKernel() override = default; + OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackKernel); + CollectiveBoxingPackKernel() = default; + ~CollectiveBoxingPackKernel() override = default; private: bool IsStateless() const override { return false; } @@ -32,28 +32,28 @@ class BoxingS2SAll2AllPackKernel final : public KernelIf<device_type> { }; template<DeviceType device_type, typename T> -void BoxingS2SAll2AllPackKernel<device_type, T>::ForwardDataContent( +void CollectiveBoxingPackKernel<device_type, T>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* in = BnInOp2Blob("in"); Blob* out = BnInOp2Blob("out"); - const BoxingS2SAll2AllPackOpConf& pack_conf = this->op_conf().boxing_s2s_all2all_pack_conf(); - const int64_t dst_split_axis = pack_conf.dst_split_axis(); + const CollectiveBoxingPackOpConf& pack_conf = this->op_conf().collective_boxing_pack_conf(); const int64_t num_ranks = pack_conf.num_ranks(); - const bool need_transpose = (dst_split_axis != 0); + const Shape logical_shape(pack_conf.logical_shape()); + const bool need_transpose = !((pack_conf.dst_sbp_parallel().has_split_parallel() + && pack_conf.dst_sbp_parallel().split_parallel().axis() == 0) + || pack_conf.dst_sbp_parallel().has_broadcast_parallel() + || pack_conf.dst_sbp_parallel().has_partial_sum_parallel()); if (need_transpose) { - DimVector transpose_in_dim_vec; - const ShapeView& in_shape = in->shape(); - FOR_RANGE(int64_t, i, 0, in_shape.NumAxes()) { - if (i == dst_split_axis) { - transpose_in_dim_vec.push_back(num_ranks); - CHECK_EQ(in_shape.At(i) % num_ranks, 0); - transpose_in_dim_vec.push_back(in_shape.At(i) / num_ranks); - } else { - transpose_in_dim_vec.push_back(in_shape.At(i)); - } + const int64_t dst_split_axis = pack_conf.dst_sbp_parallel().split_parallel().axis(); + DimVector transpose_in_dim_vec = logical_shape.dim_vec(); + if (pack_conf.src_sbp_parallel().has_split_parallel()) { + const int64_t src_split_axis = pack_conf.src_sbp_parallel().split_parallel().axis(); + transpose_in_dim_vec[src_split_axis] = transpose_in_dim_vec.at(src_split_axis) / num_ranks; } + CHECK_EQ(transpose_in_dim_vec.at(dst_split_axis) % num_ranks, 0); + transpose_in_dim_vec[dst_split_axis] = transpose_in_dim_vec.at(dst_split_axis) / num_ranks; + transpose_in_dim_vec.insert(transpose_in_dim_vec.begin() + dst_split_axis, num_ranks); const Shape transpose_in_shape(transpose_in_dim_vec); - DimVector transpose_out_dim_vec; std::vector<int32_t> perm; perm.push_back(dst_split_axis); @@ -65,26 +65,24 @@ void BoxingS2SAll2AllPackKernel<device_type, T>::ForwardDataContent( } } const Shape transpose_out_shape(transpose_out_dim_vec); - NewKernelUtil<device_type>::Transpose( ctx.device_ctx, transpose_in_shape.NumAxes(), transpose_in_shape, transpose_out_shape, perm, transpose_in_shape.elem_cnt(), in->dptr<T>(), out->mut_dptr<T>()); } else { - CHECK_EQ(dst_split_axis, 0); out->CopyDataContentFrom(ctx.device_ctx, in); } } -#define REGISTER_BOXING_S2S_ALL2ALL_PACK_KERNEL(device_type_v, dtype_pair) \ +#define REGISTER_COLLECTIVE_BOXING_PACK_KERNEL(device_type_v, dtype_pair) \ REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE( \ - OperatorConf::kBoxingS2SAll2AllPackConf, device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \ - BoxingS2SAll2AllPackKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>) + OperatorConf::kCollectiveBoxingPackConf, device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \ + CollectiveBoxingPackKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>) -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BOXING_S2S_ALL2ALL_PACK_KERNEL, DEVICE_TYPE_SEQ, +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COLLECTIVE_BOXING_PACK_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) #if defined(WITH_CUDA) -REGISTER_BOXING_S2S_ALL2ALL_PACK_KERNEL(DeviceType::kGPU, (float16, DataType::kFloat16)) +REGISTER_COLLECTIVE_BOXING_PACK_KERNEL(DeviceType::kGPU, (float16, DataType::kFloat16)) #endif } // namespace oneflow diff --git a/oneflow/core/kernel/boxing_s2s_all2all_unpack_kernel.cpp b/oneflow/core/kernel/collective_boxing_unpack_kernel.cpp similarity index 62% rename from oneflow/core/kernel/boxing_s2s_all2all_unpack_kernel.cpp rename to oneflow/core/kernel/collective_boxing_unpack_kernel.cpp index dc3b54e6bd88c4da1e53f505dc98bab8a37c4697..006bdb7048d3d249ef36c5dc0038239a63bb35e9 100644 --- a/oneflow/core/kernel/boxing_s2s_all2all_unpack_kernel.cpp +++ b/oneflow/core/kernel/collective_boxing_unpack_kernel.cpp @@ -19,11 +19,11 @@ limitations under the License. namespace oneflow { template<DeviceType device_type, typename T> -class BoxingS2SAll2AllUnpackKernel final : public KernelIf<device_type> { +class CollectiveBoxingUnpackKernel final : public KernelIf<device_type> { public: - OF_DISALLOW_COPY_AND_MOVE(BoxingS2SAll2AllUnpackKernel); - BoxingS2SAll2AllUnpackKernel() = default; - ~BoxingS2SAll2AllUnpackKernel() override = default; + OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackKernel); + CollectiveBoxingUnpackKernel() = default; + ~CollectiveBoxingUnpackKernel() override = default; private: bool IsStateless() const override { return false; } @@ -32,26 +32,29 @@ class BoxingS2SAll2AllUnpackKernel final : public KernelIf<device_type> { }; template<DeviceType device_type, typename T> -void BoxingS2SAll2AllUnpackKernel<device_type, T>::ForwardDataContent( +void CollectiveBoxingUnpackKernel<device_type, T>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* in = BnInOp2Blob("in"); Blob* out = BnInOp2Blob("out"); - const BoxingS2SAll2AllUnpackOpConf& unpack_conf = - this->op_conf().boxing_s2s_all2all_unpack_conf(); - const int64_t src_split_axis = unpack_conf.src_split_axis(); - const int64_t dst_split_axis = unpack_conf.dst_split_axis(); + const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf(); const int64_t num_ranks = unpack_conf.num_ranks(); const Shape logical_shape(unpack_conf.logical_shape()); - const bool need_transpose = (src_split_axis != 0); + const bool need_transpose = !((unpack_conf.src_sbp_parallel().has_split_parallel() + && unpack_conf.src_sbp_parallel().split_parallel().axis() == 0) + || unpack_conf.src_sbp_parallel().has_broadcast_parallel() + || unpack_conf.src_sbp_parallel().has_partial_sum_parallel()); if (need_transpose) { + const int64_t src_split_axis = unpack_conf.src_sbp_parallel().split_parallel().axis(); DimVector transpose_in_dim_vec = logical_shape.dim_vec(); CHECK_EQ(transpose_in_dim_vec.at(src_split_axis) % num_ranks, 0); - CHECK_EQ(transpose_in_dim_vec.at(dst_split_axis) % num_ranks, 0); transpose_in_dim_vec[src_split_axis] = transpose_in_dim_vec.at(src_split_axis) / num_ranks; - transpose_in_dim_vec[dst_split_axis] = transpose_in_dim_vec.at(dst_split_axis) / num_ranks; + if (unpack_conf.dst_sbp_parallel().has_split_parallel()) { + const int64_t dst_split_axis = unpack_conf.dst_sbp_parallel().split_parallel().axis(); + CHECK_EQ(transpose_in_dim_vec.at(dst_split_axis) % num_ranks, 0); + transpose_in_dim_vec[dst_split_axis] = transpose_in_dim_vec.at(dst_split_axis) / num_ranks; + } transpose_in_dim_vec.insert(transpose_in_dim_vec.begin(), num_ranks); const Shape transpose_in_shape(transpose_in_dim_vec); - DimVector transpose_out_dim_vec; std::vector<int32_t> perm; FOR_RANGE(int64_t, i, 1, transpose_in_shape.NumAxes()) { @@ -66,21 +69,20 @@ void BoxingS2SAll2AllUnpackKernel<device_type, T>::ForwardDataContent( ctx.device_ctx, transpose_in_shape.NumAxes(), transpose_in_shape, transpose_out_shape, perm, transpose_in_shape.elem_cnt(), in->dptr<T>(), out->mut_dptr<T>()); } else { - CHECK_EQ(src_split_axis, 0); out->CopyDataContentFrom(ctx.device_ctx, in); } } -#define REGISTER_BOXING_S2S_ALL2ALL_UNPACK_KERNEL(device_type_v, dtype_pair) \ +#define REGISTER_COLLECTIVE_BOXING_UNPACK_KERNEL(device_type_v, dtype_pair) \ REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE( \ - OperatorConf::kBoxingS2SAll2AllUnpackConf, device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \ - BoxingS2SAll2AllUnpackKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>) + OperatorConf::kCollectiveBoxingUnpackConf, device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \ + CollectiveBoxingUnpackKernel<device_type_v, OF_PP_PAIR_FIRST(dtype_pair)>) -OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_BOXING_S2S_ALL2ALL_UNPACK_KERNEL, DEVICE_TYPE_SEQ, +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_COLLECTIVE_BOXING_UNPACK_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ) #if defined(WITH_CUDA) -REGISTER_BOXING_S2S_ALL2ALL_UNPACK_KERNEL(DeviceType::kGPU, (float16, DataType::kFloat16)) +REGISTER_COLLECTIVE_BOXING_UNPACK_KERNEL(DeviceType::kGPU, (float16, DataType::kFloat16)) #endif } // namespace oneflow diff --git a/oneflow/core/operator/boxing_s2s_all2all_pack_op.cpp b/oneflow/core/operator/collective_boxing_pack_op.cpp similarity index 74% rename from oneflow/core/operator/boxing_s2s_all2all_pack_op.cpp rename to oneflow/core/operator/collective_boxing_pack_op.cpp index 41984bd66371cad359d0dda94268196353a20590..5935eddf30e0ed7563c51b05c99e1f9fd5ba627c 100644 --- a/oneflow/core/operator/boxing_s2s_all2all_pack_op.cpp +++ b/oneflow/core/operator/collective_boxing_pack_op.cpp @@ -19,11 +19,11 @@ limitations under the License. namespace oneflow { -class BoxingS2SAll2AllPackOp : public Operator { +class CollectiveBoxingPackOp : public Operator { public: - OF_DISALLOW_COPY_AND_MOVE(BoxingS2SAll2AllPackOp); - BoxingS2SAll2AllPackOp() = default; - ~BoxingS2SAll2AllPackOp() override = default; + OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingPackOp); + CollectiveBoxingPackOp() = default; + ~CollectiveBoxingPackOp() override = default; void InitFromOpConf() override; @@ -40,20 +40,20 @@ class BoxingS2SAll2AllPackOp : public Operator { LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; -void BoxingS2SAll2AllPackOp::InitFromOpConf() { +void CollectiveBoxingPackOp::InitFromOpConf() { EnrollInputBn("in", false); EnrollOutputBn("out", false); } -LogicalBlobId BoxingS2SAll2AllPackOp::lbi4ibn(const std::string& input_bn) const { - return this->op_conf().boxing_s2s_all2all_pack_conf().lbi(); +LogicalBlobId CollectiveBoxingPackOp::lbi4ibn(const std::string& input_bn) const { + return this->op_conf().collective_boxing_pack_conf().lbi(); } -LogicalBlobId BoxingS2SAll2AllPackOp::lbi4obn(const std::string& output_bn) const { - return this->op_conf().boxing_s2s_all2all_pack_conf().lbi(); +LogicalBlobId CollectiveBoxingPackOp::lbi4obn(const std::string& output_bn) const { + return this->op_conf().collective_boxing_pack_conf().lbi(); } -Maybe<void> BoxingS2SAll2AllPackOp::InferBlobDescs( +Maybe<void> CollectiveBoxingPackOp::InferBlobDescs( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); @@ -63,6 +63,6 @@ Maybe<void> BoxingS2SAll2AllPackOp::InferBlobDescs( return Maybe<void>::Ok(); } -REGISTER_OP(OperatorConf::kBoxingS2SAll2AllPackConf, BoxingS2SAll2AllPackOp); +REGISTER_OP(OperatorConf::kCollectiveBoxingPackConf, CollectiveBoxingPackOp); } // namespace oneflow diff --git a/oneflow/core/operator/boxing_s2s_all2all_unpack_op.cpp b/oneflow/core/operator/collective_boxing_unpack_op.cpp similarity index 60% rename from oneflow/core/operator/boxing_s2s_all2all_unpack_op.cpp rename to oneflow/core/operator/collective_boxing_unpack_op.cpp index 7db186a602f75738ea34beef365ba47b14234aad..13e7153fe0465dfb1b9a8b4eb5d3880e6b6c609a 100644 --- a/oneflow/core/operator/boxing_s2s_all2all_unpack_op.cpp +++ b/oneflow/core/operator/collective_boxing_unpack_op.cpp @@ -19,11 +19,11 @@ limitations under the License. namespace oneflow { -class BoxingS2SAll2AllUnpackOp : public Operator { +class CollectiveBoxingUnpackOp : public Operator { public: - OF_DISALLOW_COPY_AND_MOVE(BoxingS2SAll2AllUnpackOp); - BoxingS2SAll2AllUnpackOp() = default; - ~BoxingS2SAll2AllUnpackOp() override = default; + OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingUnpackOp); + CollectiveBoxingUnpackOp() = default; + ~CollectiveBoxingUnpackOp() override = default; void InitFromOpConf() override; @@ -40,34 +40,37 @@ class BoxingS2SAll2AllUnpackOp : public Operator { LogicalBlobId lbi4obn(const std::string& output_bn) const override; }; -void BoxingS2SAll2AllUnpackOp::InitFromOpConf() { +void CollectiveBoxingUnpackOp::InitFromOpConf() { EnrollInputBn("in", false); EnrollOutputBn("out", false); } -LogicalBlobId BoxingS2SAll2AllUnpackOp::lbi4ibn(const std::string& input_bn) const { - return this->op_conf().boxing_s2s_all2all_unpack_conf().lbi(); +LogicalBlobId CollectiveBoxingUnpackOp::lbi4ibn(const std::string& input_bn) const { + return this->op_conf().collective_boxing_unpack_conf().lbi(); } -LogicalBlobId BoxingS2SAll2AllUnpackOp::lbi4obn(const std::string& output_bn) const { - return this->op_conf().boxing_s2s_all2all_unpack_conf().lbi(); +LogicalBlobId CollectiveBoxingUnpackOp::lbi4obn(const std::string& output_bn) const { + return this->op_conf().collective_boxing_unpack_conf().lbi(); } -Maybe<void> BoxingS2SAll2AllUnpackOp::InferBlobDescs( +Maybe<void> CollectiveBoxingUnpackOp::InferBlobDescs( std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx) const { - const BoxingS2SAll2AllUnpackOpConf& unpack_conf = - this->op_conf().boxing_s2s_all2all_unpack_conf(); - const int64_t dst_split_axis = unpack_conf.dst_split_axis(); + const CollectiveBoxingUnpackOpConf& unpack_conf = this->op_conf().collective_boxing_unpack_conf(); + const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); - *out_blob_desc = *GetBlobDesc4BnInOp("in"); + *out_blob_desc = *in_blob_desc; Shape out_shape(unpack_conf.logical_shape()); - out_shape.Set(dst_split_axis, out_shape.At(dst_split_axis) / unpack_conf.num_ranks()); + if (unpack_conf.dst_sbp_parallel().has_split_parallel()) { + const int64_t dst_split_axis = unpack_conf.dst_sbp_parallel().split_parallel().axis(); + out_shape.Set(dst_split_axis, out_shape.At(dst_split_axis) / unpack_conf.num_ranks()); + } + CHECK_EQ_OR_RETURN(out_shape.elem_cnt(), in_blob_desc->shape().elem_cnt()); out_blob_desc->mut_shape() = out_shape; return Maybe<void>::Ok(); } -REGISTER_OP(OperatorConf::kBoxingS2SAll2AllUnpackConf, BoxingS2SAll2AllUnpackOp); +REGISTER_OP(OperatorConf::kCollectiveBoxingUnpackConf, CollectiveBoxingUnpackOp); } // namespace oneflow diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index ed64d5686a2d0130c64b2de59b46b9891a60723a..d03e7ce66e42a0cb06a717c5459557b6f826ab2b 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -554,16 +554,18 @@ message BoxingIdentityOpConf { required LogicalBlobId lbi = 1; } -message BoxingS2SAll2AllPackOpConf { +message CollectiveBoxingPackOpConf { required LogicalBlobId lbi = 1; - required int64 dst_split_axis = 2; - required int64 num_ranks = 3; + required SbpParallel src_sbp_parallel = 2; + required SbpParallel dst_sbp_parallel = 3; + required int64 num_ranks = 4; + required ShapeProto logical_shape = 5; } -message BoxingS2SAll2AllUnpackOpConf { +message CollectiveBoxingUnpackOpConf { required LogicalBlobId lbi = 1; - required int64 src_split_axis = 2; - required int64 dst_split_axis = 3; + required SbpParallel src_sbp_parallel = 2; + required SbpParallel dst_sbp_parallel = 3; required int64 num_ranks = 4; required ShapeProto logical_shape = 5; } @@ -637,8 +639,8 @@ message OperatorConf { CollectiveBoxingGenericOpConf collective_boxing_generic_conf = 170; BoxingIdentityOpConf boxing_identity_conf = 171; TensorListSplitOpConf tensor_list_split_conf = 172; - BoxingS2SAll2AllPackOpConf boxing_s2s_all2all_pack_conf = 174; - BoxingS2SAll2AllUnpackOpConf boxing_s2s_all2all_unpack_conf = 175; + CollectiveBoxingPackOpConf collective_boxing_pack_conf = 174; + CollectiveBoxingUnpackOpConf collective_boxing_unpack_conf = 175; BoxingZerosOpConf boxing_zeros_conf = 176; UserOpConf user_conf = 199;