diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index a30b7b13eea6d6f316e35c418849f999af94f9f2..f42ac80d91578dacf6e1187baee8a6c7b4a194aa 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -443,12 +443,12 @@ int64_t Actor::HandleRegstToConsumer(Regst* regst, std::function<bool(int64_t)> return real_consumer_cnt; } -bool Actor::IsReadReady() { +bool Actor::IsReadReady() const { return naive_consumed_rs_.IsCurSlotReady() && inplace_consumed_rs_.IsCurSlotReady() && IsCustomizedReadReady(); } -bool Actor::IsWriteReady() { +bool Actor::IsWriteReady() const { return naive_produced_rs_.IsCurSlotReady() && inplace_produced_rs_.IsCurSlotReady() && IsCustomizedWriteReady(); } diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h index 50d89d663458a30a5b9fdcd2db42865f71d0e2bd..bf05e4313db87753cf0c2e572e8c4800aee42d8e 100644 --- a/oneflow/core/actor/actor.h +++ b/oneflow/core/actor/actor.h @@ -155,8 +155,8 @@ class Actor { void TryLogActEvent(const std::function<void()>& Callback) const; // Ready - bool IsReadReady(); - bool IsWriteReady(); + bool IsReadReady() const; + bool IsWriteReady() const; // Naive, Inplace Or Customized void TakeOverInplaceConsumedAndProduced(const PbMap<std::string, RegstDescProto>& produced_ids); @@ -178,8 +178,8 @@ class Actor { virtual void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const {} virtual void NormalProcessCustomizedEordMsg(const ActorMsg&) {} virtual void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) { UNIMPLEMENTED(); } - virtual bool IsCustomizedReadReady() { return true; } - virtual bool IsCustomizedReadAlwaysUnReadyFromNow() { return false; } + virtual bool IsCustomizedReadReady() const { return true; } + virtual bool IsCustomizedReadAlwaysUnReadyFromNow() const { return false; } virtual std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName() { return std::make_pair(RegstNameType::kCustomized, HashSet<std::string>{}); @@ -189,7 +189,7 @@ class Actor { // Customized Produced virtual func virtual void UpdtStateAsCustomizedProducedRegst(Regst* regst) { UNIMPLEMENTED(); } - virtual bool IsCustomizedWriteReady() { return true; } + virtual bool IsCustomizedWriteReady() const { return true; } virtual std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedProducedRegstDescName() { return std::make_pair(RegstNameType::kCustomized, HashSet<std::string>{}); diff --git a/oneflow/core/actor/copy_comm_net_actor.cpp b/oneflow/core/actor/copy_comm_net_actor.cpp index c67df6c6d809f6d8556a5c1b571463b42dc3fab4..917abb8d876313739b90a12839a17831f0460c2e 100644 --- a/oneflow/core/actor/copy_comm_net_actor.cpp +++ b/oneflow/core/actor/copy_comm_net_actor.cpp @@ -86,11 +86,11 @@ void CopyCommNetActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { next_piece_id_ += 1; } -bool CopyCommNetActor::IsCustomizedReadReady() { +bool CopyCommNetActor::IsCustomizedReadReady() const { return piece_id2regst_ctx_.find(next_piece_id_) != piece_id2regst_ctx_.end(); } -bool CopyCommNetActor::IsCustomizedReadAlwaysUnReadyFromNow() { +bool CopyCommNetActor::IsCustomizedReadAlwaysUnReadyFromNow() const { return is_in_eord_ && piece_id2regst_ctx_.empty(); } diff --git a/oneflow/core/actor/copy_comm_net_actor.h b/oneflow/core/actor/copy_comm_net_actor.h index 094c88b8c39563f011ffaaf97e17e7f585660dc1..58cf354a73b71b9543578839d1566420283d7d8b 100644 --- a/oneflow/core/actor/copy_comm_net_actor.h +++ b/oneflow/core/actor/copy_comm_net_actor.h @@ -34,8 +34,8 @@ class CopyCommNetActor final : public Actor { void Act() override; void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; - bool IsCustomizedReadReady() override; - bool IsCustomizedReadAlwaysUnReadyFromNow() override; + bool IsCustomizedReadReady() const override; + bool IsCustomizedReadAlwaysUnReadyFromNow() const override; void AsyncReturnAllCustomizedReadableRegst() override; bool is_in_eord_; diff --git a/oneflow/core/actor/decode_random_compute_actor.cpp b/oneflow/core/actor/decode_random_compute_actor.cpp index 7f297d65feaaf764e2f3ab1607448e2c4b3ca268..b4bedb9fb5952b9abf0fa7968d4cb74d0516b5b8 100644 --- a/oneflow/core/actor/decode_random_compute_actor.cpp +++ b/oneflow/core/actor/decode_random_compute_actor.cpp @@ -15,7 +15,7 @@ void DecodeRandomActor::Act() { AsyncLaunchKernel(GenDefaultKernelCtx()); } -bool DecodeRandomActor::IsCustomizedReadReady() { +bool DecodeRandomActor::IsCustomizedReadReady() const { return piece_id_ < Global<RuntimeCtx>::Get()->total_piece_num(); } diff --git a/oneflow/core/actor/decode_random_compute_actor.h b/oneflow/core/actor/decode_random_compute_actor.h index 848fb5c13801a247eb91e1d0f477d8575bf40651..09d88e390fc658b0d8623542f929ba2490a9814c 100644 --- a/oneflow/core/actor/decode_random_compute_actor.h +++ b/oneflow/core/actor/decode_random_compute_actor.h @@ -18,8 +18,8 @@ class DecodeRandomActor final : public CompActor { override { return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{}); } - bool IsCustomizedReadReady() override; - bool IsCustomizedReadAlwaysUnReadyFromNow() override { return !IsCustomizedReadReady(); } + bool IsCustomizedReadReady() const override; + bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); } int HandlerWaitToStart(const ActorMsg&); diff --git a/oneflow/core/actor/input_wise_compute_actor.cpp b/oneflow/core/actor/input_wise_compute_actor.cpp index c09159ba8e7bc3a279ded6366547e78d9cc52436..724c76142cc50ba0296b14ed86a04b8ebb96db1b 100644 --- a/oneflow/core/actor/input_wise_compute_actor.cpp +++ b/oneflow/core/actor/input_wise_compute_actor.cpp @@ -37,17 +37,8 @@ void InputWiseCompActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& CHECK_EQ(0, consumed_rs_.TryPushBackRegst(msg.regst())); } -bool InputWiseCompActor::IsCustomizedReadReady() { - CHECK_EQ(-1, cur_processed_regst_desc_id_); - consumed_rs_.ForChosenRegstDeq([this](int64_t) { return cur_processed_regst_desc_id_ == -1; }, - [this](const std::deque<Regst*>& reg_deq) { - if (reg_deq.empty()) { return; } - int64_t regst_desc_id = reg_deq.front()->regst_desc_id(); - if (regst_desc_id2is_processed_.at(regst_desc_id) == false) { - cur_processed_regst_desc_id_ = regst_desc_id; - } - }); - return cur_processed_regst_desc_id_ != -1; +bool InputWiseCompActor::IsCustomizedReadReady() const { + return -1 != GetCurProcessedRegstDescId(); } void InputWiseCompActor::ForEachCurCustomizedReadableRegst( @@ -56,6 +47,7 @@ void InputWiseCompActor::ForEachCurCustomizedReadableRegst( } void InputWiseCompActor::Act() { + cur_processed_regst_desc_id_ = GetCurProcessedRegstDescId(); Regst* cur_regst = consumed_rs_.Front(cur_processed_regst_desc_id_); CHECK(cur_regst); @@ -99,4 +91,18 @@ void InputWiseCompActor::AsyncReturnAllCustomizedReadableRegst() { bool InputWiseCompActor::ProducedCtrlRegstValid(int64_t regst_desc_id) const { return true; } +int64_t InputWiseCompActor::GetCurProcessedRegstDescId() const { + int64_t cur_processed_regst_desc_id = -1; + consumed_rs_.ForChosenRegstDeq( + [cur_processed_regst_desc_id](int64_t) { return cur_processed_regst_desc_id == -1; }, + [this, &cur_processed_regst_desc_id](const std::deque<Regst*>& reg_deq) { + if (reg_deq.empty()) { return; } + int64_t regst_desc_id = reg_deq.front()->regst_desc_id(); + if (regst_desc_id2is_processed_.at(regst_desc_id) == false) { + cur_processed_regst_desc_id = regst_desc_id; + } + }); + return cur_processed_regst_desc_id; +} + } // namespace oneflow diff --git a/oneflow/core/actor/input_wise_compute_actor.h b/oneflow/core/actor/input_wise_compute_actor.h index 0e51a68de1de87967378bc697b6a91f1d6cb3e05..01c18e256c9ba517bdd242a7f3a38d84652e84ae 100644 --- a/oneflow/core/actor/input_wise_compute_actor.h +++ b/oneflow/core/actor/input_wise_compute_actor.h @@ -28,9 +28,9 @@ class InputWiseCompActor : public CompActor { void Act() override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override; - bool IsCustomizedReadReady() override; + bool IsCustomizedReadReady() const override; void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} - bool IsCustomizedReadAlwaysUnReadyFromNow() override { + bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return ReceiveAllEordMsg() && consumed_rs_.available_regst_desc_cnt() == 0; } void AsyncReturnAllCustomizedReadableRegst() override; @@ -42,6 +42,7 @@ class InputWiseCompActor : public CompActor { void AsyncSendCustomizedConsumedRegstMsgToProducer() override; virtual void SetKernelCtxOther(void** other) { *other = nullptr; } + int64_t GetCurProcessedRegstDescId() const; RegstSlot consumed_rs_; HashMap<int64_t, bool> regst_desc_id2is_processed_; diff --git a/oneflow/core/actor/normal_backward_compute_actor.cpp b/oneflow/core/actor/normal_backward_compute_actor.cpp index 577136c906c8592ecc6d9a51eb44dd067cfaa7c7..d984a934bf07de5cfa6084e1d9dc28109cf464b1 100644 --- a/oneflow/core/actor/normal_backward_compute_actor.cpp +++ b/oneflow/core/actor/normal_backward_compute_actor.cpp @@ -81,7 +81,7 @@ void NormalBackwardCompActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { AsyncReturnModelRegstUntilLastPieceIdGreaterThan(piece_id); } -bool NormalBackwardCompActor::IsCustomizedReadReady() { +bool NormalBackwardCompActor::IsCustomizedReadReady() const { if (model_regst_desc_id_ != -1) { if (model_regst_queue_.empty()) { return false; } int64_t expected_model_vid = diff --git a/oneflow/core/actor/normal_backward_compute_actor.h b/oneflow/core/actor/normal_backward_compute_actor.h index 096ccaf50377d39872dfd1e1d13d57f31765b64b..4714c778a4dd1e332afcbf9de9a6e5b76ea441d5 100644 --- a/oneflow/core/actor/normal_backward_compute_actor.h +++ b/oneflow/core/actor/normal_backward_compute_actor.h @@ -18,7 +18,7 @@ class NormalBackwardCompActor final : public CompActor { void NormalProcessNaiveReadableDataRegstMsg(const std::deque<Regst*>&) override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void Act() override; - bool IsCustomizedReadReady() override; + bool IsCustomizedReadReady() const override; void AsyncReturnAllCustomizedReadableRegst() override; std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName() override { diff --git a/oneflow/core/actor/normal_forward_compute_actor.cpp b/oneflow/core/actor/normal_forward_compute_actor.cpp index 9ac65308ae4b2aed17f24dc37401d8bea02371ef..9edc79e89caf181c21eb6b6d63ea3bf513dee444 100644 --- a/oneflow/core/actor/normal_forward_compute_actor.cpp +++ b/oneflow/core/actor/normal_forward_compute_actor.cpp @@ -37,7 +37,7 @@ void NormalForwardCompActor::VirtualCompActorInit(const TaskProto& task_proto) { } } -bool NormalForwardCompActor::IsCustomizedWriteReady() { +bool NormalForwardCompActor::IsCustomizedWriteReady() const { if (const_buf_regst_desc_id_ != -1) { CHECK(send_const_buf_regst_); } return true; } @@ -123,7 +123,7 @@ void NormalForwardCompActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { cur_piece_id_ = -1; } -bool NormalForwardCompActor::IsCustomizedReadReady() { +bool NormalForwardCompActor::IsCustomizedReadReady() const { if (model_regst_desc_id_ != -1 && model_regst_ == nullptr) { return false; } if (const_model_regst_desc_id_ != -1 && const_model_regst_ == nullptr) { return false; } return true; diff --git a/oneflow/core/actor/normal_forward_compute_actor.h b/oneflow/core/actor/normal_forward_compute_actor.h index 27d10d22f3c377d3b391233e373f5fbb121f9e8c..bc7199c033be181add85fd25bae0cbe1f871f59b 100644 --- a/oneflow/core/actor/normal_forward_compute_actor.h +++ b/oneflow/core/actor/normal_forward_compute_actor.h @@ -16,7 +16,7 @@ class NormalForwardCompActor final : public CompActor { void ForEachCurCustomizedReadableRegst(std::function<void(const Regst*)>) const override; void NormalProcessCustomizedReadableRegstMsg(const ActorMsg&) override; void Act() override; - bool IsCustomizedReadReady() override; + bool IsCustomizedReadReady() const override; void AsyncReturnAllCustomizedReadableRegst() override; std::pair<RegstNameType, HashSet<std::string>> GetNaiveOrCustomizedConsumedRegstDescName() override { @@ -31,7 +31,7 @@ class NormalForwardCompActor final : public CompActor { void VirtualAsyncSendInplaceProducedRegstMsgToConsumer() override; void AsyncSendCustomizedConsumedRegstMsgToProducer() override; - bool IsCustomizedWriteReady() override; + bool IsCustomizedWriteReady() const override; void UpdtStateAsCustomizedProducedRegst(Regst* regst) override; bool CheckOutputActId(int64_t regst_desc_id) const override; diff --git a/oneflow/core/actor/normal_model_update_compute_actor.cpp b/oneflow/core/actor/normal_model_update_compute_actor.cpp index 9a74a6053d5a3aada4975f148687cc3f11a0e0c8..9f561bbdbcc36660f50d8300ac14a419651dc843 100644 --- a/oneflow/core/actor/normal_model_update_compute_actor.cpp +++ b/oneflow/core/actor/normal_model_update_compute_actor.cpp @@ -29,7 +29,7 @@ void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) { OF_SET_MSG_HANDLER(&NormalMdUpdtCompActor::HandlerInitModelAndConstModel); } -bool NormalMdUpdtCompActor::IsCustomizedWriteReady() { +bool NormalMdUpdtCompActor::IsCustomizedWriteReady() const { if (const_model_regst_desc_id_ != -1) { CHECK(send_const_model_regst_); } return true; } diff --git a/oneflow/core/actor/normal_model_update_compute_actor.h b/oneflow/core/actor/normal_model_update_compute_actor.h index 52291e0d0474251e62c73c16a291d0038f46b30d..008e95e79c4f170b1b6c1e09df0bdcdf2f537e4e 100644 --- a/oneflow/core/actor/normal_model_update_compute_actor.h +++ b/oneflow/core/actor/normal_model_update_compute_actor.h @@ -20,7 +20,7 @@ class NormalMdUpdtCompActor final : public CompActor { } void AsyncSendCustomizedProducedRegstMsgToConsumer() override {} void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; - bool IsCustomizedWriteReady() override; + bool IsCustomizedWriteReady() const override; void UpdtStateAsCustomizedProducedRegst(Regst* regst) override; void SendConstModelRegstToConsumer(); bool CheckOutputActId(int64_t regst_desc_id) const override; diff --git a/oneflow/core/actor/record_load_actor.cpp b/oneflow/core/actor/record_load_actor.cpp index 7c6171c1915ba9795a3ab916a34b04056ea61e65..643071e12d5475add5c28cb878e877656d6f3f75 100644 --- a/oneflow/core/actor/record_load_actor.cpp +++ b/oneflow/core/actor/record_load_actor.cpp @@ -27,7 +27,7 @@ void RecordLoadActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { if (record_load_status_.record_num > 0) { HandleProducedNaiveDataRegstToConsumer(); } } -bool RecordLoadActor::IsCustomizedReadReady() { +bool RecordLoadActor::IsCustomizedReadReady() const { return !is_eof_ && piece_id_ < Global<RuntimeCtx>::Get()->total_piece_num(); } diff --git a/oneflow/core/actor/record_load_actor.h b/oneflow/core/actor/record_load_actor.h index 77c980f48b3ad04c806618e435b95b73f5f80cf9..07982d233d56a616f78ac36923e9b2aa0a2e36b8 100644 --- a/oneflow/core/actor/record_load_actor.h +++ b/oneflow/core/actor/record_load_actor.h @@ -20,8 +20,8 @@ class RecordLoadActor final : public CompActor { return std::make_pair(RegstNameType::kNaive, HashSet<std::string>{}); } void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; - bool IsCustomizedReadReady() override; - bool IsCustomizedReadAlwaysUnReadyFromNow() override { return !IsCustomizedReadReady(); } + bool IsCustomizedReadReady() const override; + bool IsCustomizedReadAlwaysUnReadyFromNow() const override { return !IsCustomizedReadReady(); } int HandlerWaitToStart(const ActorMsg&);