Skip to content
Snippets Groups Projects
Unverified Commit b23d18f0 authored by Jinhui Yuan's avatar Jinhui Yuan Committed by GitHub
Browse files

fix empty bw_cudnn_buf (#1075)

parent bdceef38
No related branches found
No related tags found
No related merge requests found
...@@ -119,13 +119,15 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn( ...@@ -119,13 +119,15 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn(
Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* weight_blob = BnInOp2Blob("weight"); const Blob* weight_blob = BnInOp2Blob("weight");
Blob* bw_cudnn_buf = BnInOp2Blob("bw_cudnn_buf"); Blob* bw_cudnn_buf = BnInOp2Blob("bw_cudnn_buf");
void* bw_cudnn_buf_ptr = bw_cudnn_buf ? bw_cudnn_buf->mut_dptr() : nullptr;
size_t bw_cudnn_buf_size = bw_cudnn_buf ? bw_cudnn_buf->ByteSizeOfDataContentField() : 0;
CudaCheck(cudnnConvolutionBackwardFilter( CudaCheck(cudnnConvolutionBackwardFilter(
device_ctx->cudnn_handle(), OnePtr<T>::value, this->in_desc_->Get(), in_blob->dptr<T>(), device_ctx->cudnn_handle(), OnePtr<T>::value, this->in_desc_->Get(), in_blob->dptr<T>(),
this->out_desc_->Get(), out_diff_blob->dptr<T>(), this->conv_desc_->Get(), this->out_desc_->Get(), out_diff_blob->dptr<T>(), this->conv_desc_->Get(),
static_cast<cudnnConvolutionBwdFilterAlgo_t>( static_cast<cudnnConvolutionBwdFilterAlgo_t>(
this->GetConvKernelConf().cudnn_bwd_filter_algo()), this->GetConvKernelConf().cudnn_bwd_filter_algo()),
bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value, bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->filter_desc_->Get(),
this->filter_desc_->Get(), weight_diff_blob->mut_dptr<T>())); weight_diff_blob->mut_dptr<T>()));
if (in_diff_blob != nullptr) { if (in_diff_blob != nullptr) {
CudaCheck(cudnnConvolutionBackwardData( CudaCheck(cudnnConvolutionBackwardData(
...@@ -133,8 +135,8 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn( ...@@ -133,8 +135,8 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn(
weight_blob->dptr<T>(), this->out_desc_->Get(), out_diff_blob->dptr<T>(), weight_blob->dptr<T>(), this->out_desc_->Get(), out_diff_blob->dptr<T>(),
this->conv_desc_->Get(), this->conv_desc_->Get(),
static_cast<cudnnConvolutionBwdDataAlgo_t>(this->GetConvKernelConf().cudnn_bwd_data_algo()), static_cast<cudnnConvolutionBwdDataAlgo_t>(this->GetConvKernelConf().cudnn_bwd_data_algo()),
bw_cudnn_buf->mut_dptr(), bw_cudnn_buf->ByteSizeOfDataContentField(), ZeroPtr<T>::value, bw_cudnn_buf_ptr, bw_cudnn_buf_size, ZeroPtr<T>::value, this->in_desc_->Get(),
this->in_desc_->Get(), in_diff_blob->mut_dptr<T>())); in_diff_blob->mut_dptr<T>()));
} }
} }
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment