diff --git a/oneflow/core/kernel/conv_kernel.cu b/oneflow/core/kernel/conv_kernel.cu index 95a56dccf0401c35c385034dfd74ba2352e74b21..3d1cca16158bc8a4ac585ac400e0a6a33db7547d 100644 --- a/oneflow/core/kernel/conv_kernel.cu +++ b/oneflow/core/kernel/conv_kernel.cu @@ -119,13 +119,15 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn( Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* weight_blob = BnInOp2Blob("weight"); Blob* bw_cudnn_buf = BnInOp2Blob("bw_cudnn_buf"); + void* bw_cudnn_buf_ptr = bw_cudnn_buf ? bw_cudnn_buf->mut_dptr() : nullptr; + size_t bw_cudnn_buf_size = bw_cudnn_buf ? bw_cudnn_buf->ByteSizeOfDataContentField() : 0; CudaCheck(cudnnConvolutionBackwardFilter( device_ctx->cudnn_handle(), OnePtr<T>::value, this->in_desc_->Get(), in_blob->dptr<T>(), this->out_desc_->Get(), out_diff_blob->dptr<T>(), this->conv_desc_->Get(), static_cast<cudnnConvolutionBwdFilterAlgo_t>( this->GetConvKernelConf().cudnn_bwd_filter_algo()), - bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value, - this->filter_desc_->Get(), weight_diff_blob->mut_dptr<T>())); + bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->filter_desc_->Get(), + weight_diff_blob->mut_dptr<T>())); if (in_diff_blob != nullptr) { CudaCheck(cudnnConvolutionBackwardData( @@ -133,8 +135,8 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn( weight_blob->dptr<T>(), this->out_desc_->Get(), out_diff_blob->dptr<T>(), this->conv_desc_->Get(), static_cast<cudnnConvolutionBwdDataAlgo_t>(this->GetConvKernelConf().cudnn_bwd_data_algo()), - bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value, - this->in_desc_->Get(), in_diff_blob->mut_dptr<T>())); + bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->in_desc_->Get(), + in_diff_blob->mut_dptr<T>())); } }