diff --git a/oneflow/core/job/improver.cpp b/oneflow/core/job/improver.cpp
index 5995112c24400b0ccf49f7b89a18f704ffd9fd6d..45446972546cf63bf858859a2f31c80000901feb 100644
--- a/oneflow/core/job/improver.cpp
+++ b/oneflow/core/job/improver.cpp
@@ -420,8 +420,10 @@ void GenMemBlockAndChunk4Plan(Plan* plan) {
   // mzuid = memory zone unique id
   HashMap<int64_t, ChunkProto> mzuid2chunk;
 
-  auto GenMemBlock4RegstIfNeed = [&](RegstDescProto* regst_desc, int64_t job_id,
-                                     int64_t machine_id) {
+  auto GenMemBlock4RegstIfNeed = [&](RegstDescProto* regst_desc, const TaskProto* task) {
+    const int64_t job_id = task->job_id();          // ids are now read from the producing task
+    const int64_t machine_id = task->machine_id();
+    const int64_t thrd_id = task->thrd_id();        // recorded on new mem blocks as thrd_id_hint
     int64_t mem_block_id = regst_desc->mem_block_id();
     int64_t mem_block_offset = regst_desc->mem_block_offset();
     CHECK_NE(mem_block_id, -1);
@@ -440,6 +442,7 @@ void GenMemBlockAndChunk4Plan(Plan* plan) {
       *(mem_block.mutable_mem_case()) = regst_desc->mem_case();
       mem_block.set_enable_reuse_mem(regst_desc->enable_reuse_mem());
       mem_block.set_mem_size(regst_main_size + mem_block_offset);
+      mem_block.set_thrd_id_hint(thrd_id);  // thread of the task that produces this regst
       CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second);
     } else {
       MemBlockProto* mem_block = &(mem_block_id2mem_block.at(mem_block_id));
@@ -461,6 +464,7 @@ void GenMemBlockAndChunk4Plan(Plan* plan) {
           MemoryCaseUtil::GetHostPinnedMemoryCaseForRegstSeparatedHeader(regst_desc->mem_case());
       mem_block.set_enable_reuse_mem(false);
       mem_block.set_mem_size(regst_separated_size);
+      mem_block.set_thrd_id_hint(thrd_id);  // same hint on the separated mem block
       CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second);
     }
   };
@@ -490,7 +494,7 @@ void GenMemBlockAndChunk4Plan(Plan* plan) {
   for (int i = 0; i < plan->task_size(); i++) {
     TaskProto* task = plan->mutable_task(i);
     for (auto& pair : *task->mutable_produced_regst_desc()) {
-      GenMemBlock4RegstIfNeed(&pair.second, task->job_id(), task->machine_id());
+      GenMemBlock4RegstIfNeed(&pair.second, task);  // job/machine/thrd ids come from the task
     }
   }
 
diff --git a/oneflow/core/memory/memory_block.proto b/oneflow/core/memory/memory_block.proto
index 3d7f61a8c275b3793ef84feb211dfe637c6ac07c..830366caa009b2febc395953ff8c83d460316d46 100644
--- a/oneflow/core/memory/memory_block.proto
+++ b/oneflow/core/memory/memory_block.proto
@@ -12,6 +12,8 @@ message MemBlockProto {
   optional int64 chunk_id = 6 [default = -1];
   optional int64 chunk_offset = 7 [default = -1];
   required int64 mem_size = 8;
+  // NOTE(chengcheng): thrd_id_hint orders the separated mem blocks packed into one group.
+  optional int64 thrd_id_hint = 9 [default = -1];
 }
 
 message ChunkProto {
diff --git a/oneflow/core/memory/memory_case_util.cpp b/oneflow/core/memory/memory_case_util.cpp
index 624dbaab6d3d0180448c24c6dbd12b0c620b0baf..9492fb23429114603db0cc518009cf064b5b3275 100644
--- a/oneflow/core/memory/memory_case_util.cpp
+++ b/oneflow/core/memory/memory_case_util.cpp
@@ -49,16 +49,26 @@ MemoryCase MemoryCaseUtil::GetHostPinnedMemoryCaseForRegstSeparatedHeader(
   return ret;
 }
 
-int64_t MemoryCaseUtil::GenMemZoneUniqueId(int64_t machine_id, const MemoryCase& mem_case) {
-  int64_t mem_zone_id = 1024;
+int64_t MemoryCaseUtil::GenMemZoneId(const MemoryCase& mem_case) {
+  // [0, 127]   = GPU device mem, zone id is the CUDA device id
+  // [128]      = CPU host mem
+  // [129, 256] = CPU host mem pinned for CUDA, zone id is 129 + device id
+  // [257, ...] = other devices (no such case is handled yet; see UNIMPLEMENTED() below)
+  if (mem_case.has_device_cuda_mem()) {
+    return mem_case.device_cuda_mem().device_id();  // GPU device mem (id not range-checked)
+  }
   if (mem_case.has_host_mem()) {
     if (mem_case.host_mem().has_cuda_pinned_mem()) {
-      mem_zone_id = 1025 + mem_case.host_mem().cuda_pinned_mem().device_id();
+      return 129 + mem_case.host_mem().cuda_pinned_mem().device_id();  // Host mem used by GPU
     }
-  } else {
-    mem_zone_id = mem_case.device_cuda_mem().device_id();
+    return 128;  // CPU host mem
   }
-  return (machine_id << 32) | mem_zone_id;
+  UNIMPLEMENTED();  // neither device_cuda_mem nor host_mem is set
+  return -1;
+}
+
+int64_t MemoryCaseUtil::GenMemZoneUniqueId(int64_t machine_id, const MemoryCase& mem_case) {
+  return (machine_id << 32) | (MemoryCaseUtil::GenMemZoneId(mem_case));  // machine | zone
 }
 
 bool MemoryCaseUtil::IsHostUnPinnedMemoryCase(const MemoryCase& mem_case) {
@@ -66,4 +76,8 @@ bool MemoryCaseUtil::IsHostUnPinnedMemoryCase(const MemoryCase& mem_case) {
          && !mem_case.host_mem().used_by_network();
 }
 
+int64_t MemoryCaseUtil::MergeThrdMemZoneId(int64_t thrd_id, const MemoryCase& mem_case) {
+  return (thrd_id << 21) | (MemoryCaseUtil::GenMemZoneId(mem_case));  // thrd id | zone id
+}
+
 }  // namespace oneflow
diff --git a/oneflow/core/memory/memory_case_util.h b/oneflow/core/memory/memory_case_util.h
index 66b4b974e1edf9c1cd3a09dd20d1bbee0bf1fee3..fbef6f2afa3c7f9544d7e8ad0b09d0af4ecb9992 100644
--- a/oneflow/core/memory/memory_case_util.h
+++ b/oneflow/core/memory/memory_case_util.h
@@ -45,6 +45,10 @@ struct MemoryCaseUtil {
 
   static int64_t GenMemZoneUniqueId(int64_t machine_id, const MemoryCase& mem_case);
 
+  static int64_t GenMemZoneId(const MemoryCase& mem_case);  // zone id without the machine id
+
+  static int64_t MergeThrdMemZoneId(int64_t thrd_id, const MemoryCase& mem_case);
+
   static bool IsHostUnPinnedMemoryCase(const MemoryCase& mem_case);
 };
 
diff --git a/oneflow/core/register/register_manager.cpp b/oneflow/core/register/register_manager.cpp
index a9716f7815907478354c727b5ff6810b639adf18..418222732dc716249c3b4342f753372d2800da33 100644
--- a/oneflow/core/register/register_manager.cpp
+++ 
b/oneflow/core/register/register_manager.cpp
@@ -33,10 +33,19 @@ void CheckBlobInRegstNotDisabled(const RegstDescProto& regst_desc) {
         == false);
 }
 
+// Mem blocks of one memory zone that are packed back to back into a single allocation.
+struct PackedChunkInfo {
+  MemoryCase mem_case;                       // memory case shared by all packed blocks
+  int64_t size;                              // sum of mem_size() over `blocks`
+  std::vector<const MemBlockProto*> blocks;  // non-owning pointers into the Plan
+  explicit PackedChunkInfo(const MemoryCase& mem) : mem_case(mem), size(0) {}
+};
+
 }  // namespace
 
 RegstMgr::RegstMgr(const Plan& plan) {
   int64_t this_machine_id = Global<MachineCtx>::Get()->this_machine_id();
+
   HashMap<int64_t, char*> chunk_id2ptr;
   for (const ChunkProto& chunk : plan.block_chunk_list().chunk()) {
     if (chunk.machine_id() != this_machine_id) { continue; }
@@ -44,20 +53,65 @@ RegstMgr::RegstMgr(const Plan& plan) {
     char* chunk_ptr = Global<MemoryAllocator>::Get()->Allocate(chunk.mem_case(), chunk.mem_size());
     CHECK(chunk_id2ptr.emplace(chunk.chunk_id(), chunk_ptr).second);
   }
+
+  // Ids of all non-empty mem blocks on this machine; checked at the end for a pointer.
+  HashSet<int64_t> all_block_ids;
+  // Chunk-less mem blocks grouped by memory zone id; each group becomes one allocation.
+  HashMap<int64_t, PackedChunkInfo> zone_id2packed_chunk;
   for (const MemBlockProto& mem_block : plan.block_chunk_list().mem_block()) {
     if (mem_block.machine_id() != this_machine_id) { continue; }
     if (mem_block.mem_size() == 0) { continue; }
-    char* mem_block_ptr = nullptr;
+    const int64_t mem_block_id = mem_block.mem_block_id();
+    CHECK(all_block_ids.insert(mem_block_id).second);
     if (mem_block.has_chunk_id()) {
       CHECK(mem_block.has_chunk_offset());
       CHECK(chunk_id2ptr.find(mem_block.chunk_id()) != chunk_id2ptr.end());
-      mem_block_ptr = chunk_id2ptr.at(mem_block.chunk_id()) + mem_block.chunk_offset();
+      char* mem_block_ptr = chunk_id2ptr.at(mem_block.chunk_id()) + mem_block.chunk_offset();
+      CHECK(mem_block_id2ptr_.emplace(mem_block_id, mem_block_ptr).second);
     } else {
-      mem_block_ptr =
-          Global<MemoryAllocator>::Get()->Allocate(mem_block.mem_case(), mem_block.mem_size());
+      // No chunk assigned: defer allocation and pack this block with others of its zone.
+      const int64_t zone_id = MemoryCaseUtil::GenMemZoneId(mem_block.mem_case());
+      auto packed_chunk_it = zone_id2packed_chunk.find(zone_id);
+      if (packed_chunk_it == zone_id2packed_chunk.end()) {
+        packed_chunk_it =
+            zone_id2packed_chunk.emplace(zone_id, PackedChunkInfo(mem_block.mem_case())).first;
+      }
+      PackedChunkInfo* packed_chunk = &(packed_chunk_it->second);
+      // All blocks packed into one allocation must share the exact same memory case.
+      CHECK(packed_chunk->mem_case == mem_block.mem_case());
+      packed_chunk->blocks.push_back(&mem_block);
+      packed_chunk->size += mem_block.mem_size();
     }
-    CHECK(mem_block_id2ptr_.emplace(mem_block.mem_block_id(), mem_block_ptr).second);
   }
+
+  for (auto& pair : zone_id2packed_chunk) {
+    PackedChunkInfo* packed_chunk = &pair.second;
+    // One allocation per zone; every packed block is given a sub-range of it.
+    char* ptr =
+        Global<MemoryAllocator>::Get()->Allocate(packed_chunk->mem_case, packed_chunk->size);
+    // Sort by thrd id hint, ties broken by the (unique) mem block id, so blocks with the same
+    // hint end up adjacent and the layout does not depend on the input order.
+    std::vector<const MemBlockProto*>* blocks = &(packed_chunk->blocks);
+    std::sort(blocks->begin(), blocks->end(),
+              [](const MemBlockProto* lhs, const MemBlockProto* rhs) {
+                if (lhs->thrd_id_hint() == rhs->thrd_id_hint()) {
+                  return lhs->mem_block_id() < rhs->mem_block_id();
+                }
+                return lhs->thrd_id_hint() < rhs->thrd_id_hint();
+              });
+    int64_t offset = 0;
+    for (const MemBlockProto* block : *blocks) {
+      CHECK(mem_block_id2ptr_.emplace(block->mem_block_id(), ptr + offset).second);
+      offset += block->mem_size();
+    }
+    CHECK_EQ(offset, packed_chunk->size);
+  }
+
+  // Sanity check: both paths above must have given every recorded block a pointer.
+  for (int64_t mem_block_id : all_block_ids) {
+    CHECK(mem_block_id2ptr_.find(mem_block_id) != mem_block_id2ptr_.end());
+  }
+
   for (const TaskProto& task : plan.task()) {
     if (task.machine_id() != this_machine_id) { continue; }
     for (const auto& pair : task.produced_regst_desc()) {