From b23d18f0dcad0969f05eda75b1304808dced864c Mon Sep 17 00:00:00 2001 From: Jinhui Yuan <yuan.ms2@gmail.com> Date: Thu, 2 Aug 2018 12:12:00 +0800 Subject: [PATCH] fix empty bw_cudnn_buf (#1075) --- oneflow/core/kernel/conv_kernel.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/oneflow/core/kernel/conv_kernel.cu b/oneflow/core/kernel/conv_kernel.cu index 95a56dccf..3d1cca161 100644 --- a/oneflow/core/kernel/conv_kernel.cu +++ b/oneflow/core/kernel/conv_kernel.cu @@ -119,13 +119,15 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn( Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* weight_blob = BnInOp2Blob("weight"); Blob* bw_cudnn_buf = BnInOp2Blob("bw_cudnn_buf"); + void* bw_cudnn_buf_ptr = bw_cudnn_buf ? bw_cudnn_buf->mut_dptr() : nullptr; + size_t bw_cudnn_buf_size = bw_cudnn_buf ? bw_cudnn_buf->ByteSizeOfDataContentField() : 0; CudaCheck(cudnnConvolutionBackwardFilter( device_ctx->cudnn_handle(), OnePtr<T>::value, this->in_desc_->Get(), in_blob->dptr<T>(), this->out_desc_->Get(), out_diff_blob->dptr<T>(), this->conv_desc_->Get(), static_cast<cudnnConvolutionBwdFilterAlgo_t>( this->GetConvKernelConf().cudnn_bwd_filter_algo()), - bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value, - this->filter_desc_->Get(), weight_diff_blob->mut_dptr<T>())); + bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->filter_desc_->Get(), + weight_diff_blob->mut_dptr<T>())); if (in_diff_blob != nullptr) { CudaCheck(cudnnConvolutionBackwardData( @@ -133,8 +135,8 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn( weight_blob->dptr<T>(), this->out_desc_->Get(), out_diff_blob->dptr<T>(), this->conv_desc_->Get(), static_cast<cudnnConvolutionBwdDataAlgo_t>(this->GetConvKernelConf().cudnn_bwd_data_algo()), - bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value, - this->in_desc_->Get(), in_diff_blob->mut_dptr<T>())); + bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->in_desc_->Get(), + in_diff_blob->mut_dptr<T>())); } } -- GitLab