diff --git a/oneflow/core/kernel/conv_kernel.cpp b/oneflow/core/kernel/conv_kernel.cpp
index 89b41a2a3e89707a37fd77bac8f0dd30e3437d40..e78a961e5bd053cfe2716128f9f51bb52448e2f7 100644
--- a/oneflow/core/kernel/conv_kernel.cpp
+++ b/oneflow/core/kernel/conv_kernel.cpp
@@ -3,20 +3,6 @@
namespace oneflow {
-namespace {
-
-template<typename T>
-const T* GetImgDptr(const Blob* blob, int64_t idx) {
- return blob->dptr<T>() + blob->shape().Count(1) * idx;
-}
-
-template<typename T>
-T* GetImgMutDptr(Blob* blob, int64_t idx) {
- return const_cast<T*>(GetImgDptr<T>(blob, idx));
-}
-
-} // namespace
-
template<DeviceType device_type, typename T>
void ConvKernelIf<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
@@ -40,7 +26,7 @@ void ConvKernelIf<device_type, T>::BackwardDataContent(
template<DeviceType device_type, typename T>
void ConvKernelIf<device_type, T>::InitConstBufBlobs(
DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
- if (this->template GetValFromCustomizedOpConf<bool>("use_bias") && !this->UseCudnn()) {
+ if (this->template GetValFromCustomizedOpConf<bool>("use_bias") && !this->UseCudnnOnGpu()) {
InitializerConf bias_multiplier_initializer_conf;
bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0,
@@ -98,20 +84,21 @@ const int32_t ConvKernelIf<device_type, T>::OpKernelDim() const {
return this->GetConvKernelConf().dim();
}
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::VirtualKernelInit(const ParallelContext* parallel_ctx) {
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::VirtualKernelInit(
+ const ParallelContext* parallel_ctx) {
const std::string& data_format =
this->template GetValFromCustomizedOpConf<std::string>("data_format");
if (data_format == "channels_first") {
- im2col_func_ = ConvKernelUtil<T>::NCDHWIm2Col;
- col2im_func_ = ConvKernelUtil<T>::NCDHWCol2Im;
- forward_func_ = KernelUtil<DeviceType::kCPU, T>::OFGemm;
+ im2col_func_ = ConvKernelUtil<device_type, T>::NCDHWIm2Col;
+ col2im_func_ = ConvKernelUtil<device_type, T>::NCDHWCol2Im;
+ forward_func_ = KernelUtil<device_type, T>::OFGemm;
dhw_offset_ = 2;
is_out_diff_need_trans_ = CblasNoTrans;
} else {
- im2col_func_ = ConvKernelUtil<T>::NDHWCIm2Col;
- col2im_func_ = ConvKernelUtil<T>::NDHWCCol2Im;
- forward_func_ = KernelUtil<DeviceType::kCPU, T>::OFGemmTrans;
+ im2col_func_ = ConvKernelUtil<device_type, T>::NDHWCIm2Col;
+ col2im_func_ = ConvKernelUtil<device_type, T>::NDHWCCol2Im;
+ forward_func_ = KernelUtil<device_type, T>::OFGemmTrans;
dhw_offset_ = 1;
is_out_diff_need_trans_ = CblasTrans;
}
@@ -123,13 +110,14 @@ void ConvKernel<DeviceType::kCPU, T>::VirtualKernelInit(const ParallelContext* p
padding_before_ = this->GetConvKernelConf().pad_small_side().data();
}
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::DoForwardDataContent(
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::DoForwardDataContent(
DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
FOR_RANGE(int64_t, i, 0, in_shape_.At(0)) {
- im2col_func_(device_ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_, out_shape_,
- strides_, dilation_rate_, padding_before_, static_cast<T*>(device_ctx->buf_ptr()));
+ im2col_func_(this->OpKernelDim(), device_ctx, GetImgDptr<T>(in_blob, i), in_shape_,
+ weight_shape_, out_shape_, strides_, dilation_rate_, padding_before_,
+ static_cast<T*>(device_ctx->buf_ptr()));
// col_buf is device_ctx->buf_ptr()
// channels first: out = weight * col_buf
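+    // shapes, as implied by the GEMM arguments in WeightBackward below:
+    //   weight  : (filters, ci * kd * kh * kw)
+    //   col_buf : (ci * kd * kh * kw, od * oh * ow)
+    //   out[i]  : (filters, od * oh * ow)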
@@ -157,24 +145,25 @@ void ConvKernel<DeviceType::kCPU, T>::DoForwardDataContent(
}
}
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::WeightBackward(
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::WeightBackward(
DeviceCtx* ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob,
Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* weight_blob = BnInOp2Blob("weight");
- Memset<DeviceType::kCPU>(ctx, weight_diff_blob->mut_dptr<T>(), 0,
- weight_diff_blob->ByteSizeOfDataContentField());
+ Memset<device_type>(ctx, weight_diff_blob->mut_dptr<T>(), 0,
+ weight_diff_blob->ByteSizeOfDataContentField());
if (in_diff_blob != nullptr) {
- Memset<DeviceType::kCPU>(ctx, in_diff_blob->mut_dptr<T>(), 0,
- in_diff_blob->ByteSizeOfDataContentField());
+ Memset<device_type>(ctx, in_diff_blob->mut_dptr<T>(), 0,
+ in_diff_blob->ByteSizeOfDataContentField());
}
FOR_RANGE(int64_t, i, 0, out_shape_.At(0)) {
- im2col_func_(ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_, out_shape_, strides_,
- dilation_rate_, padding_before_, static_cast<T*>(ctx->buf_ptr()));
+ im2col_func_(this->OpKernelDim(), ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_,
+ out_shape_, strides_, dilation_rate_, padding_before_,
+ static_cast<T*>(ctx->buf_ptr()));
// channels first: weight' += out[i]' * col_buf(T)
// channels last : weight' += out[i]'(T) * col_buf(T)
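+    // GEMM dims below: m = filters, n = ci*kd*kh*kw, k = od*oh*ow, so
+    // weight_diff (filters, ci*kd*kh*kw) accumulates out[i]' times col_buf^T
+    // across all images i, into the buffer zeroed above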
- KernelUtil<DeviceType::kCPU, T>::OFGemm(
+ KernelUtil<device_type, T>::OFGemm(
ctx, is_out_diff_need_trans_, CblasTrans,
weight_shape_.At(0), // filter
weight_shape_.Count(1), // ci * kd * kh * kw
@@ -185,7 +174,7 @@ void ConvKernel<DeviceType::kCPU, T>::WeightBackward(
if (in_diff_blob != nullptr) {
// channels first: col_buf' = weight(T) * out[i]'
// channels last : col_buf' = weight(T) * out[i]'(T)
- KernelUtil<DeviceType::kCPU, T>::OFGemm(
+ KernelUtil<device_type, T>::OFGemm(
ctx, CblasTrans, is_out_diff_need_trans_,
weight_shape_.Count(1), // ci * kd * kh * kw
out_shape_.Count(dhw_offset_, dhw_offset_ + 3), // od * oh * ow
@@ -194,24 +183,25 @@ void ConvKernel<DeviceType::kCPU, T>::WeightBackward(
static_cast<T>(0), static_cast<T*>(ctx->buf_ptr()));
// in' = col2im(col_buf')
- col2im_func_(ctx, static_cast<const T*>(ctx->buf_ptr()), in_shape_, weight_shape_, out_shape_,
- strides_, dilation_rate_, padding_before_, GetImgMutDptr<T>(in_diff_blob, i));
+ col2im_func_(this->OpKernelDim(), ctx, static_cast<const T*>(ctx->buf_ptr()), in_shape_,
+ weight_shape_, out_shape_, strides_, dilation_rate_, padding_before_,
+ GetImgMutDptr<T>(in_diff_blob, i));
}
}
}
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::BiasBackward(
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::BiasBackward(
DeviceCtx* ctx, const Blob* out_diff_blob, Blob* bias_diff_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
- Memset<DeviceType::kCPU>(ctx, bias_diff_blob->mut_dptr<T>(), 0,
- bias_diff_blob->ByteSizeOfDataContentField());
+ Memset<device_type>(ctx, bias_diff_blob->mut_dptr<T>(), 0,
+ bias_diff_blob->ByteSizeOfDataContentField());
FOR_RANGE(int64_t, i, 0, out_shape_.At(0)) {
// channels first: bias' += out' * bias_mul
// channels last: bias' += out'(T) * bias_mul
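+    // bias_mul is the constant all-ones blob (od*oh*ow, 1) set up in
+    // InitConstBufBlobs, so this GEMM just sums out[i]' over the spatial dims:
+    // bias_diff (filters, 1) += out[i]' (filters, od*oh*ow) * bias_mul (od*oh*ow, 1)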
- KernelUtil<DeviceType::kCPU, T>::OFGemm(
+ KernelUtil<device_type, T>::OFGemm(
ctx, is_out_diff_need_trans_, CblasNoTrans,
weight_shape_.At(0), // filter
1, // 1
@@ -364,8 +354,9 @@ void ColBufUtil<T>::operator()(ColBufWriter<T>* col_buf_writer, int64_t c, int64
}
template<typename T>
-void ConvKernelUtil<T>::DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& col_buf_util,
- ColBufWriter<T>* col_buf_writer) {
+void ConvKernelUtil<DeviceType::kCPU, T>::DoNCDWHFunc(const Shape& weight_shape,
+ ColBufUtil<T>& col_buf_util,
+ ColBufWriter<T>* col_buf_writer) {
for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) {
for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) {
for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) {
@@ -378,10 +369,10 @@ void ConvKernelUtil<T>::DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& co
}
template<typename T>
-void ConvKernelUtil<T>::NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape,
- const int32_t* strides, const int32_t* dilation_rate,
- const int32_t* padding_before, T* col_buf_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NCDHWIm2Col(
+ const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+ const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) {
ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before);
Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(3),
in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);
@@ -389,11 +380,10 @@ void ConvKernelUtil<T>::NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, con
}
template<typename T>
-void ConvKernelUtil<T>::NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr,
- const Shape& in_shape, const Shape& weight_shape,
- const Shape& out_shape, const int32_t* strides,
- const int32_t* dilation_rate, const int32_t* padding_before,
- T* in_diff_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NCDHWCol2Im(
+ const int dim_num, DeviceCtx* device_ctx, const T* col_buf_ptr, const Shape& in_shape,
+ const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) {
ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before);
Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3),
in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);
@@ -401,8 +391,9 @@ void ConvKernelUtil<T>::NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr,
}
template<typename T>
-void ConvKernelUtil<T>::DoNDWHCFunc(const Shape& weight_shape, ColBufUtil<T>& col_buf_util,
- ColBufWriter<T>* col_buf_writer) {
+void ConvKernelUtil<DeviceType::kCPU, T>::DoNDWHCFunc(const Shape& weight_shape,
+ ColBufUtil<T>& col_buf_util,
+ ColBufWriter<T>* col_buf_writer) {
for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) {
for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) {
for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) {
@@ -415,10 +406,10 @@ void ConvKernelUtil<T>::DoNDWHCFunc(const Shape& weight_shape, ColBufUtil<T>& co
}
template<typename T>
-void ConvKernelUtil<T>::NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape,
- const int32_t* strides, const int32_t* dilation_rate,
- const int32_t* padding_before, T* col_buf_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NDHWCIm2Col(
+ const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+ const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) {
ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before);
Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(2),
in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),
@@ -427,11 +418,10 @@ void ConvKernelUtil<T>::NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, con
}
template<typename T>
-void ConvKernelUtil<T>::NDHWCCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr,
- const Shape& in_shape, const Shape& weight_shape,
- const Shape& out_shape, const int32_t* strides,
- const int32_t* dilation_rate, const int32_t* padding_before,
- T* in_diff_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NDHWCCol2Im(
+ const int dim_num, DeviceCtx* device_ctx, const T* col_buf_ptr, const Shape& in_shape,
+ const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) {
ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before);
Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2),
in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),
diff --git a/oneflow/core/kernel/conv_kernel.cu b/oneflow/core/kernel/conv_kernel.cu
index e0243a126965a57e4284dd9f36868ef95fc270fb..2fc397ad6652929121a8a8f8479cbfe3a9968690 100644
--- a/oneflow/core/kernel/conv_kernel.cu
+++ b/oneflow/core/kernel/conv_kernel.cu
@@ -5,6 +5,52 @@ namespace oneflow {
template<typename T>
void ConvKernel<DeviceType::kGPU, T>::VirtualKernelInit(const ParallelContext* parallel_ctx) {
+ if (this->UseCudnnOnGpu()) {
+ KernelInitWithCudnn(parallel_ctx);
+ } else {
+ ConvKernelImplByIm2Col<DeviceType::kGPU, T>::VirtualKernelInit(parallel_ctx);
+ }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent(
+ DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob,
+ std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+ if (this->UseCudnnOnGpu()) {
+ DoForwardDataContentWithCudnn(device_ctx, in_blob, weight_blob, out_blob, BnInOp2Blob);
+ } else {
+ ConvKernelImplByIm2Col<DeviceType::kGPU, T>::DoForwardDataContent(
+ device_ctx, in_blob, weight_blob, out_blob, BnInOp2Blob);
+ }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::WeightBackward(
+ DeviceCtx* device_ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob,
+ Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+ if (this->UseCudnnOnGpu()) {
+ WeightBackwardWithCudnn(device_ctx, out_diff_blob, in_blob, weight_diff_blob, in_diff_blob,
+ BnInOp2Blob);
+ } else {
+ ConvKernelImplByIm2Col<DeviceType::kGPU, T>::WeightBackward(
+ device_ctx, out_diff_blob, in_blob, weight_diff_blob, in_diff_blob, BnInOp2Blob);
+ }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::BiasBackward(
+ DeviceCtx* device_ctx, const Blob* out_diff_blob, Blob* bias_diff_blob,
+ std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+ if (this->UseCudnnOnGpu()) {
+ BiasBackwardWithCudnn(device_ctx, out_diff_blob, bias_diff_blob, BnInOp2Blob);
+ } else {
+ ConvKernelImplByIm2Col<DeviceType::kGPU, T>::BiasBackward(device_ctx, out_diff_blob,
+ bias_diff_blob, BnInOp2Blob);
+ }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::KernelInitWithCudnn(const ParallelContext* parallel_ctx) {
Shape in_shape(this->GetConvKernelConf().in());
Shape out_shape(this->GetConvKernelConf().out());
Shape weight_shape(this->GetConvKernelConf().weight());
@@ -48,7 +94,7 @@ void ConvKernel<DeviceType::kGPU, T>::VirtualKernelInit(const ParallelContext* p
}
template<typename T>
-void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent(
+void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContentWithCudnn(
DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
CudaCheck(cudnnConvolutionForward(
@@ -67,7 +113,7 @@ void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent(
}
template<typename T>
-void ConvKernel<DeviceType::kGPU, T>::WeightBackward(
+void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn(
DeviceCtx* device_ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob,
Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* weight_blob = BnInOp2Blob("weight");
@@ -91,7 +137,7 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackward(
}
template<typename T>
-void ConvKernel<DeviceType::kGPU, T>::BiasBackward(
+void ConvKernel<DeviceType::kGPU, T>::BiasBackwardWithCudnn(
DeviceCtx* device_ctx, const Blob* out_diff_blob, Blob* bias_diff_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
CudaCheck(cudnnConvolutionBackwardBias(device_ctx->cudnn_handle(), OnePtr<T>::value,
@@ -100,8 +146,349 @@ void ConvKernel<DeviceType::kGPU, T>::BiasBackward(
bias_diff_blob->mut_dptr<T>()));
}
+namespace {
+
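+// Thread 0 copies the trailing dim_num of the three (d, h, w) scalar parameters
+// into block-shared arrays so the kernels below can index every per-dim quantity
+// uniformly; __syncthreads() then publishes the values to the whole block.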
+template<int dim_num>
+__device__ void InitSharedArrays(const int im_d, const int im_h, const int im_w, const int kernel_d,
+ const int kernel_h, const int kernel_w, const int out_d,
+ const int out_h, const int out_w, const int stride_d,
+ const int stride_h, const int stride_w, const int dilation_d,
+ const int dilation_h, const int dilation_w, const int pad_d,
+ const int pad_h, const int pad_w, int* shared_im,
+ int* shared_kernel, int* shared_out, int* shared_stride,
+ int* shared_dilation, int* shared_pad) {
+ if (threadIdx.x == 0) {
+ if (dim_num == 3) {
+ shared_im[0] = im_d;
+ shared_im[1] = im_h;
+ shared_im[2] = im_w;
+ shared_kernel[0] = kernel_d;
+ shared_kernel[1] = kernel_h;
+ shared_kernel[2] = kernel_w;
+ shared_out[0] = out_d;
+ shared_out[1] = out_h;
+ shared_out[2] = out_w;
+ shared_stride[0] = stride_d;
+ shared_stride[1] = stride_h;
+ shared_stride[2] = stride_w;
+ shared_dilation[0] = dilation_d;
+ shared_dilation[1] = dilation_h;
+ shared_dilation[2] = dilation_w;
+ shared_pad[0] = pad_d;
+ shared_pad[1] = pad_h;
+ shared_pad[2] = pad_w;
+ } else if (dim_num == 2) {
+ shared_im[0] = im_h;
+ shared_im[1] = im_w;
+ shared_kernel[0] = kernel_h;
+ shared_kernel[1] = kernel_w;
+ shared_out[0] = out_h;
+ shared_out[1] = out_w;
+ shared_stride[0] = stride_h;
+ shared_stride[1] = stride_w;
+ shared_dilation[0] = dilation_h;
+ shared_dilation[1] = dilation_w;
+ shared_pad[0] = pad_h;
+ shared_pad[1] = pad_w;
+ } else if (dim_num == 1) {
+ shared_im[0] = im_w;
+ shared_kernel[0] = kernel_w;
+ shared_out[0] = out_w;
+ shared_stride[0] = stride_w;
+ shared_dilation[0] = dilation_w;
+ shared_pad[0] = pad_w;
+ }
+ }
+ __syncthreads();
+}
+
+template<typename T, int dim_num, bool is_channel_first>
+__global__ void Im2ColGpu(const int n, const T* im_dptr, const int channel, const int im_d,
+ const int im_h, const int im_w, const int kernel_d, const int kernel_h,
+ const int kernel_w, const int out_d, const int out_h, const int out_w,
+ const int stride_d, const int stride_h, const int stride_w,
+ const int dilation_rate_d, const int dilation_rate_h,
+ const int dilation_rate_w, const int padding_before_d,
+ const int padding_before_h, const int padding_before_w, T* col_buf_dptr) {
+ __shared__ int shared_im[dim_num];
+ __shared__ int shared_kernel[dim_num];
+ __shared__ int shared_out[dim_num];
+ __shared__ int shared_stride[dim_num];
+ __shared__ int shared_dilation[dim_num];
+ __shared__ int shared_pad[dim_num];
+ InitSharedArrays<dim_num>(im_d, im_h, im_w, kernel_d, kernel_h, kernel_w, out_d, out_h, out_w,
+ stride_d, stride_h, stride_w, dilation_rate_d, dilation_rate_h,
+ dilation_rate_w, padding_before_d, padding_before_h, padding_before_w,
+ shared_im, shared_kernel, shared_out, shared_stride, shared_dilation,
+ shared_pad);
+
+ int out_size = 1;
+ for (int i = 0; i < dim_num; ++i) { out_size *= shared_out[i]; }
+ int kernel_index[dim_num];
+ int out_index[dim_num];
+ int channel_index;
+ int im_index[dim_num];
+ CUDA_1D_KERNEL_LOOP(index, n) {
+    // channel*od*oh*ow threads are launched in total;
+    // each thread copies one full kernel window for its (channel, output position) pair
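+    // e.g. a 2D NCHW conv with a 3x3 kernel: the thread for (c, oh, ow) fills the
+    // nine entries of column (oh * out_w + ow) that belong to channel c, one per
+    // kernel element, successive writes out_size apart (see col_buf_offset += out_size)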
+ // calc kernel_/out_/channel_index
+ channel_index = index / out_size;
+ int col_offset = index % out_size; // col_dim of col_buf: od*oh*ow
+ for (int i = dim_num - 1; i >= 0; --i) {
+ out_index[i] = col_offset % shared_out[i];
+ col_offset /= shared_out[i];
+ kernel_index[i] = 0;
+ }
+
+ int col_buf_offset = 0;
+ if (is_channel_first) { col_buf_offset = channel_index; }
+    for (int i = 0; i < dim_num; ++i) {
+      col_buf_offset *= shared_kernel[i];
+      // col_buf_offset += kernel_index[i]; omitted because kernel_index[] is all zero here
+    }
+ if (is_channel_first == false) {
+ col_buf_offset *= channel;
+ col_buf_offset += channel_index;
+ }
+ col_buf_offset *= out_size;
+ col_buf_offset += index % out_size;
+
+ while (true) {
+ // calc im_index
+ bool is_im_index_valid = true;
+ for (int i = 0; i < dim_num; ++i) {
+ im_index[i] =
+ kernel_index[i] * shared_dilation[i] - shared_pad[i] + out_index[i] * shared_stride[i];
+ if (im_index[i] < 0 || im_index[i] >= shared_im[i]) {
+ is_im_index_valid = false;
+ break;
+ }
+ }
+
+ // write into col_buf
+ if (is_im_index_valid) {
+ // calc im_offset
+ int im_offset = 0;
+ if (is_channel_first) { im_offset = channel_index; }
+ for (int i = 0; i < dim_num; ++i) {
+ im_offset *= shared_im[i];
+ im_offset += im_index[i];
+ }
+ if (is_channel_first == false) {
+ im_offset *= channel;
+ im_offset += channel_index;
+ }
+ col_buf_dptr[col_buf_offset] = im_dptr[im_offset];
+ } else {
+ col_buf_dptr[col_buf_offset] = 0;
+ }
+ col_buf_offset += out_size;
+
+      // advance kernel_index[] odometer-style until every kernel position is visited
+ bool is_loop_completed = true;
+ for (int i = dim_num - 1; i >= 0; --i) {
+ if (kernel_index[i] == shared_kernel[i] - 1) {
+ kernel_index[i] = 0;
+ } else {
+ kernel_index[i] += 1;
+ is_loop_completed = false;
+ break;
+ }
+ }
+ if (is_loop_completed) { break; }
+ }
+ }
+}
+
+template<typename T, int dim_num, bool is_channel_first>
+__global__ void Col2ImGpu(const int n, const T* col_buf_dptr, const int channel, const int im_d,
+ const int im_h, const int im_w, const int kernel_d, const int kernel_h,
+ const int kernel_w, const int out_d, const int out_h, const int out_w,
+ const int stride_d, const int stride_h, const int stride_w,
+ const int dilation_rate_d, const int dilation_rate_h,
+ const int dilation_rate_w, const int padding_before_d,
+ const int padding_before_h, const int padding_before_w, T* im_diff_dptr) {
+ __shared__ int shared_im[dim_num];
+ __shared__ int shared_kernel[dim_num];
+ __shared__ int shared_out[dim_num];
+ __shared__ int shared_stride[dim_num];
+ __shared__ int shared_dilation[dim_num];
+ __shared__ int shared_pad[dim_num];
+ InitSharedArrays<dim_num>(im_d, im_h, im_w, kernel_d, kernel_h, kernel_w, out_d, out_h, out_w,
+ stride_d, stride_h, stride_w, dilation_rate_d, dilation_rate_h,
+ dilation_rate_w, padding_before_d, padding_before_h, padding_before_w,
+ shared_im, shared_kernel, shared_out, shared_stride, shared_dilation,
+ shared_pad);
+
+ int kernel_index[dim_num];
+ int channel_index;
+ int im_index[dim_num];
+ int out_begin[dim_num];
+ int out_end[dim_num];
+ int out_index[dim_num];
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ // calc im_/channel_index
+ int im_offset = index;
+ if (is_channel_first == false) {
+ channel_index = im_offset % channel;
+ im_offset /= channel;
+ }
+ for (int i = dim_num - 1; i >= 0; --i) {
+ im_index[i] = im_offset % shared_im[i] + shared_pad[i];
+ im_offset /= shared_im[i];
+ }
+ if (is_channel_first) { channel_index = im_offset; }
+
+ // calc the out_range of this im element
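+    // an output position out covers input index out*stride + k*dilation for
+    // k in [0, kernel), so this (padded) im_index is covered iff
+    //   out*stride <= im_index <= out*stride + kernel_extent - 1,
+    // giving out in [ceil((im_index - kernel_extent + 1)/stride), floor(im_index/stride)]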
+ bool is_in_dim_wrong = false;
+ for (int i = 0; i < dim_num; ++i) {
+ const int kernel_extent = shared_dilation[i] * (shared_kernel[i] - 1) + 1;
+ if (im_index[i] < kernel_extent) {
+ out_begin[i] = 0;
+ } else {
+        // equals ceil((im_index[i]-kernel_extent+1) / stride[i]),
+        // written out as ((im_index[i]-kernel_extent+1)+(stride[i]-1))/stride[i]
+ out_begin[i] = (im_index[i] - kernel_extent) / shared_stride[i] + 1;
+ }
+ out_end[i] = min(im_index[i] / shared_stride[i] + 1, shared_out[i]);
+ out_index[i] = out_begin[i];
+
+      if (out_begin[i] >= out_end[i]) {  // this im element is never covered by any kernel window
+ is_in_dim_wrong = true;
+ break;
+ }
+ }
+ if (is_in_dim_wrong) {
+ im_diff_dptr[index] = 0;
+ continue;
+ }
+
+ T val = 0;
+ while (true) {
+ bool is_skip = false;
+ // calc kernel_index
+ for (int i = 0; i < dim_num; ++i) {
+ kernel_index[i] = im_index[i] - out_index[i] * shared_stride[i];
+ if (kernel_index[i] % shared_dilation[i] == 0) {
+ kernel_index[i] /= shared_dilation[i];
+ } else {
+ is_skip = true;
+ break;
+ }
+ }
+
+      // calc col_buf_offset
+ if (is_skip == false) {
+ int col_buf_offset = 0;
+ if (is_channel_first) { col_buf_offset = channel_index; }
+ for (int i = 0; i < dim_num; ++i) {
+ col_buf_offset *= shared_kernel[i];
+ col_buf_offset += kernel_index[i];
+ }
+ if (is_channel_first == false) {
+ col_buf_offset *= channel;
+ col_buf_offset += channel_index;
+ }
+ for (int i = 0; i < dim_num; ++i) {
+ col_buf_offset *= shared_out[i];
+ col_buf_offset += out_index[i];
+ }
+ val += col_buf_dptr[col_buf_offset];
+ }
+
+      // advance out_index[] to the next position within [out_begin, out_end)
+ bool is_iter_completed = true;
+ for (int i = dim_num - 1; i >= 0; --i) {
+ if (out_index[i] == out_end[i] - 1) {
+ out_index[i] = out_begin[i];
+ } else {
+ out_index[i] += 1;
+ is_iter_completed = false;
+ break;
+ }
+ }
+ if (is_iter_completed) { break; }
+ }
+ im_diff_dptr[index] = val;
+ }
+}
+
+} // namespace
+
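+// note: the calls below assume the conv kernel conf stores in/out/weight shapes
+// padded to 5D and strides/dilation_rate/padding_before padded to 3 entries
+// (size-1 dims on the small side), so the fixed At()/[] accesses stay in range
+// for dim_num 1 and 2; the kernels consume only the trailing dim_num values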
+#define IM2COL_KERNEL_CALL(kernel_func_name, dim_num, is_channel_first, kernel_num, src_dptr, \
+ dst_dptr) \
+ kernel_func_name<T, dim_num, is_channel_first> \
+ <<<BlocksNum4ThreadsNum(kernel_num), kCudaThreadsNumPerBlock, 0, \
+ device_ctx->cuda_stream()>>>( \
+ kernel_num, src_dptr, in_shape.At(1), in_shape.At(2), in_shape.At(3), in_shape.At(4), \
+ weight_shape.At(2), weight_shape.At(3), weight_shape.At(4), out_shape.At(2), \
+ out_shape.At(3), out_shape.At(4), strides[0], strides[1], strides[2], dilation_rate[0], \
+ dilation_rate[1], dilation_rate[2], padding_before[0], padding_before[1], \
+ padding_before[2], dst_dptr)
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NCDHWIm2Col(
+ const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+ const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_dptr) {
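+  // one GPU thread per (input channel, output spatial position) pair;
+  // each thread copies one full kernel window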
+ int32_t kernels = weight_shape.At(1) * out_shape.Count(2);
+ switch (dim_num) {
+ case 1: IM2COL_KERNEL_CALL(Im2ColGpu, 1, true, kernels, in_dptr, col_buf_dptr); break;
+ case 2: IM2COL_KERNEL_CALL(Im2ColGpu, 2, true, kernels, in_dptr, col_buf_dptr); break;
+ case 3: IM2COL_KERNEL_CALL(Im2ColGpu, 3, true, kernels, in_dptr, col_buf_dptr); break;
+ default: UNIMPLEMENTED();
+ }
+}
+
+#define IM2COL_KERNEL_CALL_NDHWC(kernel_func_name, dim_num, kernel_num, src_dptr, dst_dptr)     \
+  kernel_func_name<T, dim_num, false>                                                           \
+      <<<BlocksNum4ThreadsNum(kernel_num), kCudaThreadsNumPerBlock, 0,                          \
+         device_ctx->cuda_stream()>>>(                                                          \
+          kernel_num, src_dptr, in_shape.At(4), in_shape.At(1), in_shape.At(2), in_shape.At(3), \
+          weight_shape.At(1), weight_shape.At(2), weight_shape.At(3), out_shape.At(1),          \
+          out_shape.At(2), out_shape.At(3), strides[0], strides[1], strides[2],                 \
+          dilation_rate[0], dilation_rate[1], dilation_rate[2], padding_before[0],              \
+          padding_before[1], padding_before[2], dst_dptr)
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NDHWCIm2Col(
+    const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_dptr) {
+  // channels-last shapes are (N, D, H, W, C): channel is At(4), spatial dims At(1..3)
+  int32_t kernels = weight_shape.At(4) * out_shape.Count(1, 4);
+  switch (dim_num) {
+    case 1: IM2COL_KERNEL_CALL_NDHWC(Im2ColGpu, 1, kernels, in_dptr, col_buf_dptr); break;
+    case 2: IM2COL_KERNEL_CALL_NDHWC(Im2ColGpu, 2, kernels, in_dptr, col_buf_dptr); break;
+    case 3: IM2COL_KERNEL_CALL_NDHWC(Im2ColGpu, 3, kernels, in_dptr, col_buf_dptr); break;
+    default: UNIMPLEMENTED();
+  }
+}
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NCDHWCol2Im(
+ const int dim_num, DeviceCtx* device_ctx, const T* col_buf_dptr, const Shape& in_shape,
+ const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_dptr) {
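+  // one GPU thread per input image element; each accumulates every col_buf
+  // entry that maps back to it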
+ int32_t im_size = in_shape.Count(1);
+ switch (dim_num) {
+ case 1: IM2COL_KERNEL_CALL(Col2ImGpu, 1, true, im_size, col_buf_dptr, in_diff_dptr); break;
+ case 2: IM2COL_KERNEL_CALL(Col2ImGpu, 2, true, im_size, col_buf_dptr, in_diff_dptr); break;
+ case 3: IM2COL_KERNEL_CALL(Col2ImGpu, 3, true, im_size, col_buf_dptr, in_diff_dptr); break;
+ default: UNIMPLEMENTED();
+ }
+}
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NDHWCCol2Im(
+    const int dim_num, DeviceCtx* device_ctx, const T* col_buf_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_dptr) {
+  int32_t im_size = in_shape.Count(1);
+  switch (dim_num) {
+    case 1: IM2COL_KERNEL_CALL_NDHWC(Col2ImGpu, 1, im_size, col_buf_dptr, in_diff_dptr); break;
+    case 2: IM2COL_KERNEL_CALL_NDHWC(Col2ImGpu, 2, im_size, col_buf_dptr, in_diff_dptr); break;
+    case 3: IM2COL_KERNEL_CALL_NDHWC(Col2ImGpu, 3, im_size, col_buf_dptr, in_diff_dptr); break;
+    default: UNIMPLEMENTED();
+  }
+}
+
+#undef IM2COL_KERNEL_CALL
+#undef IM2COL_KERNEL_CALL_NDHWC
+
#define INSTANTIATE_CONV_KERNEL(type_cpp, type_proto) \
template class ConvKernel<DeviceType::kGPU, type_cpp>;
OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CONV_KERNEL, FLOATING_DATA_TYPE_SEQ)
+#define INSTANTIATE_CONV_KERNEL_UTIL(type_cpp, type_proto) \
+ template class ConvKernelUtil<DeviceType::kGPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CONV_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
+
} // namespace oneflow
diff --git a/oneflow/core/kernel/conv_kernel.h b/oneflow/core/kernel/conv_kernel.h
index 2aaf5ceb1872c7fa30efc28fc940cdf161716516..2712902920a084cf11529bdb34f7a896d54f9f5b 100644
--- a/oneflow/core/kernel/conv_kernel.h
+++ b/oneflow/core/kernel/conv_kernel.h
@@ -7,6 +7,20 @@
namespace oneflow {
+template<typename T>
+const T* GetImgDptr(const Blob* blob, int64_t idx) {
+  return blob->dptr<T>() + blob->shape().Count(1) * idx;
+}
+
+template<typename T>
+T* GetImgMutDptr(Blob* blob, int64_t idx) {
+  return const_cast<T*>(GetImgDptr<T>(blob, idx));
+}
+
template<DeviceType device_type, typename T>
class ConvKernelIf : public KernelIfWithActivation<device_type, T>,
public KernelIfWithModel<device_type, T> {
@@ -44,16 +58,18 @@ class ConvKernelIf : public KernelIfWithActivation<device_type, T>,
};
template<typename T>
-using Im2ColFunc = void (*)(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape,
- const int32_t* strides, const int32_t* dilation_rate,
- const int32_t* padding_before, T* col_buf);
+using Im2ColFunc = void (*)(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+ const Shape& in_shape, const Shape& weight_shape,
+ const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before,
+ T* col_buf);
template<typename T>
-using Col2ImFunc = void (*)(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape,
- const int32_t* strides, const int32_t* dilation_rate,
- const int32_t* padding_before, T* in_diff_ptr);
+using Col2ImFunc = void (*)(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+ const Shape& in_shape, const Shape& weight_shape,
+ const Shape& out_shape, const int32_t* strides,
+ const int32_t* dilation_rate, const int32_t* padding_before,
+ T* in_diff_ptr);
template<typename T>
using GemmFunc = void (*)(DeviceCtx* ctx, enum CBLAS_TRANSPOSE, enum CBLAS_TRANSPOSE, const int m,
@@ -61,16 +77,13 @@ using GemmFunc = void (*)(DeviceCtx* ctx, enum CBLAS_TRANSPOSE, enum CBLAS_TRANS
const T beta, T* c);
template<DeviceType device_type, typename T>
-class ConvKernel;
-
-template<typename T>
-class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kCPU, T> {
+class ConvKernelImplByIm2Col : public ConvKernelIf<device_type, T> {
public:
- OF_DISALLOW_COPY_AND_MOVE(ConvKernel);
- ConvKernel() = default;
- ~ConvKernel() = default;
+ OF_DISALLOW_COPY_AND_MOVE(ConvKernelImplByIm2Col);
+ ConvKernelImplByIm2Col() = default;
+ ~ConvKernelImplByIm2Col() = default;
- private:
+ protected:
void VirtualKernelInit(const ParallelContext*) override;
void DoForwardDataContent(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob,
Blob* out_blob,
@@ -80,6 +93,8 @@ class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kC
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void BiasBackward(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+
+ private:
Im2ColFunc<T> im2col_func_;
Col2ImFunc<T> col2im_func_;
GemmFunc<T> forward_func_;
@@ -93,8 +108,21 @@ class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kC
Shape weight_shape_;
};
+template<DeviceType device_type, typename T>
+class ConvKernel;
+
+template<typename T>
+class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelImplByIm2Col<DeviceType::kCPU, T> {
+ public:
+ OF_DISALLOW_COPY_AND_MOVE(ConvKernel);
+ ConvKernel() = default;
+ ~ConvKernel() = default;
+};
+
template<typename T>
-class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelIf<DeviceType::kGPU, T> {
+class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelImplByIm2Col<DeviceType::kGPU, T> {
public:
OF_DISALLOW_COPY_AND_MOVE(ConvKernel);
ConvKernel() = default;
@@ -102,14 +130,26 @@ class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelIf<DeviceType::kG
private:
void VirtualKernelInit(const ParallelContext*) override;
+ void KernelInitWithCudnn(const ParallelContext*);
+
void DoForwardDataContent(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob,
Blob* out_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+ void DoForwardDataContentWithCudnn(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob,
+ Blob* out_blob,
+ std::function<Blob*(const std::string&)> BnInOp2Blob) const;
+
void WeightBackward(DeviceCtx*, const Blob* out_diff_blob, const Blob* in_blob,
Blob* weight_diff_blob, Blob* in_diff_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+ void WeightBackwardWithCudnn(DeviceCtx*, const Blob* out_diff_blob, const Blob* in_blob,
+ Blob* weight_diff_blob, Blob* in_diff_blob,
+ std::function<Blob*(const std::string&)> BnInOp2Blob) const;
+
void BiasBackward(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+ void BiasBackwardWithCudnn(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob,
+ std::function<Blob*(const std::string&)> BnInOp2Blob) const;
std::unique_ptr<CudnnTensorDesc> in_desc_;
std::unique_ptr<CudnnTensorDesc> out_desc_;
@@ -194,26 +234,31 @@ class ColBufUtil final {
DHWValidFunc<T> dhw_valid_func_;
};
+template<DeviceType device_type, typename T>
+class ConvKernelUtil;
+
template<typename T>
-struct ConvKernelUtil final {
+struct ConvKernelUtil<DeviceType::kCPU, T> final {
public:
- static void NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
- const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf);
+ static void NCDHWIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* col_buf);
- static void NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
- const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf);
+ static void NDHWCIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* col_buf);
- static void NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
- const int32_t* dilation_rate, const int32_t* padding_before,
- T* in_diff_ptr);
+ static void NCDHWCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* in_diff_ptr);
- static void NDHWCCol2Im(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape,
- const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
- const int32_t* dilation_rate, const int32_t* padding_before,
- T* in_diff_ptr);
+ static void NDHWCCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* in_diff_ptr);
private:
static void DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& conv_util,
@@ -223,6 +268,32 @@ struct ConvKernelUtil final {
ColBufWriter<T>* col_buf_writer);
};
+template<typename T>
+struct ConvKernelUtil<DeviceType::kGPU, T> final {
+ public:
+ static void NCDHWIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* col_buf);
+
+ static void NDHWCIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* col_buf);
+
+ static void NCDHWCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* in_diff_ptr);
+
+ static void NDHWCCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+ const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+ const int32_t* strides, const int32_t* dilation_rate,
+ const int32_t* padding_before, T* in_diff_ptr);
+};
+
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_CONV_KERNEL_H_
diff --git a/oneflow/core/kernel/kernel.h b/oneflow/core/kernel/kernel.h
index 001b45b5766b273b7a861072371b1bf6c4d15757..4424847afd238d66c668c68b1b84d2423c3049fb 100644
--- a/oneflow/core/kernel/kernel.h
+++ b/oneflow/core/kernel/kernel.h
@@ -142,7 +142,8 @@ class KernelIf : public Kernel {
const PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns,
void (Blob::*Copy)(DeviceCtx*, const Blob*)) const;
- bool UseCudnn() const { return device_type == DeviceType::kGPU && op_conf().use_cudnn_on_gpu(); }
+ bool UseCudnn() const { return device_type == DeviceType::kGPU && UseCudnnOnGpu(); }
+ bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); }
};
template<DeviceType device_type, typename ModelType>
diff --git a/oneflow/core/operator/conv_op.cpp b/oneflow/core/operator/conv_op.cpp
index 6f0baca0228fc43731c6ee11f83aeb9251493750..20003955da311fc0095bdb6ba6bf2fbcc10e9729 100644
--- a/oneflow/core/operator/conv_op.cpp
+++ b/oneflow/core/operator/conv_op.cpp
@@ -120,7 +120,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G
if (GetValFromCustomizedConf<bool>("use_bias")) {
// bias and bias_multiplier
GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({filters, 1});
- if (!UseCudnn()) {
+ if (!UseCudnnOnGpu()) {
std::vector<int64_t> bias_mul_shape(NDims + 1, 1);
for (size_t i = 0; i != NDims; ++i) { bias_mul_shape[i + 1] = out_shape[dhw_offset + i]; }
GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape(bias_mul_shape);
@@ -130,7 +130,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G
ConvOpCtx* conv_op_ctx = new ConvOpCtx();
EnrollOpCtx(conv_op_ctx);
- if (!UseCudnn()) {
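+  // col_buf is needed wherever the im2col path runs: always on CPU, and on GPU
+  // whenever cudnn is disabled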
+ if (device_type() == DeviceType::kCPU || !UseCudnnOnGpu()) {
// col_buf
size_t col_buf_elem_cnt = 1;
for (size_t i = 0; i != NDims + 1; ++i) { col_buf_elem_cnt *= weight_shape[i + 1]; }
@@ -139,7 +139,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G
}
#ifdef WITH_CUDA
- if (device_type() == DeviceType::kGPU) {
+ if (device_type() == DeviceType::kGPU && UseCudnnOnGpu()) {
// cudnn_buf
InferCudnnAlgo(GetBlobDesc4BnInOp, &(conv_op_ctx->cudnn_conv_algo_ctx));
*buf_size = std::max({conv_op_ctx->cudnn_conv_algo_ctx.fwd_ws_size,
@@ -235,7 +235,7 @@ void ConvOp<NDims>::VirtualGenKernelConf(
const ParallelContext* parallel_ctx, KernelConf* kernel_conf, const OpContext* op_ctx) const {
ConvKernelConf* conv_conf = kernel_conf->mutable_conv_conf();
conv_conf->set_dim(NDims);
- if (!UseCudnn()) {
+ if (!UseCudnnOnGpu()) {
GenKernelConfWithoutCudnn(GetBlobDesc4BnInOp, conv_conf);
} else {
GenKernelConfWithCudnn(GetBlobDesc4BnInOp, kernel_conf, conv_conf, op_ctx);
diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index 79501498dd03d06c356f6c8e7e955cc07a458978..91069547c6139400bbd286e57b564dd39cf301bc 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -41,10 +41,6 @@ LogicalBlobId* Operator::MutBnInOp2Lbi(const std::string& bn_in_op) {
}
}
-bool Operator::UseCudnn() const {
- return device_type() == DeviceType::kGPU && op_conf().use_cudnn_on_gpu();
-}
-
const std::string& Operator::SoleIbn() const {
CHECK_EQ(input_bns().size(), 1);
return input_bns().Get(0);
diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h
index 6572b1acc7ab1cadbfc1b9f2f838e77d81579a5a..e08bf058a0b5ff511e5231e40d3cd26d9daa503e 100644
--- a/oneflow/core/operator/operator.h
+++ b/oneflow/core/operator/operator.h
@@ -47,7 +47,8 @@ class Operator {
// Getters
const std::string& op_name() const { return op_conf().name(); }
DeviceType device_type() const { return op_attribute_.op_conf().device_type(); }
- bool UseCudnn() const;
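+  // UseCudnnOnGpu() checks the conf flag alone; UseCudnn() additionally requires
+  // that the op is actually placed on GPU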
+ bool UseCudnn() const { return device_type() == DeviceType::kGPU && UseCudnnOnGpu(); }
+ bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); }
const OperatorConf& op_conf() const { return op_attribute_.op_conf(); }
virtual const PbMessage& GetCustomizedConf() const { UNIMPLEMENTED(); }