From 6556bdac5d211b6c49714cba74d640c3ea2a8b8e Mon Sep 17 00:00:00 2001 From: willzhang4a58 <willzhang4a58@gmail.com> Date: Wed, 30 May 2018 18:59:40 +0800 Subject: [PATCH] max_part_num --- benchmark/alexnet/report.md | 4 ++-- oneflow/core/job/runtime.cpp | 16 +++++++++------- oneflow/core/kernel/decode_ofrecord_kernel.cpp | 12 ++++++++---- oneflow/core/kernel/decode_ofrecord_kernel.h | 1 + oneflow/core/record/ofrecord_decoder.cpp | 15 ++++++++------- oneflow/core/record/ofrecord_decoder.h | 14 ++++++++------ 6 files changed, 36 insertions(+), 26 deletions(-) diff --git a/benchmark/alexnet/report.md b/benchmark/alexnet/report.md index ac5029d8d..8a8e3af26 100644 --- a/benchmark/alexnet/report.md +++ b/benchmark/alexnet/report.md @@ -2,6 +2,6 @@ batch_size: 1024 gpu num | time (one batch) :-------| :------------- -1 | 538ms +1 | 549ms 2 | 285ms -4 | 189ms +4 | 171ms diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp index 24d26d53b..a9216639b 100644 --- a/oneflow/core/job/runtime.cpp +++ b/oneflow/core/job/runtime.cpp @@ -53,7 +53,7 @@ Runtime::Runtime(const Plan& plan, bool is_experiment_phase) { LOG(INFO) << "All actor on this machine are constructed"; OF_BARRIER(); LOG(INFO) << "All actor on all machine are constructed"; - Global<CommNet>::Get()->RegisterMemoryDone(); + if (Global<CommNet>::Get()) { Global<CommNet>::Get()->RegisterMemoryDone(); } runtime_ctx->NewCounter("model_init_cnt", mdupdt_tasks.size()); SendCmdMsg(mdupdt_tasks, ActorCmd::kInitModel); runtime_ctx->WaitUntilCntEqualZero("model_init_cnt"); @@ -82,17 +82,19 @@ void Runtime::NewAllGlobal(const Plan& plan, bool is_experiment_phase) { } } Global<RuntimeCtx>::New(piece_num, is_experiment_phase); + if (job_desc->TotalMachineNum() > 1) { #ifdef PLATFORM_POSIX - if (job_desc->use_rdma()) { + if (job_desc->use_rdma()) { #ifdef WITH_RDMA - IBVerbsCommNet::Init(plan); + IBVerbsCommNet::Init(plan); #else - LOG(FATAL) << "RDMA components not found"; + LOG(FATAL) << "RDMA components not found"; #endif - } else { - EpollCommNet::Init(plan); - } + } else { + EpollCommNet::Init(plan); + } #endif + } Global<SnapshotMgr>::New(plan); Global<MemoryAllocator>::New(); Global<RegstMgr>::New(); diff --git a/oneflow/core/kernel/decode_ofrecord_kernel.cpp b/oneflow/core/kernel/decode_ofrecord_kernel.cpp index 08f5108b0..0b8ae68b2 100644 --- a/oneflow/core/kernel/decode_ofrecord_kernel.cpp +++ b/oneflow/core/kernel/decode_ofrecord_kernel.cpp @@ -4,9 +4,10 @@ namespace oneflow { -void DecodeOFRecordKernel::VirtualKernelInit(const ParallelContext*) { +void DecodeOFRecordKernel::VirtualKernelInit(const ParallelContext* parallel_ctx) { random_seed_gen_.reset(new std::mt19937(kernel_conf().decode_ofrecord_conf().random_seed())); distribution_.reset(new std::uniform_int_distribution<int32_t>(0, 1024 * 1024)); + parallel_num_ = parallel_ctx->parallel_num(); } int32_t DecodeOFRecordKernel::NextRandomInt() const { return (*distribution_)(*random_seed_gen_); } @@ -24,9 +25,12 @@ void DecodeOFRecordKernel::Forward(const KernelCtx& ctx, const BlobConf& blob_conf = decode_conf.blob(i); OFRecordDecoderIf* decoder = GetOFRecordDecoder(blob_conf.encode_case().encode_case(), blob_conf.data_type()); - int32_t max_col_id = - decoder->DecodeOneCol(ctx.device_ctx, record_blob, blob_conf, status->cur_col_id_, out_blob, - std::bind(&DecodeOFRecordKernel::NextRandomInt, this)); + int32_t compute_thread_num = Global<ThreadMgr>::Get()->compute_thread_pool()->thread_num(); + int32_t max_col_id = decoder->DecodeOneCol( + ctx.device_ctx, + compute_thread_num / parallel_num_ + (compute_thread_num % parallel_num_ == 0 ? 0 : 1), + record_blob, blob_conf, status->cur_col_id_, out_blob, + std::bind(&DecodeOFRecordKernel::NextRandomInt, this)); if (status->max_col_id_ == -1) { status->max_col_id_ = max_col_id; diff --git a/oneflow/core/kernel/decode_ofrecord_kernel.h b/oneflow/core/kernel/decode_ofrecord_kernel.h index c39f0d4f2..74da69bd3 100644 --- a/oneflow/core/kernel/decode_ofrecord_kernel.h +++ b/oneflow/core/kernel/decode_ofrecord_kernel.h @@ -26,6 +26,7 @@ class DecodeOFRecordKernel final : public KernelIf<DeviceType::kCPU> { std::unique_ptr<std::mt19937> random_seed_gen_; std::unique_ptr<std::uniform_int_distribution<int32_t>> distribution_; + int64_t parallel_num_; }; } // namespace oneflow diff --git a/oneflow/core/record/ofrecord_decoder.cpp b/oneflow/core/record/ofrecord_decoder.cpp index 717de59b8..0189f9942 100644 --- a/oneflow/core/record/ofrecord_decoder.cpp +++ b/oneflow/core/record/ofrecord_decoder.cpp @@ -73,14 +73,15 @@ void DoPreprocess(const PreprocessConf& conf, T* dptr, const Shape& shape) { template<EncodeCase encode_case, typename T> int32_t OFRecordDecoder<encode_case, T>::DecodeOneCol( - DeviceCtx* ctx, RecordBlob<OFRecord>* record_blob, const BlobConf& blob_conf, int32_t col_id, - Blob* out_blob, std::function<int32_t(void)> NextRandomInt) const { + DeviceCtx* ctx, int32_t max_part_num, RecordBlob<OFRecord>* record_blob, + const BlobConf& blob_conf, int32_t col_id, Blob* out_blob, + std::function<int32_t(void)> NextRandomInt) const { int32_t max_col_id = 0; if (out_blob->has_col_num_field()) { max_col_id = ReadColNum(ctx, record_blob, blob_conf.name(), out_blob) - 1; } if (out_blob->has_data_id_field()) { ReadDataId(ctx, record_blob, out_blob); } - ReadDataContent(ctx, record_blob, blob_conf, col_id, out_blob, NextRandomInt); + ReadDataContent(ctx, max_part_num, record_blob, blob_conf, col_id, out_blob, NextRandomInt); return max_col_id; } @@ -127,12 +128,12 @@ void OFRecordDecoder<encode_case, T>::ReadDataId(DeviceCtx* ctx, RecordBlob<OFRe template<EncodeCase encode_case, typename T> void OFRecordDecoder<encode_case, T>::ReadDataContent( - DeviceCtx* ctx, RecordBlob<OFRecord>* record_blob, const BlobConf& blob_conf, int32_t col_id, - Blob* out_blob, std::function<int32_t(void)> NextRandomInt) const { + DeviceCtx* ctx, int32_t max_part_num, RecordBlob<OFRecord>* record_blob, + const BlobConf& blob_conf, int32_t col_id, Blob* out_blob, + std::function<int32_t(void)> NextRandomInt) const { int64_t one_col_elem_num = out_blob->shape().Count(1); int32_t random_seed = NextRandomInt(); - int32_t part_num = std::min(record_blob->record_num(), - Global<ThreadMgr>::Get()->compute_thread_pool()->thread_num()); + int32_t part_num = std::min(record_blob->record_num(), max_part_num); if (part_num >= 2) { BlockingCounter bc(part_num); FOR_RANGE(int32_t, part_id, 0, part_num) { diff --git a/oneflow/core/record/ofrecord_decoder.h b/oneflow/core/record/ofrecord_decoder.h index c71b37034..caefe96e0 100644 --- a/oneflow/core/record/ofrecord_decoder.h +++ b/oneflow/core/record/ofrecord_decoder.h @@ -12,8 +12,8 @@ class OFRecordDecoderIf { OF_DISALLOW_COPY_AND_MOVE(OFRecordDecoderIf); virtual ~OFRecordDecoderIf() = default; - virtual int32_t DecodeOneCol(DeviceCtx*, RecordBlob<OFRecord>*, const BlobConf&, - int32_t cur_col_id, Blob* out_blob, + virtual int32_t DecodeOneCol(DeviceCtx*, int32_t max_part_num, RecordBlob<OFRecord>*, + const BlobConf&, int32_t cur_col_id, Blob* out_blob, std::function<int32_t(void)> NextRandomInt) const = 0; protected: @@ -28,8 +28,9 @@ class OFRecordDecoder : public OFRecordDecoderIf { OF_DISALLOW_COPY_AND_MOVE(OFRecordDecoder); virtual ~OFRecordDecoder() = default; - int32_t DecodeOneCol(DeviceCtx*, RecordBlob<OFRecord>*, const BlobConf&, int32_t cur_col_id, - Blob* out_blob, std::function<int32_t(void)> NextRandomInt) const override; + int32_t DecodeOneCol(DeviceCtx*, int32_t max_part_num, RecordBlob<OFRecord>*, const BlobConf&, + int32_t cur_col_id, Blob* out_blob, + std::function<int32_t(void)> NextRandomInt) const override; protected: OFRecordDecoder() = default; @@ -43,8 +44,9 @@ class OFRecordDecoder : public OFRecordDecoderIf { int32_t ReadColNum(DeviceCtx*, RecordBlob<OFRecord>*, const std::string& name, Blob* out_blob) const; void ReadDataId(DeviceCtx*, RecordBlob<OFRecord>*, Blob* out_blob) const; - void ReadDataContent(DeviceCtx*, RecordBlob<OFRecord>*, const BlobConf&, int32_t col_id, - Blob* out_blob, std::function<int32_t(void)> NextRandomInt) const; + void ReadDataContent(DeviceCtx*, int32_t max_part_num, RecordBlob<OFRecord>*, const BlobConf&, + int32_t col_id, Blob* out_blob, + std::function<int32_t(void)> NextRandomInt) const; void ReadPartDataContent(DeviceCtx*, RecordBlob<OFRecord>*, const BlobConf&, int32_t col_id, Blob* out_blob, int32_t part_id, int32_t part_num, int64_t one_col_elem_num, int32_t random_seed) const; -- GitLab