diff --git a/CMakeLists.txt b/CMakeLists.txt
index 96fd893d002b7d00ea2719a807ff768a2f690434..f9883c38cf37f188f5ec9a018ef9414385b5562d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -40,7 +40,7 @@ if (WIN32)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /D_ITERATOR_DEBUG_LEVEL=0")
else()
list(APPEND CUDA_NVCC_FLAGS -std=c++11 -w -Wno-deprecated-gpu-targets)
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare -Wno-unknown-pragmas -fopenmp")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wno-sign-compare")
if (RELEASE_VERSION)
list(APPEND CUDA_NVCC_FLAGS -O3)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
diff --git a/benchmark/alexnet/report.md b/benchmark/alexnet/report.md
index 946cd9cc1438e53ddbdc849202e5ab855760e06c..ac5029d8dd1c1d5b34f22873e297a700be08cad4 100644
--- a/benchmark/alexnet/report.md
+++ b/benchmark/alexnet/report.md
@@ -2,6 +2,6 @@ batch_size: 1024
gpu num | time (one batch)
:-------| :-------------
-1 | 541ms
-2 | 282ms
-4 | 207ms
+1 | 538ms
+2 | 285ms
+4 | 189ms
diff --git a/oneflow/core/kernel/boxing_kernel.cpp b/oneflow/core/kernel/boxing_kernel.cpp
index 42c8404ec97e79652d8e0b10cef92849a45b3d0f..8f2f8981cc2c3ca9262fe2b24d66997b9d056301 100644
--- a/oneflow/core/kernel/boxing_kernel.cpp
+++ b/oneflow/core/kernel/boxing_kernel.cpp
@@ -1,8 +1,8 @@
#include "oneflow/core/kernel/boxing_kernel.h"
-#include <omp.h>
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/common/balanced_splitter.h"
+#include "oneflow/core/thread/thread_manager.h"
namespace oneflow {
@@ -83,6 +83,45 @@ class DataContentDesc final {
int32_t axis_;
};
+void ConcatSplitPartDataContent(DeviceCtx* ctx, const DataContentDesc& in_desc,
+                                const DataContentDesc& out_desc, int32_t part_id,
+                                int32_t part_num) {  // copies one balanced part of the concat->split transfer
+  size_t one_elem_size = in_desc.OneElemSize();
+  BalancedSplitter bs(in_desc.TotalElemNum(), part_num);
+  Range range = bs.At(part_id);
+  int64_t in_idx = range.begin();
+  int64_t in_elem_num = 0;
+  char* in_ptr = nullptr;
+  int64_t out_idx = range.begin();
+  int64_t out_elem_num = 0;
+  char* out_ptr = nullptr;
+  while (in_elem_num > 0 || out_elem_num > 0 || in_idx < range.end() || out_idx < range.end()) {
+    if (in_elem_num == 0) {
+      std::tie(in_elem_num, in_ptr) = in_desc.CalcContinuousElemNumStartFrom(in_idx);
+      in_elem_num = std::min(in_elem_num, range.end() - in_idx);
+      if (in_elem_num == 0) { break; }
+      in_idx += in_elem_num;
+    }
+    if (out_elem_num == 0) {
+      std::tie(out_elem_num, out_ptr) = out_desc.CalcContinuousElemNumStartFrom(out_idx);
+      out_elem_num = std::min(out_elem_num, range.end() - out_idx);
+      if (out_elem_num == 0) { break; }
+      out_idx += out_elem_num;
+    }
+    int64_t copy_elem_num = std::min(in_elem_num, out_elem_num);
+    size_t copy_size = copy_elem_num * one_elem_size;
+    Memcpy<DeviceType::kCPU>(ctx, out_ptr, in_ptr, copy_size);
+    in_elem_num -= copy_elem_num;
+    out_elem_num -= copy_elem_num;
+    in_ptr += copy_size;
+    out_ptr += copy_size;
+  }
+  CHECK_EQ(in_elem_num, 0);  // every element claimed for this part must be fully copied
+  CHECK_EQ(out_elem_num, 0);
+  CHECK_EQ(in_idx, range.end());
+  CHECK_EQ(out_idx, range.end());
+}
+
void ConcatSplitDataContent(DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob,
const PbRpf<std::string>& concat_bns, int32_t concat_axis,
const PbRpf<std::string>& split_bns, int32_t split_axis) {
@@ -90,42 +129,21 @@ void ConcatSplitDataContent(DeviceCtx* ctx, std::function<Blob*(const std::strin
DataContentDesc out_desc(BnInOp2Blob, &split_bns, split_axis);
CHECK_EQ(in_desc.TotalElemNum(), out_desc.TotalElemNum());
CHECK_EQ(in_desc.OneElemSize(), out_desc.OneElemSize());
-#pragma omp parallel
-  {
-    size_t one_elem_size = in_desc.OneElemSize();
-    BalancedSplitter bs(in_desc.TotalElemNum(), omp_get_num_threads());
-    Range range = bs.At(omp_get_thread_num());
-    int64_t in_idx = range.begin();
-    int64_t in_elem_num = 0;
-    char* in_ptr = nullptr;
-    int64_t out_idx = range.begin();
-    int64_t out_elem_num = 0;
-    char* out_ptr = nullptr;
-    while (in_elem_num > 0 || out_elem_num > 0 || in_idx < range.end() || out_idx < range.end()) {
-      if (in_elem_num == 0) {
-        std::tie(in_elem_num, in_ptr) = in_desc.CalcContinuousElemNumStartFrom(in_idx);
-        in_elem_num = std::min(in_elem_num, range.end() - in_idx);
-        if (in_elem_num == 0) { break; }
-        in_idx += in_elem_num;
-      }
-      if (out_elem_num == 0) {
-        std::tie(out_elem_num, out_ptr) = out_desc.CalcContinuousElemNumStartFrom(out_idx);
-        out_elem_num = std::min(out_elem_num, range.end() - out_idx);
-        if (out_elem_num == 0) { break; }
-        out_idx += out_elem_num;
-      }
-      int64_t copy_elem_num = std::min(in_elem_num, out_elem_num);
-      size_t copy_size = copy_elem_num * one_elem_size;
-      Memcpy<DeviceType::kCPU>(ctx, out_ptr, in_ptr, copy_size);
-      in_elem_num -= copy_elem_num;
-      out_elem_num -= copy_elem_num;
-      in_ptr += copy_size;
-      out_ptr += copy_size;
+  static const size_t min_byte_one_part = 128;  // minimum bytes of payload per worker task
+  int64_t max_part_num = in_desc.TotalElemNum() * in_desc.OneElemSize() / min_byte_one_part;  // 64-bit: avoid int32 overflow on large blobs
+  int32_t part_num = static_cast<int32_t>(std::min<int64_t>(max_part_num, Global<ThreadMgr>::Get()->compute_thread_pool()->thread_num()));
+  if (part_num >= 2) {
+    BlockingCounter bc(part_num);
+    FOR_RANGE(int32_t, part_id, 0, part_num) {
+      Global<ThreadMgr>::Get()->compute_thread_pool()->AddWork(
+          [ctx, &in_desc, &out_desc, part_id, part_num, &bc]() {  // scalars/pointer by value; descs & bc outlive the wait below
+            ConcatSplitPartDataContent(ctx, in_desc, out_desc, part_id, part_num);
+            bc.Decrease();
+          });
}
-    CHECK_EQ(in_elem_num, 0);
-    CHECK_EQ(out_elem_num, 0);
-    CHECK_EQ(in_idx, range.end());
-    CHECK_EQ(out_idx, range.end());
+    bc.WaitUntilCntEqualZero();
+  } else {
+    ConcatSplitPartDataContent(ctx, in_desc, out_desc, 0, 1);
}
}
diff --git a/oneflow/core/kernel/boxing_kernel.h b/oneflow/core/kernel/boxing_kernel.h
index 9d613d9d096e2984fe205e8ee46ea87a56ebe9b6..7d969cee0c37bbe1b1e3aa0e8850a52095750151 100644
--- a/oneflow/core/kernel/boxing_kernel.h
+++ b/oneflow/core/kernel/boxing_kernel.h
@@ -22,6 +22,7 @@ class BoxingKernel final : public KernelIf<DeviceType::kCPU> {
void ForwardColNum(const KernelCtx&, std::function<Blob*(const std::string&)>) const override;
void SetColId(const KernelCtx&, std::function<Blob*(const std::string&)>) const;
void SetMaxColId(const KernelCtx&, std::function<Blob*(const std::string&)>) const;
+
PbRpf<std::string> ibn_0_;
PbRpf<std::string> obn_0_;
};
diff --git a/oneflow/core/kernel/decode_ofrecord_kernel.cpp b/oneflow/core/kernel/decode_ofrecord_kernel.cpp
index 8d0193985ea1e7285baeca3983c3adac5f28210e..08f5108b0e230b5e2e91627049a4b8839f1edb0c 100644
--- a/oneflow/core/kernel/decode_ofrecord_kernel.cpp
+++ b/oneflow/core/kernel/decode_ofrecord_kernel.cpp
@@ -1,5 +1,6 @@
#include "oneflow/core/kernel/decode_ofrecord_kernel.h"
#include "oneflow/core/record/ofrecord_decoder.h"
+#include "oneflow/core/thread/thread_manager.h"
namespace oneflow {
diff --git a/oneflow/core/record/ofrecord_decoder.cpp b/oneflow/core/record/ofrecord_decoder.cpp
index 933d2240af96e7b847869e898936a1d12368ae4f..717de59b873d65c7f281693200658369303f7dd0 100644
--- a/oneflow/core/record/ofrecord_decoder.cpp
+++ b/oneflow/core/record/ofrecord_decoder.cpp
@@ -2,6 +2,7 @@
#include "oneflow/core/record/ofrecord_raw_decoder.h"
#include "oneflow/core/record/ofrecord_jpeg_decoder.h"
#include "oneflow/core/common/balanced_splitter.h"
+#include "oneflow/core/thread/thread_manager.h"
namespace oneflow {
@@ -130,30 +131,23 @@ void OFRecordDecoder<encode_case, T>::ReadDataContent(
Blob* out_blob, std::function<int32_t(void)> NextRandomInt) const {
int64_t one_col_elem_num = out_blob->shape().Count(1);
int32_t random_seed = NextRandomInt();
-#pragma omp parallel
-  {
-    BalancedSplitter bs(record_blob->record_num(), omp_get_num_threads());
-    Range range = bs.At(omp_get_thread_num());
-    if (range.size() > 0) {
-      std::mt19937 gen(random_seed + omp_get_thread_num());
-      std::uniform_int_distribution<int32_t> distribution;
-      FOR_RANGE(int32_t, i, range.begin(), range.end()) {
-        const OFRecord& record = record_blob->GetRecord(i);
-        CHECK(record.feature().find(blob_conf.name()) != record.feature().end())
-            << "Field " << blob_conf.name() << " not found";
-        const Feature& feature = record.feature().at(blob_conf.name());
-        T* out_dptr = out_blob->mut_dptr<T>() + i * one_col_elem_num;
-        if (col_id < out_blob->col_num(i)) {
-          ReadOneCol(ctx, feature, blob_conf, col_id, out_dptr, one_col_elem_num,
-                     [&]() { return distribution(gen); });
-          FOR_RANGE(size_t, j, 0, blob_conf.preprocess_size()) {
-            DoPreprocess<T>(blob_conf.preprocess(j), out_dptr, out_blob->shape());
-          }
-        } else {
-          Memset<DeviceType::kCPU>(ctx, out_dptr, 0, one_col_elem_num * sizeof(T));
-        }
-      }
+  int32_t part_num = std::min(record_blob->record_num(),
+                              Global<ThreadMgr>::Get()->compute_thread_pool()->thread_num());
+  if (part_num >= 2) {
+    BlockingCounter bc(part_num);
+    FOR_RANGE(int32_t, part_id, 0, part_num) {
+      Global<ThreadMgr>::Get()->compute_thread_pool()->AddWork(
+          [ctx, record_blob, &blob_conf, col_id, out_blob, &bc, part_id, part_num,
+           one_col_elem_num, random_seed, this]() {  // scalars/pointers by value; blob_conf & bc outlive the wait below
+            ReadPartDataContent(ctx, record_blob, blob_conf, col_id, out_blob, part_id, part_num,
+                                one_col_elem_num, random_seed);
+            bc.Decrease();
+          });
}
+    bc.WaitUntilCntEqualZero();
+  } else {
+    ReadPartDataContent(ctx, record_blob, blob_conf, col_id, out_blob, 0, 1, one_col_elem_num,
+                        random_seed);
}
int64_t left_row_num = out_blob->shape().At(0) - record_blob->record_num();
if (left_row_num > 0) {
@@ -163,6 +157,33 @@ void OFRecordDecoder<encode_case, T>::ReadDataContent(
}
}
+template<EncodeCase encode_case, typename T>
+void OFRecordDecoder<encode_case, T>::ReadPartDataContent(
+    DeviceCtx* ctx, RecordBlob<OFRecord>* record_blob, const BlobConf& blob_conf, int32_t col_id,
+    Blob* out_blob, int32_t part_id, int32_t part_num, int64_t one_col_elem_num,
+    int32_t random_seed) const {  // decodes the records in this part's balanced sub-range
+  BalancedSplitter bs(record_blob->record_num(), part_num);
+  Range range = bs.At(part_id);
+  std::mt19937 gen(random_seed + part_id);  // per-part seed keeps decoding deterministic for a fixed part_num
+  std::uniform_int_distribution<int32_t> distribution;
+  FOR_RANGE(int32_t, i, range.begin(), range.end()) {
+    const OFRecord& record = record_blob->GetRecord(i);
+    CHECK(record.feature().find(blob_conf.name()) != record.feature().end())
+        << "Field " << blob_conf.name() << " not found";
+    const Feature& feature = record.feature().at(blob_conf.name());
+    T* out_dptr = out_blob->mut_dptr<T>() + i * one_col_elem_num;
+    if (col_id < out_blob->col_num(i)) {
+      ReadOneCol(ctx, feature, blob_conf, col_id, out_dptr, one_col_elem_num,
+                 [&]() { return distribution(gen); });
+      FOR_RANGE(size_t, j, 0, blob_conf.preprocess_size()) {
+        DoPreprocess<T>(blob_conf.preprocess(j), out_dptr, out_blob->shape());
+      }
+    } else {
+      Memset<DeviceType::kCPU>(ctx, out_dptr, 0, one_col_elem_num * sizeof(T));
+    }
+  }
+}
+
OFRecordDecoderIf* GetOFRecordDecoder(EncodeCase encode_case, DataType data_type) {
static const HashMap<std::string, OFRecordDecoderIf*> obj = {
diff --git a/oneflow/core/record/ofrecord_decoder.h b/oneflow/core/record/ofrecord_decoder.h
index 54f0dbc49c1820ccb2e3718a55d97c51c3987551..c71b37034174e99c597c39992a1b1f0ae641542f 100644
--- a/oneflow/core/record/ofrecord_decoder.h
+++ b/oneflow/core/record/ofrecord_decoder.h
@@ -45,6 +45,9 @@ class OFRecordDecoder : public OFRecordDecoderIf {
void ReadDataId(DeviceCtx*, RecordBlob<OFRecord>*, Blob* out_blob) const;
void ReadDataContent(DeviceCtx*, RecordBlob<OFRecord>*, const BlobConf&, int32_t col_id,
Blob* out_blob, std::function<int32_t(void)> NextRandomInt) const;
+  void ReadPartDataContent(DeviceCtx*, RecordBlob<OFRecord>*, const BlobConf&, int32_t col_id,
+                           Blob* out_blob, int32_t part_id, int32_t part_num,
+                           int64_t one_col_elem_num, int32_t random_seed) const;  // decodes balanced part part_id of part_num for ReadDataContent
};
template<EncodeCase encode_case, typename T>
diff --git a/oneflow/core/thread/thread_manager.cpp b/oneflow/core/thread/thread_manager.cpp
index 3cb5e5421491ebf8a44a4c3b9edcb862e1d2827c..db4e40f65e90501920f452e1c8257b0e6239c63c 100644
--- a/oneflow/core/thread/thread_manager.cpp
+++ b/oneflow/core/thread/thread_manager.cpp
@@ -38,6 +38,7 @@ ThreadMgr::ThreadMgr(const Plan& plan) {
threads_.push_back(new CpuThread(thrd_id++, 0));
}
threads_.push_back(new CpuThread(thrd_id++, 0)); // comm_net
+  compute_thread_pool_.reset(new ThreadPool(job_desc->CpuDeviceNum()));  // pool sized to configured CPU device count; reset(new) since project targets C++11 (no make_unique)
}
} // namespace oneflow
diff --git a/oneflow/core/thread/thread_manager.h b/oneflow/core/thread/thread_manager.h
index 1f5417692449606af973a607ae54d6097c8a98ed..7a1119b11d56d0f48e0f81ef21b2275ad1449a4c 100644
--- a/oneflow/core/thread/thread_manager.h
+++ b/oneflow/core/thread/thread_manager.h
@@ -5,6 +5,7 @@
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/thread/thread.h"
+#include "oneflow/core/thread/thread_pool.h"
namespace oneflow {
@@ -16,11 +17,14 @@ class ThreadMgr final {
Thread* GetThrd(int64_t thrd_id);
+  ThreadPool* compute_thread_pool() { return compute_thread_pool_.get(); }  // non-owning; shared pool for CPU kernels to parallelize work
+
private:
friend class Global<ThreadMgr>;
ThreadMgr(const Plan& plan);
std::vector<Thread*> threads_;
+  std::unique_ptr<ThreadPool> compute_thread_pool_;  // owned; constructed in ThreadMgr ctor
};
} // namespace oneflow
diff --git a/oneflow/core/thread/thread_pool.cpp b/oneflow/core/thread/thread_pool.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5a124f830ebaada0637daa4cf293d34fd623daf2
--- /dev/null
+++ b/oneflow/core/thread/thread_pool.cpp
@@ -0,0 +1,34 @@
+#include "oneflow/core/thread/thread_pool.h"
+
+namespace oneflow {
+
+ThreadPool::ThreadPool(int32_t thread_num)
+    : work_chans_(thread_num), threads_(thread_num), cur_chan_idx_(0) {
+  FOR_RANGE(int32_t, i, 0, thread_num) {
+    Channel<std::function<void()>>* chan = &(work_chans_.at(i));
+    threads_[i] = std::thread([chan]() {  // each worker drains its own channel
+      std::function<void()> work;
+      while (chan->Receive(&work) == 0) { work(); }  // 0 == success; loop ends when channel closes
+    });
+  }
+}
+
+ThreadPool::~ThreadPool() {
+  FOR_RANGE(int32_t, i, 0, work_chans_.size()) {
+    work_chans_.at(i).CloseSendEnd();  // NOTE(review): closing receive end too may drop still-queued work — callers appear to wait via BlockingCounter first; confirm
+    work_chans_.at(i).CloseReceiveEnd();
+    threads_.at(i).join();
+  }
+}
+
+void ThreadPool::AddWork(std::function<void()> work) {
+  size_t chan_idx = 0;  // single-channel pools always dispatch to channel 0, no lock needed
+  if (work_chans_.size() > 1) {
+    std::unique_lock<std::mutex> lck(cur_chan_idx_mtx_);
+    chan_idx = cur_chan_idx_;  // round-robin dispatch across worker channels
+    cur_chan_idx_ = (cur_chan_idx_ + 1) % work_chans_.size();
+  }
+  work_chans_.at(chan_idx).Send(work);  // send outside the lock: don't serialize producers on a potentially blocking channel op
+}
+
+} // namespace oneflow
diff --git a/oneflow/core/thread/thread_pool.h b/oneflow/core/thread/thread_pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..b3adedf02e5e9fdfd7adc8af7cd92162aee0ec1e
--- /dev/null
+++ b/oneflow/core/thread/thread_pool.h
@@ -0,0 +1,29 @@
+#ifndef ONEFLOW_CORE_THREAD_THREAD_POOL_H_
+#define ONEFLOW_CORE_THREAD_THREAD_POOL_H_
+
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/common/channel.h"
+
+namespace oneflow {
+
+class ThreadPool final {  // fixed-size worker pool: one work channel per thread
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ThreadPool);
+  ThreadPool() = delete;
+  explicit ThreadPool(int32_t thread_num);  // explicit: an int must not silently convert to a pool
+  ~ThreadPool();
+
+  int32_t thread_num() const { return static_cast<int32_t>(threads_.size()); }
+  void AddWork(std::function<void()> work);  // round-robin enqueue onto a worker channel
+
+ private:
+  std::vector<Channel<std::function<void()>>> work_chans_;  // one channel per worker thread
+  std::vector<std::thread> threads_;
+
+  std::mutex cur_chan_idx_mtx_;  // guards cur_chan_idx_ round-robin state
+  int32_t cur_chan_idx_;
+};
+
+} // namespace oneflow
+
+#endif  // ONEFLOW_CORE_THREAD_THREAD_POOL_H_