From c37985d7a33f28caa1e72c7c6e4125f2f636101f Mon Sep 17 00:00:00 2001
From: Niu Chong <niuchong_dnm@163.com>
Date: Wed, 13 Jun 2018 07:44:12 +0800
Subject: [PATCH] Conv GPU Kernel Without Cudnn (#916)

* feat: add UseCudnnOnGpu in Operator and fix conv op
* feat: add WithCudnn and WithoutCudnn in ConvKernel<kGPU, T>
* feat: add CUDA NCDHWIm2ColGpu kernel and compile done
* refine: rename Im2ColNCDHWGpu() to NCDHWIm2ColGpu()
* fix: reverse update about int32_t to int64_t in BlocksNum4ThreadsNum()
* feat: add CUDA NCDHWCol2ImGpu kernel
* refactor: extract InitSharedArrays() for device code
* feat: add CUDA NDHWCIm2ColGpu kernel
* feat: add CUDA NDHWCCol2ImGpu kernel and fix typos
* fix: fix the bug of calc im_offset in NDHWCIm2ColGpu()
* refactor: extract Im2ColCalcKernelAndOutIndex() and Im2ColCalcImIndex()
* fix: fix format and the missing shared_im[] parameter in Im2ColCalcImIndex()
* refactor: merge NCDHWCol2ImGpu() and NDHWCCol2ImGpu() into Col2ImGpu()
* refactor: merge NCDHWIm2ColGpu() and NDHWCIm2ColGpu() into Im2ColGpu()
* feat: add class ConvKernelImplByIm2Col between ConvKernelIf and ConvKernel; compile done, to be run
* fix: add explicit template instantiation for ConvKernelUtil
* refine: remove unused class function declarations, e.g. KernelInitWithoutCudnn
* fix(operator/conv_op.cpp): make sure UseCudnnOnGpu() == true when inferring the cudnn algo
* refine(kernel/conv_kernel.cu): put the gpu kernel functions inside the anonymous namespace
* refactor: add dim_num as the template parameter of Im2ColGpu() and Col2ImGpu()
* refactor: add is_channel_first as the template parameter of Im2ColGpu() and Col2ImGpu()
* refine(kernel/conv_kernel.cu): add #undef IM2COL_KERNEL_CALL
* refine(kernel/conv_kernel.cu): add dim_num as the template parameter of InitSharedArrays()
* fix(kernel/conv_kernel.cu): fix the bug of using col_offset in Im2ColGpu()
---
 oneflow/core/kernel/conv_kernel.cpp | 122 ++++-----
 oneflow/core/kernel/conv_kernel.cu  | 393 +++++++++++++++++++++++++++-
 oneflow/core/kernel/conv_kernel.h   | 135 +++++++---
 oneflow/core/kernel/kernel.h        |   3 +-
 oneflow/core/operator/conv_op.cpp   |   8 +-
 oneflow/core/operator/operator.cpp  |   4 -
 oneflow/core/operator/operator.h    |   3 +-
 7 files changed, 557 insertions(+), 111 deletions(-)

diff --git a/oneflow/core/kernel/conv_kernel.cpp b/oneflow/core/kernel/conv_kernel.cpp index 89b41a2a3..e78a961e5 100644 --- a/oneflow/core/kernel/conv_kernel.cpp +++ b/oneflow/core/kernel/conv_kernel.cpp @@ -3,20 +3,6 @@ namespace oneflow { -namespace { - -template<typename T> -const T* GetImgDptr(const Blob* blob, int64_t idx) { - return blob->dptr<T>() + blob->shape().Count(1) * idx; -} - -template<typename T> -T* GetImgMutDptr(Blob* blob, int64_t idx) { - return const_cast<T*>(GetImgDptr<T>(blob, idx)); -} - -} // namespace - template<DeviceType device_type, typename T> void ConvKernelIf<device_type, T>::ForwardDataContent( const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { @@ -40,7 +26,7 @@ void ConvKernelIf<device_type, T>::BackwardDataContent( template<DeviceType device_type, typename T> void ConvKernelIf<device_type, T>::InitConstBufBlobs( DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const { - if (this->template GetValFromCustomizedOpConf<bool>("use_bias") && !this->UseCudnn()) { + if (this->template GetValFromCustomizedOpConf<bool>("use_bias") && !this->UseCudnnOnGpu()) { InitializerConf bias_multiplier_initializer_conf; bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0, @@ -98,20 +84,21 @@ const int32_t ConvKernelIf<device_type, T>::OpKernelDim() const { return this->GetConvKernelConf().dim(); } -template<typename T> -void ConvKernel<DeviceType::kCPU, T>::VirtualKernelInit(const ParallelContext* parallel_ctx) { +template<DeviceType device_type, typename T> +void ConvKernelImplByIm2Col<device_type, T>::VirtualKernelInit( + const ParallelContext* parallel_ctx) { const std::string& data_format = this->template GetValFromCustomizedOpConf<std::string>("data_format"); if (data_format == "channels_first") { - im2col_func_ = ConvKernelUtil<T>::NCDHWIm2Col; - col2im_func_ = ConvKernelUtil<T>::NCDHWCol2Im; - forward_func_ = KernelUtil<DeviceType::kCPU, T>::OFGemm; + im2col_func_ = ConvKernelUtil<device_type, T>::NCDHWIm2Col; + col2im_func_ = ConvKernelUtil<device_type, T>::NCDHWCol2Im; + forward_func_ = KernelUtil<device_type, T>::OFGemm; dhw_offset_ = 2; is_out_diff_need_trans_ = CblasNoTrans; } else { - im2col_func_ = ConvKernelUtil<T>::NDHWCIm2Col; - col2im_func_ = ConvKernelUtil<T>::NDHWCCol2Im; - forward_func_ = KernelUtil<DeviceType::kCPU, T>::OFGemmTrans; + im2col_func_ = ConvKernelUtil<device_type, T>::NDHWCIm2Col; + col2im_func_ = ConvKernelUtil<device_type, T>::NDHWCCol2Im; + forward_func_ = KernelUtil<device_type, T>::OFGemmTrans; dhw_offset_ = 1; is_out_diff_need_trans_ = CblasTrans; } @@ -123,13 +110,14 @@ void ConvKernel<DeviceType::kCPU, T>::VirtualKernelInit(const ParallelContext* p padding_before_ = this->GetConvKernelConf().pad_small_side().data(); } -template<typename T> -void ConvKernel<DeviceType::kCPU, T>::DoForwardDataContent( +template<DeviceType device_type, typename T> +void ConvKernelImplByIm2Col<device_type, T>::DoForwardDataContent( DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { FOR_RANGE(int64_t, i, 0, in_shape_.At(0)) { - im2col_func_(device_ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_, out_shape_, - strides_, dilation_rate_, padding_before_, static_cast<T*>(device_ctx->buf_ptr())); + im2col_func_(this->OpKernelDim(), device_ctx, GetImgDptr<T>(in_blob, i), in_shape_, + weight_shape_, out_shape_, strides_, dilation_rate_, padding_before_, + static_cast<T*>(device_ctx->buf_ptr())); // col_buf is device_ctx->buf_ptr() // channels first: out = weight * col_buf @@ -157,24 +145,25 @@ void ConvKernel<DeviceType::kCPU, T>::DoForwardDataContent( } } -template<typename T> -void ConvKernel<DeviceType::kCPU, T>::WeightBackward( +template<DeviceType device_type, typename T> +void ConvKernelImplByIm2Col<device_type, T>::WeightBackward( DeviceCtx* ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob, Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* weight_blob = BnInOp2Blob("weight"); - Memset<DeviceType::kCPU>(ctx, weight_diff_blob->mut_dptr<T>(), 0, - weight_diff_blob->ByteSizeOfDataContentField()); + Memset<device_type>(ctx, weight_diff_blob->mut_dptr<T>(), 0, + weight_diff_blob->ByteSizeOfDataContentField()); if (in_diff_blob != nullptr) { - Memset<DeviceType::kCPU>(ctx, in_diff_blob->mut_dptr<T>(), 0, - in_diff_blob->ByteSizeOfDataContentField()); + Memset<device_type>(ctx, in_diff_blob->mut_dptr<T>(), 0, + in_diff_blob->ByteSizeOfDataContentField()); } FOR_RANGE(int64_t, i, 0, out_shape_.At(0)) { - im2col_func_(ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_, out_shape_, 
strides_, - dilation_rate_, padding_before_, static_cast<T*>(ctx->buf_ptr())); + im2col_func_(this->OpKernelDim(), ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_, + out_shape_, strides_, dilation_rate_, padding_before_, + static_cast<T*>(ctx->buf_ptr())); // channels first: weight' += out[i]' * col_buf(T) // channels last : weight' += out[i]'(T) * col_buf(T) - KernelUtil<DeviceType::kCPU, T>::OFGemm( + KernelUtil<device_type, T>::OFGemm( ctx, is_out_diff_need_trans_, CblasTrans, weight_shape_.At(0), // filter weight_shape_.Count(1), // ci * kd * kh * kw @@ -185,7 +174,7 @@ void ConvKernel<DeviceType::kCPU, T>::WeightBackward( if (in_diff_blob != nullptr) { // channels first: col_buf' = weight(T) * out[i]' // channels last : col_buf' = weight(T) * out[i]'(T) - KernelUtil<DeviceType::kCPU, T>::OFGemm( + KernelUtil<device_type, T>::OFGemm( ctx, CblasTrans, is_out_diff_need_trans_, weight_shape_.Count(1), // ci * kd * kh * kw out_shape_.Count(dhw_offset_, dhw_offset_ + 3), // od * oh * ow @@ -194,24 +183,25 @@ void ConvKernel<DeviceType::kCPU, T>::WeightBackward( static_cast<T>(0), static_cast<T*>(ctx->buf_ptr())); // in' = col2im(col_buf') - col2im_func_(ctx, static_cast<const T*>(ctx->buf_ptr()), in_shape_, weight_shape_, out_shape_, - strides_, dilation_rate_, padding_before_, GetImgMutDptr<T>(in_diff_blob, i)); + col2im_func_(this->OpKernelDim(), ctx, static_cast<const T*>(ctx->buf_ptr()), in_shape_, + weight_shape_, out_shape_, strides_, dilation_rate_, padding_before_, + GetImgMutDptr<T>(in_diff_blob, i)); } } } -template<typename T> -void ConvKernel<DeviceType::kCPU, T>::BiasBackward( +template<DeviceType device_type, typename T> +void ConvKernelImplByIm2Col<device_type, T>::BiasBackward( DeviceCtx* ctx, const Blob* out_diff_blob, Blob* bias_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier"); - Memset<DeviceType::kCPU>(ctx, bias_diff_blob->mut_dptr<T>(), 0, - bias_diff_blob->ByteSizeOfDataContentField()); + Memset<device_type>(ctx, bias_diff_blob->mut_dptr<T>(), 0, + bias_diff_blob->ByteSizeOfDataContentField()); FOR_RANGE(int64_t, i, 0, out_shape_.At(0)) { // channels first: bias' += out' * bias_mul // channels last: bias' += out'(T) * bias_mul - KernelUtil<DeviceType::kCPU, T>::OFGemm( + KernelUtil<device_type, T>::OFGemm( ctx, is_out_diff_need_trans_, CblasNoTrans, weight_shape_.At(0), // filter 1, // 1 @@ -364,8 +354,9 @@ void ColBufUtil<T>::operator()(ColBufWriter<T>* col_buf_writer, int64_t c, int64 } template<typename T> -void ConvKernelUtil<T>::DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& col_buf_util, - ColBufWriter<T>* col_buf_writer) { +void ConvKernelUtil<DeviceType::kCPU, T>::DoNCDWHFunc(const Shape& weight_shape, + ColBufUtil<T>& col_buf_util, + ColBufWriter<T>* col_buf_writer) { for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) { for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) { @@ -378,10 +369,10 @@ void ConvKernelUtil<T>::DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& co } template<typename T> -void ConvKernelUtil<T>::NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, - const int32_t* strides, const int32_t* dilation_rate, - const int32_t* padding_before, T* col_buf_ptr) { +void ConvKernelUtil<DeviceType::kCPU, T>::NCDHWIm2Col( + const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const 
Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) { ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before); Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); @@ -389,11 +380,10 @@ void ConvKernelUtil<T>::NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, con } template<typename T> -void ConvKernelUtil<T>::NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr, - const Shape& in_shape, const Shape& weight_shape, - const Shape& out_shape, const int32_t* strides, - const int32_t* dilation_rate, const int32_t* padding_before, - T* in_diff_ptr) { +void ConvKernelUtil<DeviceType::kCPU, T>::NCDHWCol2Im( + const int dim_num, DeviceCtx* device_ctx, const T* col_buf_ptr, const Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) { ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before); Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); @@ -401,8 +391,9 @@ void ConvKernelUtil<T>::NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr, } template<typename T> -void ConvKernelUtil<T>::DoNDWHCFunc(const Shape& weight_shape, ColBufUtil<T>& col_buf_util, - ColBufWriter<T>* col_buf_writer) { +void ConvKernelUtil<DeviceType::kCPU, T>::DoNDWHCFunc(const Shape& weight_shape, + ColBufUtil<T>& col_buf_util, + ColBufWriter<T>* col_buf_writer) { for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) { for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) { for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) { @@ -415,10 +406,10 @@ void ConvKernelUtil<T>::DoNDWHCFunc(const Shape& weight_shape, ColBufUtil<T>& co } template<typename T> -void ConvKernelUtil<T>::NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, - const int32_t* strides, const int32_t* dilation_rate, - const int32_t* padding_before, T* col_buf_ptr) { +void ConvKernelUtil<DeviceType::kCPU, T>::NDHWCIm2Col( + const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) { ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before); Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), @@ -427,11 +418,10 @@ void ConvKernelUtil<T>::NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, con } template<typename T> -void ConvKernelUtil<T>::NDHWCCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr, - const Shape& in_shape, const Shape& weight_shape, - const Shape& out_shape, const int32_t* strides, - const int32_t* dilation_rate, const int32_t* padding_before, - T* in_diff_ptr) { +void ConvKernelUtil<DeviceType::kCPU, T>::NDHWCCol2Im( + const int dim_num, DeviceCtx* device_ctx, const T* col_buf_ptr, const Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* 
in_diff_ptr) { ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before); Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2), in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), diff --git a/oneflow/core/kernel/conv_kernel.cu b/oneflow/core/kernel/conv_kernel.cu index e0243a126..2fc397ad6 100644 --- a/oneflow/core/kernel/conv_kernel.cu +++ b/oneflow/core/kernel/conv_kernel.cu @@ -5,6 +5,52 @@ namespace oneflow { template<typename T> void ConvKernel<DeviceType::kGPU, T>::VirtualKernelInit(const ParallelContext* parallel_ctx) { + if (this->UseCudnnOnGpu()) { + KernelInitWithCudnn(parallel_ctx); + } else { + ConvKernelImplByIm2Col<DeviceType::kGPU, T>::VirtualKernelInit(parallel_ctx); + } +} + +template<typename T> +void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent( + DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + if (this->UseCudnnOnGpu()) { + DoForwardDataContentWithCudnn(device_ctx, in_blob, weight_blob, out_blob, BnInOp2Blob); + } else { + ConvKernelImplByIm2Col<DeviceType::kGPU, T>::DoForwardDataContent( + device_ctx, in_blob, weight_blob, out_blob, BnInOp2Blob); + } +} + +template<typename T> +void ConvKernel<DeviceType::kGPU, T>::WeightBackward( + DeviceCtx* device_ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob, + Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { + if (this->UseCudnnOnGpu()) { + WeightBackwardWithCudnn(device_ctx, out_diff_blob, in_blob, weight_diff_blob, in_diff_blob, + BnInOp2Blob); + } else { + ConvKernelImplByIm2Col<DeviceType::kGPU, T>::WeightBackward( + device_ctx, out_diff_blob, in_blob, weight_diff_blob, in_diff_blob, BnInOp2Blob); + } +} + +template<typename T> +void ConvKernel<DeviceType::kGPU, T>::BiasBackward( + DeviceCtx* device_ctx, const Blob* out_diff_blob, Blob* bias_diff_blob, + std::function<Blob*(const std::string&)> BnInOp2Blob) const { + if (this->UseCudnnOnGpu()) { + BiasBackwardWithCudnn(device_ctx, out_diff_blob, bias_diff_blob, BnInOp2Blob); + } else { + ConvKernelImplByIm2Col<DeviceType::kGPU, T>::BiasBackward(device_ctx, out_diff_blob, + bias_diff_blob, BnInOp2Blob); + } +} + +template<typename T> +void ConvKernel<DeviceType::kGPU, T>::KernelInitWithCudnn(const ParallelContext* parallel_ctx) { Shape in_shape(this->GetConvKernelConf().in()); Shape out_shape(this->GetConvKernelConf().out()); Shape weight_shape(this->GetConvKernelConf().weight()); @@ -48,7 +94,7 @@ void ConvKernel<DeviceType::kGPU, T>::VirtualKernelInit(const ParallelContext* p } template<typename T> -void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent( +void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContentWithCudnn( DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { CudaCheck(cudnnConvolutionForward( @@ -67,7 +113,7 @@ void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent( } template<typename T> -void ConvKernel<DeviceType::kGPU, T>::WeightBackward( +void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn( DeviceCtx* device_ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob, Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { const Blob* weight_blob = BnInOp2Blob("weight"); @@ -91,7 +137,7 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackward( } 
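Note: the WithoutCudnn path above dispatches to ConvKernelImplByIm2Col, which lowers convolution to im2col plus one GEMM per image. As a reference for the CUDA kernels added below, here is a minimal host-side sketch of the same channels-first lowering for the 2-D case; the function name and signature are illustrative only and are not part of OneFlow or of this patch:

    // Illustrative 2-D NCHW im2col: fills a (C*kh*kw) x (oh*ow) column buffer,
    // matching the layout Im2ColGpu<T, 2, /*is_channel_first=*/true> produces.
    void Im2ColNchwRef(const float* im, int C, int H, int W, int kh, int kw, int oh, int ow,
                       int stride_h, int stride_w, int dilation_h, int dilation_w, int pad_h,
                       int pad_w, float* col) {
      for (int c = 0; c < C; ++c) {
        for (int i = 0; i < kh; ++i) {
          for (int j = 0; j < kw; ++j) {
            const int row = (c * kh + i) * kw + j;  // row index in the column buffer
            for (int y = 0; y < oh; ++y) {
              for (int x = 0; x < ow; ++x) {
                const int iy = y * stride_h - pad_h + i * dilation_h;
                const int ix = x * stride_w - pad_w + j * dilation_w;
                col[row * (oh * ow) + y * ow + x] =
                    (iy >= 0 && iy < H && ix >= 0 && ix < W) ? im[(c * H + iy) * W + ix]
                                                             : 0.0f;  // zero padding
              }
            }
          }
        }
      }
    }
    // The forward pass then reduces to one GEMM per image:
    //   out(filters, oh*ow) = weight(filters, C*kh*kw) * col(C*kh*kw, oh*ow)

Roughly speaking, each thread of Im2ColGpu below corresponds to one (c, y, x) triple of these loops and walks the (i, j) kernel window by itself.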
template<typename T> -void ConvKernel<DeviceType::kGPU, T>::BiasBackward( +void ConvKernel<DeviceType::kGPU, T>::BiasBackwardWithCudnn( DeviceCtx* device_ctx, const Blob* out_diff_blob, Blob* bias_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const { CudaCheck(cudnnConvolutionBackwardBias(device_ctx->cudnn_handle(), OnePtr<T>::value, @@ -100,8 +146,349 @@ void ConvKernel<DeviceType::kGPU, T>::BiasBackward( bias_diff_blob->mut_dptr<T>())); } +namespace { + +template<int dim_num> +__device__ void InitSharedArrays(const int im_d, const int im_h, const int im_w, const int kernel_d, + const int kernel_h, const int kernel_w, const int out_d, + const int out_h, const int out_w, const int stride_d, + const int stride_h, const int stride_w, const int dilation_d, + const int dilation_h, const int dilation_w, const int pad_d, + const int pad_h, const int pad_w, int* shared_im, + int* shared_kernel, int* shared_out, int* shared_stride, + int* shared_dilation, int* shared_pad) { + if (threadIdx.x == 0) { + if (dim_num == 3) { + shared_im[0] = im_d; + shared_im[1] = im_h; + shared_im[2] = im_w; + shared_kernel[0] = kernel_d; + shared_kernel[1] = kernel_h; + shared_kernel[2] = kernel_w; + shared_out[0] = out_d; + shared_out[1] = out_h; + shared_out[2] = out_w; + shared_stride[0] = stride_d; + shared_stride[1] = stride_h; + shared_stride[2] = stride_w; + shared_dilation[0] = dilation_d; + shared_dilation[1] = dilation_h; + shared_dilation[2] = dilation_w; + shared_pad[0] = pad_d; + shared_pad[1] = pad_h; + shared_pad[2] = pad_w; + } else if (dim_num == 2) { + shared_im[0] = im_h; + shared_im[1] = im_w; + shared_kernel[0] = kernel_h; + shared_kernel[1] = kernel_w; + shared_out[0] = out_h; + shared_out[1] = out_w; + shared_stride[0] = stride_h; + shared_stride[1] = stride_w; + shared_dilation[0] = dilation_h; + shared_dilation[1] = dilation_w; + shared_pad[0] = pad_h; + shared_pad[1] = pad_w; + } else if (dim_num == 1) { + shared_im[0] = im_w; + shared_kernel[0] = kernel_w; + shared_out[0] = out_w; + shared_stride[0] = stride_w; + shared_dilation[0] = dilation_w; + shared_pad[0] = pad_w; + } + } + __syncthreads(); +} + +template<typename T, int dim_num, bool is_channel_first> +__global__ void Im2ColGpu(const int n, const T* im_dptr, const int channel, const int im_d, + const int im_h, const int im_w, const int kernel_d, const int kernel_h, + const int kernel_w, const int out_d, const int out_h, const int out_w, + const int stride_d, const int stride_h, const int stride_w, + const int dilation_rate_d, const int dilation_rate_h, + const int dilation_rate_w, const int padding_before_d, + const int padding_before_h, const int padding_before_w, T* col_buf_dptr) { + __shared__ int shared_im[dim_num]; + __shared__ int shared_kernel[dim_num]; + __shared__ int shared_out[dim_num]; + __shared__ int shared_stride[dim_num]; + __shared__ int shared_dilation[dim_num]; + __shared__ int shared_pad[dim_num]; + InitSharedArrays<dim_num>(im_d, im_h, im_w, kernel_d, kernel_h, kernel_w, out_d, out_h, out_w, + stride_d, stride_h, stride_w, dilation_rate_d, dilation_rate_h, + dilation_rate_w, padding_before_d, padding_before_h, padding_before_w, + shared_im, shared_kernel, shared_out, shared_stride, shared_dilation, + shared_pad); + + int out_size = 1; + for (int i = 0; i < dim_num; ++i) { out_size *= shared_out[i]; } + int kernel_index[dim_num]; + int out_index[dim_num]; + int channel_index; + int im_index[dim_num]; + CUDA_1D_KERNEL_LOOP(index, n) { + // total launch channel*od*oh*ow threads, + // each thread 
is responsible for copying one whole kernel-sized window + // calc kernel_/out_/channel_index + channel_index = index / out_size; + int col_offset = index % out_size; // col_dim of col_buf: od*oh*ow + for (int i = dim_num - 1; i >= 0; --i) { + out_index[i] = col_offset % shared_out[i]; + col_offset /= shared_out[i]; + kernel_index[i] = 0; + } + + int col_buf_offset = 0; + if (is_channel_first) { col_buf_offset = channel_index; } + for (int i = 0; i < dim_num; ++i) { + col_buf_offset *= shared_kernel[i]; + // col_buf_offset += kernel_index[i]; omitted because kernel_index[i] == 0 here + } + if (is_channel_first == false) { + col_buf_offset *= channel; + col_buf_offset += channel_index; + } + col_buf_offset *= out_size; + col_buf_offset += index % out_size; + + while (true) { + // calc im_index + bool is_im_index_valid = true; + for (int i = 0; i < dim_num; ++i) { + im_index[i] = + kernel_index[i] * shared_dilation[i] - shared_pad[i] + out_index[i] * shared_stride[i]; + if (im_index[i] < 0 || im_index[i] >= shared_im[i]) { + is_im_index_valid = false; + break; + } + } + + // write into col_buf + if (is_im_index_valid) { + // calc im_offset + int im_offset = 0; + if (is_channel_first) { im_offset = channel_index; } + for (int i = 0; i < dim_num; ++i) { + im_offset *= shared_im[i]; + im_offset += im_index[i]; + } + if (is_channel_first == false) { + im_offset *= channel; + im_offset += channel_index; + } + col_buf_dptr[col_buf_offset] = im_dptr[im_offset]; + } else { + col_buf_dptr[col_buf_offset] = 0; + } + col_buf_offset += out_size; + + // iterate over all kernel indices + bool is_loop_completed = true; + for (int i = dim_num - 1; i >= 0; --i) { + if (kernel_index[i] == shared_kernel[i] - 1) { + kernel_index[i] = 0; + } else { + kernel_index[i] += 1; + is_loop_completed = false; + break; + } + } + if (is_loop_completed) { break; } + } + } +} + +template<typename T, int dim_num, bool is_channel_first> +__global__ void Col2ImGpu(const int n, const T* col_buf_dptr, const int channel, const int im_d, + const int im_h, const int im_w, const int kernel_d, const int kernel_h, + const int kernel_w, const int out_d, const int out_h, const int out_w, + const int stride_d, const int stride_h, const int stride_w, + const int dilation_rate_d, const int dilation_rate_h, + const int dilation_rate_w, const int padding_before_d, + const int padding_before_h, const int padding_before_w, T* im_diff_dptr) { + __shared__ int shared_im[dim_num]; + __shared__ int shared_kernel[dim_num]; + __shared__ int shared_out[dim_num]; + __shared__ int shared_stride[dim_num]; + __shared__ int shared_dilation[dim_num]; + __shared__ int shared_pad[dim_num]; + InitSharedArrays<dim_num>(im_d, im_h, im_w, kernel_d, kernel_h, kernel_w, out_d, out_h, out_w, + stride_d, stride_h, stride_w, dilation_rate_d, dilation_rate_h, + dilation_rate_w, padding_before_d, padding_before_h, padding_before_w, + shared_im, shared_kernel, shared_out, shared_stride, shared_dilation, + shared_pad); + + int kernel_index[dim_num]; + int channel_index; + int im_index[dim_num]; + int out_begin[dim_num]; + int out_end[dim_num]; + int out_index[dim_num]; + CUDA_1D_KERNEL_LOOP(index, n) { + // calc im_/channel_index + int im_offset = index; + if (is_channel_first == false) { + channel_index = im_offset % channel; + im_offset /= channel; + } + for (int i = dim_num - 1; i >= 0; --i) { + im_index[i] = im_offset % shared_im[i] + shared_pad[i]; + im_offset /= shared_im[i]; + } + if (is_channel_first) { channel_index = im_offset; } + + // calc the out_range of this im element + bool
is_in_dim_wrong = false; + for (int i = 0; i < dim_num; ++i) { + const int kernel_extent = shared_dilation[i] * (shared_kernel[i] - 1) + 1; + if (im_index[i] < kernel_extent) { + out_begin[i] = 0; + } else { + // original equation: ((im_index[i]-kernel_extent+1)+(stride[i]-1))/stride[i] + out_begin[i] = (im_index[i] - kernel_extent) / shared_stride[i] + 1; + } + out_end[i] = min(im_index[i] / shared_stride[i] + 1, shared_out[i]); + out_index[i] = out_begin[i]; + + if (out_begin[i] >= out_end[i]) { // im elements never covered by any kernel window + is_in_dim_wrong = true; + break; + } + } + if (is_in_dim_wrong) { + im_diff_dptr[index] = 0; + continue; + } + + T val = 0; + while (true) { + bool is_skip = false; + // calc kernel_index + for (int i = 0; i < dim_num; ++i) { + kernel_index[i] = im_index[i] - out_index[i] * shared_stride[i]; + if (kernel_index[i] % shared_dilation[i] == 0) { + kernel_index[i] /= shared_dilation[i]; + } else { + is_skip = true; + break; + } + } + + // calc col_buf_offset + if (is_skip == false) { + int col_buf_offset = 0; + if (is_channel_first) { col_buf_offset = channel_index; } + for (int i = 0; i < dim_num; ++i) { + col_buf_offset *= shared_kernel[i]; + col_buf_offset += kernel_index[i]; + } + if (is_channel_first == false) { + col_buf_offset *= channel; + col_buf_offset += channel_index; + } + for (int i = 0; i < dim_num; ++i) { + col_buf_offset *= shared_out[i]; + col_buf_offset += out_index[i]; + } + val += col_buf_dptr[col_buf_offset]; + } + + // iterate to the next out_index[] + bool is_iter_completed = true; + for (int i = dim_num - 1; i >= 0; --i) { + if (out_index[i] == out_end[i] - 1) { + out_index[i] = out_begin[i]; + } else { + out_index[i] += 1; + is_iter_completed = false; + break; + } + } + if (is_iter_completed) { break; } + } + im_diff_dptr[index] = val; + } +} + +} // namespace + +#define IM2COL_KERNEL_CALL(kernel_func_name, dim_num, is_channel_first, kernel_num, src_dptr, \ + dst_dptr) \ + kernel_func_name<T, dim_num, is_channel_first> \ + <<<BlocksNum4ThreadsNum(kernel_num), kCudaThreadsNumPerBlock, 0, \ + device_ctx->cuda_stream()>>>( \ + kernel_num, src_dptr, in_shape.At(1), in_shape.At(2), in_shape.At(3), in_shape.At(4), \ + weight_shape.At(2), weight_shape.At(3), weight_shape.At(4), out_shape.At(2), \ + out_shape.At(3), out_shape.At(4), strides[0], strides[1], strides[2], dilation_rate[0], \ + dilation_rate[1], dilation_rate[2], padding_before[0], padding_before[1], \ + padding_before[2], dst_dptr) + +template<typename T> +void ConvKernelUtil<DeviceType::kGPU, T>::NCDHWIm2Col( + const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_dptr) { + int32_t kernels = weight_shape.At(1) * out_shape.Count(2); + switch (dim_num) { + case 1: IM2COL_KERNEL_CALL(Im2ColGpu, 1, true, kernels, in_dptr, col_buf_dptr); break; + case 2: IM2COL_KERNEL_CALL(Im2ColGpu, 2, true, kernels, in_dptr, col_buf_dptr); break; + case 3: IM2COL_KERNEL_CALL(Im2ColGpu, 3, true, kernels, in_dptr, col_buf_dptr); break; + default: UNIMPLEMENTED(); + } +} + +template<typename T> +void ConvKernelUtil<DeviceType::kGPU, T>::NDHWCIm2Col( + const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_dptr) { + int32_t kernels = weight_shape.At(1) *
out_shape.Count(2); + switch (dim_num) { + case 1: IM2COL_KERNEL_CALL(Im2ColGpu, 1, false, kernels, in_dptr, col_buf_dptr); break; + case 2: IM2COL_KERNEL_CALL(Im2ColGpu, 2, false, kernels, in_dptr, col_buf_dptr); break; + case 3: IM2COL_KERNEL_CALL(Im2ColGpu, 3, false, kernels, in_dptr, col_buf_dptr); break; + default: UNIMPLEMENTED(); + } +} + +template<typename T> +void ConvKernelUtil<DeviceType::kGPU, T>::NCDHWCol2Im( + const int dim_num, DeviceCtx* device_ctx, const T* col_buf_dptr, const Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_dptr) { + int32_t im_size = in_shape.Count(1); + switch (dim_num) { + case 1: IM2COL_KERNEL_CALL(Col2ImGpu, 1, true, im_size, col_buf_dptr, in_diff_dptr); break; + case 2: IM2COL_KERNEL_CALL(Col2ImGpu, 2, true, im_size, col_buf_dptr, in_diff_dptr); break; + case 3: IM2COL_KERNEL_CALL(Col2ImGpu, 3, true, im_size, col_buf_dptr, in_diff_dptr); break; + default: UNIMPLEMENTED(); + } +} + +template<typename T> +void ConvKernelUtil<DeviceType::kGPU, T>::NDHWCCol2Im( + const int dim_num, DeviceCtx* device_ctx, const T* col_buf_dptr, const Shape& in_shape, + const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_dptr) { + int32_t im_size = in_shape.Count(1); + switch (dim_num) { + case 1: IM2COL_KERNEL_CALL(Col2ImGpu, 1, false, im_size, col_buf_dptr, in_diff_dptr); break; + case 2: IM2COL_KERNEL_CALL(Col2ImGpu, 2, false, im_size, col_buf_dptr, in_diff_dptr); break; + case 3: IM2COL_KERNEL_CALL(Col2ImGpu, 3, false, im_size, col_buf_dptr, in_diff_dptr); break; + default: UNIMPLEMENTED(); + } +} + +#undef IM2COL_KERNEL_CALL + #define INSTANTIATE_CONV_KERNEL(type_cpp, type_proto) \ template class ConvKernel<DeviceType::kGPU, type_cpp>; OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CONV_KERNEL, FLOATING_DATA_TYPE_SEQ) +#define INSTANTIATE_CONV_KERNEL_UTIL(type_cpp, type_proto) \ + template class ConvKernelUtil<DeviceType::kGPU, type_cpp>; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CONV_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ) + } // namespace oneflow diff --git a/oneflow/core/kernel/conv_kernel.h b/oneflow/core/kernel/conv_kernel.h index 2aaf5ceb1..271290292 100644 --- a/oneflow/core/kernel/conv_kernel.h +++ b/oneflow/core/kernel/conv_kernel.h @@ -7,6 +7,20 @@ namespace oneflow { +namespace { + +template<typename T> +const T* GetImgDptr(const Blob* blob, int64_t idx) { + return blob->dptr<T>() + blob->shape().Count(1) * idx; +} + +template<typename T> +T* GetImgMutDptr(Blob* blob, int64_t idx) { + return const_cast<T*>(GetImgDptr<T>(blob, idx)); +} + +} // namespace + template<DeviceType device_type, typename T> class ConvKernelIf : public KernelIfWithActivation<device_type, T>, public KernelIfWithModel<device_type, T> { @@ -44,16 +58,18 @@ class ConvKernelIf : public KernelIfWithActivation<device_type, T>, }; template<typename T> -using Im2ColFunc = void (*)(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, - const int32_t* strides, const int32_t* dilation_rate, - const int32_t* padding_before, T* col_buf); +using Im2ColFunc = void (*)(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, + const Shape& in_shape, const Shape& weight_shape, + const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, + T* col_buf); template<typename T> -using Col2ImFunc = 
void (*)(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, - const int32_t* strides, const int32_t* dilation_rate, - const int32_t* padding_before, T* in_diff_ptr); +using Col2ImFunc = void (*)(const int dim_num, DeviceCtx* device_ctx, const T* col_buf, + const Shape& in_shape, const Shape& weight_shape, + const Shape& out_shape, const int32_t* strides, + const int32_t* dilation_rate, const int32_t* padding_before, + T* in_diff_ptr); template<typename T> using GemmFunc = void (*)(DeviceCtx* ctx, enum CBLAS_TRANSPOSE, enum CBLAS_TRANSPOSE, const int m, @@ -61,16 +77,13 @@ using GemmFunc = void (*)(DeviceCtx* ctx, enum CBLAS_TRANSPOSE, enum CBLAS_TRANS const T beta, T* c); template<DeviceType device_type, typename T> -class ConvKernel; - -template<typename T> -class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kCPU, T> { +class ConvKernelImplByIm2Col : public ConvKernelIf<device_type, T> { public: - OF_DISALLOW_COPY_AND_MOVE(ConvKernel); - ConvKernel() = default; - ~ConvKernel() = default; + OF_DISALLOW_COPY_AND_MOVE(ConvKernelImplByIm2Col); + ConvKernelImplByIm2Col() = default; + ~ConvKernelImplByIm2Col() = default; - private: + protected: void VirtualKernelInit(const ParallelContext*) override; void DoForwardDataContent(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob, @@ -80,6 +93,8 @@ class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kC std::function<Blob*(const std::string&)> BnInOp2Blob) const override; void BiasBackward(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + + private: Im2ColFunc<T> im2col_func_; Col2ImFunc<T> col2im_func_; GemmFunc<T> forward_func_; @@ -93,8 +108,21 @@ class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kC Shape weight_shape_; }; +template<DeviceType device_type, typename T> +class ConvKernel; + +template<typename T> +class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelImplByIm2Col<DeviceType::kCPU, T> { + public: + OF_DISALLOW_COPY_AND_MOVE(ConvKernel); + ConvKernel() = default; + ~ConvKernel() = default; + + private: +}; + template<typename T> -class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelIf<DeviceType::kGPU, T> { +class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelImplByIm2Col<DeviceType::kGPU, T> { public: OF_DISALLOW_COPY_AND_MOVE(ConvKernel); ConvKernel() = default; @@ -102,14 +130,26 @@ class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelIf<DeviceType::kG private: void VirtualKernelInit(const ParallelContext*) override; + void KernelInitWithCudnn(const ParallelContext*); + void DoForwardDataContent(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void DoForwardDataContentWithCudnn(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob, + Blob* out_blob, + std::function<Blob*(const std::string&)> BnInOp2Blob) const; + void WeightBackward(DeviceCtx*, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob, Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void WeightBackwardWithCudnn(DeviceCtx*, const Blob* out_diff_blob, const Blob* in_blob, + Blob* weight_diff_blob, Blob* in_diff_blob, + std::function<Blob*(const std::string&)> BnInOp2Blob) const; + void BiasBackward(DeviceCtx*, const Blob* 
out_diff_blob, Blob* bias_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const override; + void BiasBackwardWithCudnn(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob, + std::function<Blob*(const std::string&)> BnInOp2Blob) const; std::unique_ptr<CudnnTensorDesc> in_desc_; std::unique_ptr<CudnnTensorDesc> out_desc_; @@ -194,26 +234,31 @@ class ColBufUtil final { DHWValidFunc<T> dhw_valid_func_; }; +template<DeviceType device_type, typename T> +class ConvKernelUtil; + template<typename T> -struct ConvKernelUtil final { +struct ConvKernelUtil<DeviceType::kCPU, T> final { public: - static void NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, - const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf); + static void NCDHWIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, + const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* col_buf); - static void NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, - const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf); + static void NDHWCIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, + const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* col_buf); - static void NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, - const int32_t* dilation_rate, const int32_t* padding_before, - T* in_diff_ptr); + static void NCDHWCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf, + const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* in_diff_ptr); - static void NDHWCCol2Im(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape, - const Shape& weight_shape, const Shape& out_shape, const int32_t* strides, - const int32_t* dilation_rate, const int32_t* padding_before, - T* in_diff_ptr); + static void NDHWCCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf, + const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* in_diff_ptr); private: static void DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& conv_util, @@ -223,6 +268,32 @@ struct ConvKernelUtil final { ColBufWriter<T>* col_buf_writer); }; +template<typename T> +struct ConvKernelUtil<DeviceType::kGPU, T> final { + public: + static void NCDHWIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, + const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* col_buf); + + static void NDHWCIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, + const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* col_buf); + + static void NCDHWCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf, + const Shape& 
in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* in_diff_ptr); + + static void NDHWCCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf, + const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* in_diff_ptr); + + private: +}; + } // namespace oneflow #endif // ONEFLOW_CORE_KERNEL_CONV_KERNEL_H_ diff --git a/oneflow/core/kernel/kernel.h b/oneflow/core/kernel/kernel.h index 001b45b57..4424847af 100644 --- a/oneflow/core/kernel/kernel.h +++ b/oneflow/core/kernel/kernel.h @@ -142,7 +142,8 @@ class KernelIf : public Kernel { const PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns, void (Blob::*Copy)(DeviceCtx*, const Blob*)) const; - bool UseCudnn() const { return device_type == DeviceType::kGPU && op_conf().use_cudnn_on_gpu(); } + bool UseCudnn() const { return device_type == DeviceType::kGPU && UseCudnnOnGpu(); } + bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); } }; template<DeviceType device_type, typename ModelType> diff --git a/oneflow/core/operator/conv_op.cpp b/oneflow/core/operator/conv_op.cpp index 6f0baca02..20003955d 100644 --- a/oneflow/core/operator/conv_op.cpp +++ b/oneflow/core/operator/conv_op.cpp @@ -120,7 +120,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G if (GetValFromCustomizedConf<bool>("use_bias")) { // bias and bias_multiplier GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({filters, 1}); - if (!UseCudnn()) { + if (!UseCudnnOnGpu()) { std::vector<int64_t> bias_mul_shape(NDims + 1, 1); for (size_t i = 0; i != NDims; ++i) { bias_mul_shape[i + 1] = out_shape[dhw_offset + i]; } GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape(bias_mul_shape); @@ -130,7 +130,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G ConvOpCtx* conv_op_ctx = new ConvOpCtx(); EnrollOpCtx(conv_op_ctx); - if (!UseCudnn()) { + if (device_type() == DeviceType::kCPU || !UseCudnnOnGpu()) { // col_buf size_t col_buf_elem_cnt = 1; for (size_t i = 0; i != NDims + 1; ++i) { col_buf_elem_cnt *= weight_shape[i + 1]; } @@ -139,7 +139,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G } #ifdef WITH_CUDA - if (device_type() == DeviceType::kGPU) { + if (device_type() == DeviceType::kGPU && UseCudnnOnGpu()) { // cudnn_buf InferCudnnAlgo(GetBlobDesc4BnInOp, &(conv_op_ctx->cudnn_conv_algo_ctx)); *buf_size = std::max({conv_op_ctx->cudnn_conv_algo_ctx.fwd_ws_size, @@ -235,7 +235,7 @@ void ConvOp<NDims>::VirtualGenKernelConf( const ParallelContext* parallel_ctx, KernelConf* kernel_conf, const OpContext* op_ctx) const { ConvKernelConf* conv_conf = kernel_conf->mutable_conv_conf(); conv_conf->set_dim(NDims); - if (!UseCudnn()) { + if (!UseCudnnOnGpu()) { GenKernelConfWithoutCudnn(GetBlobDesc4BnInOp, conv_conf); } else { GenKernelConfWithCudnn(GetBlobDesc4BnInOp, kernel_conf, conv_conf, op_ctx); diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 79501498d..91069547c 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -41,10 +41,6 @@ LogicalBlobId* Operator::MutBnInOp2Lbi(const std::string& bn_in_op) { } } -bool Operator::UseCudnn() const { - return device_type() == DeviceType::kGPU && op_conf().use_cudnn_on_gpu(); -} - const std::string& Operator::SoleIbn() const { 
CHECK_EQ(input_bns().size(), 1); return input_bns().Get(0); diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 6572b1acc..e08bf058a 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -47,7 +47,8 @@ class Operator { // Getters const std::string& op_name() const { return op_conf().name(); } DeviceType device_type() const { return op_attribute_.op_conf().device_type(); } - bool UseCudnn() const; + bool UseCudnn() const { return device_type() == DeviceType::kGPU && UseCudnnOnGpu(); } + bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); } const OperatorConf& op_conf() const { return op_attribute_.op_conf(); } virtual const PbMessage& GetCustomizedConf() const { UNIMPLEMENTED(); } -- GitLab
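Note: in the backward pass, Col2ImGpu assigns one thread per input element and accumulates over every output position whose kernel window covers that element. Per dimension, an output position o reads input coordinate im (already shifted by padding_before) iff o*stride <= im <= o*stride + dilation*(kernel-1) and (im - o*stride) is divisible by dilation. Below is a minimal host-side mirror of the resulting out_begin/out_end computation; the names are illustrative only and are not OneFlow API:

    #include <algorithm>

    struct OutRange { int begin; int end; };  // half-open [begin, end)

    // Contributing output positions for one input coordinate along one dimension,
    // mirroring the out_begin/out_end logic in Col2ImGpu.
    OutRange ContributingOutputs(int im, int kernel, int stride, int dilation, int out_size) {
      const int kernel_extent = dilation * (kernel - 1) + 1;
      OutRange r;
      // smallest o with o * stride + kernel_extent > im,
      // i.e. ceil((im - kernel_extent + 1) / stride), clamped at 0
      r.begin = (im < kernel_extent) ? 0 : (im - kernel_extent) / stride + 1;
      // one past the largest o with o * stride <= im, capped by the output extent
      r.end = std::min(im / stride + 1, out_size);
      return r;  // begin >= end means no kernel window ever reads this element
    }

Candidates inside the range still have to pass the divisibility test, which is the is_skip branch in the kernel.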