diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index fb9735365173a3efbe0cf1d42b2781fb726dbdfe..3966f6510d7cbc2cf0e1060161d7e68b6e25ec71 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -61,7 +61,7 @@ message OtherConf { optional bool use_ordered_allreduce_in_mdupdt = 115 [default = false]; optional bool use_synthetic_data = 116 [default = false]; optional bool enable_write_snapshot = 130 [default = true]; - optional bool enable_blob_mem_sharing = 140 [default = false]; + optional bool enable_blob_mem_sharing = 140 [default = true]; oneof JobType { TrainConf train_conf = 200; diff --git a/oneflow/core/register/blob_desc.cpp b/oneflow/core/register/blob_desc.cpp index 0013a3d5d332ba2062fbd2564244fd684f3d3fbd..44abc68e1bbd190a6be8e317cbc8495c573eb691 100644 --- a/oneflow/core/register/blob_desc.cpp +++ b/oneflow/core/register/blob_desc.cpp @@ -110,6 +110,7 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc( int32_t blob_desc_cnt = 0; std::unique_ptr<BlobDesc> ret(new BlobDesc()); const BlobDesc* last_blob_desc = nullptr; + HashMap<int32_t, size_t> blob_mem_id2size; int32_t last_blob_mem_id = -1; int64_t last_body_byte_size = 0; @@ -119,11 +120,15 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc( header_byte_size += rt_blob_desc.ByteSizeOfBlobHeader(); int64_t cur_body_byte_size = rt_blob_desc.ByteSizeOfBlobBody(); int32_t blob_mem_id = blob_desc->blob_mem_id(); - if (blob_mem_id == -1 || blob_mem_id != last_blob_mem_id) { + if (blob_mem_id == -1) { body_byte_size += cur_body_byte_size; - } - if (blob_mem_id != -1 && blob_mem_id == last_blob_mem_id) { - CHECK_EQ(cur_body_byte_size, last_body_byte_size); + } else { + auto size_it = blob_mem_id2size.find(blob_mem_id); + if (size_it == blob_mem_id2size.end()) { + CHECK(blob_mem_id2size.emplace(blob_mem_id, cur_body_byte_size).second); + } else { + CHECK_EQ(size_it->second, cur_body_byte_size); + } } data_type_set.insert(static_cast<int>(blob_desc->data_type())); if (max_col_num == -1) { @@ -136,6 +141,7 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc( last_blob_mem_id = blob_mem_id; last_body_byte_size = cur_body_byte_size; } + for (auto& pair : blob_mem_id2size) { body_byte_size += pair.second; } if (blob_desc_cnt == 0) { // do nothing } else if (blob_desc_cnt == 1) { diff --git a/oneflow/core/register/register_manager.cpp b/oneflow/core/register/register_manager.cpp index dc28547cccd03693f124749109891b81d2736248..c91888f843ee79d304edb3960d1ffe4cc5c0607a 100644 --- a/oneflow/core/register/register_manager.cpp +++ b/oneflow/core/register/register_manager.cpp @@ -127,19 +127,21 @@ void RegstMgr::NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regs cur_body_pointer = main_mem_ptr + packed_blob_desc->ByteSizeOfBlobHeader(); } int32_t last_blob_mem_id = -1; + size_t last_size = 0; for (const LbiBlobDescPair& lbi : lbis) { const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi()); + int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id(); + if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) { + cur_body_pointer += last_size; + } std::unique_ptr<Blob> blob_ptr( new Blob(regst, blob_desc, cur_header_pointer, cur_body_pointer)); InitOFRecordBlobIfNeed(blob_ptr.get()); CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second); - cur_header_pointer += blob_desc->ByteSizeOfBlobHeader(); - int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id(); - if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) { - cur_body_pointer += blob_desc->ByteSizeOfBlobBody(); - } + last_blob_mem_id = cur_blob_mem_id; + last_size = blob_desc->ByteSizeOfBlobBody(); } }