Commit ab899c67 authored by Xinqi

add reentrant_lock_op in main_job

parent 81f7aae7
......@@ -32,9 +32,12 @@ const std::vector<int64_t>& CriticalSectionDesc::CriticalSectionIds4JobId(int64_
return job_id2critical_section_ids_.at(job_id);
}
const HashSet<int64_t>& CriticalSectionDesc::GetIntersectingCriticalSectionIds(int64_t idx) const {
void CriticalSectionDesc::DumpCriticalSectionId2IntersectinIds(PbRpf<IdList>* id2id_list) const {
CHECK(inited_);
return critical_section_id2intersecting_ids_.at(idx);
FOR_RANGE(int64_t, i, 0, critical_sections_.size()) {
*id2id_list->Add()->mutable_id() = {critical_section_id2intersecting_ids_.at(i).begin(),
critical_section_id2intersecting_ids_.at(i).end()};
}
}
void CriticalSectionDesc::UpdateJobId2CriticalSectionIds() {
......
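For orientation, here is a minimal sketch (not part of the commit) of how the dumped table can be read back into plain vectors, mirroring what ReentrantLockStatus::Init does further below; it assumes only the PbRpf<IdList> type and the repeated int64 id field already used in this diff:

PbRpf<IdList> id2id_list;
Global<CriticalSectionDesc>::Get()->DumpCriticalSectionId2IntersectinIds(&id2id_list);
std::vector<std::vector<int64_t>> table;
for (const IdList& ids : id2id_list) {
  // table[i] lists the critical sections whose execution may conflict with section i
  table.emplace_back(ids.id().begin(), ids.id().end());
}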
......@@ -3,6 +3,7 @@
#include "oneflow/core/common/util.h"
#include "oneflow/core/job/critical_section.pb.h"
#include "oneflow/core/common/protobuf.h"
namespace oneflow {
......@@ -17,7 +18,7 @@ class CriticalSectionDesc final {
const CriticalSection& GetCriticalSection(int64_t) const;
CriticalSection* MutCriticalSection(int64_t) const;
const std::vector<int64_t>& CriticalSectionIds4JobId(int64_t) const;
const HashSet<int64_t>& GetIntersectingCriticalSectionIds(int64_t) const;
void DumpCriticalSectionId2IntersectinIds(PbRpf<IdList>* id2id_list) const;
private:
friend class Global<CriticalSectionDesc>;
......
......@@ -182,13 +182,13 @@ void WithJobSetLevelGlobalObjs(
Global<std::vector<std::unique_ptr<JobDesc>>>::Get()->emplace_back(
new JobDesc(job_set.job_conf(i), i));
}
Global<BufferMgr<int32_t>>::New();
Global<BufferMgr<int32_t>>::Get()->NewChannel(kChannelNameGlobalWaitJobId,
Global<BufferMgr<int64_t>>::New();
Global<BufferMgr<int64_t>>::Get()->NewChannel(kChannelNameGlobalWaitJobId,
job_set.job_conf_size());
Handler(job_set.job_conf());
Global<BufferMgr<int32_t>>::Delete();
Global<BufferMgr<int64_t>>::Delete();
Global<std::vector<std::unique_ptr<JobDesc>>>::Delete();
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
Global<CriticalSectionDesc>::Delete();
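The channel widened here from int32_t to int64_t carries job ids from the runtime into the main job. A rough sketch of its lifecycle, assuming BufferMgr<T>::Get(name) returns a buffer with Send/Receive/Close as used in runtime.cpp and wait_and_send_ids_kernel.cpp below (job_num and job_id are hypothetical variables):

Global<BufferMgr<int64_t>>::New();
Global<BufferMgr<int64_t>>::Get()->NewChannel(kChannelNameGlobalWaitJobId, job_num);
// producer side (Runtime): enqueue the id of the job to run
Global<BufferMgr<int64_t>>::Get()->Get(kChannelNameGlobalWaitJobId)->Send(job_id);
// consumer side (WaitAndSendIdsKernel): block until a job id arrives
int64_t in_id = 0;
ChannelStatus status = Global<BufferMgr<int64_t>>::Get()->Get(kChannelNameGlobalWaitJobId)->Receive(&in_id);
// shutdown
Global<BufferMgr<int64_t>>::Get()->Get(kChannelNameGlobalWaitJobId)->Close();
Global<BufferMgr<int64_t>>::Delete();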
......@@ -407,6 +407,9 @@ HashSet<std::string> GetArgOpNames(const std::vector<Job>& jobs) {
HashSet<std::string> arg_op_names;
for (const Job& job : jobs) {
for (const auto& arg_op_name : job.arg_op_name()) { arg_op_names.insert(arg_op_name); }
for (const OperatorConf& op_conf : job.net().op()) {
if (op_conf.has_variable_conf()) { arg_op_names.insert(op_conf.name()); }
}
}
return arg_op_names;
}
......@@ -419,9 +422,7 @@ HashMap<std::string, HashSet<int64_t>> GetInterfaceOpName2JobIds(const std::vect
const auto& job = jobs.at(i);
for (const auto& op : job.net().op()) {
if (IsInterfaceOpConf(op)) {
if (op.has_variable_conf() == false) {
CHECK(arg_op_names.find(op.name()) != arg_op_names.end());
}
CHECK(arg_op_names.find(op.name()) != arg_op_names.end());
CHECK(interface_op_name2job_ids[op.name()].emplace(i).second);
} else {
CHECK(unique_op_name_check.find(op.name()) == unique_op_name_check.end());
......@@ -481,7 +482,7 @@ std::vector<TaskProto*> SortSameOpNameTaskProtos(const std::string& op_name, Pla
return task_protos;
}
RegstDescProto* GetSoleDataRegst(TaskProto* task_proto) {
RegstDescProto* GetSoleProducedDataRegst(TaskProto* task_proto) {
RegstDescProto* ret = nullptr;
for (auto& pair : *task_proto->mutable_produced_regst_desc()) {
RegstDescProto* regst_desc = &pair.second;
......@@ -508,9 +509,9 @@ void BindInterfaceMemBlockId(const std::vector<Job>& jobs, std::vector<Plan>* su
FOR_RANGE(int32_t, i, 0, first_vec.size()) {
CHECK_EQ(task_protos.at(i)->machine_id(), first_vec.at(i)->machine_id());
CHECK_EQ(task_protos.at(i)->thrd_id(), first_vec.at(i)->thrd_id());
const RegstDescProto& first_regst_desc = *GetSoleDataRegst(first_vec.at(i));
const RegstDescProto& first_regst_desc = *GetSoleProducedDataRegst(first_vec.at(i));
CHECK_EQ(first_regst_desc.mem_shared_offset(), 0);
RegstDescProto* regst_desc = GetSoleDataRegst(task_protos.at(i));
RegstDescProto* regst_desc = GetSoleProducedDataRegst(task_protos.at(i));
CHECK_EQ(regst_desc->mem_shared_offset(), 0);
regst_desc->set_mem_shared_id(first_regst_desc.mem_shared_id());
}
......@@ -519,7 +520,9 @@ void BindInterfaceMemBlockId(const std::vector<Job>& jobs, std::vector<Plan>* su
}
void MakeMainJob(const std::vector<Job>& jobs, Job* main_job,
std::vector<std::string>* identity_tick_op_names) {
std::vector<std::string>* identity_tick_op_names,
LogicalBlobId* critical_section_sink_lbi) {
CHECK(Global<MachineCtx>::Get()->IsThisMachineMaster());
std::vector<OperatorConf> op_confs;
OperatorConf wait_and_send_ids_op_conf;
{
......@@ -527,17 +530,29 @@ void MakeMainJob(const std::vector<Job>& jobs, Job* main_job,
auto* wait_and_send_ids_conf = wait_and_send_ids_op_conf.mutable_wait_and_send_ids_conf();
wait_and_send_ids_conf->set_out("out");
wait_and_send_ids_conf->set_wait_channel_name(kChannelNameGlobalWaitJobId);
wait_and_send_ids_conf->set_data_type(DataType::kInt32);
FOR_RANGE(int64_t, i, 0, Global<std::vector<std::unique_ptr<JobDesc>>>::Get()->size()) {
const auto& cs_idx = Global<CriticalSectionDesc>::Get()->CriticalSectionIds4JobId(i);
*wait_and_send_ids_conf->add_id_list()->mutable_id() = {cs_idx.begin(), cs_idx.end()};
}
}
op_confs.push_back(wait_and_send_ids_op_conf);
OperatorConf reentrant_lock_op_conf;
{
reentrant_lock_op_conf.set_name(std::string("System-Main-ReentrantLock_") + NewUniqueId());
auto* reentrant_lock_conf = reentrant_lock_op_conf.mutable_reentrant_lock_conf();
reentrant_lock_conf->set_start(wait_and_send_ids_op_conf.name() + "/out");
// ibn "end" is set after plan generated because we don't like cycle in job
reentrant_lock_conf->set_out("out");
Global<CriticalSectionDesc>::Get()->DumpCriticalSectionId2IntersectinIds(
reentrant_lock_conf->mutable_lock_id2intersecting_lock_ids());
}
op_confs.push_back(reentrant_lock_op_conf);
OperatorConf cs_case_op_conf;
{
cs_case_op_conf.set_name(std::string("System-Main-Case_") + NewUniqueId());
auto* cs_case_conf = cs_case_op_conf.mutable_case_conf();
cs_case_conf->set_in(wait_and_send_ids_op_conf.name() + "/out");
cs_case_conf->set_in(reentrant_lock_op_conf.name() + "/out");
FOR_RANGE(int64_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
cs_case_conf->add_out(GenRepeatedBn("out", i));
}
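Reading the construction above end to end, the main job's control chain now looks roughly like this (a sketch; the per-op semantics are inferred only from this diff):

// WaitAndSendIds   ("out": next critical-section id for the submitted job)
//   -> ReentrantLock ("start" input; forwards a critical-section id once its lock can be acquired)
//     -> Case          (fans out to one output per critical section)
//       -> per-critical-section identity ticks
//         -> Esac        ("out" becomes critical_section_sink_lbi)
// ReentrantLock's "end" input is fed from the Esac output, but only after the
// plan is generated (see ConnectCriticalSectionEndToReentrantLockEnd below),
// because the job graph itself must stay acyclic.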
......@@ -560,9 +575,13 @@ void MakeMainJob(const std::vector<Job>& jobs, Job* main_job,
cs_esac_conf->add_in(identity_tick_op_name + "/out");
}
cs_esac_conf->set_out("out");
cs_esac_conf->set_data_type(DataType::kInt32);
}
op_confs.push_back(cs_esac_op_conf);
critical_section_sink_lbi->set_op_name(cs_esac_op_conf.name());
critical_section_sink_lbi->set_blob_name("out");
ParallelConf parallel_conf;
parallel_conf.set_policy(kDataParallel);
parallel_conf.add_device_name("0:cpu:0");
......@@ -574,11 +593,50 @@ void MakeMainJob(const std::vector<Job>& jobs, Job* main_job,
main_job->mutable_other()->set_default_data_type(DataType::kInt32);
}
void CompileMainJob(Job* main_job, int32_t job_id, Plan* main_plan) {
void ConnectCriticalSectionEndToReentrantLockEnd(Plan* main_plan,
const LogicalBlobId& critical_section_sink_lbi) {
TaskProto* reentrant_lock_task = nullptr;
TaskProto* cs_sink_task = nullptr;
FOR_RANGE(int64_t, i, 0, main_plan->task_size()) {
auto* task = main_plan->mutable_task(i);
CHECK_EQ(task->exec_sequence().exec_node_size(), 1);
if (task->task_type() == TaskType::kReentrantLock) {
CHECK_ISNULL(reentrant_lock_task);
reentrant_lock_task = task;
} else {
const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
if (critical_section_sink_lbi.op_name() == kernel_conf.op_attribute().op_conf().name()) {
CHECK_ISNULL(cs_sink_task);
cs_sink_task = task;
}
}
}
CHECK_NOTNULL(reentrant_lock_task);
CHECK_NOTNULL(cs_sink_task);
RegstDescProto* cs_end_regst = GetSoleProducedDataRegst(cs_sink_task);
cs_end_regst->add_consumer_task_id(reentrant_lock_task->task_id());
reentrant_lock_task->mutable_consumed_regst_desc_id()->at("in").add_regst_desc_id(
cs_end_regst->regst_desc_id());
auto* reentrant_exec_node = reentrant_lock_task->mutable_exec_sequence()->mutable_exec_node(0);
(*reentrant_exec_node->mutable_bn_in_op2regst_desc_id())["end"] = cs_end_regst->regst_desc_id();
auto* op_attribute = reentrant_exec_node->mutable_kernel_conf()->mutable_op_attribute();
op_attribute->add_input_bns("end");
(*op_attribute->mutable_bn_in_op2lbi())["end"] = critical_section_sink_lbi;
auto* reentrant_lock_conf = op_attribute->mutable_op_conf()->mutable_reentrant_lock_conf();
reentrant_lock_conf->set_end(GenLogicalBlobName(critical_section_sink_lbi));
}
void CompileMainJob(Job* main_job, const LogicalBlobId& critical_section_sink_lbi, int32_t job_id,
Plan* main_plan) {
CHECK(Global<MachineCtx>::Get()->IsThisMachineMaster());
JobConf job_conf = ConvertJob2JobConf(*main_job);
Global<JobDesc>::New(job_conf, job_id);
CompileCurJobOnMaster(main_job, main_plan, false);
Global<JobDesc>::Delete();
ConnectCriticalSectionEndToReentrantLockEnd(main_plan, critical_section_sink_lbi);
}
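In summary, ConnectCriticalSectionEndToReentrantLockEnd patches five pieces of the already-compiled plan so that they stay mutually consistent:

// 1. the Esac task's produced regst gains the ReentrantLock task as a consumer;
// 2. the ReentrantLock task's consumed regst list under "in" records the new regst id;
// 3. the exec node's bn_in_op2regst_desc_id binds the "end" bn to that regst;
// 4. the kernel's op_attribute gains "end" in input_bns and in bn_in_op2lbi;
// 5. the ReentrantLockOpConf's (now optional) "end" field is filled with the logical blob name.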
void AddGlobalJobDesc(const Job& job, int32_t job_id) {
......@@ -619,12 +677,13 @@ void FinishGlobalCriticalSectionDesc(const std::vector<Plan>& plans) {
}
}
}
HashSet<int64_t> input_output_mem_block_ids;
HashMap<int64_t, HashSet<int64_t>> job_id2input_output_mem_block_ids;
auto* critical_section_desc = Global<CriticalSectionDesc>::Get();
// set mem_block_id for InputOutputCriticalSection
FOR_RANGE(int64_t, i, 0, critical_section_desc->CriticalSectionNum()) {
auto* critical_section = critical_section_desc->MutCriticalSection(i);
int64_t job_id = critical_section->job_id();
auto* input_output_mem_block_ids = &job_id2input_output_mem_block_ids[job_id];
if (critical_section->has_input_output_critical_section()) {
HashSet<int64_t> mem_block_ids;
for (const auto& op_name :
......@@ -633,7 +692,7 @@ void FinishGlobalCriticalSectionDesc(const std::vector<Plan>& plans) {
mem_block_ids.insert(cur_mem_block_ids.begin(), cur_mem_block_ids.end());
}
*critical_section->mutable_mem_block_id() = {mem_block_ids.begin(), mem_block_ids.end()};
input_output_mem_block_ids.insert(mem_block_ids.begin(), mem_block_ids.end());
input_output_mem_block_ids->insert(mem_block_ids.begin(), mem_block_ids.end());
} else {
CHECK(critical_section->has_total_job_critical_section());
}
......@@ -642,11 +701,22 @@ void FinishGlobalCriticalSectionDesc(const std::vector<Plan>& plans) {
// set mem_block_id for TotalJobCriticalSection
FOR_RANGE(int64_t, i, 0, critical_section_desc->CriticalSectionNum()) {
auto* critical_section = critical_section_desc->MutCriticalSection(i);
int64_t job_id = critical_section->job_id();
const auto& input_output_mem_block_ids = job_id2input_output_mem_block_ids.at(job_id);
if (critical_section->has_total_job_critical_section()) {
int64_t job_id = critical_section->job_id();
CHECK(unique_job_id_check.emplace(job_id).second);
auto* mem_block_ids = &job_id2mem_block_ids.at(job_id);
mem_block_ids->erase(input_output_mem_block_ids.begin(), input_output_mem_block_ids.end());
{
auto it = mem_block_ids->begin();
while (it != mem_block_ids->end()) {
if (input_output_mem_block_ids.find(*it) == input_output_mem_block_ids.end()) {
++it;
} else {
it = mem_block_ids->erase(it);
}
}
}
*critical_section->mutable_mem_block_id() = {mem_block_ids->begin(), mem_block_ids->end()};
}
}
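The explicit loop above replaces a range-erase that passed iterators of a different container to HashSet::erase, which (assuming HashSet follows std::unordered_set semantics) is not a valid "remove these values" operation. The intended in-place set difference, shown in isolation with made-up values:

#include <cstdint>
#include <iterator>
#include <unordered_set>

int main() {
  std::unordered_set<int64_t> total_job_blocks = {1, 2, 3, 4};
  const std::unordered_set<int64_t> io_blocks = {2, 4};
  for (auto it = total_job_blocks.begin(); it != total_job_blocks.end();) {
    // drop block ids already owned by the InputOutput critical section
    it = (io_blocks.count(*it) != 0) ? total_job_blocks.erase(it) : std::next(it);
  }
  // total_job_blocks now holds {1, 3}
  return 0;
}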
......@@ -671,9 +741,10 @@ void CompileAndMergePlanOnMaster(const PbRpf<JobConf>& job_confs, Plan* plan) {
std::vector<std::string> identity_tick_op_names;
{
Job main_job;
MakeMainJob(jobs, &main_job, &identity_tick_op_names);
LogicalBlobId critical_section_sink_lbi;
MakeMainJob(jobs, &main_job, &identity_tick_op_names, &critical_section_sink_lbi);
AddGlobalJobDesc(main_job, sub_plans.size());
CompileMainJob(&main_job, sub_plans.size(), &main_plan);
CompileMainJob(&main_job, critical_section_sink_lbi, sub_plans.size(), &main_plan);
}
LinkMainPlan(plan, main_plan, identity_tick_op_names);
TeePersistentLogStream::Create("merged_plan")->Write(*plan);
......
......@@ -79,9 +79,9 @@ Runtime::Runtime(const Plan& plan, size_t total_piece_num, bool is_experiment_ph
// if this machine is master
FOR_RANGE(int64_t, i, 0,
Global<std::vector<std::unique_ptr<JobDesc>>>::Get()->at(0)->TotalBatchNum()) {
Global<BufferMgr<int32_t>>::Get()->Get(kChannelNameGlobalWaitJobId)->Send(0);
Global<BufferMgr<int64_t>>::Get()->Get(kChannelNameGlobalWaitJobId)->Send(0);
}
Global<BufferMgr<int32_t>>::Get()->Get(kChannelNameGlobalWaitJobId)->Close();
Global<BufferMgr<int64_t>>::Get()->Get(kChannelNameGlobalWaitJobId)->Close();
runtime_ctx->WaitUntilCntEqualZero("running_actor_cnt");
OF_BARRIER();
DeleteAllGlobal();
......
......@@ -180,14 +180,17 @@ void ForEachInputOutputCriticalSectionOpNodes(
const OpGraph& op_graph,
const std::function<void(const HashSet<const OpNode*>&, const std::vector<std::string>&)>&
Handler) {
HashSet<std::string> arg_op_names;
for (const auto& name : Global<JobDesc>::Get()->arg_op_name()) { arg_op_names.insert(name); }
HashMap<OperatorConf::OpTypeCase, HashSet<const OpNode*>> op_type_case2op_nodes;
for (const std::string& op_name : Global<JobDesc>::Get()->arg_op_name()) {
const OpNode* op_node = op_graph.OpNode4OpName(op_name);
op_graph.ForEachNode([&](OpNode* op_node) {
const auto& op_name = op_node->op().op_name();
const auto& op_conf = op_node->op().op_conf();
if (IsInterfaceOpConf(op_conf)) {
if (IsInterfaceOpConf(op_conf)
&& (op_conf.has_variable_conf() || arg_op_names.find(op_name) != arg_op_names.end())) {
CHECK(op_type_case2op_nodes[op_conf.op_type_case()].emplace(op_node).second);
}
}
});
for (OperatorConf::OpTypeCase op_type_case :
{OperatorConf::kVariableConf, OperatorConf::kInputConf}) {
if (op_type_case2op_nodes[op_type_case].empty()) { continue; }
......
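The rewritten loop widens which nodes are collected: an op node now lands in an InputOutput critical section group when (sketched as a predicate, based only on the lines above)

// IsInterfaceOpConf(op_conf)
//   && (op_conf.has_variable_conf()
//       || arg_op_names.find(op_name) != arg_op_names.end())
// Previously only the ops listed in JobDesc::arg_op_name() were walked.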
......@@ -11,14 +11,17 @@ void ReentrantLockStatus::Init(const KernelConf& kernel_conf) {
total_acquired_lock_num_ = 0;
lock_id2queued_request_act_id_.resize(conf.lock_id2intersecting_lock_ids_size());
lock_id2acquired_num_.resize(conf.lock_id2intersecting_lock_ids_size());
lock_id2intersecting_lock_ids_ = conf.lock_id2intersecting_lock_ids();
for (const IdList ids : conf.lock_id2intersecting_lock_ids()) {
lock_id2intersecting_lock_ids_.push_back(
std::vector<int64_t>(ids.id().begin(), ids.id().end()));
}
}
bool ReentrantLockStatus::TryAcquireLock(int64_t lock_id) {
CHECK_EQ(lock_id2queued_request_act_id_.at(lock_id).empty(), false);
int64_t act_id = lock_id2queued_request_act_id_.at(lock_id).front();
bool blocked = false;
for (int64_t intersect_lock_id : lock_id2intersecting_lock_ids_.Get(lock_id).id()) {
for (int64_t intersect_lock_id : lock_id2intersecting_lock_ids_.at(lock_id)) {
if (lock_id2acquired_num_.at(intersect_lock_id) > 0
|| (lock_id2queued_request_act_id_.at(intersect_lock_id).empty() == false
&& lock_id2queued_request_act_id_.at(intersect_lock_id).front() < act_id)) {
......@@ -48,7 +51,7 @@ int64_t ReentrantLockStatus::ReleaseLock(int64_t lock_id) {
if (lock_id2acquired_num_.at(lock_id) == 0) {
int64_t min_act_id = cur_act_id();
int64_t min_lock_id = -1;
for (int64_t intersect_lock_id : lock_id2intersecting_lock_ids_.Get(lock_id).id()) {
for (int64_t intersect_lock_id : lock_id2intersecting_lock_ids_.at(lock_id)) {
CHECK_EQ(lock_id2acquired_num_.at(intersect_lock_id), 0);
int64_t act_id = lock_id2queued_request_act_id_.at(intersect_lock_id).front();
if (act_id < min_act_id) {
......
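As a standalone illustration of the acquire rule in TryAcquireLock above (a sketch, not the kernel's API; the function and parameter names are made up):

#include <cstdint>
#include <queue>
#include <vector>

// A lock requested by act `act_id` is blocked if any intersecting lock is
// currently held, or has an older request waiting at the front of its queue.
bool IsBlocked(int64_t act_id, const std::vector<int64_t>& intersecting_lock_ids,
               const std::vector<size_t>& acquired_num,
               const std::vector<std::queue<int64_t>>& queued_act_ids) {
  for (int64_t other : intersecting_lock_ids) {
    if (acquired_num.at(other) > 0) { return true; }
    if (!queued_act_ids.at(other).empty() && queued_act_ids.at(other).front() < act_id) {
      return true;
    }
  }
  return false;
}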
......@@ -46,7 +46,7 @@ class ReentrantLockStatus final {
size_t total_acquired_lock_num_;
std::vector<std::queue<int64_t>> lock_id2queued_request_act_id_;
std::vector<size_t> lock_id2acquired_num_;
PbRpf<IdList> lock_id2intersecting_lock_ids_;
std::vector<std::vector<int64_t>> lock_id2intersecting_lock_ids_;
};
template<typename T>
......
......@@ -2,22 +2,24 @@
namespace oneflow {
void WaitAndSendIdsKernel::ForwardDataContent(
template<typename T>
void WaitAndSendIdsKernel<T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
CHECK(ctx.other);
auto* status = static_cast<WaitAndSendIdsStatus*>(ctx.other);
const auto& conf = op_conf().wait_and_send_ids_conf();
const auto& conf = this->op_conf().wait_and_send_ids_conf();
if (status->out_idx_ >= status->out_num_) {
status->channel_status_ =
Global<BufferMgr<int32_t>>::Get()->Get(conf.wait_channel_name())->Receive(&status->in_id_);
Global<BufferMgr<int64_t>>::Get()->Get(conf.wait_channel_name())->Receive(&status->in_id_);
if (status->channel_status_ == kChannelStatusErrorClosed) { return; }
status->out_idx_ = 0;
status->out_num_ = conf.id_list(status->in_id_).id_size();
}
*BnInOp2Blob("out")->mut_dptr<int32_t>() = conf.id_list(status->in_id_).id(status->out_idx_);
*BnInOp2Blob("out")->mut_dptr<T>() = conf.id_list(status->in_id_).id(status->out_idx_);
++status->out_idx_;
}
REGISTER_KERNEL(OperatorConf::kWaitAndSendIdsConf, WaitAndSendIdsKernel);
ADD_CPU_DEFAULT_KERNEL_CREATOR(OperatorConf::kWaitAndSendIdsConf, WaitAndSendIdsKernel,
INT_DATA_TYPE_SEQ);
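// Note on the registration switch (an assumption about the macros, not verified
// here): REGISTER_KERNEL bound the op to a single non-templated kernel, while
// ADD_CPU_DEFAULT_KERNEL_CREATOR with INT_DATA_TYPE_SEQ presumably registers one
// WaitAndSendIdsKernel<T> per integral data type, so the instantiation chosen at
// runtime can follow the data_type field added to WaitAndSendIdsOpConf below.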
} // namespace oneflow
......@@ -8,11 +8,12 @@ namespace oneflow {
struct WaitAndSendIdsStatus final {
ChannelStatus channel_status_;
int32_t in_id_;
int32_t out_idx_;
int32_t out_num_;
int64_t in_id_;
int64_t out_idx_;
size_t out_num_;
};
template<typename T>
class WaitAndSendIdsKernel final : public KernelIf<DeviceType::kCPU> {
public:
OF_DISALLOW_COPY_AND_MOVE(WaitAndSendIdsKernel);
......
......@@ -1278,12 +1278,13 @@ message IdList {
message WaitAndSendIdsOpConf {
required string out = 1;
required string wait_channel_name = 2;
repeated IdList id_list = 3;
required DataType data_type = 4 [default = kInt32];
}
message ReentrantLockOpConf {
required string start = 1;
required string end = 2;
optional string end = 2;
required string out = 3;
repeated IdList lock_id2intersecting_lock_ids = 4;
}
......
......@@ -6,7 +6,7 @@ namespace oneflow {
void ReentrantLockOp::InitFromOpConf() {
EnrollInputBn("start", false);
EnrollInputBn("end", false);
if (op_conf().reentrant_lock_conf().has_end()) { EnrollInputBn("end", false); }
EnrollOutputBn("out", false);
}
......@@ -16,7 +16,7 @@ void ReentrantLockOp::InferBlobDescs(
CHECK_EQ(parallel_ctx->parallel_num(), 1);
BlobDesc* out = GetBlobDesc4BnInOp("out");
out->mut_shape() = Shape({1});
const DataType data_type = op_conf().esac_conf().data_type();
const DataType data_type = GetBlobDesc4BnInOp("out")->data_type();
CHECK(IsIntegralDataType(data_type));
out->set_data_type(data_type);
}
......
......@@ -22,7 +22,7 @@ void WaitAndSendIdsOp::InferBlobDescs(
const ParallelContext* parallel_ctx) const {
CHECK_EQ(parallel_ctx->parallel_num(), 1);
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
GetBlobDesc4BnInOp("out")->set_data_type(DataType::kInt32);
GetBlobDesc4BnInOp("out")->set_data_type(op_conf().wait_and_send_ids_conf().data_type());
}
void WaitAndSendIdsOp::InferHasBatchDim(
......