Skip to content
Snippets Groups Projects
Unverified Commit 5be84c50, authored by Jinhui Yuan, committed by GitHub
Browse files

refine act_id order condition (#1088)

* refine act_id order condition

* strict act id check (excluding model regst)

* add TODO: figure out the ActNumForEachOutput of model regsts to MdSave area
parent 40a9b9a5
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,8 @@ void AccumulateCompActor::Init(const TaskProto& task_proto, int32_t max_acc_cnt,
next_piece_id_ = 0;
}
// Returns max_acc_cnt_ as the per-output act count for this actor.
// Actor adds ActNumForEachOutput() to a returned regst's act_id to compute
// the act_id it expects the next time a regst of that desc is returned.
int64_t AccumulateCompActor::ActNumForEachOutput() const {
  const int64_t acts_per_output = max_acc_cnt_;
  return acts_per_output;
}
void AccumulateCompActor::Act() {
Regst* in_regst = GetNaiveSoleCurReadable();
Regst* out_regst = GetCurSoleWriteableRegst();
......
......@@ -13,6 +13,7 @@ class AccumulateCompActor : public CompActor {
protected:
void Init(const TaskProto&, int32_t max_acc_cnt, ColIdOrder order);
int64_t ActNumForEachOutput() const override;
private:
void Act() override;
......
......@@ -58,7 +58,7 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
writeable_produced_ctrl_regst_[regst->regst_desc_id()].push_back(regst.get());
produced_ctrl_regst2reading_cnt_[regst.get()] = 0;
}
produced_ctrl_regst2max_act_id_[pair.first] = act_id_;
produced_ctrl_regst2expected_act_id_[pair.first] = act_id_;
}
// non ctrl regst
......@@ -82,7 +82,7 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
writeable_produced_data_regst_[regst->regst_desc_id()].push_back(regst.get());
produced_data_regst2reading_cnt_[regst.get()] = 0;
}
produced_data_regst2max_act_id_[pair.first] = act_id_;
produced_data_regst2expected_act_id_[pair.first] = act_id_;
}
actual_writeable_produced_data_regst_desc_num_ = writeable_produced_data_regst_.size();
writeable_produced_data_regst_desc_cnt_ = actual_writeable_produced_data_regst_desc_num_;
......@@ -439,13 +439,11 @@ int Actor::ProcessWriteableCtrlRegstMsg(const ActorMsg& msg) {
auto writeable_it = writeable_produced_ctrl_regst_.find(regst->regst_desc_id());
CHECK(writeable_it != writeable_produced_ctrl_regst_.end());
if (writeable_it->second.empty()) { writeable_ctrl_regst_desc_cnt_ += 1; }
int64_t& max_act_id = produced_ctrl_regst2max_act_id_[regst->regst_desc_id()];
if (max_act_id >= 0) {
CHECK_GT(regst->act_id(), max_act_id);
max_act_id = regst->act_id();
} else if (regst->act_id() >= 0) {
max_act_id = regst->act_id();
int64_t& expected_act_id = produced_ctrl_regst2expected_act_id_[regst->regst_desc_id()];
if (expected_act_id >= 0 && CheckOutputActId(regst->regst_desc_id())) {
CHECK_EQ(regst->act_id(), expected_act_id);
}
expected_act_id = regst->act_id() + ActNumForEachOutput();
writeable_it->second.push_back(regst);
return 0;
}
......@@ -522,13 +520,11 @@ int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
auto writeable_it = writeable_produced_data_regst_.find(regst->regst_desc_id());
CHECK(writeable_it != writeable_produced_data_regst_.end());
if (writeable_it->second.empty()) { writeable_produced_data_regst_desc_cnt_ += 1; }
int64_t& max_act_id = produced_data_regst2max_act_id_[regst->regst_desc_id()];
if (max_act_id >= 0) {
CHECK_GT(regst->act_id(), max_act_id);
max_act_id = regst->act_id();
} else if (regst->act_id() >= 0) {
max_act_id = regst->act_id();
int64_t& expected_act_id = produced_data_regst2expected_act_id_[regst->regst_desc_id()];
if (expected_act_id >= 0 && CheckOutputActId(regst->regst_desc_id())) {
CHECK_EQ(regst->act_id(), expected_act_id);
}
expected_act_id = regst->act_id() + ActNumForEachOutput();
writeable_it->second.push_back(regst);
return 0;
}
......
......@@ -85,6 +85,11 @@ class Actor {
// Overridable hook; the default implementation always answers false.
virtual bool IsCustomizedReadAlwaysUnReadyFromNow() {
  return false;
}
bool IsWriteReady();
// Overridable hook; the default implementation intentionally does nothing.
virtual void AsyncReturnAllCustomizedReadableRegst() {
}
// Step added to a returned regst's act_id to obtain the act_id expected for
// the next returned regst of the same regst desc. Defaults to 1; subclasses
// may override.
virtual int64_t ActNumForEachOutput() const {
  return 1;
}
// Tells whether the act_id of a returned regst with the given regst_desc_id
// should be compared against the expected act_id. The default checks every
// regst desc; subclasses may override to exclude some.
// TODO(jiyuan): figure out the ActNumForEachOutput of the model regsts to MdSave area
virtual bool CheckOutputActId(int64_t regst_desc_id) const {
  return true;
}
// Async Do on device_ctx_
void AsyncLaunchKernel(const KernelCtx&, std::function<Regst*(int64_t)> Regst4RegstDescId);
......@@ -149,7 +154,7 @@ class Actor {
// Status of Produced Registers
HashMap<int64_t, std::vector<std::unique_ptr<Regst>>> produced_data_regsts_;
HashMap<int64_t, std::deque<Regst*>> writeable_produced_data_regst_;
HashMap<int64_t, int64_t> produced_data_regst2max_act_id_;
HashMap<int64_t, int64_t> produced_data_regst2expected_act_id_;
HashMap<Regst*, int64_t> produced_data_regst2reading_cnt_;
int64_t actual_writeable_produced_data_regst_desc_num_;
int64_t writeable_produced_data_regst_desc_cnt_;
......@@ -163,7 +168,7 @@ class Actor {
// Status of Control Registers
HashMap<int64_t, std::vector<std::unique_ptr<Regst>>> produced_ctrl_regst_;
HashMap<int64_t, std::deque<Regst*>> writeable_produced_ctrl_regst_;
HashMap<int64_t, int64_t> produced_ctrl_regst2max_act_id_;
HashMap<int64_t, int64_t> produced_ctrl_regst2expected_act_id_;
HashMap<Regst*, int64_t> produced_ctrl_regst2reading_cnt_;
HashMap<int64_t, std::deque<Regst*>> consumed_ctrl_regst_;
int64_t total_reading_ctrl_cnt_;
......
......@@ -29,6 +29,8 @@ void InputWiseCompActor::Init(const TaskProto& task_proto) {
OF_SET_MSG_HANDLER(&InputWiseCompActor::HandlerNormal);
}
// Returns the number of entries in regst_desc_id2in_bn_id_ as the per-output
// act count. The unsigned size is converted to int64_t explicitly.
int64_t InputWiseCompActor::ActNumForEachOutput() const {
  return static_cast<int64_t>(regst_desc_id2in_bn_id_.size());
}
void InputWiseCompActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) {
Regst* regst = msg.regst();
int regst_desc_id = regst->regst_desc_id();
......
......@@ -17,6 +17,7 @@ class InputWiseCompActor : public CompActor {
// Read-only accessor for processed_regst_desc_id_cnt_.
int64_t processed_regst_desc_id_cnt() const {
  return processed_regst_desc_id_cnt_;
}
// Returns how many entries readable_regsts_ currently holds.
int64_t RegstDescNum() const {
  return static_cast<int64_t>(readable_regsts_.size());
}
// Looks up the value stored for `id` in regst_desc_id2in_bn_id_.
// NOTE(review): uses at(), so a missing id is presumably a hard failure
// rather than a default insert - confirm against the HashMap type.
int64_t InBnId4RegstDescId(int64_t id) const {
  const int64_t in_bn_id = regst_desc_id2in_bn_id_.at(id);
  return in_bn_id;
}
int64_t ActNumForEachOutput() const override;
// True only when this actor's device type is kGPU and the job description
// has enable_mem_sharing() set. The JobDesc is consulted only on GPU.
bool EnableInplace() const {
  if (GetDeviceType() != DeviceType::kGPU) { return false; }
  return Global<JobDesc>::Get()->enable_mem_sharing();
}
......
......@@ -34,6 +34,10 @@ void NormalForwardCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
}
}
// Excludes the regst desc identified by forward_model_regst_desc_id_ from the
// act_id check; every other regst desc is checked.
bool NormalForwardCompActor::CheckOutputActId(int64_t regst_desc_id) const {
  const bool is_forward_model_regst = (regst_desc_id == forward_model_regst_desc_id_);
  return !is_forward_model_regst;
}
void NormalForwardCompActor::ForEachCurCustomizedReadableRegst(
std::function<void(const Regst*)> handler) const {
if (model_regst_desc_id_ != -1) { handler(model_regst_); }
......
......@@ -21,6 +21,7 @@ class NormalForwardCompActor final : public CompActor {
// Returns the pair {false, {"in"}}.
// NOTE(review): the meaning of the bool flag is defined by the base class,
// which is not visible here - confirm before relying on it.
std::pair<bool, std::vector<std::string>> GetNaiveConsumedRegstDescName() override {
  std::vector<std::string> naive_names{"in"};
  return {false, naive_names};
}
bool CheckOutputActId(int64_t regst_desc_id) const override;
int HandlerInitModelAndConstBuf(const ActorMsg&);
void UpdateModelRegstPtr(Regst* regst);
......
......@@ -19,6 +19,10 @@ void NormalMdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
OF_SET_MSG_HANDLER(&NormalMdUpdtCompActor::HandlerInitModelAndConstModel);
}
// Excludes the model regst desc and the const-model regst desc from the
// act_id check; every other regst desc is checked.
bool NormalMdUpdtCompActor::CheckOutputActId(int64_t regst_desc_id) const {
  if (regst_desc_id == model_regst_desc_id_) { return false; }
  if (regst_desc_id == const_model_regst_desc_id_) { return false; }
  return true;
}
void NormalMdUpdtCompActor::Act() {
Regst* cur_model_regst = GetCurWriteableRegst(model_regst_desc_id_);
cur_model_regst->set_model_version_id(next_model_version_id_);
......
......@@ -17,6 +17,7 @@ class NormalMdUpdtCompActor final : public CompActor {
// Returns the pair {true, {}}: the flag set to true with an empty name list.
// NOTE(review): the meaning of the bool flag is defined by the base class,
// which is not visible here - confirm before relying on it.
std::pair<bool, std::vector<std::string>> GetNaiveConsumedRegstDescName() override {
  std::vector<std::string> no_names;
  return {true, no_names};
}
bool CheckOutputActId(int64_t regst_desc_id) const override;
void InitRegstBySendToFw(int64_t regst_desc_id);
int HandlerInitModelAndConstModel(const ActorMsg&);
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment