From b23d18f0dcad0969f05eda75b1304808dced864c Mon Sep 17 00:00:00 2001
From: Jinhui Yuan <yuan.ms2@gmail.com>
Date: Thu, 2 Aug 2018 12:12:00 +0800
Subject: [PATCH] fix empty bw_cudnn_buf (#1075)

---
 oneflow/core/kernel/conv_kernel.cu | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/oneflow/core/kernel/conv_kernel.cu b/oneflow/core/kernel/conv_kernel.cu
index 95a56dccf..3d1cca161 100644
--- a/oneflow/core/kernel/conv_kernel.cu
+++ b/oneflow/core/kernel/conv_kernel.cu
@@ -119,13 +119,15 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn(
     Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* weight_blob = BnInOp2Blob("weight");
   Blob* bw_cudnn_buf = BnInOp2Blob("bw_cudnn_buf");
+  void* bw_cudnn_buf_ptr = bw_cudnn_buf ? bw_cudnn_buf->mut_dptr() : nullptr;
+  size_t bw_cudnn_buf_size = bw_cudnn_buf ? bw_cudnn_buf->ByteSizeOfDataContentField() : 0;
   CudaCheck(cudnnConvolutionBackwardFilter(
       device_ctx->cudnn_handle(), OnePtr<T>::value, this->in_desc_->Get(), in_blob->dptr<T>(),
       this->out_desc_->Get(), out_diff_blob->dptr<T>(), this->conv_desc_->Get(),
       static_cast<cudnnConvolutionBwdFilterAlgo_t>(
           this->GetConvKernelConf().cudnn_bwd_filter_algo()),
-      bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value,
-      this->filter_desc_->Get(), weight_diff_blob->mut_dptr<T>()));
+      bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->filter_desc_->Get(),
+      weight_diff_blob->mut_dptr<T>()));
 
   if (in_diff_blob != nullptr) {
     CudaCheck(cudnnConvolutionBackwardData(
@@ -133,8 +135,8 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn(
         weight_blob->dptr<T>(), this->out_desc_->Get(), out_diff_blob->dptr<T>(),
         this->conv_desc_->Get(),
         static_cast<cudnnConvolutionBwdDataAlgo_t>(this->GetConvKernelConf().cudnn_bwd_data_algo()),
-        bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value,
-        this->in_desc_->Get(), in_diff_blob->mut_dptr<T>()));
+        bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->in_desc_->Get(),
+        in_diff_blob->mut_dptr<T>()));
   }
 }
 
-- 
GitLab