From d6a6079769e923954b5cd4ab74a7c793af7c1e90 Mon Sep 17 00:00:00 2001 From: Shenghang Tsai <jackalcooper@gmail.com> Date: Fri, 7 May 2021 14:38:06 +0800 Subject: [PATCH] Delete ops outside of loops in job passes (#4815) * Serialize proto in binary rather than text * move del ops out from loop * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp | 4 +++- oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp b/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp index 92fa477be..49484d4f1 100644 --- a/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp +++ b/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp @@ -81,6 +81,7 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_ }); auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); + std::vector<OperatorConf> delete_ops; op_graph.ForEachNode([&](const OpNode* op_node) { const OperatorConf& op_conf = op_node->op().op_conf(); if (!op_conf.has_user_conf()) { return; } @@ -132,8 +133,9 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_ } } } - job_builder->DelOps({op_conf}); + delete_ops.push_back(op_conf); }); + job_builder->DelOps(delete_ops); for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); } return Maybe<void>::Ok(); } diff --git a/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp b/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp index e27fb19e8..c809fcc8b 100644 --- a/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp +++ b/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp @@ -61,6 +61,7 @@ class FuseCastScalePass final : public JobPass { Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph); + std::vector<OperatorConf> delete_ops; op_graph.ForEachNode([&](const OpNode* op_node) { if (!IsUserOpWithTypeName(op_node->op().op_conf(), "cast")) { return; } if (!IsSafeToDelete(op_node)) { return; } @@ -86,7 +87,6 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu } if (op_node->parallel_desc().device_type() != DeviceType::kGPU) { return; } double scale = 1.0; - std::vector<OperatorConf> delete_ops; if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul")) { const user_op::UserOpConfWrapper scalar_mul_op_conf(sole_dst_node->op().op_conf()); if (scalar_mul_op_conf.attr<bool>("has_int_operand")) { @@ -111,9 +111,10 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu OperatorConf new_op_conf = sole_dst_node->op().op_conf(); *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf(); - job_builder->DelOps(delete_ops); + job_builder->MutOpsOnlyOnce({new_op_conf}); }); + job_builder->DelOps(delete_ops); return Maybe<void>::Ok(); } -- GitLab