Skip to content
Snippets Groups Projects
Unverified Commit 2e8f33fc authored by Jinhui Yuan's avatar Jinhui Yuan Committed by GitHub
Browse files

fix blob mem sharing issues (#1134)

parent 980b0dc3
No related branches found
No related tags found
No related merge requests found
......@@ -61,7 +61,7 @@ message OtherConf {
optional bool use_ordered_allreduce_in_mdupdt = 115 [default = false];
optional bool use_synthetic_data = 116 [default = false];
optional bool enable_write_snapshot = 130 [default = true];
optional bool enable_blob_mem_sharing = 140 [default = false];
optional bool enable_blob_mem_sharing = 140 [default = true];
oneof JobType {
TrainConf train_conf = 200;
......
......@@ -110,6 +110,7 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
int32_t blob_desc_cnt = 0;
std::unique_ptr<BlobDesc> ret(new BlobDesc());
const BlobDesc* last_blob_desc = nullptr;
HashMap<int32_t, size_t> blob_mem_id2size;
int32_t last_blob_mem_id = -1;
int64_t last_body_byte_size = 0;
......@@ -119,11 +120,15 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
header_byte_size += rt_blob_desc.ByteSizeOfBlobHeader();
int64_t cur_body_byte_size = rt_blob_desc.ByteSizeOfBlobBody();
int32_t blob_mem_id = blob_desc->blob_mem_id();
if (blob_mem_id == -1 || blob_mem_id != last_blob_mem_id) {
if (blob_mem_id == -1) {
body_byte_size += cur_body_byte_size;
}
if (blob_mem_id != -1 && blob_mem_id == last_blob_mem_id) {
CHECK_EQ(cur_body_byte_size, last_body_byte_size);
} else {
auto size_it = blob_mem_id2size.find(blob_mem_id);
if (size_it == blob_mem_id2size.end()) {
CHECK(blob_mem_id2size.emplace(blob_mem_id, cur_body_byte_size).second);
} else {
CHECK_EQ(size_it->second, cur_body_byte_size);
}
}
data_type_set.insert(static_cast<int>(blob_desc->data_type()));
if (max_col_num == -1) {
......@@ -136,6 +141,7 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
last_blob_mem_id = blob_mem_id;
last_body_byte_size = cur_body_byte_size;
}
for (auto& pair : blob_mem_id2size) { body_byte_size += pair.second; }
if (blob_desc_cnt == 0) {
// do nothing
} else if (blob_desc_cnt == 1) {
......
......@@ -127,19 +127,21 @@ void RegstMgr::NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regs
cur_body_pointer = main_mem_ptr + packed_blob_desc->ByteSizeOfBlobHeader();
}
int32_t last_blob_mem_id = -1;
size_t last_size = 0;
for (const LbiBlobDescPair& lbi : lbis) {
const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi());
int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
cur_body_pointer += last_size;
}
std::unique_ptr<Blob> blob_ptr(
new Blob(regst, blob_desc, cur_header_pointer, cur_body_pointer));
InitOFRecordBlobIfNeed(blob_ptr.get());
CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second);
cur_header_pointer += blob_desc->ByteSizeOfBlobHeader();
int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
cur_body_pointer += blob_desc->ByteSizeOfBlobBody();
}
last_blob_mem_id = cur_blob_mem_id;
last_size = blob_desc->ByteSizeOfBlobBody();
}
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment