diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp
index 5bc210f9330af9d70cd7f0913aba44d4118f809d..2ba8be25eb3b16b6d12c91a16a23976d4bcfc986 100644
--- a/oneflow/core/actor/actor.cpp
+++ b/oneflow/core/actor/actor.cpp
@@ -103,7 +103,8 @@ void Actor::InitDeviceCtx(const ThreadCtx& thread_ctx) {
         cuda_handle = thread_ctx.g_cuda_stream.get();
       } else {
         CHECK(Global<IDMgr>::Get()->IsIndependentLocalWorkStreamId(GetLocalWorkStreamId()));
-        cuda_handle = &cuda_handle_;
+        cuda_handle_.reset(new CudaStreamHandle(thread_ctx.cb_event_chan));
+        cuda_handle = cuda_handle_.get();
       }
       device_ctx_.reset(new CudaDeviceCtx(thread_ctx.buf_ptr, thread_ctx.buf_size, cuda_handle));
       break;
diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h
index d920dec9c1a3ef986f7ceb87a4064a015537f8fa..17e89dad3332d18339762d2e955e4fa42eeb1eb4 100644
--- a/oneflow/core/actor/actor.h
+++ b/oneflow/core/actor/actor.h
@@ -134,7 +134,7 @@ class Actor {
   MsgHandler msg_handler_;
   std::unique_ptr<DeviceCtx> device_ctx_;
   HashSet<int64_t> eord_regst_desc_ids_;
-  CudaStreamHandle cuda_handle_;
+  std::unique_ptr<CudaStreamHandle> cuda_handle_;
 
   // Status of Produced Registers
   HashMap<int64_t, std::deque<Regst*>> writeable_produced_regst_;
diff --git a/oneflow/core/device/cuda_device_context.cpp b/oneflow/core/device/cuda_device_context.cpp
deleted file mode 100644
index 67854bf93184a328f7e545b260a8f40bfb152b04..0000000000000000000000000000000000000000
--- a/oneflow/core/device/cuda_device_context.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "oneflow/core/device/cuda_device_context.h"
-
-namespace oneflow {
-
-#ifdef WITH_CUDA
-
-namespace {
-
-void CUDART_CB CudaCallBackHandle(cudaStream_t, cudaError_t status, void* void_ptr) {
-  CudaCheck(status);
-  auto callback_ptr = static_cast<std::function<void()>*>(void_ptr);
-  (*callback_ptr)();
-  delete callback_ptr;
-}
-
-}  // namespace
-
-void CudaDeviceCtx::AddCallBack(std::function<void()> callback_stack) const {
-  auto callback_heap = new std::function<void()>(callback_stack);
-  CudaCheck(cudaStreamAddCallback(cuda_stream(), &CudaCallBackHandle, callback_heap, 0));
-}
-
-#endif  // WITH_CUDA
-
-}  // namespace oneflow
diff --git a/oneflow/core/device/cuda_device_context.h b/oneflow/core/device/cuda_device_context.h
index fff6e8eec2f0f027cfd0f838289c40f89395262c..2685ff38c2a55359595aa8872f90133ad3744063 100644
--- a/oneflow/core/device/cuda_device_context.h
+++ b/oneflow/core/device/cuda_device_context.h
@@ -24,7 +24,9 @@ class CudaDeviceCtx final : public DeviceCtx {
   const cudnnHandle_t& cudnn_handle() const { return *(cuda_handler_->cudnn_handle()); }
   const Eigen::GpuDevice& eigen_gpu_device() const { return *(cuda_handler_->eigen_gpu_device()); }
 
-  void AddCallBack(std::function<void()> callback) const override;
+  void AddCallBack(std::function<void()> callback) const override {
+    cuda_handler_->AddCallBack(callback);
+  }
 
  private:
   CudaStreamHandle* cuda_handler_;
diff --git a/oneflow/core/device/cuda_stream_handle.cpp b/oneflow/core/device/cuda_stream_handle.cpp
index 1283a5a37d950e25ba7c7762363093d651f0fd45..80c127fb2ef2faba94728811e355be683397faf8 100644
--- a/oneflow/core/device/cuda_stream_handle.cpp
+++ b/oneflow/core/device/cuda_stream_handle.cpp
@@ -49,6 +49,15 @@ const Eigen::GpuDevice* CudaStreamHandle::eigen_gpu_device() {
   return eigen_gpu_device_.get();
 }
 
+void CudaStreamHandle::AddCallBack(std::function<void()> callback) {
+  CudaCBEvent cb_event;
+  cb_event.callback = callback;
+  CudaCheck(
+      cudaEventCreateWithFlags(&(cb_event.event), cudaEventBlockingSync | cudaEventDisableTiming));
+  CudaCheck(cudaEventRecord(cb_event.event, *cuda_stream()));
+  cb_event_chan_->Send(cb_event);
+}
+
 CudaStreamHandle::~CudaStreamHandle() {
   if (cudnn_handle_) { CudaCheck(cudnnDestroy(*cudnn_handle_)); }
   if (cublas_pmh_handle_) { CudaCheck(cublasDestroy(*cublas_pmh_handle_)); }
diff --git a/oneflow/core/device/cuda_stream_handle.h b/oneflow/core/device/cuda_stream_handle.h
index 2baf855310a57c070e974f9c5f16159096217e00..731391da65538e7d1faa960ab957b2b2c16d9ba4 100644
--- a/oneflow/core/device/cuda_stream_handle.h
+++ b/oneflow/core/device/cuda_stream_handle.h
@@ -1,6 +1,7 @@
 #ifndef ONEFLOW_CORE_DEVICE_CUDA_STREAM_HANDLE_H_
 #define ONEFLOW_CORE_DEVICE_CUDA_STREAM_HANDLE_H_
 
+#include "oneflow/core/common/channel.h"
 #include "oneflow/core/device/cuda_util.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
@@ -8,10 +9,16 @@ namespace oneflow {
 
 #ifdef WITH_CUDA
 
+struct CudaCBEvent {
+  std::function<void()> callback;
+  cudaEvent_t event;
+};
+
 class CudaStreamHandle final {
  public:
   OF_DISALLOW_COPY_AND_MOVE(CudaStreamHandle);
-  CudaStreamHandle() = default;
+  CudaStreamHandle() = delete;
+  CudaStreamHandle(Channel<CudaCBEvent>* cb_event_chan) : cb_event_chan_(cb_event_chan) {}
 
   const cudaStream_t* cuda_stream();
   const cublasHandle_t* cublas_pmh_handle();
@@ -19,9 +26,12 @@ class CudaStreamHandle final {
   const cudnnHandle_t* cudnn_handle();
   const Eigen::GpuDevice* eigen_gpu_device();
 
+  void AddCallBack(std::function<void()> callback);
+
   ~CudaStreamHandle();
 
  private:
+  Channel<CudaCBEvent>* cb_event_chan_;
   std::unique_ptr<cudaStream_t> cuda_stream_;
   std::unique_ptr<cublasHandle_t> cublas_pmh_handle_;
   std::unique_ptr<cublasHandle_t> cublas_pmd_handle_;
diff --git a/oneflow/core/kernel/opkernel_test_case.cpp b/oneflow/core/kernel/opkernel_test_case.cpp
index e445f885eab7270c0d426fd8a3899afafc5a05f1..8e2b61e65fd6170554a79bfa192ce8396ad0023d 100644
--- a/oneflow/core/kernel/opkernel_test_case.cpp
+++ b/oneflow/core/kernel/opkernel_test_case.cpp
@@ -120,7 +120,7 @@ void OpKernelTestUtil<DeviceType::kCPU>::BuildKernelCtx(KernelCtx* ctx) {
 
 template<>
 void OpKernelTestUtil<DeviceType::kGPU>::BuildKernelCtx(KernelCtx* ctx) {
-  if (!Global<CudaStreamHandle>::Get()) { Global<CudaStreamHandle>::New(); }
+  if (!Global<CudaStreamHandle>::Get()) { Global<CudaStreamHandle>::New(nullptr); }
   CudaStreamHandle* cuda_handle = Global<CudaStreamHandle>::Get();
   ctx->device_ctx = new CudaDeviceCtx(nullptr, 0, cuda_handle);
 }
diff --git a/oneflow/core/operator/conv_op.cpp b/oneflow/core/operator/conv_op.cpp
index 5a61a350e9dd1bd0790681bf3e13da7f823a6f0c..6f0baca0228fc43731c6ee11f83aeb9251493750 100644
--- a/oneflow/core/operator/conv_op.cpp
+++ b/oneflow/core/operator/conv_op.cpp
@@ -268,7 +268,7 @@ template<int32_t NDims>
 void ConvOp<NDims>::InferCudnnAlgo(
     std::function<const BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
     CudnnConvAlgoCtx* conv_ctx) const {
-  CudaStreamHandle cuda_handle;
+  CudaStreamHandle cuda_handle(nullptr);
 
   const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
   const BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
diff --git a/oneflow/core/thread/cpu_thread.cpp b/oneflow/core/thread/cpu_thread.cpp
index aa6edc3a7a2f0d63fcc16932a802fd8e11f7f303..02e645be0b8cf46d4a2b241aa20449f6b844cf0d 100644
--- a/oneflow/core/thread/cpu_thread.cpp
+++ b/oneflow/core/thread/cpu_thread.cpp
@@ -14,6 +14,7 @@ CpuThread::CpuThread(int64_t thrd_id, size_t buf_size) {
       ThreadCtx ctx;
       ctx.buf_ptr = buf_ptr;
       ctx.buf_size = buf_size;
+      ctx.cb_event_chan = nullptr;
       PollMsgChannel(ctx);
     }
     if (buf_ptr) { free(buf_ptr); }
diff --git a/oneflow/core/thread/gpu_thread.cpp b/oneflow/core/thread/gpu_thread.cpp
index de9fd1fce47ac7988196d9a742459dbce5d27df8..cb289d38639ed51e086a81bee6efe7beec33a19e 100644
--- a/oneflow/core/thread/gpu_thread.cpp
+++ b/oneflow/core/thread/gpu_thread.cpp
@@ -15,11 +15,26 @@ GpuThread::GpuThread(int64_t thrd_id, int64_t dev_id, size_t buf_size) {
       ThreadCtx ctx;
       ctx.buf_ptr = buf_ptr;
       ctx.buf_size = buf_size;
-      ctx.g_cuda_stream.reset(new CudaStreamHandle);
+      ctx.g_cuda_stream.reset(new CudaStreamHandle(&cb_event_chan_));
+      ctx.cb_event_chan = &cb_event_chan_;
       PollMsgChannel(ctx);
     }
     if (buf_ptr) { CudaCheck(cudaFree(buf_ptr)); }
   });
+  cb_event_poller_ = std::thread([this]() {
+    CudaCBEvent cb_event;
+    while (cb_event_chan_.Receive(&cb_event) == 0) {
+      CudaCheck(cudaEventSynchronize(cb_event.event));
+      cb_event.callback();
+      CudaCheck(cudaEventDestroy(cb_event.event));
+    }
+  });
+}
+
+GpuThread::~GpuThread() {
+  cb_event_chan_.CloseSendEnd();
+  cb_event_chan_.CloseReceiveEnd();
+  cb_event_poller_.join();
 }
 
 #endif
diff --git a/oneflow/core/thread/gpu_thread.h b/oneflow/core/thread/gpu_thread.h
index a920862dd427bacf3445aecde1d814152a1366e3..e98ab7285b48a113d8395e33b1b3ec46461e4dff 100644
--- a/oneflow/core/thread/gpu_thread.h
+++ b/oneflow/core/thread/gpu_thread.h
@@ -11,11 +11,13 @@ class GpuThread final : public Thread {
  public:
   OF_DISALLOW_COPY_AND_MOVE(GpuThread);
   GpuThread() = delete;
-  ~GpuThread() = default;
+  ~GpuThread();
 
   GpuThread(int64_t thrd_id, int64_t dev_id, size_t buf_size);
 
  private:
+  Channel<CudaCBEvent> cb_event_chan_;
+  std::thread cb_event_poller_;
 };
 
 #endif
diff --git a/oneflow/core/thread/thread_context.h b/oneflow/core/thread/thread_context.h
index d9bd82c12c9e042d6c79f13bc1925718a3d0bf8c..f5edc491ad30b1407796ede8245f6d794b50968f 100644
--- a/oneflow/core/thread/thread_context.h
+++ b/oneflow/core/thread/thread_context.h
@@ -10,6 +10,7 @@ struct ThreadCtx {
   size_t buf_size;
 #ifdef WITH_CUDA
   std::unique_ptr<CudaStreamHandle> g_cuda_stream;
+  Channel<CudaCBEvent>* cb_event_chan;
 #endif
 };