Skip to content
Snippets Groups Projects
Unverified Commit d6a60797 authored by Shenghang Tsai's avatar Shenghang Tsai Committed by GitHub
Browse files

Delete ops outside of loops in job passes (#4815)


* Serialize proto in binary rather than text

* move del ops out from loop

* refine

Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 4491fef4
No related branches found
No related tags found
No related merge requests found
......@@ -81,6 +81,7 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
});
auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
std::vector<OperatorConf> delete_ops;
op_graph.ForEachNode([&](const OpNode* op_node) {
const OperatorConf& op_conf = op_node->op().op_conf();
if (!op_conf.has_user_conf()) { return; }
......@@ -132,8 +133,9 @@ Maybe<void> FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_
}
}
}
job_builder->DelOps({op_conf});
delete_ops.push_back(op_conf);
});
job_builder->DelOps(delete_ops);
for (const auto& pair : op_name2op_conf) { job_builder->MutOpsOnlyOnce({pair.second}); }
return Maybe<void>::Ok();
}
......
......@@ -61,6 +61,7 @@ class FuseCastScalePass final : public JobPass {
Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const {
const auto IsSafeToDelete = MakePredicatorIsSafeToDelete(op_graph);
std::vector<OperatorConf> delete_ops;
op_graph.ForEachNode([&](const OpNode* op_node) {
if (!IsUserOpWithTypeName(op_node->op().op_conf(), "cast")) { return; }
if (!IsSafeToDelete(op_node)) { return; }
......@@ -86,7 +87,6 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
}
if (op_node->parallel_desc().device_type() != DeviceType::kGPU) { return; }
double scale = 1.0;
std::vector<OperatorConf> delete_ops;
if (IsUserOpWithTypeName(sole_dst_node->op().op_conf(), "scalar_mul")) {
const user_op::UserOpConfWrapper scalar_mul_op_conf(sole_dst_node->op().op_conf());
if (scalar_mul_op_conf.attr<bool>("has_int_operand")) {
......@@ -111,9 +111,10 @@ Maybe<void> FuseCastScalePass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
OperatorConf new_op_conf = sole_dst_node->op().op_conf();
*new_op_conf.mutable_user_conf() = fused_op_builder.Build().op_conf().user_conf();
job_builder->DelOps(delete_ops);
job_builder->MutOpsOnlyOnce({new_op_conf});
});
job_builder->DelOps(delete_ops);
return Maybe<void>::Ok();
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment