diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp
index 3eeb429e825af84e8008e037836962a9f1a04889..9b321f39fe22302be693d96f89700a4bcbd71e58 100644
--- a/oneflow/core/actor/actor.cpp
+++ b/oneflow/core/actor/actor.cpp
@@ -236,6 +236,10 @@ void Actor::ForEachCurNaiveReadableDataRegst(std::function<void(const Regst*)> f
});
}
+bool Actor::ReceiveEordMsg(int64_t regst_desc_id) const {  // is this id in eord_regst_desc_ids_?
+ return eord_regst_desc_ids_.count(regst_desc_id) > 0;
+}
+
int Actor::HandlerNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kEordMsg) {
remaining_eord_cnt_ -= 1;
diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h
index a741506c5b3227f6202621b63a33a9ba3b635954..1842b716d4e63e5d07a7772deef0a0787133f6dd 100644
--- a/oneflow/core/actor/actor.h
+++ b/oneflow/core/actor/actor.h
@@ -51,6 +51,7 @@ class Actor {
Actor() = default;
const ParallelContext* parallel_ctx() const { return parallel_ctx_.get(); }
bool ReceiveAllEordMsg() const { return remaining_eord_cnt_ == 0; }
+ bool ReceiveEordMsg(int64_t regst_desc_id) const;  // true iff regst_desc_id is in eord_regst_desc_ids_
DeviceType GetDeviceType() const;
virtual void VirtualActorInit(const TaskProto&) {}
int64_t Name2SoleRegstDescId(const std::string& name) const;
diff --git a/oneflow/core/actor/reentrant_lock_compute_actor.cpp b/oneflow/core/actor/reentrant_lock_compute_actor.cpp
index 4f5fef21de54f127938ef2c9f4454ebb296e48f2..30de238eb5a5f672466936774de50bb73cdc51a2 100644
--- a/oneflow/core/actor/reentrant_lock_compute_actor.cpp
+++ b/oneflow/core/actor/reentrant_lock_compute_actor.cpp
@@ -7,8 +7,9 @@ void ReentrantLockCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
const auto& kernel_conf = task_proto.exec_sequence().exec_node().Get(0).kernel_conf();
const auto& ibns = kernel_conf.op_attribute().input_bns();
for (const auto& ibn : ibns) {
- CHECK(regst_desc_id2ibn_.emplace(exec_kernel_vec().at(0).bn_in_op2regst_desc_id.at(ibn), ibn)
- .second);
+ int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2regst_desc_id.at(ibn);
+ if (ibn == "start") { eord_regst_desc_id_ = regst_desc_id; }
+ CHECK(regst_desc_id2ibn_.emplace(regst_desc_id, ibn).second);
}
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
for (const int64_t regst_desc_id : pair.second.regst_desc_id()) {
@@ -49,7 +50,8 @@ void ReentrantLockCompActor::Act() {
}
bool ReentrantLockCompActor::IsCustomizedReadAlwaysUnReadyFromNow() const {
- return ReceiveAllEordMsg() && reentrant_lock_status_.total_queued_request_lock_num() == 0
+ return ReceiveEordMsg(eord_regst_desc_id_)  // id of the regst bound to input "start" is in the EORD set
+ && reentrant_lock_status_.total_queued_request_lock_num() == 0  // queued-request count is zero
&& reentrant_lock_status_.total_acquired_lock_num() == 0;
}
diff --git a/oneflow/core/actor/reentrant_lock_compute_actor.h b/oneflow/core/actor/reentrant_lock_compute_actor.h
index cde0e875c5829b99028346473351112f33108899..c1c7b818f3142e0fae9b8d5bd6d8bb01c7a6a994 100644
--- a/oneflow/core/actor/reentrant_lock_compute_actor.h
+++ b/oneflow/core/actor/reentrant_lock_compute_actor.h
@@ -37,6 +37,7 @@ class ReentrantLockCompActor final : public CompActor {
int64_t cur_processed_regst_desc_id_;
HashMap<int64_t, std::string> regst_desc_id2ibn_;
ReentrantLockStatus reentrant_lock_status_;
+ int64_t eord_regst_desc_id_ = -1;  // regst of ibn "start"; -1 (assumed invalid id) until Init sets it
};
} // namespace oneflow