Skip to content
Snippets Groups Projects
Commit 3cc42016 authored by Juncheng's avatar Juncheng Committed by Niu Chong
Browse files

actor::IsReadReady/WriteReady add const qualifier (#1892)

parent 77131a3f
No related branches found
No related tags found
No related merge requests found
Showing
with 45 additions and 38 deletions
......@@ -443,12 +443,12 @@ int64_t Actor::HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)>
return real_consumer_cnt;
}
bool Actor::IsReadReady() {
bool Actor::IsReadReady() const {
return naive_consumed_rs_.IsCurSlotReady() && inplace_consumed_rs_.IsCurSlotReady()
&& IsCustomizedReadReady();
}
bool Actor::IsWriteReady() {
bool Actor::IsWriteReady() const {
return naive_produced_rs_.IsCurSlotReady() && inplace_produced_rs_.IsCurSlotReady()
&& IsCustomizedWriteReady();
}
......
......@@ -155,8 +155,8 @@ class Actor {
void TryLogActEvent(const std::function<void()>& Callback) const;
// Ready
bool IsReadReady();
bool IsWriteReady();
bool IsReadReady() const;
bool IsWriteReady() const;
// Naive, Inplace Or Customized
void TakeOverInplaceConsumedAndProduced(const PbMap<std::string, RegstDescProto>& produced_ids);
......@@ -178,8 +178,8 @@ class Actor {
virtual void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const {}
virtual void NormalProcessCustomizedEordMsg(const ActorMsg&) {}
virtual void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) { UNIMPLEMENTED(); }
virtual bool IsCustomizedReadReady() { return true; }
virtual bool IsCustomizedReadAlwaysUnReadyFromNow() { return false; }
virtual bool IsCustomizedReadReady() const { return true; }
virtual bool IsCustomizedReadAlwaysUnReadyFromNow() const { return false; }
virtual std::pair<RegstNameType, HashSet<std::string>>
GetNaiveOrCustomizedConsumedRegstDescName() {
return std::make_pair(RegstNameType::kCustomized, HashSet<std::string>{});
......@@ -189,7 +189,7 @@ class Actor {
// Customized Produced virtual func
virtual void UpdtStateAsCustomizedProducedRegst(Regst* regst) { UNIMPLEMENTED(); }
virtual bool IsCustomizedWriteReady() { return true; }
virtual bool IsCustomizedWriteReady() const { return true; }
virtual std::pair<RegstNameType, HashSet<std::string>>
GetNaiveOrCustomizedProducedRegstDescName() {
return std::make_pair(RegstNameType::kCustomized, HashSet<std::string>{});
......
......@@ -86,11 +86,11 @@ void CopyCommNetActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {
next_piece_id_ += 1;
}
bool CopyCommNetActor::IsCustomizedReadReady() {
bool CopyCommNetActor::IsCustomizedReadReady() const {
return piece_id2regst_ctx_.find(next_piece_id_) != piece_id2regst_ctx_.end();
}
bool CopyCommNetActor::IsCustomizedReadAlwaysUnReadyFromNow() {
bool CopyCommNetActor::IsCustomizedReadAlwaysUnReadyFromNow() const {
return is_in_eord_ && piece_id2regst_ctx_.empty();
}
......
......@@ -34,8 +34,8 @@ class CopyCommNetActor final : public Actor {
void Act() override;
void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;
void AsyncSendCustomizedConsumedRegstMsgToProducer() override;
bool IsCustomizedReadReady() override;
bool IsCustomizedReadAlwaysUnReadyFromNow() override;
bool IsCustomizedReadReady() const override;
bool IsCustomizedReadAlwaysUnReadyFromNow() const override;
void AsyncReturnAllCustomizedReadableRegst() override;
bool is_in_eord_;
......
......@@ -15,7 +15,7 @@ void DecodeRandomActor::Act() {
AsyncLaunchKernel(GenDefaultKernelCtx());
}
bool DecodeRandomActor::IsCustomizedReadReady() {
bool DecodeRandomActor::IsCustomizedReadReady() const {
return piece_id_ < Global<RuntimeCtx>::Get()->total_piece_num();
}
......
......@@ -18,8 +18,8 @@ class DecodeRandomActor final : public CompActor {
override {
return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});
}
bool IsCustomizedReadReady() override;
bool IsCustomizedReadAlwaysUnReadyFromNow() override { return !IsCustomizedReadReady(); }
bool IsCustomizedReadReady() const override;
bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); }
int HandlerWaitToStart(const ActorMsg&);
......
......@@ -37,17 +37,8 @@ void InputWiseCompActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg&
CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst()));
}
bool InputWiseCompActor::IsCustomizedReadReady() {
CHECK_EQ(-1, cur_processed_regst_desc_id_);
consumed_rs_.ForChosenRegstDeq([this](int64_t) { return cur_processed_regst_desc_id_ == -1; },
[this](const std::deque<Regst*>& reg_deq) {
if (reg_deq.empty()) { return; }
int64_t regst_desc_id = reg_deq.front()->regst_desc_id();
if (regst_desc_id2is_processed_.at(regst_desc_id) == false) {
cur_processed_regst_desc_id_ = regst_desc_id;
}
});
return cur_processed_regst_desc_id_ != -1;
bool InputWiseCompActor::IsCustomizedReadReady() const {
return -1 != GetCurProcessedRegstDescId();
}
void InputWiseCompActor::ForEachCurCustomizedReadableRegst(
......@@ -56,6 +47,7 @@ void InputWiseCompActor::ForEachCurCustomizedReadableRegst(
}
void InputWiseCompActor::Act() {
cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId();
Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_);
CHECK(cur_regst);
......@@ -99,4 +91,18 @@ void InputWiseCompActor::AsyncReturnAllCustomizedReadableRegst() {
bool InputWiseCompActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; }
int64_t InputWiseCompActor::GetCurProcessedRegstDescId() const {
int64_t cur_processed_regst_desc_id = -1;
consumed_rs_.ForChosenRegstDeq(
[cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; },
[this, &cur_processed_regst_desc_id](const std::deque<Regst*>& reg_deq) {
if (reg_deq.empty()) { return; }
int64_t regst_desc_id = reg_deq.front()->regst_desc_id();
if (regst_desc_id2is_processed_.at(regst_desc_id) == false) {
cur_processed_regst_desc_id = regst_desc_id;
}
});
return cur_processed_regst_desc_id;
}
} // namespace oneflow
......@@ -28,9 +28,9 @@ class InputWiseCompActor : public CompActor {
void Act() override;
void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;
void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;
bool IsCustomizedReadReady() override;
bool IsCustomizedReadReady() const override;
void NormalProcessCustomizedEordMsg(const ActorMsg&) override {}
bool IsCustomizedReadAlwaysUnReadyFromNow() override {
bool IsCustomizedReadAlwaysUnReadyFromNow() const override {
return ReceiveAllEordMsg() && consumed_rs_.available_regst_desc_cnt() == 0;
}
void AsyncReturnAllCustomizedReadableRegst() override;
......@@ -42,6 +42,7 @@ class InputWiseCompActor : public CompActor {
void AsyncSendCustomizedConsumedRegstMsgToProducer() override;
virtual void SetKernelCtxOther(void** other) { *other = nullptr; }
int64_t GetCurProcessedRegstDescId() const;
RegstSlot consumed_rs_;
HashMap<int64_t, bool> regst_desc_id2is_processed_;
......
......@@ -81,7 +81,7 @@ void NormalBackwardCompActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {
AsyncReturnModelRegstUntilLastPieceIdGreaterThan(piece_id);
}
bool NormalBackwardCompActor::IsCustomizedReadReady() {
bool NormalBackwardCompActor::IsCustomizedReadReady() const {
if (model_regst_desc_id_ != -1) {
if (model_regst_queue_.empty()) { return false; }
int64_t expected_model_vid =
......
......@@ -18,7 +18,7 @@ class NormalBackwardCompActor final : public CompActor {
void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) override;
void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;
void Act() override;
bool IsCustomizedReadReady() override;
bool IsCustomizedReadReady() const override;
void AsyncReturnAllCustomizedReadableRegst() override;
std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()
override {
......
......@@ -37,7 +37,7 @@ void NormalForwardCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
}
}
bool NormalForwardCompActor::IsCustomizedWriteReady() {
bool NormalForwardCompActor::IsCustomizedWriteReady() const {
if (const_buf_regst_desc_id_ != -1) { CHECK(send_const_buf_regst_); }
return true;
}
......@@ -123,7 +123,7 @@ void NormalForwardCompActor::AsyncSendCustomizedConsumedRegstMsgToProducer() {
cur_piece_id_ = -1;
}
bool NormalForwardCompActor::IsCustomizedReadReady() {
bool NormalForwardCompActor::IsCustomizedReadReady() const {
if (model_regst_desc_id_ != -1 && model_regst_ == nullptr) { return false; }
if (const_model_regst_desc_id_ != -1 && const_model_regst_ == nullptr) { return false; }
return true;
......
......@@ -16,7 +16,7 @@ class NormalForwardCompActor final : public CompActor {
void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override;
void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override;
void Act() override;
bool IsCustomizedReadReady() override;
bool IsCustomizedReadReady() const override;
void AsyncReturnAllCustomizedReadableRegst() override;
std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName()
override {
......@@ -31,7 +31,7 @@ class NormalForwardCompActor final : public CompActor {
void VirtualAsyncSendInplaceProducedRegstMsgToConsumer() override;
void AsyncSendCustomizedConsumedRegstMsgToProducer() override;
bool IsCustomizedWriteReady() override;
bool IsCustomizedWriteReady() const override;
void UpdtStateAsCustomizedProducedRegst(Regst* regst) override;
bool CheckOutputActId(int64_t regst_desc_id) const override;
......
......@@ -29,7 +29,7 @@ void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
OF_SET_MSG_HANDLER(&NormalMdUpdtCompActor::HandlerInitModelAndConstModel);
}
bool NormalMdUpdtCompActor::IsCustomizedWriteReady() {
bool NormalMdUpdtCompActor::IsCustomizedWriteReady() const {
if (const_model_regst_desc_id_ != -1) { CHECK(send_const_model_regst_); }
return true;
}
......
......@@ -20,7 +20,7 @@ class NormalMdUpdtCompActor final : public CompActor {
}
void AsyncSendCustomizedProducedRegstMsgToConsumer() override {}
void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;
bool IsCustomizedWriteReady() override;
bool IsCustomizedWriteReady() const override;
void UpdtStateAsCustomizedProducedRegst(Regst* regst) override;
void SendConstModelRegstToConsumer();
bool CheckOutputActId(int64_t regst_desc_id) const override;
......
......@@ -27,7 +27,7 @@ void RecordLoadActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {
if (record_load_status_.record_num > 0) { HandleProducedNaiveDataRegstToConsumer(); }
}
bool RecordLoadActor::IsCustomizedReadReady() {
bool RecordLoadActor::IsCustomizedReadReady() const {
return !is_eof_ && piece_id_ < Global<RuntimeCtx>::Get()->total_piece_num();
}
......
......@@ -20,8 +20,8 @@ class RecordLoadActor final : public CompActor {
return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{});
}
void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override;
bool IsCustomizedReadReady() override;
bool IsCustomizedReadAlwaysUnReadyFromNow() override { return !IsCustomizedReadReady(); }
bool IsCustomizedReadReady() const override;
bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); }
int HandlerWaitToStart(const ActorMsg&);
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment