diff --git a/oneflow/core/framework/nn_graph.cpp b/oneflow/core/framework/nn_graph.cpp index a18b4b6e82336a1a8dbd45fb48e702ffb0256361..59253959b7669f71ebe060d38f24101fe69c7d0e 100644 --- a/oneflow/core/framework/nn_graph.cpp +++ b/oneflow/core/framework/nn_graph.cpp @@ -66,7 +66,10 @@ Maybe<void> NNGraph::RegisterVariableOpNamesAndTensors( } else { var_blob = JUST(var->eager_blob_object())->mut_blob(); } - CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(variable_op_names.at(i), var_blob).second); + const std::string& var_name = variable_op_names.at(i); + CHECK_OR_RETURN(!var_name.empty()); + CHECK_OR_RETURN(variable_op_name2eager_blob_.emplace(var_name, var_blob).second); + CHECK_OR_RETURN(variable_op_names_.insert(var_name).second); } return Maybe<void>::Ok(); } @@ -85,6 +88,7 @@ Maybe<void> NNGraph::CompileAndInitRuntime() { double start = GetCurTime(); // TODO(chengcheng): new memory reused by chunk Compiler().Compile(&job_, &plan_, /* need_job_complete */ true); + PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(&plan_, variable_op_names_); LOG(INFO) << "\njob_id: " << job_ctx->job_id() << " , job_name: " << name_ << " , compile time: " << (GetCurTime() - start) / 1000000000.0 << " seconds.\n"; @@ -93,7 +97,7 @@ Maybe<void> NNGraph::CompileAndInitRuntime() { } // TODO(chengcheng): test collective boxing for multi-job. PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_); - PlanUtil::SetForceInplaceMemBlock(&plan_); + // PlanUtil::SetForceInplaceMemBlock(&plan_); NOTE(chengcheng): only for ssp. PlanUtil::DumpCtrlRegstInfoToPlan(&plan_); } if (GlobalProcessCtx::WorldSize() > 1) { @@ -110,7 +114,7 @@ Maybe<void> NNGraph::CompileAndInitRuntime() { PlanUtil::PopulateOpAttibute(&plan_, plan_.job_id2op_attribute_ref_table()); NewRuntimeBuffers(); - runtime_.reset(new Runtime(plan_, GetMaxVal<size_t>(), false)); + runtime_.reset(new Runtime(plan_, GetMaxVal<size_t>(), false, variable_op_name2eager_blob_)); runtime_inited_ = true; return Maybe<void>::Ok(); } diff --git a/oneflow/core/framework/nn_graph.h b/oneflow/core/framework/nn_graph.h index 44d968a76123cfa55c409246bbc79a0296f56fbf..2d59a407588efc5e3daf05f6270dc89b15d5fba5 100644 --- a/oneflow/core/framework/nn_graph.h +++ b/oneflow/core/framework/nn_graph.h @@ -52,6 +52,7 @@ class NNGraph final : public NNGraphIf { std::vector<std::string> input_op_names_; std::vector<std::string> output_op_names_; HashMap<std::string, Blob*> variable_op_name2eager_blob_; + HashSet<std::string> variable_op_names_; Job job_; Plan plan_; // TODO(chengcheng): temp impl using runtime now, need reimplement for dynamic multi nn.Graph. diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 92df9f1c6a0c9bae77e8a74574f06025f2f84bc4..7a529f53180943fa247eb2824eb1c6714a843188 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -108,7 +108,6 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable); PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); - PlanUtil::GenMemBlockAndChunk4Plan(plan); Global<OpGraph>::Delete(); } diff --git a/oneflow/core/job/improver.cpp b/oneflow/core/job/improver.cpp index cc5d6bf3843c1669acd634ea704fde596d002df6..c4cb54caf9f1383109a5323fe904251700d6e7f1 100644 --- a/oneflow/core/job/improver.cpp +++ b/oneflow/core/job/improver.cpp @@ -286,7 +286,7 @@ void GenMemBlockAndChunk4Plan(Plan* plan) { mem_block.add_job_id(job_id); mem_block.set_machine_id(machine_id); *(mem_block.mutable_mem_case()) = - MemoryCaseUtil::GetHostPinnedMemoryCaseForRegstSeparatedHeader(regst_desc->mem_case()); + MemoryCaseUtil::GetHostMemoryCaseForRegstSeparatedHeader(regst_desc->mem_case()); mem_block.set_enable_reuse_mem(false); mem_block.set_mem_size(regst_separated_size); mem_block.set_thrd_id_hint(thrd_id); diff --git a/oneflow/core/job/inter_job_mem_sharing_util.cpp b/oneflow/core/job/inter_job_mem_sharing_util.cpp index 62a3b9952830a257e111dd083b451dcdb26b998d..e17794124e9b65a9b876b232cd2d40571e16dd10 100644 --- a/oneflow/core/job/inter_job_mem_sharing_util.cpp +++ b/oneflow/core/job/inter_job_mem_sharing_util.cpp @@ -275,7 +275,7 @@ void MergeSharedMemBlockR2L(RegstDescProto* lhs, RegstDescProto* rhs, int64_t merged_header_id = lhs->separated_header_mem_block_id(); int64_t erased_header_id = rhs->separated_header_mem_block_id(); MemoryCase header_mem_case = - MemoryCaseUtil::GetHostPinnedMemoryCaseForRegstSeparatedHeader(lhs->mem_case()); + MemoryCaseUtil::GetHostMemoryCaseForRegstSeparatedHeader(lhs->mem_case()); MemBlockProto* merged_header_block = CheckValidAndGetMemBlock(merged_header_id, separated_header_mem_size, header_mem_case); MemBlockProto* erased_header_block = diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index a6f727da48b4fd88d2c4cc15890e5e90a44d6e39..9545c03ea89eec9163c9f4470778f744eb99eb88 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -210,6 +210,7 @@ Maybe<void> CompileCurJobOnMaster(Job* job, Plan* plan, bool need_job_complete) if (GlobalProcessCtx::IsThisProcessMaster()) { double start = GetCurTime(); Compiler().Compile(job, plan, need_job_complete); + PlanUtil::GenMemBlockAndChunk4Plan(plan); LOG(INFO) << "\njob_id: " << job_desc.job_id() << " , job_name: " << job_desc.job_name() << " , compile time: " << (GetCurTime() - start) / 1000000000.0 << " seconds.\n"; @@ -1024,7 +1025,9 @@ Maybe<void> Oneflow::Init(const oneflow::JobSet& job_set) { LOG(ERROR) << "this is dry run, exiting"; exit(0); } - runtime_.reset(new Runtime(plan_, GetMaxVal<size_t>(), false)); + + HashMap<std::string, Blob*> variable_op_name2eager_blob; + runtime_.reset(new Runtime(plan_, GetMaxVal<size_t>(), false, variable_op_name2eager_blob)); OF_PROFILER_RANGE_POP(); // new Runtime return Maybe<void>::Ok(); } diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index d995d09dce2229c7e2344a7166ea99b9b12988a9..81be8c0f70fc8c944766842300c359a43156b0c1 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -63,10 +63,29 @@ void PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) { } void PlanUtil::GenMemBlockAndChunk4Plan(Plan* plan) { + HashSet<std::string> variable_op_names; + PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan(plan, variable_op_names); +} + +void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( + Plan* plan, const HashSet<std::string>& variable_op_names) { HashMap<int64_t, MemBlockProto> mem_block_id2mem_block; // mzuid = memory zone unique id HashMap<int64_t, ChunkProto> mzuid2chunk; + auto IsVariableRegst = [&](const TaskProto* task, std::string* name) -> bool { + if (variable_op_names.empty()) { return false; } + if (task->exec_sequence().exec_node_size() != 1) { return false; } + const auto& op_conf = + GetOpAttribute(plan, task->job_id(), task->exec_sequence().exec_node(0).kernel_conf()) + .op_conf(); + if (!op_conf.has_variable_conf()) { return false; } + const std::string& var_name = op_conf.name(); + if (variable_op_names.find(var_name) == variable_op_names.end()) { return false; } + *name = var_name; + return true; + }; + auto GenMemBlock4RegstIfNeed = [&](RegstDescProto* regst_desc, const TaskProto* task) { const int64_t job_id = task->job_id(); const int64_t machine_id = task->machine_id(); @@ -77,10 +96,25 @@ void PlanUtil::GenMemBlockAndChunk4Plan(Plan* plan) { CHECK_NE(mem_block_offset, -1); CHECK_EQ(regst_desc->separated_header_mem_block_id(), -1); + std::string var_name; + bool is_variable_regst = IsVariableRegst(task, &var_name); + if (is_variable_regst) { + CHECK(!var_name.empty()); + CHECK_EQ(regst_desc->register_num(), 1); + CHECK_EQ(regst_desc->min_register_num(), 1); + CHECK_EQ(regst_desc->max_register_num(), 1); + regst_desc->set_variable_op_name(var_name); + } + RtRegstDesc rt_regst_desc(*regst_desc); int64_t regst_main_size = rt_regst_desc.TotalMainByteSize4AllRegst(); int64_t regst_separated_size = rt_regst_desc.TotalSeparatedHeaderByteSize4AllRegst(); + if (is_variable_regst) { + CHECK_GT(regst_separated_size, + 0); // NOTE(chengcheng): variable regst header ALWAYS separated + } + if (mem_block_id2mem_block.find(mem_block_id) == mem_block_id2mem_block.end()) { MemBlockProto mem_block; mem_block.set_mem_block_id(mem_block_id); @@ -90,9 +124,14 @@ void PlanUtil::GenMemBlockAndChunk4Plan(Plan* plan) { mem_block.set_enable_reuse_mem(regst_desc->enable_reuse_mem()); mem_block.set_mem_size(regst_main_size + mem_block_offset); mem_block.set_thrd_id_hint(thrd_id); + if (is_variable_regst) { + mem_block.set_variable_op_name(var_name); + mem_block.set_is_separated_header(false); + } CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second); } else { MemBlockProto* mem_block = &(mem_block_id2mem_block.at(mem_block_id)); + CHECK(!mem_block->has_variable_op_name()); // variable regst mem block is unique. CHECK_EQ(mem_block->job_id(0), job_id); CHECK_EQ(mem_block->machine_id(), machine_id); CHECK(mem_block->mem_case() == regst_desc->mem_case()); @@ -108,10 +147,14 @@ void PlanUtil::GenMemBlockAndChunk4Plan(Plan* plan) { mem_block.add_job_id(job_id); mem_block.set_machine_id(machine_id); *(mem_block.mutable_mem_case()) = - MemoryCaseUtil::GetHostPinnedMemoryCaseForRegstSeparatedHeader(regst_desc->mem_case()); + MemoryCaseUtil::GetHostMemoryCaseForRegstSeparatedHeader(regst_desc->mem_case()); mem_block.set_enable_reuse_mem(false); mem_block.set_mem_size(regst_separated_size); mem_block.set_thrd_id_hint(thrd_id); + if (is_variable_regst) { + mem_block.set_variable_op_name(var_name); + mem_block.set_is_separated_header(true); + } CHECK(mem_block_id2mem_block.emplace(mem_block.mem_block_id(), mem_block).second); } }; @@ -211,7 +254,7 @@ void PlanUtil::CleanUselessMemBlockAndCheckValid(Plan* plan) { CHECK_EQ(header_mem_block.mem_size(), separated_header_mem_size); CHECK_EQ(task.machine_id(), header_mem_block.machine_id()); CHECK(header_mem_block.mem_case() - == MemoryCaseUtil::GetHostPinnedMemoryCaseForRegstSeparatedHeader(regst.mem_case())); + == MemoryCaseUtil::GetHostMemoryCaseForRegstSeparatedHeader(regst.mem_case())); CHECK(header_mem_block.enable_reuse_mem() == false); const auto& header_block_job_ids = mem_block_id2job_ids[header_block_id]; CHECK(header_block_job_ids.find(task.job_id()) != header_block_job_ids.end()); diff --git a/oneflow/core/job/plan_util.h b/oneflow/core/job/plan_util.h index 1e6d9021d05e4fae37cfa36099cfeb13ca3ff4e4..9675f0888fee05d0354478bd1a8c6b418abde3a9 100644 --- a/oneflow/core/job/plan_util.h +++ b/oneflow/core/job/plan_util.h @@ -18,6 +18,7 @@ limitations under the License. #include <functional> #include "oneflow/core/common/protobuf.h" +#include "oneflow/core/common/util.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/job.pb.h" @@ -28,6 +29,8 @@ struct PlanUtil { static std::function<const TaskProto*(int64_t)> MakeGetterTaskProto4TaskId(const Plan& plan); static void SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan); static void GenMemBlockAndChunk4Plan(Plan* plan); + static void GenMemBlockAndChunkWithVariableOpNames4Plan( + Plan* plan, const HashSet<std::string>& variable_op_names); static void CleanUselessMemBlockAndCheckValid(Plan* plan); static void ToDotFile(const Plan& plan, const std::string& filepath); static std::function<RegstDescProto*(int64_t)> MakeMutRegstDesc4Id(Plan* plan); diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp index 49df84f26f42df79e92ccd3a5b8eda271f381a83..dcb90ff9162a55a857904c04db993ea9c0fa614e 100644 --- a/oneflow/core/job/runtime.cpp +++ b/oneflow/core/job/runtime.cpp @@ -63,8 +63,9 @@ bool HasNonCtrlConsumedRegstDescId(const TaskProto& task) { } // namespace -Runtime::Runtime(const Plan& plan, size_t total_piece_num, bool is_experiment_phase) { - NewAllGlobal(plan, total_piece_num, is_experiment_phase); +Runtime::Runtime(const Plan& plan, size_t total_piece_num, bool is_experiment_phase, + const HashMap<std::string, Blob*>& variable_op_name2eager_blob) { + NewAllGlobal(plan, total_piece_num, is_experiment_phase, variable_op_name2eager_blob); std::vector<const TaskProto*> source_tasks; std::vector<const TaskProto*> other_tasks; int64_t this_machine_task_num = 0; @@ -95,7 +96,8 @@ Runtime::~Runtime() { DeleteAllGlobal(); } -void Runtime::NewAllGlobal(const Plan& plan, size_t total_piece_num, bool is_experiment_phase) { +void Runtime::NewAllGlobal(const Plan& plan, size_t total_piece_num, bool is_experiment_phase, + const HashMap<std::string, Blob*>& variable_op_name2eager_blob) { Global<RuntimeCtx>::New(total_piece_num, is_experiment_phase); if (GlobalProcessCtx::IsThisProcessMaster() && Global<RuntimeCtx>::Get()->NeedCollectActEvent()) { Global<ActEventLogger>::New(is_experiment_phase); @@ -126,7 +128,8 @@ void Runtime::NewAllGlobal(const Plan& plan, size_t total_piece_num, bool is_exp } Global<boxing::collective::CollectiveBoxingExecutor>::New(plan); Global<MemoryAllocator>::New(); - Global<RegstMgr>::New(plan); + Global<RegstMgr>::New(); + Global<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob); Global<ActorMsgBus>::New(); Global<ThreadMgr>::New(); Global<ThreadMgr>::Get()->AddPlan(plan); diff --git a/oneflow/core/job/runtime.h b/oneflow/core/job/runtime.h index 30d25632584245056b5c5c26669f52522e80dab5..c52c0b077f555b83276a7d29c418340358c9fb8d 100644 --- a/oneflow/core/job/runtime.h +++ b/oneflow/core/job/runtime.h @@ -19,6 +19,7 @@ limitations under the License. #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/plan.pb.h" #include "oneflow/core/job/runtime_context.h" +#include "oneflow/core/register/blob.h" namespace oneflow { @@ -28,10 +29,13 @@ class Runtime final { Runtime() = delete; ~Runtime(); - Runtime(const Plan& plan, size_t total_piece_num, bool is_experiment_phase); + // TODO(chengcheng): refactor Runtime interface about variable_op_name2eager_blob + Runtime(const Plan& plan, size_t total_piece_num, bool is_experiment_phase, + const HashMap<std::string, Blob*>& variable_op_name2eager_blob); private: - void NewAllGlobal(const Plan& plan, size_t total_piece_num, bool is_experiment_phase); + void NewAllGlobal(const Plan& plan, size_t total_piece_num, bool is_experiment_phase, + const HashMap<std::string, Blob*>& variable_op_name2eager_blob); void DeleteAllGlobal(); }; diff --git a/oneflow/core/job/runtime_job_descs.cpp b/oneflow/core/job/runtime_job_descs.cpp index 0449468108ba3d39a7d8a05b6514e1aa41614aae..62acf962930e0c269c52bd8f85f7462d54cd99e0 100644 --- a/oneflow/core/job/runtime_job_descs.cpp +++ b/oneflow/core/job/runtime_job_descs.cpp @@ -24,4 +24,10 @@ RuntimeJobDescs::RuntimeJobDescs(const PbMap<int64_t, JobConfigProto>& proto) { } } +const JobDesc& RuntimeJobDescs::job_desc(int64_t job_id) const { + auto it = job_id2job_desc_.find(job_id); + CHECK(it != job_id2job_desc_.end()); + return *(it->second); +} + } // namespace oneflow diff --git a/oneflow/core/job/runtime_job_descs.h b/oneflow/core/job/runtime_job_descs.h index 2c93afce759cebe2484bb57751a484eee1e82e45..aab62c1463913bf4622fa73d4182dd3ec68e5c2e 100644 --- a/oneflow/core/job/runtime_job_descs.h +++ b/oneflow/core/job/runtime_job_descs.h @@ -24,7 +24,7 @@ namespace oneflow { class RuntimeJobDescs final { public: explicit RuntimeJobDescs(const PbMap<int64_t, JobConfigProto>& proto); - const JobDesc& job_desc(int64_t job_id) const { return *job_id2job_desc_.at(job_id); } + const JobDesc& job_desc(int64_t job_id) const; private: HashMap<int64_t, std::unique_ptr<JobDesc>> job_id2job_desc_; diff --git a/oneflow/core/memory/memory_block.proto b/oneflow/core/memory/memory_block.proto index 830366caa009b2febc395953ff8c83d460316d46..56bb0562526c597be32a19c4ce9f9ea696f7af55 100644 --- a/oneflow/core/memory/memory_block.proto +++ b/oneflow/core/memory/memory_block.proto @@ -14,6 +14,9 @@ message MemBlockProto { required int64 mem_size = 8; // NOTE(chengcheng): thrd id hint is used by packed separated block group order. optional int64 thrd_id_hint = 9 [default = -1]; + // NOTE(chengcheng): mark this block memory is shared with EagerParameter. + optional string variable_op_name = 10 [default = ""]; + optional bool is_separated_header = 11 [default = false]; } message ChunkProto { diff --git a/oneflow/core/memory/memory_case_util.cpp b/oneflow/core/memory/memory_case_util.cpp index 862b6b8d8cbc36030c4899e6ad538f6ca2fd7379..00885047fa3d3591253beb519c6b9b2b9742f3d3 100644 --- a/oneflow/core/memory/memory_case_util.cpp +++ b/oneflow/core/memory/memory_case_util.cpp @@ -37,12 +37,13 @@ bool MemoryCaseUtil::GetCommonMemoryCase(const MemoryCase& a, const MemoryCase& } } -MemoryCase MemoryCaseUtil::GetHostPinnedMemoryCaseForRegstSeparatedHeader( - const MemoryCase& mem_case) { - CHECK(mem_case.has_device_cuda_mem()); +MemoryCase MemoryCaseUtil::GetHostMemoryCaseForRegstSeparatedHeader(const MemoryCase& mem_case) { MemoryCase ret; - ret.mutable_host_mem()->mutable_cuda_pinned_mem()->set_device_id( - mem_case.device_cuda_mem().device_id()); + ret.mutable_host_mem(); + if (mem_case.has_device_cuda_mem()) { + ret.mutable_host_mem()->mutable_cuda_pinned_mem()->set_device_id( + mem_case.device_cuda_mem().device_id()); + } return ret; } diff --git a/oneflow/core/memory/memory_case_util.h b/oneflow/core/memory/memory_case_util.h index abb6885f548c5951ab1f30a2e602787cf1c4a1ae..62086884c528eb0794a649eed5436f75fef03cbb 100644 --- a/oneflow/core/memory/memory_case_util.h +++ b/oneflow/core/memory/memory_case_util.h @@ -42,7 +42,7 @@ inline bool operator==(const MemoryCase& lhs, const MemoryCase& rhs) { struct MemoryCaseUtil { static bool GetCommonMemoryCase(const MemoryCase& a, const MemoryCase& b, MemoryCase* common); - static MemoryCase GetHostPinnedMemoryCaseForRegstSeparatedHeader(const MemoryCase& mem_case); + static MemoryCase GetHostMemoryCaseForRegstSeparatedHeader(const MemoryCase& mem_case); static int64_t GenMemZoneUniqueId(int64_t machine_id, const MemoryCase& mem_case); diff --git a/oneflow/core/register/register_desc.proto b/oneflow/core/register/register_desc.proto index 9f5281a2f63eb9a1d86338dba58362f8f00e3558..38332cc7244cea156c1bf8fb6198571e7d1eace0 100644 --- a/oneflow/core/register/register_desc.proto +++ b/oneflow/core/register/register_desc.proto @@ -44,4 +44,6 @@ message RegstDescProto { int64 hint_inplace_consumed_regst_desc_id = 14 [default = -1]; int64 force_inplace_consumed_regst_desc_id = 15 [default = -1]; } + // NOTE(chengcheng): mark this regst memory is shared with EagerParameter. + optional string variable_op_name = 16 [default = ""]; } diff --git a/oneflow/core/register/register_manager.cpp b/oneflow/core/register/register_manager.cpp index f072fc2eadb6e432531848939badc410a0cf87aa..00ece25b2628eb2af501db96e796fd8ad091214f 100644 --- a/oneflow/core/register/register_manager.cpp +++ b/oneflow/core/register/register_manager.cpp @@ -39,9 +39,11 @@ struct PackedChunkInfo { } // namespace -RegstMgr::RegstMgr(const Plan& plan) { +void RegstMgr::AddPlan(const Plan& plan, + const HashMap<std::string, Blob*>& variable_op_name2eager_blob) { int64_t this_machine_id = GlobalProcessCtx::Rank(); + // TODO(chengcheng): create chunk mgr for reuse memory between plans. HashMap<int64_t, char*> chunk_id2ptr; for (const ChunkProto& chunk : plan.block_chunk_list().chunk()) { if (chunk.machine_id() != this_machine_id) { continue; } @@ -57,11 +59,36 @@ RegstMgr::RegstMgr(const Plan& plan) { if (mem_block.mem_size() == 0) { continue; } const int64_t mem_block_id = mem_block.mem_block_id(); CHECK(all_block_ids.insert(mem_block_id).second); + if (mem_block.has_chunk_id()) { CHECK(mem_block.has_chunk_offset()); CHECK(chunk_id2ptr.find(mem_block.chunk_id()) != chunk_id2ptr.end()); char* mem_block_ptr = chunk_id2ptr.at(mem_block.chunk_id()) + mem_block.chunk_offset(); CHECK(mem_block_id2ptr_.emplace(mem_block_id, mem_block_ptr).second); + CHECK(!mem_block.has_variable_op_name()); + } else if (mem_block.has_variable_op_name()) { + // NOTE(chengcheng): bind mem_block_ptr to variable blob header_ptr and body_ptr + CHECK(!mem_block.enable_reuse_mem()); + const std::string& var_name = mem_block.variable_op_name(); + CHECK(!var_name.empty()); + auto it = variable_op_name2eager_blob.find(var_name); + CHECK(it != variable_op_name2eager_blob.end()) + << " CANNOT find variable op name: " << var_name; + CHECK(mem_block.has_is_separated_header()); + Blob* var_blob = it->second; + CHECK(var_blob) << " variable op name: " << var_name << " in rank: " << this_machine_id + << " CANNNOT NULL."; + if (mem_block.is_separated_header()) { + CHECK_GE(var_blob->blob_desc().AlignedByteSizeOfBlobHeader(), mem_block.mem_size()); + CHECK_GE(mem_block.mem_size(), var_blob->blob_desc().ByteSizeOfBlobHeader()); + CHECK(mem_block_id2ptr_.emplace(mem_block_id, var_blob->mut_header_ptr()).second); + CHECK(mem_block.mem_case().has_host_mem()); + } else { + CHECK_GE(var_blob->blob_desc().AlignedByteSizeOfBlobBody(), mem_block.mem_size()); + CHECK_GE(mem_block.mem_size(), var_blob->blob_desc().ByteSizeOfBlobBody()); + CHECK(mem_block_id2ptr_.emplace(mem_block_id, var_blob->ForceMutDptr<char>()).second); + CHECK(mem_block.mem_case() == var_blob->mem_case()); + } } else { int64_t zone_id = MemoryCaseUtil::GenMemZoneId(mem_block.mem_case()); if (zone_id2packed_chunk.find(zone_id) == zone_id2packed_chunk.end()) { @@ -115,6 +142,11 @@ RegstMgr::RegstMgr(const Plan& plan) { } } +void RegstMgr::AddPlan(const Plan& plan) { + HashMap<std::string, Blob*> variable_op_name2eager_blob; + AddPlan(plan, variable_op_name2eager_blob); +} + void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto, std::function<void(Regst*)> OneRegstDone) { const int64_t regst_desc_id = regst_desc_proto.regst_desc_id(); diff --git a/oneflow/core/register/register_manager.h b/oneflow/core/register/register_manager.h index 6375b13b99cbe42804b3bb477b732bea71d21670..aab7f22d47f1dc198362a5f1188efa208a856333 100644 --- a/oneflow/core/register/register_manager.h +++ b/oneflow/core/register/register_manager.h @@ -32,9 +32,11 @@ namespace oneflow { class RegstMgr final { public: OF_DISALLOW_COPY_AND_MOVE(RegstMgr); - RegstMgr() = delete; + RegstMgr() = default; ~RegstMgr() = default; + void AddPlan(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob); + void AddPlan(const Plan& plan); void NewRegsts(const RegstDescProto& regst_desc_proto, std::function<void(Regst*)> OneRegstDone); const RtRegstDesc& RegstDesc4RegstDescId(int64_t regst_desc_id) const; bool HasRegstDescId(int64_t regst_desc_id) const; @@ -43,11 +45,9 @@ class RegstMgr final { Blob* Blob4LbiAndParallelId(const LogicalBlobId& lbi, const int64_t parallel_id); private: - friend class Global<RegstMgr>; - - explicit RegstMgr(const Plan& plan); void NewBlobsInOneRegst(const std::vector<LbiBlobDescPair>& lbis, Regst*, const RtRegstDesc*, char* main_mem_ptr, char* separated_header_mem_ptr); + HashMap<int64_t, std::unique_ptr<const RtRegstDesc>> regst_desc_id2rt_regst_desc_; HashMap<LogicalBlobId, HashMap<int64_t, Blob*>> lbi2parallel_id2blob_; HashMap<int64_t, char*> mem_block_id2ptr_; diff --git a/oneflow/core/register/runtime_register_desc.cpp b/oneflow/core/register/runtime_register_desc.cpp index 32ee774b90a6a47d99eec9c6afe3f734b578d696..91af766026a69e467a05b6d2ca1a29256d2e7614 100644 --- a/oneflow/core/register/runtime_register_desc.cpp +++ b/oneflow/core/register/runtime_register_desc.cpp @@ -44,6 +44,14 @@ RtRegstDesc::RtRegstDesc(const RegstDescProto& proto) { } else { sorted_blob_desc_vec_.push_back(std::make_unique<const BlobDesc>(BlobDesc(DataType::kChar))); } + + if ((proto.mem_case().has_device_cuda_mem()) + || (proto.has_variable_op_name() && !proto.variable_op_name().empty())) { + // NOTE(chengcheng): When this regst is shared with EagerBlobObject, header is ALWAYS separated. + has_separated_header_ = true; + } else { + has_separated_header_ = false; + } } int64_t RtRegstDesc::GetOrdinalForLbi(const LogicalBlobId& lbi) const { @@ -86,7 +94,7 @@ size_t RtRegstDesc::TotalMainByteSize4AllRegst() const { } size_t RtRegstDesc::MainByteSize4OneRegst() const { - if (mem_case_.has_device_cuda_mem()) { + if (has_separated_header_) { return GetSoleBlobDesc()->AlignedByteSizeOfBlobBody(); } else { return GetSoleBlobDesc()->AlignedTotalByteSize(); @@ -98,8 +106,9 @@ size_t RtRegstDesc::TotalSeparatedHeaderByteSize4AllRegst() const { } size_t RtRegstDesc::SeparatedHeaderByteSize4OneRegst() const { - if (mem_case_.has_device_cuda_mem()) { - return GetSoleBlobDesc()->ByteSizeOfBlobHeader(); + if (has_separated_header_) { + // NOTE(chengcheng): Header size need to be aligned for XRT memory allocate + return GetSoleBlobDesc()->AlignedByteSizeOfBlobHeader(); } else { return 0; } diff --git a/oneflow/core/register/runtime_register_desc.h b/oneflow/core/register/runtime_register_desc.h index 3867c58c44e1df4d4d81b1e922d9a313bb183114..6adcd1bbd9c021835537d53e080402162424e56a 100644 --- a/oneflow/core/register/runtime_register_desc.h +++ b/oneflow/core/register/runtime_register_desc.h @@ -65,6 +65,8 @@ class RtRegstDesc { std::unique_ptr<Shape> data_regst_time_shape_; std::vector<std::unique_ptr<const BlobDesc>> sorted_blob_desc_vec_; std::vector<LogicalBlobId> sorted_lbi_vec_; + + bool has_separated_header_; }; } // namespace oneflow diff --git a/python/oneflow/test/graph/test_graph_linear.py b/python/oneflow/test/graph/test_graph_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..000021207e889df700ab9d43ca4255563ea12df6 --- /dev/null +++ b/python/oneflow/test/graph/test_graph_linear.py @@ -0,0 +1,69 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import numpy as np + +import oneflow as flow +import oneflow.unittest + + +def _test_linear_graph(test_case, device): + linear = flow.nn.Linear(3, 8, False) + linear = linear.to(device) + input_arr = np.array( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=np.float32, + ) + np_weight = np.ones((3, 8)).astype(np.float32) + np_weight.fill(2.3) + x = flow.tensor(input_arr, device=device) + flow.nn.init.constant_(linear.weight, 2.3) + of_eager_out = linear(x) + np_out = np.matmul(input_arr, np_weight) + test_case.assertTrue(np.allclose(of_eager_out.numpy(), np_out, 1e-05, 1e-05)) + + class LinearGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.my_linear = linear + + def build(self, x): + return self.my_linear(x) + + linear_g = LinearGraph() + of_lazy_out = linear_g(x) + test_case.assertTrue(np.array_equal(of_lazy_out.numpy(), of_eager_out.numpy())) + + +class TestLinearGraph(oneflow.unittest.TestCase): + def test_linear_graph_gpu(test_case): + _test_linear_graph(test_case, flow.device("cuda")) + + def test_linear_graph_cpu(test_case): + _test_linear_graph(test_case, flow.device("cpu")) + + +if __name__ == "__main__": + unittest.main()