From 2e8f33fc1dc366a1a26b9be2cae2cc12d8b36f2c Mon Sep 17 00:00:00 2001
From: Jinhui Yuan <yuan.ms2@gmail.com>
Date: Sun, 19 Aug 2018 01:11:29 +0800
Subject: [PATCH] fix blob mem sharing issues (#1134)

---
 oneflow/core/job/job_conf.proto            |  2 +-
 oneflow/core/register/blob_desc.cpp        | 14 ++++++++++----
 oneflow/core/register/register_manager.cpp | 12 +++++++-----
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto
index fb9735365..3966f6510 100644
--- a/oneflow/core/job/job_conf.proto
+++ b/oneflow/core/job/job_conf.proto
@@ -61,7 +61,7 @@ message OtherConf {
   optional bool use_ordered_allreduce_in_mdupdt = 115 [default = false];
   optional bool use_synthetic_data =  116 [default = false];
   optional bool enable_write_snapshot = 130 [default = true];
-  optional bool enable_blob_mem_sharing = 140 [default = false];
+  optional bool enable_blob_mem_sharing = 140 [default = true];
 
   oneof JobType {
     TrainConf train_conf = 200;
diff --git a/oneflow/core/register/blob_desc.cpp b/oneflow/core/register/blob_desc.cpp
index 0013a3d5d..44abc68e1 100644
--- a/oneflow/core/register/blob_desc.cpp
+++ b/oneflow/core/register/blob_desc.cpp
@@ -110,6 +110,7 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
   int32_t blob_desc_cnt = 0;
   std::unique_ptr<BlobDesc> ret(new BlobDesc());
   const BlobDesc* last_blob_desc = nullptr;
+  HashMap<int32_t, size_t> blob_mem_id2size;
 
   int32_t last_blob_mem_id = -1;
   int64_t last_body_byte_size = 0;
@@ -119,11 +120,15 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
     header_byte_size += rt_blob_desc.ByteSizeOfBlobHeader();
     int64_t cur_body_byte_size = rt_blob_desc.ByteSizeOfBlobBody();
     int32_t blob_mem_id = blob_desc->blob_mem_id();
-    if (blob_mem_id == -1 || blob_mem_id != last_blob_mem_id) {
+    if (blob_mem_id == -1) {
       body_byte_size += cur_body_byte_size;
-    }
-    if (blob_mem_id != -1 && blob_mem_id == last_blob_mem_id) {
-      CHECK_EQ(cur_body_byte_size, last_body_byte_size);
+    } else {
+      auto size_it = blob_mem_id2size.find(blob_mem_id);
+      if (size_it == blob_mem_id2size.end()) {
+        CHECK(blob_mem_id2size.emplace(blob_mem_id, cur_body_byte_size).second);
+      } else {
+        CHECK_EQ(size_it->second, cur_body_byte_size);
+      }
     }
     data_type_set.insert(static_cast<int>(blob_desc->data_type()));
     if (max_col_num == -1) {
@@ -136,6 +141,7 @@ std::unique_ptr<BlobDesc> ComputePackedBlobDesc(
     last_blob_mem_id = blob_mem_id;
     last_body_byte_size = cur_body_byte_size;
   }
+  for (auto& pair : blob_mem_id2size) { body_byte_size += pair.second; }
   if (blob_desc_cnt == 0) {
     // do nothing
   } else if (blob_desc_cnt == 1) {
diff --git a/oneflow/core/register/register_manager.cpp b/oneflow/core/register/register_manager.cpp
index dc28547cc..c91888f84 100644
--- a/oneflow/core/register/register_manager.cpp
+++ b/oneflow/core/register/register_manager.cpp
@@ -127,19 +127,21 @@ void RegstMgr::NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regs
     cur_body_pointer = main_mem_ptr + packed_blob_desc->ByteSizeOfBlobHeader();
   }
   int32_t last_blob_mem_id = -1;
+  size_t last_size = 0;
   for (const LbiBlobDescPair& lbi : lbis) {
     const RtBlobDesc* blob_desc = rt_regst_desc->GetRtBlobDescFromLbi(lbi.lbi());
+    int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
+    if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
+      cur_body_pointer += last_size;
+    }
     std::unique_ptr<Blob> blob_ptr(
         new Blob(regst, blob_desc, cur_header_pointer, cur_body_pointer));
     InitOFRecordBlobIfNeed(blob_ptr.get());
     CHECK(regst->lbi2blob_.emplace(lbi.lbi(), std::move(blob_ptr)).second);
-
     cur_header_pointer += blob_desc->ByteSizeOfBlobHeader();
-    int32_t cur_blob_mem_id = lbi.blob_desc().header().blob_mem_id();
-    if (cur_blob_mem_id == -1 || cur_blob_mem_id != last_blob_mem_id) {
-      cur_body_pointer += blob_desc->ByteSizeOfBlobBody();
-    }
+
     last_blob_mem_id = cur_blob_mem_id;
+    last_size = blob_desc->ByteSizeOfBlobBody();
   }
 }
 
-- 
GitLab