diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp
index e5639be88b7c6ea596a8a717ce86bbc1f1a9cd47..2462b4ded84fcd0ba0361f9432cce688539316ae 100644
--- a/oneflow/core/job/compiler.cpp
+++ b/oneflow/core/job/compiler.cpp
@@ -118,7 +118,8 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
if (task_node->IsMeaningLess()) { return; }
TaskProto task_proto;
task_node->ToProto(&task_proto);
- if (task_node->GetTaskType() == kNormalForward) {
+ if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
+ || task_node->GetTaskType() == kAcc) {
CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
}
plan->mutable_task()->Add(std::move(task_proto));
diff --git a/oneflow/core/job/inter_job_mem_sharing_util.cpp b/oneflow/core/job/inter_job_mem_sharing_util.cpp
index 3707cb3d426c81f315178e4ccbc650326cc88600..62a3b9952830a257e111dd083b451dcdb26b998d 100644
--- a/oneflow/core/job/inter_job_mem_sharing_util.cpp
+++ b/oneflow/core/job/inter_job_mem_sharing_util.cpp
@@ -33,7 +33,7 @@ void GetOpName2JobId2TaskProtos(
if (task->exec_sequence().exec_node_size() == 1) {
const KernelConf& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
std::string op_name =
- PlanUtil::GeOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();
+ PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();
if (op_names.find(op_name) != op_names.end()) {
CHECK(task->has_parallel_ctx());
(*op_name2job_id2task_protos)[op_name][task->job_id()].push_back(task);
diff --git a/oneflow/core/job/intra_job_mem_sharing_util.cpp b/oneflow/core/job/intra_job_mem_sharing_util.cpp
index 1443502d5ecd3ae24529fcc6a1393ecef8e73e7c..c8a54f56c342c0ea8faf49978754189491f4b519 100644
--- a/oneflow/core/job/intra_job_mem_sharing_util.cpp
+++ b/oneflow/core/job/intra_job_mem_sharing_util.cpp
@@ -198,8 +198,8 @@ void GenMemChainTasksAndRegsts(
std::string* op_name) -> bool {
if (task_proto->task_type() == TaskType::kNormalForward
&& task_proto->exec_sequence().exec_node_size() == 1) {
- *op_name = PlanUtil::GeOpAttribute(plan, task_proto->job_id(),
- task_proto->exec_sequence().exec_node(0).kernel_conf())
+ *op_name = PlanUtil::GetOpAttribute(plan, task_proto->job_id(),
+ task_proto->exec_sequence().exec_node(0).kernel_conf())
.op_conf()
.name();
return true;
diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp
index 2d47003ec6f05c4e83d39340e2dbbf56c6c3efe7..67fdc45fe9adecf5a70c834868ec3223e85123b6 100644
--- a/oneflow/core/job/oneflow.cpp
+++ b/oneflow/core/job/oneflow.cpp
@@ -231,8 +231,8 @@ const boxing::collective::RankDesc& GetRankDesc(const OperatorConf& conf) {
const boxing::collective::RankDesc& GetRankDesc(Plan* plan, const TaskProto& task_proto) {
CHECK_EQ(task_proto.exec_sequence().exec_node_size(), 1);
- return GetRankDesc(PlanUtil::GeOpAttribute(plan, task_proto.job_id(),
- task_proto.exec_sequence().exec_node(0).kernel_conf())
+ return GetRankDesc(PlanUtil::GetOpAttribute(plan, task_proto.job_id(),
+ task_proto.exec_sequence().exec_node(0).kernel_conf())
.op_conf());
}
@@ -432,8 +432,8 @@ RegstDescProto* GetSoleDataRegstDescProto(TaskProto* task) {
const OperatorConf& GetSoleOpConf(Plan* plan, const TaskProto& task) {
CHECK_EQ(task.exec_sequence().exec_node_size(), 1);
- return PlanUtil::GeOpAttribute(plan, task.job_id(),
- task.exec_sequence().exec_node(0).kernel_conf())
+ return PlanUtil::GetOpAttribute(plan, task.job_id(),
+ task.exec_sequence().exec_node(0).kernel_conf())
.op_conf();
}
@@ -441,7 +441,7 @@ void UpdateSoleObnRegstDescId(Plan* plan, TaskProto* task) {
CHECK_EQ(task->exec_sequence().exec_node_size(), 1);
auto* exec_node = task->mutable_exec_sequence()->mutable_exec_node(0);
const auto& obns =
- PlanUtil::GeOpAttribute(plan, task->job_id(), exec_node->kernel_conf()).output_bns();
+ PlanUtil::GetOpAttribute(plan, task->job_id(), exec_node->kernel_conf()).output_bns();
CHECK_EQ(obns.size(), 1);
int64_t regst_desc_id = GetSoleDataRegstDescProto(task)->regst_desc_id();
(*exec_node->mutable_bn_in_op2regst_desc_id())[obns.Get(0)] = regst_desc_id;
@@ -504,7 +504,7 @@ void LinkMainPlan(Plan* plan, Plan&& main_plan,
if (task->exec_sequence().exec_node_size() != 1) { return false; }
const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
OperatorConf::OpTypeCase op_type_case =
- PlanUtil::GeOpAttribute(plan, task->job_id(), kernel_conf).op_conf().op_type_case();
+ PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().op_type_case();
return op_type_case == OperatorConf::kSourceTickConf
|| op_type_case == OperatorConf::kSinkTickConf;
};
@@ -516,7 +516,7 @@ void LinkMainPlan(Plan* plan, Plan&& main_plan,
if (IsInterfaceTickTockTask(task) == false) { continue; }
const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
const auto& op_name =
- PlanUtil::GeOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();
+ PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();
CHECK(sole_tick_op_name2sole_task.emplace(op_name, task).second);
}
auto TaskProto4TaskId = PlanUtil::MakeGetterTaskProto4TaskId(*plan);
@@ -547,7 +547,7 @@ void LinkMainPlan(Plan* plan, Plan&& main_plan,
if (task.task_type() == TaskType::kSourceTick) {
CHECK(task.exec_sequence().exec_node_size() == 1);
const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf();
- const auto& op_conf = PlanUtil::GeOpAttribute(plan, task.job_id(), kernel_conf).op_conf();
+ const auto& op_conf = PlanUtil::GetOpAttribute(plan, task.job_id(), kernel_conf).op_conf();
CHECK(op_conf.has_source_tick_conf());
CHECK(source_tick_op_names.find(op_conf.name()) != source_tick_op_names.end());
return true;
@@ -883,7 +883,7 @@ Maybe<void> ConnectCriticalSectionEndToReentrantLockEnd(
CHECK_EQ_OR_RETURN(task->exec_sequence().exec_node_size(), 1);
const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
const auto& op_name =
- PlanUtil::GeOpAttribute(main_plan, task->job_id(), kernel_conf).op_conf().name();
+ PlanUtil::GetOpAttribute(main_plan, task->job_id(), kernel_conf).op_conf().name();
if (op_name == lock_back_edge.reentrant_lock_op_name) {
CHECK_ISNULL_OR_RETURN(reentrant_lock_task);
reentrant_lock_task = task;
@@ -950,7 +950,7 @@ void FinishGlobalCriticalSectionDesc(const Plan& plan, int64_t job_size) {
if (task.exec_sequence().exec_node_size() == 1) {
const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf();
const std::string& op_name =
- PlanUtil::GeOpAttribute(&plan, task.job_id(), kernel_conf).op_conf().name();
+ PlanUtil::GetOpAttribute(&plan, task.job_id(), kernel_conf).op_conf().name();
HashSet<int64_t>* mem_block_ids =
&(job_id2sole_op_name2mem_block_ids.at(task.job_id())[op_name]);
for (const auto& pair : task.produced_regst_desc()) {
diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp
index fe153e58a54fe0ef765395d2aadbacc1cad9518a..dd1dfe46028f6a0b9a2a030b9d4c608fee019fd3 100644
--- a/oneflow/core/job/plan_util.cpp
+++ b/oneflow/core/job/plan_util.cpp
@@ -345,7 +345,7 @@ void PlanUtil::ToDotFile(const Plan& plan, const std::string& filepath) {
std::string pass_tag = kNoPassTag;
for (const ExecNodeProto& exec_node : task_proto.exec_sequence().exec_node()) {
const auto& op_conf =
- GeOpAttribute(&plan, task_proto.job_id(), exec_node.kernel_conf()).op_conf();
+ GetOpAttribute(&plan, task_proto.job_id(), exec_node.kernel_conf()).op_conf();
op_name += op_conf.name();
if (op_conf.has_pass_tag()) { pass_tag = op_conf.pass_tag(); }
}
@@ -526,8 +526,8 @@ void PlanUtil::SetForceInplaceMemBlock(Plan* plan) {
}
}
-const oneflow::OpAttribute& PlanUtil::GeOpAttribute(const Plan* plan, int64_t job_id,
- const oneflow::KernelConf& kernel_conf) {
+const oneflow::OpAttribute& PlanUtil::GetOpAttribute(const Plan* plan, int64_t job_id,
+ const oneflow::KernelConf& kernel_conf) {
if (kernel_conf.has_op_attribute()) {
return kernel_conf.op_attribute();
} else if (kernel_conf.has_op_attribute_ref()) {
diff --git a/oneflow/core/job/plan_util.h b/oneflow/core/job/plan_util.h
index 0427fb00d46fd9bfddddacaa7ba1541bb769f6cd..cce71d9ac0062f1c757aa5e002299c70165b10ef 100644
--- a/oneflow/core/job/plan_util.h
+++ b/oneflow/core/job/plan_util.h
@@ -30,8 +30,8 @@ struct PlanUtil {
static void ToDotFile(const Plan& plan, const std::string& filepath);
static std::function<RegstDescProto*(int64_t)> MakeMutRegstDesc4Id(Plan* plan);
static void SetForceInplaceMemBlock(Plan* plan);
- static const oneflow::OpAttribute& GeOpAttribute(const Plan* plan, int64_t job_id,
- const oneflow::KernelConf& kernel_conf);
+ static const oneflow::OpAttribute& GetOpAttribute(const Plan* plan, int64_t job_id,
+ const oneflow::KernelConf& kernel_conf);
};
} // namespace oneflow