Skip to content
Snippets Groups Projects
Commit 80010aaa authored by chengtbf's avatar chengtbf Committed by Jinhui Yuan
Browse files

Dev regst manager (#920)

* refine regst manager

* refine regst manager version 2

* remove memInfo

* refine code and reduce rows

* using protobuf message differencer

* explicit

* memory allocator free memory

* add regst desc proto list init for test code
parent 5fa0b3fa
No related branches found
No related tags found
No related merge requests found
Showing
with 184 additions and 113 deletions
......@@ -15,7 +15,7 @@ ActorMsg ActorMsg::BuildRegstMsgToConsumer(int64_t producer, int64_t consumer,
== Global<MachineCtx>::Get()->this_machine_id()) {
msg.regst_wrapper_.comm_net_token = nullptr;
} else {
msg.regst_wrapper_.comm_net_token = regst_raw_ptr->packed_blob()->comm_net_token();
msg.regst_wrapper_.comm_net_token = regst_raw_ptr->comm_net_token();
msg.regst_wrapper_.regst_status = regst_raw_ptr->status();
}
return msg;
......
......@@ -71,8 +71,7 @@ void CopyCommNetActor::Act() {
int64_t src_actor_id = readable_it->second.producer;
int64_t src_machine_id = Global<IDMgr>::Get()->MachineId4ActorId(src_actor_id);
// writeable
Blob* writeable_blob = GetCurSoleWriteableRegst()->packed_blob();
void* writeable_token = writeable_blob->comm_net_token();
void* writeable_token = GetCurSoleWriteableRegst()->comm_net_token();
// Async
void* read_id =
Global<CommNet>::Get()->Read(actor_read_id_, src_machine_id, readable_token, writeable_token);
......
......@@ -25,6 +25,12 @@ bool HasFieldInPbMessage(const PbMessage& msg, const std::string& field_name) {
return fd != nullptr;
}
// Looks up the field called `field_name` in the message type of `msg` and returns its
// descriptor. A name that does not exist in the message type trips CHECK_NOTNULL, so
// callers always receive a non-null pointer.
const PbFd* GetPbFdFromPbMessage(const PbMessage& msg, const std::string& field_name) {
  const PbFd* fd = msg.GetDescriptor()->FindFieldByName(field_name);
  CHECK_NOTNULL(fd);
  return fd;
}
#define DEFINE_GET_VAL_FROM_PBMESSAGE(cpp_type, pb_type_name) \
template<> \
cpp_type GetValFromPbMessage<cpp_type>(const PbMessage& msg, const std::string& field_name) { \
......
......@@ -7,8 +7,10 @@
#include <google/protobuf/descriptor.h>
#include <google/protobuf/map.h>
#include <google/protobuf/message.h>
#include <google/protobuf/util/message_differencer.h>
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
namespace oneflow {
......@@ -22,6 +24,8 @@ template<typename T1, typename T2>
using PbMapPair = google::protobuf::MapPair<T1, T2>;
template<typename K, typename V>
using PbMap = google::protobuf::Map<K, V>;
using PbFd = google::protobuf::FieldDescriptor;
using PbMd = google::protobuf::util::MessageDifferencer;
#define PROTOBUF_BASIC_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(std::string, String) \
......@@ -34,7 +38,7 @@ using PbMap = google::protobuf::Map<K, V>;
#define PROTOBUF_GET_FIELDDESC(msg, field_name) \
auto d = const_cast<google::protobuf::Descriptor*>(msg.GetDescriptor()); \
auto fd = const_cast<google::protobuf::FieldDescriptor*>(d->FindFieldByName(field_name));
auto fd = const_cast<PbFd*>(d->FindFieldByName(field_name));
#define PROTOBUF_REFLECTION(msg, field_name) \
PROTOBUF_GET_FIELDDESC(msg, field_name) \
......@@ -51,6 +55,8 @@ bool HasFieldInPbMessage(const PbMessage&, const std::string& field_name);
// Get From PbMessage
const PbFd* GetPbFdFromPbMessage(const PbMessage&, const std::string& field_name);
template<typename T>
T GetValFromPbMessage(const PbMessage&, const std::string& field_name);
......@@ -128,10 +134,37 @@ const T* GetMsgPtrFromPbMessage(const PbMessage& msg, const std::string& field_n
}
}
// Strict weak ordering for LogicalBlobId: lexicographic over
// (op_name, blob_name, b121_id, clone_id, is_packed_id), in that order.
inline bool operator<(const LogicalBlobId& lhs, const LogicalBlobId& rhs) {
  // String fields: one three-way compare decides both "differs" and "which is smaller".
  const int op_name_order = lhs.op_name().compare(rhs.op_name());
  if (op_name_order != 0) { return op_name_order < 0; }
  const int blob_name_order = lhs.blob_name().compare(rhs.blob_name());
  if (blob_name_order != 0) { return blob_name_order < 0; }
  // Scalar fields: the first one that differs decides.
  if (lhs.b121_id() != rhs.b121_id()) { return lhs.b121_id() < rhs.b121_id(); }
  if (lhs.clone_id() != rhs.clone_id()) { return lhs.clone_id() < rhs.clone_id(); }
  // Last key: equal values fall out as `false`, which is what a strict order requires.
  return lhs.is_packed_id() < rhs.is_packed_id();
}
// Equality for LogicalBlobId, delegated to protobuf's MessageDifferencer with its
// default configuration (no ignored fields).
inline bool operator==(const LogicalBlobId& lhs, const LogicalBlobId& rhs) {
  return PbMd::Equals(lhs, rhs);
}
// Persistent
PersistentOutStream& operator<<(PersistentOutStream&, const PbMessage&);
} // namespace oneflow
namespace std {

// Hash support so LogicalBlobId can key unordered containers. The five identifying
// fields are concatenated into one string (numbers rendered with std::to_string) and
// that string is hashed with std::hash<std::string>.
template<>
struct hash<oneflow::LogicalBlobId> {
  size_t operator()(const oneflow::LogicalBlobId& lbi) const {
    std::string key = lbi.op_name();
    key += lbi.blob_name();
    key += std::to_string(lbi.b121_id());
    key += std::to_string(lbi.clone_id());
    key += std::to_string(lbi.is_packed_id());
    return std::hash<std::string>()(key);
  }
};

}  // namespace std
#endif // ONEFLOW_CORE_COMMON_PROTOBUF_H_
......@@ -23,7 +23,6 @@
#include <unordered_set>
#include <utility>
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/common/meta_util.hpp"
DECLARE_string(log_dir);
......@@ -196,34 +195,6 @@ void Erase(T& container, std::function<bool(const typename T::value_type&)> Need
Erase<T>(container, NeedErase, [](const typename T::value_type&) {});
}
// Orders LogicalBlobId values lexicographically by
// (op_name, blob_name, b121_id, clone_id, is_packed_id).
inline bool operator<(const LogicalBlobId& lhs, const LogicalBlobId& rhs) {
  const int by_op_name = lhs.op_name().compare(rhs.op_name());
  if (by_op_name != 0) { return by_op_name < 0; }
  const int by_blob_name = lhs.blob_name().compare(rhs.blob_name());
  if (by_blob_name != 0) { return by_blob_name < 0; }
  if (lhs.b121_id() != rhs.b121_id()) { return lhs.b121_id() < rhs.b121_id(); }
  if (lhs.clone_id() != rhs.clone_id()) { return lhs.clone_id() < rhs.clone_id(); }
  // Final tie-breaker; identical ids compare as "not less".
  return lhs.is_packed_id() < rhs.is_packed_id();
}
// Field-by-field equality for LogicalBlobId: every one of op_name, blob_name, b121_id,
// clone_id and is_packed_id must match. Fields are checked in that order and the first
// mismatch ends the comparison.
inline bool operator==(const LogicalBlobId& lhs, const LogicalBlobId& rhs) {
  if (lhs.op_name() != rhs.op_name()) { return false; }
  if (lhs.blob_name() != rhs.blob_name()) { return false; }
  if (lhs.b121_id() != rhs.b121_id()) { return false; }
  if (lhs.clone_id() != rhs.clone_id()) { return false; }
  return lhs.is_packed_id() == rhs.is_packed_id();
}
} // namespace oneflow
namespace std {

// std::hash specialization for LogicalBlobId: hashes the concatenation of op_name,
// blob_name and the decimal renderings of b121_id, clone_id and is_packed_id.
template<>
struct hash<oneflow::LogicalBlobId> {
  size_t operator()(const oneflow::LogicalBlobId& lbi) const {
    std::string concatenated = lbi.op_name();
    concatenated += lbi.blob_name();
    concatenated += std::to_string(lbi.b121_id());
    concatenated += std::to_string(lbi.clone_id());
    concatenated += std::to_string(lbi.is_packed_id());
    return std::hash<std::string>()(concatenated);
  }
};

}  // namespace std
#endif // ONEFLOW_CORE_COMMON_UTIL_H_
......@@ -97,7 +97,7 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) {
}
Global<SnapshotMgr>::New(plan);
Global<MemoryAllocator>::New();
Global<RegstMgr>::New();
Global<RegstMgr>::New(plan);
Global<ActorMsgBus>::New();
Global<ThreadMgr>::New(plan);
}
......
......@@ -225,7 +225,7 @@ void KernelIfWithActivation<device_type, T>::GenActivationBlob(
std::unique_ptr<Blob>* activation_blob, void* buf_ptr,
const BlobDesc* activation_blob_desc) const {
activation_blob->reset(
NewBlob(nullptr, activation_blob_desc, static_cast<char*>(buf_ptr), nullptr, device_type));
NewBlob(nullptr, activation_blob_desc, static_cast<char*>(buf_ptr), device_type));
}
std::unique_ptr<const Kernel> ConstructKernel(const ParallelContext* parallel_ctx,
......
......@@ -242,7 +242,7 @@ size_t GetTmpSizeForReduceSum(DataType data_type, int64_t sum_elem_num) {
char* host_raw_dptr = nullptr; \
CudaCheck(cudaMallocHost(&host_raw_dptr, blob->TotalByteSize())); \
std::unique_ptr<Blob> host_blob; \
host_blob.reset(NewBlob(nullptr, &blob_desc, host_raw_dptr, nullptr, DeviceType::kGPU));
host_blob.reset(NewBlob(nullptr, &blob_desc, host_raw_dptr, DeviceType::kGPU));
// asynchronous copy to device
#define AFTER_CPU_INITIALIZE() \
......
......@@ -72,7 +72,7 @@ template<>
Blob* OpKernelTestUtil<DeviceType::kCPU>::CreateBlob(const BlobDesc* blob_desc, Regst* regst) {
void* mem_ptr = nullptr;
CudaCheck(cudaMallocHost(&mem_ptr, blob_desc->TotalByteSize()));
return NewBlob(regst, blob_desc, static_cast<char*>(mem_ptr), nullptr, DeviceType::kCPU);
return NewBlob(regst, blob_desc, static_cast<char*>(mem_ptr), DeviceType::kCPU);
}
template<DeviceType src_device_type, DeviceType dst_device_type>
......
......@@ -10,7 +10,7 @@ template<>
Blob* OpKernelTestUtil<DeviceType::kGPU>::CreateBlob(const BlobDesc* blob_desc, Regst* regst) {
void* mem_ptr = nullptr;
CudaCheck(cudaMalloc(&mem_ptr, blob_desc->TotalByteSize()));
return NewBlob(regst, blob_desc, static_cast<char*>(mem_ptr), nullptr, DeviceType::kGPU);
return NewBlob(regst, blob_desc, static_cast<char*>(mem_ptr), DeviceType::kGPU);
}
template<>
......
......@@ -4,11 +4,13 @@
namespace oneflow {
std::tuple<char*, void*, std::function<void()>> MemoryAllocator::Allocate(MemoryCase mem_case,
std::size_t size) {
// Runs every deleter stored in deleters_, front to back, when the allocator is destroyed.
MemoryAllocator::~MemoryAllocator() {
  // Bind by const reference: iterating by value would copy-construct each std::function
  // (and the state bound into it) just to invoke it once.
  for (const std::function<void()>& deleter : deleters_) { deleter(); }
}
char* MemoryAllocator::Allocate(MemoryCase mem_case, std::size_t size) {
const int memset_val = 0;
char* dptr = nullptr;
void* comm_net_token = nullptr;
if (mem_case.has_host_mem()) {
if (mem_case.host_mem().used_by_device()) {
CudaCheck(cudaMallocHost(&dptr, size));
......@@ -16,38 +18,27 @@ std::tuple<char*, void*, std::function<void()>> MemoryAllocator::Allocate(Memory
dptr = reinterpret_cast<char*>(malloc(size));
CHECK_NOTNULL(dptr);
}
if (mem_case.host_mem().used_by_network()) {
comm_net_token = Global<CommNet>::Get()->RegisterMemory(dptr, size);
}
memset(dptr, memset_val, size);
} else if (mem_case.has_device_cuda_mem()) {
int32_t current_device_id;
CudaCheck(cudaGetDevice(&current_device_id));
CHECK_EQ(mem_case.device_cuda_mem().device_id(), current_device_id);
CudaCheck(cudaSetDevice(mem_case.device_cuda_mem().device_id()));
CudaCheck(cudaMalloc(&dptr, size));
CudaCheck(cudaMemset(dptr, memset_val, size));
} else {
UNIMPLEMENTED();
}
return std::make_tuple(
dptr, comm_net_token,
std::bind(&MemoryAllocator::Deallocate, this, dptr, comm_net_token, mem_case));
deleters_.push_back(std::bind(&MemoryAllocator::Deallocate, this, dptr, mem_case));
return dptr;
}
void MemoryAllocator::Deallocate(char* dptr, void* comm_net_token, MemoryCase mem_case) {
void MemoryAllocator::Deallocate(char* dptr, MemoryCase mem_case) {
if (mem_case.has_host_mem()) {
if (mem_case.host_mem().used_by_network()) {
Global<CommNet>::Get()->UnRegisterMemory(comm_net_token);
}
if (mem_case.host_mem().used_by_device()) {
CudaCheck(cudaFreeHost(dptr));
} else {
free(dptr);
}
} else if (mem_case.has_device_cuda_mem()) {
int32_t current_device_id = -1;
CudaCheck(cudaGetDevice(&current_device_id));
CHECK_EQ(mem_case.device_cuda_mem().device_id(), current_device_id);
CudaCheck(cudaSetDevice(mem_case.device_cuda_mem().device_id()));
CudaCheck(cudaFree(dptr));
} else {
UNIMPLEMENTED();
......
......@@ -9,13 +9,17 @@ namespace oneflow {
class MemoryAllocator final {
public:
OF_DISALLOW_COPY_AND_MOVE(MemoryAllocator);
~MemoryAllocator() = default;
~MemoryAllocator();
std::tuple<char*, void*, std::function<void()>> Allocate(MemoryCase mem_case, std::size_t size);
char* Allocate(MemoryCase mem_case, std::size_t size);
private:
friend class Global<MemoryAllocator>;
MemoryAllocator() = default;
void Deallocate(char* dptr, void*, MemoryCase mem_case);
void Deallocate(char* dptr, MemoryCase mem_case);
std::list<std::function<void()>> deleters_;
};
} // namespace oneflow
......
......@@ -7,7 +7,7 @@
namespace oneflow {
Blob::Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, void* comm_net_token) {
Blob::Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr) {
mem_ptr_ = mem_ptr;
if (blob_desc->has_data_id_field()) {
data_id_ptr_ = mem_ptr;
......@@ -22,7 +22,6 @@ Blob::Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, void* comm_ne
}
dptr_ = offset + RoundUp(blob_desc->ByteSizeOfColNumField(), kCudaAlignSize);
blob_desc_ = blob_desc;
comm_net_token_ = comm_net_token;
regst_ = regst;
}
......@@ -50,22 +49,21 @@ int32_t Blob::max_col_id() const { return regst_->max_col_id(); }
void Blob::set_max_col_id(int32_t val) { regst_->set_max_col_id(val); }
const MemoryCase& Blob::mem_case() const { return regst_->regst_desc()->mem_case(); }
#define MAKE_BLOB_ENTRY(data_type_pair, ndims, device_type) \
{GetHashKey(OF_PP_PAIR_SECOND(data_type_pair), ndims, device_type), \
[](Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, void* comm_net_token) { \
return new BlobImpl<OF_PP_PAIR_FIRST(data_type_pair), ndims, device_type>( \
regst, blob_desc, mem_ptr, comm_net_token); \
#define MAKE_BLOB_ENTRY(data_type_pair, ndims, device_type) \
{GetHashKey(OF_PP_PAIR_SECOND(data_type_pair), ndims, device_type), \
[](Regst* regst, const BlobDesc* blob_desc, char* mem_ptr) { \
return new BlobImpl<OF_PP_PAIR_FIRST(data_type_pair), ndims, device_type>(regst, blob_desc, \
mem_ptr); \
}},
Blob* NewBlob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, void* comm_net_token,
DeviceType device_type) {
static const HashMap<std::string, std::function<Blob*(Regst * regst, const BlobDesc* blob_desc,
char* mem_ptr, void* comm_net_token)>>
Blob* NewBlob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, DeviceType device_type) {
static const HashMap<
std::string, std::function<Blob*(Regst * regst, const BlobDesc* blob_desc, char* mem_ptr)>>
creators = {OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_BLOB_ENTRY, ALL_DATA_TYPE_SEQ, DIM_SEQ,
DEVICE_TYPE_SEQ)};
std::string key = GetHashKey(blob_desc->data_type(),
static_cast<int32_t>(blob_desc->shape().NumAxes()), device_type);
return creators.at(key)(regst, blob_desc, mem_ptr, comm_net_token);
return creators.at(key)(regst, blob_desc, mem_ptr);
}
} // namespace oneflow
......@@ -55,8 +55,6 @@ class Blob : public BlobIf {
return static_cast<T*>(dptr_);
}
void* comm_net_token() const { return comm_net_token_; }
const BlobDesc& blob_desc() const { return *blob_desc_; }
const BlobDesc* blob_desc_ptr() const { return blob_desc_; }
const Shape& shape() const { return blob_desc_->shape(); }
......@@ -82,9 +80,7 @@ class Blob : public BlobIf {
const MemoryCase& mem_case() const;
protected:
Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr)
: Blob(regst, blob_desc, mem_ptr, nullptr) {}
Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, void* comm_net_token);
Blob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr);
private:
template<typename T>
......@@ -99,13 +95,11 @@ class Blob : public BlobIf {
char* data_id_ptr_;
int32_t* col_num_ptr_;
void* dptr_;
void* comm_net_token_;
const BlobDesc* blob_desc_;
Regst* regst_;
};
Blob* NewBlob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, void* comm_net_token,
DeviceType device_type);
Blob* NewBlob(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, DeviceType device_type);
class RecordBlobIf : public BlobIf {
public:
......
......@@ -19,9 +19,7 @@ class BlobImpl final : public Blob {
public:
OF_DISALLOW_COPY_AND_MOVE(BlobImpl);
BlobImpl(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr)
: BlobImpl(regst, blob_desc, mem_ptr, nullptr) {}
BlobImpl(Regst* regst, const BlobDesc* blob_desc, char* mem_ptr, void* comm_net_token)
: Blob(regst, blob_desc, mem_ptr, comm_net_token) {
: Blob(regst, blob_desc, mem_ptr) {
CHECK_EQ(NDIMS, blob_desc_ptr()->shape().NumAxes());
for (int32_t d = 0; d < NDIMS; ++d) { dsizes_[d] = blob_desc_ptr()->shape().At(d); }
tensor_ =
......
#include "oneflow/core/register/register.h"
#include "oneflow/core/job/keyword.h"
#include "oneflow/core/comm_network/comm_network.h"
namespace oneflow {
......@@ -14,6 +15,11 @@ Regst::Regst() {
status_.col_id = 0;
status_.max_col_id = 0;
regst_desc_ = nullptr;
comm_net_token_ = nullptr;
}
// Unregisters this regst's memory from the CommNet when a comm-net token is held.
// With a null token the destructor has nothing to undo.
Regst::~Regst() {
  if (comm_net_token_ == nullptr) { return; }
  Global<CommNet>::Get()->UnRegisterMemory(comm_net_token_);
}
Blob* Regst::GetBlobByLbi(const LogicalBlobId& lbi) {
......
......@@ -18,7 +18,7 @@ struct RegstStatus {
class Regst final {
public:
OF_DISALLOW_COPY_AND_MOVE(Regst);
~Regst() { deleter_(); }
~Regst();
// Getters
const RegstStatus& status() const { return status_; }
......@@ -44,6 +44,8 @@ class Regst final {
bool IsMaxCol() const { return col_id() == max_col_id(); }
void* comm_net_token() const { return comm_net_token_; }
// Setters
void set_piece_id(int64_t val) { status_.piece_id = val; }
void set_model_version_id(int64_t val) { status_.model_version_id = val; }
......@@ -55,9 +57,9 @@ class Regst final {
friend class RegstMgr;
Regst();
void* comm_net_token_;
RegstStatus status_;
const RtRegstDesc* regst_desc_;
std::function<void()> deleter_;
HashMap<LogicalBlobId, std::unique_ptr<BlobIf>> lbi2blob_;
std::unique_ptr<BlobIf> packed_blob_;
};
......
#include "oneflow/core/register/register_manager.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/register/blob.h"
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/memory/memory_case.pb.h"
namespace std {

// Hash for MemoryCase so it can key a hash map.
//  - without host_mem: the CUDA device id is the hash;
//  - with host_mem: 1024 plus the used_by_device flag.
// used_by_network does not take part in the hash.
template<>
struct hash<oneflow::MemoryCase> {
  size_t operator()(const oneflow::MemoryCase& val) const {
    if (!val.has_host_mem()) { return val.device_cuda_mem().device_id(); }
    return val.host_mem().used_by_device() + 1024;
  }
};

}  // namespace std
namespace oneflow {
// MemoryCase equality that ignores the host_mem `used_by_network` field; every other
// field is compared by protobuf's MessageDifferencer.
inline bool operator==(const MemoryCase& lhs, const MemoryCase& rhs) {
  const PbFd* used_by_network_fd = GetPbFdFromPbMessage(lhs.host_mem(), "used_by_network");
  PbMd differencer;
  differencer.IgnoreField(used_by_network_fd);
  return differencer.Compare(lhs, rhs);
}
// Builds the manager from a Plan: gathers the regst desc protos produced by the tasks
// placed on this machine and initializes from that list. Tasks of other machines are
// skipped.
RegstMgr::RegstMgr(const Plan& plan) {
  std::list<const RegstDescProto*> local_regst_protos;
  for (const TaskProto& task : plan.task()) {
    const bool runs_on_this_machine =
        task.machine_id() == Global<MachineCtx>::Get()->this_machine_id();
    if (runs_on_this_machine) {
      for (const auto& name_and_proto : task.produced_regst_desc()) {
        local_regst_protos.push_back(&name_and_proto.second);
      }
    }
  }
  InitFromRegstProtoList(local_regst_protos);
}
// Builds the manager directly from an explicit list of regst desc protos, without
// filtering by machine. NOTE(review): the commit notes mention a proto-list init
// "for test code", so this overload is presumably meant for tests; confirm.
RegstMgr::RegstMgr(const std::list<const RegstDescProto*>& regst_protos) {
InitFromRegstProtoList(regst_protos);
}
// Creates the runtime regst descriptors for `regst_protos` and assigns each one a
// start address inside a buffer shared by all descriptors with an equal MemoryCase.
//   1. register one RtRegstDesc per proto and sum the bytes needed per MemoryCase;
//   2. allocate one buffer per distinct MemoryCase through Global<MemoryAllocator>;
//   3. hand every regst desc the current cursor of its MemoryCase buffer and advance
//      the cursor by that desc's TotalByteSize4AllRegst().
void RegstMgr::InitFromRegstProtoList(const std::list<const RegstDescProto*>& regst_protos) {
// Cursor into each MemoryCase's buffer: the next address not yet handed out.
HashMap<MemoryCase, char*> mem_case2mem_ptr;
// Bytes required by all regst descs that share a MemoryCase.
HashMap<MemoryCase, size_t> mem_case2mem_size;
for (const RegstDescProto* regst_desc_proto : regst_protos) {
// emplace() reports false for a regst_desc_id that is already present, which trips
// the CHECK: ids must be unique within the list.
CHECK(regst_desc_id2rt_regst_desc_
.emplace(regst_desc_proto->regst_desc_id(),
std::make_unique<const RtRegstDesc>(*regst_desc_proto))
.second);
// operator[] default-inserts a zero size the first time a MemoryCase is seen.
mem_case2mem_size[regst_desc_proto->mem_case()] +=
regst_desc_id2rt_regst_desc_.at(regst_desc_proto->regst_desc_id())
->TotalByteSize4AllRegst();
}
// One allocation per distinct MemoryCase, sized to hold all of its regst descs.
for (const auto& pair : mem_case2mem_size) {
CHECK(
mem_case2mem_ptr
.emplace(pair.first, Global<MemoryAllocator>::Get()->Allocate(pair.first, pair.second))
.second);
}
// NOTE(review): this walks a HashMap, so the order in which descs are laid out inside
// a buffer is presumably unspecified; confirm HashMap is an unordered container.
for (const auto& pair : regst_desc_id2rt_regst_desc_) {
const RtRegstDesc* rt_regst_desc = pair.second.get();
const MemoryCase& mem_case = rt_regst_desc->mem_case();
// Record the start address first, then bump the cursor past this desc's bytes.
CHECK(regst_desc_id2mem_ptr_.emplace(pair.first, mem_case2mem_ptr.at(mem_case)).second);
mem_case2mem_ptr.at(mem_case) += rt_regst_desc->TotalByteSize4AllRegst();
}
}
void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto, DeviceType device_type,
std::function<void(Regst*)> OneRegstDone) {
const RtRegstDesc* runtime_regst_desc = new RtRegstDesc(regst_desc_proto);
{
std::unique_lock<std::mutex> lck(rt_regst_descs_mtx_);
rt_regst_descs_.emplace_back(runtime_regst_desc);
}
const int64_t regst_desc_id = regst_desc_proto.regst_desc_id();
const RegstDescTypeProto& regst_desc_type = regst_desc_proto.regst_desc_type();
for (int64_t i = 0; i < regst_desc_proto.register_num(); ++i) {
const RtRegstDesc* rt_regst_desc = regst_desc_id2rt_regst_desc_.at(regst_desc_id).get();
char* mem_ptr = regst_desc_id2mem_ptr_.at(regst_desc_id);
std::vector<LogicalBlobId> lbis;
if (regst_desc_type.has_normal_regst_desc()) {
for (const LbiBlobDescPair& pair : regst_desc_type.normal_regst_desc().lbi2blob_desc()) {
lbis.push_back(pair.lbi());
}
CHECK(!lbis.empty());
}
for (int64_t i = 0; i < rt_regst_desc->register_num(); ++i) {
Regst* regst = new Regst;
regst->regst_desc_ = runtime_regst_desc;
regst->regst_desc_ = rt_regst_desc;
if (regst_desc_type.has_normal_regst_desc()) {
std::vector<LogicalBlobId> lbis;
for (const LbiBlobDescPair& pair : regst_desc_type.normal_regst_desc().lbi2blob_desc()) {
lbis.push_back(pair.lbi());
}
CHECK(!lbis.empty());
std::sort(lbis.begin(), lbis.end());
std::tuple<char*, void*, std::function<void()>> allocation_result =
Global<MemoryAllocator>::Get()->Allocate(
regst_desc_proto.mem_case(), runtime_regst_desc->packed_blob_desc()->TotalByteSize());
char* cur_pointer = std::get<0>(allocation_result);
char* cur_pointer = mem_ptr;
for (const LogicalBlobId& lbi : lbis) {
const BlobDesc* blob_desc = runtime_regst_desc->GetBlobDescFromLbi(lbi);
const BlobDesc* blob_desc = rt_regst_desc->GetBlobDescFromLbi(lbi);
std::unique_ptr<Blob> blob_ptr;
blob_ptr.reset(NewBlob(regst, blob_desc, cur_pointer, nullptr, device_type));
blob_ptr.reset(NewBlob(regst, blob_desc, cur_pointer, device_type));
CHECK(regst->lbi2blob_.emplace(lbi, std::move(blob_ptr)).second);
cur_pointer += blob_desc->TotalByteSize();
}
regst->packed_blob_.reset(NewBlob(regst, runtime_regst_desc->packed_blob_desc(),
std::get<0>(allocation_result),
std::get<1>(allocation_result), device_type));
regst->deleter_ = std::get<2>(allocation_result);
regst->packed_blob_.reset(
NewBlob(regst, rt_regst_desc->packed_blob_desc(), mem_ptr, device_type));
if (rt_regst_desc->mem_case().has_host_mem()
&& rt_regst_desc->mem_case().host_mem().used_by_network()) {
regst->comm_net_token_ = Global<CommNet>::Get()->RegisterMemory(
mem_ptr, rt_regst_desc->packed_blob_desc()->TotalByteSize());
}
mem_ptr += rt_regst_desc->packed_blob_desc()->TotalByteSize();
} else if (regst_desc_type.has_record_regst_desc()) {
const RecordTypeProto& record_type = regst_desc_type.record_regst_desc().record_type();
switch (record_type) {
case kOFRecord: regst->packed_blob_.reset(new RecordBlob<OFRecord>); break;
default: UNIMPLEMENTED();
}
regst->deleter_ = []() {};
} else if (regst_desc_type.has_delay_regst_desc()) {
regst->deleter_ = []() {};
// do nothing
} else {
UNIMPLEMENTED();
}
......
......@@ -13,6 +13,7 @@ namespace oneflow {
class RegstMgr final {
public:
OF_DISALLOW_COPY_AND_MOVE(RegstMgr);
RegstMgr() = delete;
~RegstMgr() = default;
void NewRegsts(const RegstDescProto& regst_desc_proto, DeviceType device_type,
......@@ -20,10 +21,13 @@ class RegstMgr final {
private:
friend class Global<RegstMgr>;
RegstMgr() = default;
std::mutex rt_regst_descs_mtx_;
std::list<std::unique_ptr<const RtRegstDesc>> rt_regst_descs_;
explicit RegstMgr(const Plan& plan);
explicit RegstMgr(const std::list<const RegstDescProto*>& regst_protos);
void InitFromRegstProtoList(const std::list<const RegstDescProto*>& regst_protos);
HashMap<int64_t, std::unique_ptr<const RtRegstDesc>> regst_desc_id2rt_regst_desc_;
HashMap<int64_t, char*> regst_desc_id2mem_ptr_;
};
} // namespace oneflow
......
......@@ -23,6 +23,9 @@ class RtRegstDesc {
const BlobDesc* GetBlobDescFromLbi(const LogicalBlobId& lbi) const;
const BlobDesc* packed_blob_desc() const { return &packed_blob_desc_; }
// Bytes needed to hold every regst of this descriptor: the packed blob's byte size
// multiplied by register_num_.
size_t TotalByteSize4AllRegst() const {
  const auto bytes_per_regst = packed_blob_desc_.TotalByteSize();
  return bytes_per_regst * register_num_;
}
private:
int64_t regst_desc_id_;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment