From d6a6079769e923954b5cd4ab74a7c793af7c1e90 Mon Sep 17 00:00:00 2001
From: Shenghang Tsai <jackalcooper@gmail.com>
Date: Fri, 7 May 2021 14:38:06 +0800
Subject: [PATCH] Delete ops outside of loops in job passes (#4815)

* Serialize proto in binary rather than text

* move del ops out from loop

* refine

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp | 4 +++-
 oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp    | 5 +++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp b/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp
index 92fa477be..49484d4f1 100644
--- a/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp
+++ b/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp
@@ -81,6 +81,7 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
   });
 
   auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
+  std::vector<OperatorConf> delete_ops;
   op_graph.ForEachNode([&](const OpNode* op_node) {
     const OperatorConf& op_conf = op_node->op().op_conf();
     if (!op_conf.has_user_conf()) { return; }
@@ -132,8 +133,9 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
         }
       }
     }
-    job_builder->DelOps({op_conf});
+    delete_ops.push_back(op_conf);
   });
+  job_builder->DelOps(delete_ops);
   for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); }
   return Maybe<void>::Ok();
 }
diff --git a/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp b/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp
index e27fb19e8..c809fcc8b 100644
--- a/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp
+++ b/oneflow/core/job_rewriter/fuse_cast_scale_pass.cpp
@@ -61,6 +61,7 @@ class FuseCastScalePass final : public JobPass {
 
 Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
   const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph);
+  std::vector<OperatorConf> delete_ops;
   op_graph.ForEachNode([&](const OpNode* op_node) {
     if (!IsUserOpWithTypeName(op_node->op().op_conf(), "cast")) { return; }
     if (!IsSafeToDelete(op_node)) { return; }
@@ -86,7 +87,6 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
     }
     if (op_node->parallel_desc().device_type() != DeviceType::kGPU) { return; }
     double scale = 1.0;
-    std::vector<OperatorConf> delete_ops;
     if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul")) {
       const user_op::UserOpConfWrapper scalar_mul_op_conf(sole_dst_node->op().op_conf());
       if (scalar_mul_op_conf.attr<bool>("has_int_operand")) {
@@ -111,9 +111,10 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
 
     OperatorConf new_op_conf = sole_dst_node->op().op_conf();
     *new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf();
-    job_builder->DelOps(delete_ops);
+
     job_builder->MutOpsOnlyOnce({new_op_conf});
   });
+  job_builder->DelOps(delete_ops);
   return Maybe<void>::Ok();
 }
 
-- 
GitLab