Skip to content
Snippets Groups Projects
Commit 5e7a5660 authored by Xinqi's avatar Xinqi
Browse files

ReentrantLockComputeActor::eord_regst_desc_id_

parent 839f157d
No related branches found
No related tags found
No related merge requests found
......@@ -236,6 +236,10 @@ void Actor::ForEachCurNaiveReadableDataRegst(std::function<void(const Regst*)> f
});
}
bool Actor::ReceiveEordMsg(int64_t regst_desc_id) const {
return eord_regst_desc_ids_.find(regst_desc_id) != eord_regst_desc_ids_.end();
}
int Actor::HandlerNormal(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kEordMsg) {
remaining_eord_cnt_ -= 1;
......
......@@ -51,6 +51,7 @@ class Actor {
Actor() = default;
const ParallelContext* parallel_ctx() const { return parallel_ctx_.get(); }
bool ReceiveAllEordMsg() const { return remaining_eord_cnt_ == 0; }
bool ReceiveEordMsg(int64_t regst_desc_id) const;
DeviceType GetDeviceType() const;
virtual void VirtualActorInit(const TaskProto&) {}
int64_t Name2SoleRegstDescId(const std::string& name) const;
......
......@@ -7,8 +7,9 @@ void ReentrantLockCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
const auto& kernel_conf = task_proto.exec_sequence().exec_node().Get(0).kernel_conf();
const auto& ibns = kernel_conf.op_attribute().input_bns();
for (const auto& ibn : ibns) {
CHECK(regst_desc_id2ibn_.emplace(exec_kernel_vec().at(0).bn_in_op2regst_desc_id.at(ibn), ibn)
.second);
int64_t regst_desc_id = exec_kernel_vec().at(0).bn_in_op2regst_desc_id.at(ibn);
if (ibn == "start") { eord_regst_desc_id_ = regst_desc_id; }
CHECK(regst_desc_id2ibn_.emplace(regst_desc_id, ibn).second);
}
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
for (const int64_t regst_desc_id : pair.second.regst_desc_id()) {
......@@ -49,7 +50,8 @@ void ReentrantLockCompActor::Act() {
}
bool ReentrantLockCompActor::IsCustomizedReadAlwaysUnReadyFromNow() const {
return ReceiveAllEordMsg() && reentrant_lock_status_.total_queued_request_lock_num() == 0
return ReceiveEordMsg(eord_regst_desc_id_)
&& reentrant_lock_status_.total_queued_request_lock_num() == 0
&& reentrant_lock_status_.total_acquired_lock_num() == 0;
}
......
......@@ -37,6 +37,7 @@ class ReentrantLockCompActor final : public CompActor {
int64_t cur_processed_regst_desc_id_;
HashMap<int64_t, std::string> regst_desc_id2ibn_;
ReentrantLockStatus reentrant_lock_status_;
int64_t eord_regst_desc_id_;
};
} // namespace oneflow
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment