diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index 5bc210f9330af9d70cd7f0913aba44d4118f809d..2ba8be25eb3b16b6d12c91a16a23976d4bcfc986 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -103,7 +103,8 @@ void Actor::InitDeviceCtx(const ThreadCtx& thread_ctx) { cuda_handle = thread_ctx.g_cuda_stream.get(); } else { CHECK(Global<IDMgr>::Get()->IsIndependentLocalWorkStreamId(GetLocalWorkStreamId())); - cuda_handle = &cuda_handle_; + cuda_handle_.reset(new CudaStreamHandle(thread_ctx.cb_event_chan)); + cuda_handle = cuda_handle_.get(); } device_ctx_.reset(new CudaDeviceCtx(thread_ctx.buf_ptr, thread_ctx.buf_size, cuda_handle)); break; diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h index d920dec9c1a3ef986f7ceb87a4064a015537f8fa..17e89dad3332d18339762d2e955e4fa42eeb1eb4 100644 --- a/oneflow/core/actor/actor.h +++ b/oneflow/core/actor/actor.h @@ -134,7 +134,7 @@ class Actor { MsgHandler msg_handler_; std::unique_ptr<DeviceCtx> device_ctx_; HashSet<int64_t> eord_regst_desc_ids_; - CudaStreamHandle cuda_handle_; + std::unique_ptr<CudaStreamHandle> cuda_handle_; // Status of Produced Registers HashMap<int64_t, std::deque<Regst*>> writeable_produced_regst_; diff --git a/oneflow/core/device/cuda_device_context.cpp b/oneflow/core/device/cuda_device_context.cpp deleted file mode 100644 index 67854bf93184a328f7e545b260a8f40bfb152b04..0000000000000000000000000000000000000000 --- a/oneflow/core/device/cuda_device_context.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "oneflow/core/device/cuda_device_context.h" - -namespace oneflow { - -#ifdef WITH_CUDA - -namespace { - -void CUDART_CB CudaCallBackHandle(cudaStream_t, cudaError_t status, void* void_ptr) { - CudaCheck(status); - auto callback_ptr = static_cast<std::function<void()>*>(void_ptr); - (*callback_ptr)(); - delete callback_ptr; -} - -} // namespace - -void CudaDeviceCtx::AddCallBack(std::function<void()> callback_stack) const { - auto callback_heap = new std::function<void()>(callback_stack); - CudaCheck(cudaStreamAddCallback(cuda_stream(), &CudaCallBackHandle, callback_heap, 0)); -} - -#endif // WITH_CUDA - -} // namespace oneflow diff --git a/oneflow/core/device/cuda_device_context.h b/oneflow/core/device/cuda_device_context.h index fff6e8eec2f0f027cfd0f838289c40f89395262c..2685ff38c2a55359595aa8872f90133ad3744063 100644 --- a/oneflow/core/device/cuda_device_context.h +++ b/oneflow/core/device/cuda_device_context.h @@ -24,7 +24,9 @@ class CudaDeviceCtx final : public DeviceCtx { const cudnnHandle_t& cudnn_handle() const { return *(cuda_handler_->cudnn_handle()); } const Eigen::GpuDevice& eigen_gpu_device() const { return *(cuda_handler_->eigen_gpu_device()); } - void AddCallBack(std::function<void()> callback) const override; + void AddCallBack(std::function<void()> callback) const override { + cuda_handler_->AddCallBack(callback); + } private: CudaStreamHandle* cuda_handler_; diff --git a/oneflow/core/device/cuda_stream_handle.cpp b/oneflow/core/device/cuda_stream_handle.cpp index 1283a5a37d950e25ba7c7762363093d651f0fd45..80c127fb2ef2faba94728811e355be683397faf8 100644 --- a/oneflow/core/device/cuda_stream_handle.cpp +++ b/oneflow/core/device/cuda_stream_handle.cpp @@ -49,6 +49,15 @@ const Eigen::GpuDevice* CudaStreamHandle::eigen_gpu_device() { return eigen_gpu_device_.get(); } +void CudaStreamHandle::AddCallBack(std::function<void()> callback) { + CudaCBEvent cb_event; + cb_event.callback = callback; + CudaCheck( + cudaEventCreateWithFlags(&(cb_event.event), cudaEventBlockingSync | cudaEventDisableTiming)); + CudaCheck(cudaEventRecord(cb_event.event, *cuda_stream())); + cb_event_chan_->Send(cb_event); +} + CudaStreamHandle::~CudaStreamHandle() { if (cudnn_handle_) { CudaCheck(cudnnDestroy(*cudnn_handle_)); } if (cublas_pmh_handle_) { CudaCheck(cublasDestroy(*cublas_pmh_handle_)); } diff --git a/oneflow/core/device/cuda_stream_handle.h b/oneflow/core/device/cuda_stream_handle.h index 2baf855310a57c070e974f9c5f16159096217e00..731391da65538e7d1faa960ab957b2b2c16d9ba4 100644 --- a/oneflow/core/device/cuda_stream_handle.h +++ b/oneflow/core/device/cuda_stream_handle.h @@ -1,6 +1,7 @@ #ifndef ONEFLOW_CORE_DEVICE_CUDA_STREAM_HANDLE_H_ #define ONEFLOW_CORE_DEVICE_CUDA_STREAM_HANDLE_H_ +#include "oneflow/core/common/channel.h" #include "oneflow/core/device/cuda_util.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -8,10 +9,16 @@ namespace oneflow { #ifdef WITH_CUDA +struct CudaCBEvent { + std::function<void()> callback; + cudaEvent_t event; +}; + class CudaStreamHandle final { public: OF_DISALLOW_COPY_AND_MOVE(CudaStreamHandle); - CudaStreamHandle() = default; + CudaStreamHandle() = delete; + CudaStreamHandle(Channel<CudaCBEvent>* cb_event_chan) : cb_event_chan_(cb_event_chan) {} const cudaStream_t* cuda_stream(); const cublasHandle_t* cublas_pmh_handle(); @@ -19,9 +26,12 @@ class CudaStreamHandle final { const cudnnHandle_t* cudnn_handle(); const Eigen::GpuDevice* eigen_gpu_device(); + void AddCallBack(std::function<void()> callback); + ~CudaStreamHandle(); private: + Channel<CudaCBEvent>* cb_event_chan_; std::unique_ptr<cudaStream_t> cuda_stream_; std::unique_ptr<cublasHandle_t> cublas_pmh_handle_; std::unique_ptr<cublasHandle_t> cublas_pmd_handle_; diff --git a/oneflow/core/kernel/opkernel_test_case.cpp b/oneflow/core/kernel/opkernel_test_case.cpp index e445f885eab7270c0d426fd8a3899afafc5a05f1..8e2b61e65fd6170554a79bfa192ce8396ad0023d 100644 --- a/oneflow/core/kernel/opkernel_test_case.cpp +++ b/oneflow/core/kernel/opkernel_test_case.cpp @@ -120,7 +120,7 @@ void OpKernelTestUtil<DeviceType::kCPU>::BuildKernelCtx(KernelCtx* ctx) { template<> void OpKernelTestUtil<DeviceType::kGPU>::BuildKernelCtx(KernelCtx* ctx) { - if (!Global<CudaStreamHandle>::Get()) { Global<CudaStreamHandle>::New(); } + if (!Global<CudaStreamHandle>::Get()) { Global<CudaStreamHandle>::New(nullptr); } CudaStreamHandle* cuda_handle = Global<CudaStreamHandle>::Get(); ctx->device_ctx = new CudaDeviceCtx(nullptr, 0, cuda_handle); } diff --git a/oneflow/core/operator/conv_op.cpp b/oneflow/core/operator/conv_op.cpp index 5a61a350e9dd1bd0790681bf3e13da7f823a6f0c..6f0baca0228fc43731c6ee11f83aeb9251493750 100644 --- a/oneflow/core/operator/conv_op.cpp +++ b/oneflow/core/operator/conv_op.cpp @@ -268,7 +268,7 @@ template<int32_t NDims> void ConvOp<NDims>::InferCudnnAlgo( std::function<const BlobDesc*(const std::string)> GetBlobDesc4BnInOp, CudnnConvAlgoCtx* conv_ctx) const { - CudaStreamHandle cuda_handle; + CudaStreamHandle cuda_handle(nullptr); const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in"); const BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); diff --git a/oneflow/core/thread/cpu_thread.cpp b/oneflow/core/thread/cpu_thread.cpp index aa6edc3a7a2f0d63fcc16932a802fd8e11f7f303..02e645be0b8cf46d4a2b241aa20449f6b844cf0d 100644 --- a/oneflow/core/thread/cpu_thread.cpp +++ b/oneflow/core/thread/cpu_thread.cpp @@ -14,6 +14,7 @@ CpuThread::CpuThread(int64_t thrd_id, size_t buf_size) { ThreadCtx ctx; ctx.buf_ptr = buf_ptr; ctx.buf_size = buf_size; + ctx.cb_event_chan = nullptr; PollMsgChannel(ctx); } if (buf_ptr) { free(buf_ptr); } diff --git a/oneflow/core/thread/gpu_thread.cpp b/oneflow/core/thread/gpu_thread.cpp index de9fd1fce47ac7988196d9a742459dbce5d27df8..cb289d38639ed51e086a81bee6efe7beec33a19e 100644 --- a/oneflow/core/thread/gpu_thread.cpp +++ b/oneflow/core/thread/gpu_thread.cpp @@ -15,11 +15,26 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id, size_t buf_size) { ThreadCtx ctx; ctx.buf_ptr = buf_ptr; ctx.buf_size = buf_size; - ctx.g_cuda_stream.reset(new CudaStreamHandle); + ctx.g_cuda_stream.reset(new CudaStreamHandle(&cb_event_chan_)); + ctx.cb_event_chan = &cb_event_chan_; PollMsgChannel(ctx); } if (buf_ptr) { CudaCheck(cudaFree(buf_ptr)); } }); + cb_event_poller_ = std::thread([this]() { + CudaCBEvent cb_event; + while (cb_event_chan_.Receive(&cb_event) == 0) { + CudaCheck(cudaEventSynchronize(cb_event.event)); + cb_event.callback(); + CudaCheck(cudaEventDestroy(cb_event.event)); + } + }); +} + +GpuThread::~GpuThread() { + cb_event_chan_.CloseSendEnd(); + cb_event_chan_.CloseReceiveEnd(); + cb_event_poller_.join(); } #endif diff --git a/oneflow/core/thread/gpu_thread.h b/oneflow/core/thread/gpu_thread.h index a920862dd427bacf3445aecde1d814152a1366e3..e98ab7285b48a113d8395e33b1b3ec46461e4dff 100644 --- a/oneflow/core/thread/gpu_thread.h +++ b/oneflow/core/thread/gpu_thread.h @@ -11,11 +11,13 @@ class GpuThread final : public Thread { public: OF_DISALLOW_COPY_AND_MOVE(GpuThread); GpuThread() = delete; - ~GpuThread() = default; + ~GpuThread(); GpuThread(int64_t thrd_id, int64_t dev_id, size_t buf_size); private: + Channel<CudaCBEvent> cb_event_chan_; + std::thread cb_event_poller_; }; #endif diff --git a/oneflow/core/thread/thread_context.h b/oneflow/core/thread/thread_context.h index d9bd82c12c9e042d6c79f13bc1925718a3d0bf8c..f5edc491ad30b1407796ede8245f6d794b50968f 100644 --- a/oneflow/core/thread/thread_context.h +++ b/oneflow/core/thread/thread_context.h @@ -10,6 +10,7 @@ struct ThreadCtx { size_t buf_size; #ifdef WITH_CUDA std::unique_ptr<CudaStreamHandle> g_cuda_stream; + Channel<CudaCBEvent>* cb_event_chan; #endif };