From c37985d7a33f28caa1e72c7c6e4125f2f636101f Mon Sep 17 00:00:00 2001
From: Niu Chong <niuchong_dnm@163.com>
Date: Wed, 13 Jun 2018 07:44:12 +0800
Subject: [PATCH] Conv GPU Kernel Without Cudnn (#916)

* feat: add UseCudnnOnGpu in Operator and fix conv op

* feat: add WithCudnn and WithoutCudnn in ConvKernel<kGPU, T>

* feat: add CUDA NCDHWIm2ColGpu kernel; compiles

* refine: rename Im2ColNCDHWGpu() to NCDHWIm2ColGpu()

* fix: revert the int32_t-to-int64_t change in BlocksNum4ThreadsNum()

* feat: add CUDA NCDHWCol2ImGpu kernel

* refactor: extract InitSharedArrays() for device code

* feat: add CUDA NDHWCIm2ColGpu kernel

* feat: add CUDA NDHWCCol2ImGpu kernel and fix typos

* fix: fix the im_offset calculation in NDHWCIm2ColGpu

* refactor: extract Im2ColCalcKernelAndOutIndex() and Im2ColCalcImIndex()

* fix: fix formatting and add the missing shared_im[] parameter in Im2ColCalcImIndex()

* refactor: merge NCDHWCol2ImGpu() and NDHWCCol2ImGpu() into Col2ImGpu()

* refactor: merge NCDHWIm2ColGpu() and NDHWCIm2ColGpu() into Im2ColGpu()

* feat: add class ConvKernelImplByIm2Col between ConvKernelIf and ConvKernel; compiles, to be tested

* fix: add explicit template instantiation for ConvKernelUtil

* refine: remove unused member function declarations, e.g. KernelInitWithoutCudnn

* fix(operator/conv_op.cpp): make sure UseCudnnOnGpu() == true when inferring the cudnn algo

* refine(kernel/conv_kernel.cu): move the gpu kernel functions into the anonymous namespace

* refactor: add dim_num as the template parameter of Im2ColGpu() and Col2ImGpu()

* refactor: add is_channel_first as the template parameter of Im2ColGpu() and Col2ImGpu()

* refine(kernel/conv_kernel.cu): add #undef IM2COL_KERNEL_CALL

* refine(kernel/conv_kernel.cu): add dim_num as the template parameter of InitSharedArrays()

* fix(kernel/conv_kernel.cu): fix the col_offset usage bug in Im2ColGpu()
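
Implementation note: the GPU im2col/col2im kernels take dim_num and
is_channel_first as template parameters so the per-dimension __shared__
arrays can be statically sized and the layout branches resolve at compile
time; the runtime dim_num is converted into a template argument by a
switch (wrapped in the IM2COL_KERNEL_CALL macro below). A minimal sketch
of that dispatch pattern, with hypothetical CopyDimsGpu/LaunchCopyDims
standing in for the real Im2ColGpu/IM2COL_KERNEL_CALL pair (dims and out
are assumed to be device pointers):

    #include <cuda_runtime.h>

    // dim_num is a compile-time constant, so the shared array has a
    // fixed size and the loop bound is known to the compiler.
    template<int dim_num>
    __global__ void CopyDimsGpu(const int* dims, int* out) {
      __shared__ int shared_dims[dim_num];
      if (threadIdx.x == 0) {
        for (int i = 0; i < dim_num; ++i) { shared_dims[i] = dims[i]; }
      }
      __syncthreads();
      if (threadIdx.x < dim_num) { out[threadIdx.x] = shared_dims[threadIdx.x]; }
    }

    // The switch converts the runtime dim_num into a template argument,
    // mirroring the switch (dim_num) blocks in NCDHWIm2Col/NDHWCIm2Col.
    void LaunchCopyDims(int dim_num, const int* dims, int* out) {
      switch (dim_num) {
        case 1: CopyDimsGpu<1><<<1, 32>>>(dims, out); break;
        case 2: CopyDimsGpu<2><<<1, 32>>>(dims, out); break;
        case 3: CopyDimsGpu<3><<<1, 32>>>(dims, out); break;
        default: break;  // the real code calls UNIMPLEMENTED() here
      }
    }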
---
 oneflow/core/kernel/conv_kernel.cpp | 122 ++++-----
 oneflow/core/kernel/conv_kernel.cu  | 393 +++++++++++++++++++++++++++-
 oneflow/core/kernel/conv_kernel.h   | 135 +++++++---
 oneflow/core/kernel/kernel.h        |   3 +-
 oneflow/core/operator/conv_op.cpp   |   8 +-
 oneflow/core/operator/operator.cpp  |   4 -
 oneflow/core/operator/operator.h    |   3 +-
 7 files changed, 557 insertions(+), 111 deletions(-)

diff --git a/oneflow/core/kernel/conv_kernel.cpp b/oneflow/core/kernel/conv_kernel.cpp
index 89b41a2a3..e78a961e5 100644
--- a/oneflow/core/kernel/conv_kernel.cpp
+++ b/oneflow/core/kernel/conv_kernel.cpp
@@ -3,20 +3,6 @@
 
 namespace oneflow {
 
-namespace {
-
-template<typename T>
-const T* GetImgDptr(const Blob* blob, int64_t idx) {
-  return blob->dptr<T>() + blob->shape().Count(1) * idx;
-}
-
-template<typename T>
-T* GetImgMutDptr(Blob* blob, int64_t idx) {
-  return const_cast<T*>(GetImgDptr<T>(blob, idx));
-}
-
-}  // namespace
-
 template<DeviceType device_type, typename T>
 void ConvKernelIf<device_type, T>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
@@ -40,7 +26,7 @@ void ConvKernelIf<device_type, T>::BackwardDataContent(
 template<DeviceType device_type, typename T>
 void ConvKernelIf<device_type, T>::InitConstBufBlobs(
     DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  if (this->template GetValFromCustomizedOpConf<bool>("use_bias") && !this->UseCudnn()) {
+  if (this->template GetValFromCustomizedOpConf<bool>("use_bias") && !this->UseCudnnOnGpu()) {
     InitializerConf bias_multiplier_initializer_conf;
     bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
     KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0,
@@ -98,20 +84,21 @@ const int32_t ConvKernelIf<device_type, T>::OpKernelDim() const {
   return this->GetConvKernelConf().dim();
 }
 
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::VirtualKernelInit(const ParallelContext* parallel_ctx) {
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::VirtualKernelInit(
+    const ParallelContext* parallel_ctx) {
   const std::string& data_format =
       this->template GetValFromCustomizedOpConf<std::string>("data_format");
   if (data_format == "channels_first") {
-    im2col_func_ = ConvKernelUtil<T>::NCDHWIm2Col;
-    col2im_func_ = ConvKernelUtil<T>::NCDHWCol2Im;
-    forward_func_ = KernelUtil<DeviceType::kCPU, T>::OFGemm;
+    im2col_func_ = ConvKernelUtil<device_type, T>::NCDHWIm2Col;
+    col2im_func_ = ConvKernelUtil<device_type, T>::NCDHWCol2Im;
+    forward_func_ = KernelUtil<device_type, T>::OFGemm;
     dhw_offset_ = 2;
     is_out_diff_need_trans_ = CblasNoTrans;
   } else {
-    im2col_func_ = ConvKernelUtil<T>::NDHWCIm2Col;
-    col2im_func_ = ConvKernelUtil<T>::NDHWCCol2Im;
-    forward_func_ = KernelUtil<DeviceType::kCPU, T>::OFGemmTrans;
+    im2col_func_ = ConvKernelUtil<device_type, T>::NDHWCIm2Col;
+    col2im_func_ = ConvKernelUtil<device_type, T>::NDHWCCol2Im;
+    forward_func_ = KernelUtil<device_type, T>::OFGemmTrans;
     dhw_offset_ = 1;
     is_out_diff_need_trans_ = CblasTrans;
   }
@@ -123,13 +110,14 @@ void ConvKernel<DeviceType::kCPU, T>::VirtualKernelInit(const ParallelContext* p
   padding_before_ = this->GetConvKernelConf().pad_small_side().data();
 }
 
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::DoForwardDataContent(
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::DoForwardDataContent(
     DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob,
     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   FOR_RANGE(int64_t, i, 0, in_shape_.At(0)) {
-    im2col_func_(device_ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_, out_shape_,
-                 strides_, dilation_rate_, padding_before_, static_cast<T*>(device_ctx->buf_ptr()));
+    im2col_func_(this->OpKernelDim(), device_ctx, GetImgDptr<T>(in_blob, i), in_shape_,
+                 weight_shape_, out_shape_, strides_, dilation_rate_, padding_before_,
+                 static_cast<T*>(device_ctx->buf_ptr()));
 
     // col_buf is device_ctx->buf_ptr()
     // channels first: out = weight * col_buf
@@ -157,24 +145,25 @@ void ConvKernel<DeviceType::kCPU, T>::DoForwardDataContent(
   }
 }
 
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::WeightBackward(
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::WeightBackward(
     DeviceCtx* ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob,
     Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* weight_blob = BnInOp2Blob("weight");
-  Memset<DeviceType::kCPU>(ctx, weight_diff_blob->mut_dptr<T>(), 0,
-                           weight_diff_blob->ByteSizeOfDataContentField());
+  Memset<device_type>(ctx, weight_diff_blob->mut_dptr<T>(), 0,
+                      weight_diff_blob->ByteSizeOfDataContentField());
   if (in_diff_blob != nullptr) {
-    Memset<DeviceType::kCPU>(ctx, in_diff_blob->mut_dptr<T>(), 0,
-                             in_diff_blob->ByteSizeOfDataContentField());
+    Memset<device_type>(ctx, in_diff_blob->mut_dptr<T>(), 0,
+                        in_diff_blob->ByteSizeOfDataContentField());
   }
   FOR_RANGE(int64_t, i, 0, out_shape_.At(0)) {
-    im2col_func_(ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_, out_shape_, strides_,
-                 dilation_rate_, padding_before_, static_cast<T*>(ctx->buf_ptr()));
+    im2col_func_(this->OpKernelDim(), ctx, GetImgDptr<T>(in_blob, i), in_shape_, weight_shape_,
+                 out_shape_, strides_, dilation_rate_, padding_before_,
+                 static_cast<T*>(ctx->buf_ptr()));
 
     // channels first:  weight' += out[i]' * col_buf(T)
     // channels last :  weight' += out[i]'(T) * col_buf(T)
-    KernelUtil<DeviceType::kCPU, T>::OFGemm(
+    KernelUtil<device_type, T>::OFGemm(
         ctx, is_out_diff_need_trans_, CblasTrans,
         weight_shape_.At(0),                             //  filter
         weight_shape_.Count(1),                          //  ci * kd * kh * kw
@@ -185,7 +174,7 @@ void ConvKernel<DeviceType::kCPU, T>::WeightBackward(
     if (in_diff_blob != nullptr) {
       // channels first:  col_buf' = weight(T) * out[i]'
       // channels last :  col_buf' = weight(T) * out[i]'(T)
-      KernelUtil<DeviceType::kCPU, T>::OFGemm(
+      KernelUtil<device_type, T>::OFGemm(
           ctx, CblasTrans, is_out_diff_need_trans_,
           weight_shape_.Count(1),                          //  ci * kd * kh * kw
           out_shape_.Count(dhw_offset_, dhw_offset_ + 3),  //  od * oh * ow
@@ -194,24 +183,25 @@ void ConvKernel<DeviceType::kCPU, T>::WeightBackward(
           static_cast<T>(0), static_cast<T*>(ctx->buf_ptr()));
 
       // in' = col2im(col_buf')
-      col2im_func_(ctx, static_cast<const T*>(ctx->buf_ptr()), in_shape_, weight_shape_, out_shape_,
-                   strides_, dilation_rate_, padding_before_, GetImgMutDptr<T>(in_diff_blob, i));
+      col2im_func_(this->OpKernelDim(), ctx, static_cast<const T*>(ctx->buf_ptr()), in_shape_,
+                   weight_shape_, out_shape_, strides_, dilation_rate_, padding_before_,
+                   GetImgMutDptr<T>(in_diff_blob, i));
     }
   }
 }
 
-template<typename T>
-void ConvKernel<DeviceType::kCPU, T>::BiasBackward(
+template<DeviceType device_type, typename T>
+void ConvKernelImplByIm2Col<device_type, T>::BiasBackward(
     DeviceCtx* ctx, const Blob* out_diff_blob, Blob* bias_diff_blob,
     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
-  Memset<DeviceType::kCPU>(ctx, bias_diff_blob->mut_dptr<T>(), 0,
-                           bias_diff_blob->ByteSizeOfDataContentField());
+  Memset<device_type>(ctx, bias_diff_blob->mut_dptr<T>(), 0,
+                      bias_diff_blob->ByteSizeOfDataContentField());
 
   FOR_RANGE(int64_t, i, 0, out_shape_.At(0)) {
     // channels first:  bias' += out' * bias_mul
     // channels last:   bias' += out'(T) * bias_mul
-    KernelUtil<DeviceType::kCPU, T>::OFGemm(
+    KernelUtil<device_type, T>::OFGemm(
         ctx, is_out_diff_need_trans_, CblasNoTrans,
         weight_shape_.At(0),                             //  filter
         1,                                               //  1
@@ -364,8 +354,9 @@ void ColBufUtil<T>::operator()(ColBufWriter<T>* col_buf_writer, int64_t c, int64
 }
 
 template<typename T>
-void ConvKernelUtil<T>::DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& col_buf_util,
-                                    ColBufWriter<T>* col_buf_writer) {
+void ConvKernelUtil<DeviceType::kCPU, T>::DoNCDWHFunc(const Shape& weight_shape,
+                                                      ColBufUtil<T>& col_buf_util,
+                                                      ColBufWriter<T>* col_buf_writer) {
   for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) {
     for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) {
       for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) {
@@ -378,10 +369,10 @@ void ConvKernelUtil<T>::DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& co
 }
 
 template<typename T>
-void ConvKernelUtil<T>::NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
-                                    const Shape& weight_shape, const Shape& out_shape,
-                                    const int32_t* strides, const int32_t* dilation_rate,
-                                    const int32_t* padding_before, T* col_buf_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NCDHWIm2Col(
+    const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) {
   ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before);
   Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(3),
                                  in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);
@@ -389,11 +380,10 @@ void ConvKernelUtil<T>::NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, con
 }
 
 template<typename T>
-void ConvKernelUtil<T>::NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr,
-                                    const Shape& in_shape, const Shape& weight_shape,
-                                    const Shape& out_shape, const int32_t* strides,
-                                    const int32_t* dilation_rate, const int32_t* padding_before,
-                                    T* in_diff_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NCDHWCol2Im(
+    const int dim_num, DeviceCtx* device_ctx, const T* col_buf_ptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) {
   ColBufUtil<T> col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before);
   Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3),
                                  in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1);
@@ -401,8 +391,9 @@ void ConvKernelUtil<T>::NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr,
 }
 
 template<typename T>
-void ConvKernelUtil<T>::DoNDWHCFunc(const Shape& weight_shape, ColBufUtil<T>& col_buf_util,
-                                    ColBufWriter<T>* col_buf_writer) {
+void ConvKernelUtil<DeviceType::kCPU, T>::DoNDWHCFunc(const Shape& weight_shape,
+                                                      ColBufUtil<T>& col_buf_util,
+                                                      ColBufWriter<T>* col_buf_writer) {
   for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) {
     for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) {
       for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) {
@@ -415,10 +406,10 @@ void ConvKernelUtil<T>::DoNDWHCFunc(const Shape& weight_shape, ColBufUtil<T>& co
 }
 
 template<typename T>
-void ConvKernelUtil<T>::NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
-                                    const Shape& weight_shape, const Shape& out_shape,
-                                    const int32_t* strides, const int32_t* dilation_rate,
-                                    const int32_t* padding_before, T* col_buf_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NDHWCIm2Col(
+    const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_ptr) {
   ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before);
   Im2ColWriter<T> col_buf_writer(in_dptr, col_buf_ptr, in_shape.Count(2), in_shape.Count(2),
                                  in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),
@@ -427,11 +418,10 @@ void ConvKernelUtil<T>::NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, con
 }
 
 template<typename T>
-void ConvKernelUtil<T>::NDHWCCol2Im(DeviceCtx* device_ctx, const T* col_buf_ptr,
-                                    const Shape& in_shape, const Shape& weight_shape,
-                                    const Shape& out_shape, const int32_t* strides,
-                                    const int32_t* dilation_rate, const int32_t* padding_before,
-                                    T* in_diff_ptr) {
+void ConvKernelUtil<DeviceType::kCPU, T>::NDHWCCol2Im(
+    const int dim_num, DeviceCtx* device_ctx, const T* col_buf_ptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_ptr) {
   ColBufUtil<T> col_buf_util(in_shape, out_shape, 1, strides, dilation_rate, padding_before);
   Col2ImWriter<T> col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2),
                                  in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4),
diff --git a/oneflow/core/kernel/conv_kernel.cu b/oneflow/core/kernel/conv_kernel.cu
index e0243a126..2fc397ad6 100644
--- a/oneflow/core/kernel/conv_kernel.cu
+++ b/oneflow/core/kernel/conv_kernel.cu
@@ -5,6 +5,52 @@ namespace oneflow {
 
 template<typename T>
 void ConvKernel<DeviceType::kGPU, T>::VirtualKernelInit(const ParallelContext* parallel_ctx) {
+  if (this->UseCudnnOnGpu()) {
+    KernelInitWithCudnn(parallel_ctx);
+  } else {
+    ConvKernelImplByIm2Col<DeviceType::kGPU, T>::VirtualKernelInit(parallel_ctx);
+  }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent(
+    DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  if (this->UseCudnnOnGpu()) {
+    DoForwardDataContentWithCudnn(device_ctx, in_blob, weight_blob, out_blob, BnInOp2Blob);
+  } else {
+    ConvKernelImplByIm2Col<DeviceType::kGPU, T>::DoForwardDataContent(
+        device_ctx, in_blob, weight_blob, out_blob, BnInOp2Blob);
+  }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::WeightBackward(
+    DeviceCtx* device_ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob,
+    Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  if (this->UseCudnnOnGpu()) {
+    WeightBackwardWithCudnn(device_ctx, out_diff_blob, in_blob, weight_diff_blob, in_diff_blob,
+                            BnInOp2Blob);
+  } else {
+    ConvKernelImplByIm2Col<DeviceType::kGPU, T>::WeightBackward(
+        device_ctx, out_diff_blob, in_blob, weight_diff_blob, in_diff_blob, BnInOp2Blob);
+  }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::BiasBackward(
+    DeviceCtx* device_ctx, const Blob* out_diff_blob, Blob* bias_diff_blob,
+    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  if (this->UseCudnnOnGpu()) {
+    BiasBackwardWithCudnn(device_ctx, out_diff_blob, bias_diff_blob, BnInOp2Blob);
+  } else {
+    ConvKernelImplByIm2Col<DeviceType::kGPU, T>::BiasBackward(device_ctx, out_diff_blob,
+                                                              bias_diff_blob, BnInOp2Blob);
+  }
+}
+
+template<typename T>
+void ConvKernel<DeviceType::kGPU, T>::KernelInitWithCudnn(const ParallelContext* parallel_ctx) {
   Shape in_shape(this->GetConvKernelConf().in());
   Shape out_shape(this->GetConvKernelConf().out());
   Shape weight_shape(this->GetConvKernelConf().weight());
@@ -48,7 +94,7 @@ void ConvKernel<DeviceType::kGPU, T>::VirtualKernelInit(const ParallelContext* p
 }
 
 template<typename T>
-void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent(
+void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContentWithCudnn(
     DeviceCtx* device_ctx, const Blob* in_blob, const Blob* weight_blob, Blob* out_blob,
     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   CudaCheck(cudnnConvolutionForward(
@@ -67,7 +113,7 @@ void ConvKernel<DeviceType::kGPU, T>::DoForwardDataContent(
 }
 
 template<typename T>
-void ConvKernel<DeviceType::kGPU, T>::WeightBackward(
+void ConvKernel<DeviceType::kGPU, T>::WeightBackwardWithCudnn(
     DeviceCtx* device_ctx, const Blob* out_diff_blob, const Blob* in_blob, Blob* weight_diff_blob,
     Blob* in_diff_blob, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   const Blob* weight_blob = BnInOp2Blob("weight");
@@ -91,7 +137,7 @@ void ConvKernel<DeviceType::kGPU, T>::WeightBackward(
 }
 
 template<typename T>
-void ConvKernel<DeviceType::kGPU, T>::BiasBackward(
+void ConvKernel<DeviceType::kGPU, T>::BiasBackwardWithCudnn(
     DeviceCtx* device_ctx, const Blob* out_diff_blob, Blob* bias_diff_blob,
     std::function<Blob*(const std::string&)> BnInOp2Blob) const {
   CudaCheck(cudnnConvolutionBackwardBias(device_ctx->cudnn_handle(), OnePtr<T>::value,
@@ -100,8 +146,349 @@ void ConvKernel<DeviceType::kGPU, T>::BiasBackward(
                                          bias_diff_blob->mut_dptr<T>()));
 }
 
+namespace {
+
+template<int dim_num>
+__device__ void InitSharedArrays(const int im_d, const int im_h, const int im_w, const int kernel_d,
+                                 const int kernel_h, const int kernel_w, const int out_d,
+                                 const int out_h, const int out_w, const int stride_d,
+                                 const int stride_h, const int stride_w, const int dilation_d,
+                                 const int dilation_h, const int dilation_w, const int pad_d,
+                                 const int pad_h, const int pad_w, int* shared_im,
+                                 int* shared_kernel, int* shared_out, int* shared_stride,
+                                 int* shared_dilation, int* shared_pad) {
+  if (threadIdx.x == 0) {
+    if (dim_num == 3) {
+      shared_im[0] = im_d;
+      shared_im[1] = im_h;
+      shared_im[2] = im_w;
+      shared_kernel[0] = kernel_d;
+      shared_kernel[1] = kernel_h;
+      shared_kernel[2] = kernel_w;
+      shared_out[0] = out_d;
+      shared_out[1] = out_h;
+      shared_out[2] = out_w;
+      shared_stride[0] = stride_d;
+      shared_stride[1] = stride_h;
+      shared_stride[2] = stride_w;
+      shared_dilation[0] = dilation_d;
+      shared_dilation[1] = dilation_h;
+      shared_dilation[2] = dilation_w;
+      shared_pad[0] = pad_d;
+      shared_pad[1] = pad_h;
+      shared_pad[2] = pad_w;
+    } else if (dim_num == 2) {
+      shared_im[0] = im_h;
+      shared_im[1] = im_w;
+      shared_kernel[0] = kernel_h;
+      shared_kernel[1] = kernel_w;
+      shared_out[0] = out_h;
+      shared_out[1] = out_w;
+      shared_stride[0] = stride_h;
+      shared_stride[1] = stride_w;
+      shared_dilation[0] = dilation_h;
+      shared_dilation[1] = dilation_w;
+      shared_pad[0] = pad_h;
+      shared_pad[1] = pad_w;
+    } else if (dim_num == 1) {
+      shared_im[0] = im_w;
+      shared_kernel[0] = kernel_w;
+      shared_out[0] = out_w;
+      shared_stride[0] = stride_w;
+      shared_dilation[0] = dilation_w;
+      shared_pad[0] = pad_w;
+    }
+  }
+  __syncthreads();
+}
+
+template<typename T, int dim_num, bool is_channel_first>
+__global__ void Im2ColGpu(const int n, const T* im_dptr, const int channel, const int im_d,
+                          const int im_h, const int im_w, const int kernel_d, const int kernel_h,
+                          const int kernel_w, const int out_d, const int out_h, const int out_w,
+                          const int stride_d, const int stride_h, const int stride_w,
+                          const int dilation_rate_d, const int dilation_rate_h,
+                          const int dilation_rate_w, const int padding_before_d,
+                          const int padding_before_h, const int padding_before_w, T* col_buf_dptr) {
+  __shared__ int shared_im[dim_num];
+  __shared__ int shared_kernel[dim_num];
+  __shared__ int shared_out[dim_num];
+  __shared__ int shared_stride[dim_num];
+  __shared__ int shared_dilation[dim_num];
+  __shared__ int shared_pad[dim_num];
+  InitSharedArrays<dim_num>(im_d, im_h, im_w, kernel_d, kernel_h, kernel_w, out_d, out_h, out_w,
+                            stride_d, stride_h, stride_w, dilation_rate_d, dilation_rate_h,
+                            dilation_rate_w, padding_before_d, padding_before_h, padding_before_w,
+                            shared_im, shared_kernel, shared_out, shared_stride, shared_dilation,
+                            shared_pad);
+
+  int out_size = 1;
+  for (int i = 0; i < dim_num; ++i) { out_size *= shared_out[i]; }
+  int kernel_index[dim_num];
+  int out_index[dim_num];
+  int channel_index;
+  int im_index[dim_num];
+  CUDA_1D_KERNEL_LOOP(index, n) {
+    // channel*od*oh*ow threads are launched in total;
+    // each thread is responsible for copying one whole kernel-sized window
+    // calc kernel_/out_/channel_index
+    channel_index = index / out_size;
+    int col_offset = index % out_size;  // col_dim of col_buf: od*oh*ow
+    for (int i = dim_num - 1; i >= 0; --i) {
+      out_index[i] = col_offset % shared_out[i];
+      col_offset /= shared_out[i];
+      kernel_index[i] = 0;
+    }
+
+    int col_buf_offset = 0;
+    if (is_channel_first) { col_buf_offset = channel_index; }
+    for (int i = 0; i < dim_num; ++i) {
+      col_buf_offset *= shared_kernel[i];
+      // col_buf_offset += kernel_index[i]; skipped because kernel_index[] is still all zeros here
+    }
+    if (is_channel_first == false) {
+      col_buf_offset *= channel;
+      col_buf_offset += channel_index;
+    }
+    col_buf_offset *= out_size;
+    col_buf_offset += index % out_size;
+
+    while (true) {
+      // calc im_index
+      bool is_im_index_valid = true;
+      for (int i = 0; i < dim_num; ++i) {
+        im_index[i] =
+            kernel_index[i] * shared_dilation[i] - shared_pad[i] + out_index[i] * shared_stride[i];
+        if (im_index[i] < 0 || im_index[i] >= shared_im[i]) {
+          is_im_index_valid = false;
+          break;
+        }
+      }
+
+      // write into col_buf
+      if (is_im_index_valid) {
+        // calc im_offset
+        int im_offset = 0;
+        if (is_channel_first) { im_offset = channel_index; }
+        for (int i = 0; i < dim_num; ++i) {
+          im_offset *= shared_im[i];
+          im_offset += im_index[i];
+        }
+        if (is_channel_first == false) {
+          im_offset *= channel;
+          im_offset += channel_index;
+        }
+        col_buf_dptr[col_buf_offset] = im_dptr[im_offset];
+      } else {
+        col_buf_dptr[col_buf_offset] = 0;
+      }
+      col_buf_offset += out_size;
+
+      // loop over all kernel indices
+      bool is_loop_completed = true;
+      for (int i = dim_num - 1; i >= 0; --i) {
+        if (kernel_index[i] == shared_kernel[i] - 1) {
+          kernel_index[i] = 0;
+        } else {
+          kernel_index[i] += 1;
+          is_loop_completed = false;
+          break;
+        }
+      }
+      if (is_loop_completed) { break; }
+    }
+  }
+}
+
+template<typename T, int dim_num, bool is_channel_first>
+__global__ void Col2ImGpu(const int n, const T* col_buf_dptr, const int channel, const int im_d,
+                          const int im_h, const int im_w, const int kernel_d, const int kernel_h,
+                          const int kernel_w, const int out_d, const int out_h, const int out_w,
+                          const int stride_d, const int stride_h, const int stride_w,
+                          const int dilation_rate_d, const int dilation_rate_h,
+                          const int dilation_rate_w, const int padding_before_d,
+                          const int padding_before_h, const int padding_before_w, T* im_diff_dptr) {
+  __shared__ int shared_im[dim_num];
+  __shared__ int shared_kernel[dim_num];
+  __shared__ int shared_out[dim_num];
+  __shared__ int shared_stride[dim_num];
+  __shared__ int shared_dilation[dim_num];
+  __shared__ int shared_pad[dim_num];
+  InitSharedArrays<dim_num>(im_d, im_h, im_w, kernel_d, kernel_h, kernel_w, out_d, out_h, out_w,
+                            stride_d, stride_h, stride_w, dilation_rate_d, dilation_rate_h,
+                            dilation_rate_w, padding_before_d, padding_before_h, padding_before_w,
+                            shared_im, shared_kernel, shared_out, shared_stride, shared_dilation,
+                            shared_pad);
+
+  int kernel_index[dim_num];
+  int channel_index;
+  int im_index[dim_num];
+  int out_begin[dim_num];
+  int out_end[dim_num];
+  int out_index[dim_num];
+  CUDA_1D_KERNEL_LOOP(index, n) {
+    // calc im_/channel_index
+    int im_offset = index;
+    if (is_channel_first == false) {
+      channel_index = im_offset % channel;
+      im_offset /= channel;
+    }
+    for (int i = dim_num - 1; i >= 0; --i) {
+      im_index[i] = im_offset % shared_im[i] + shared_pad[i];
+      im_offset /= shared_im[i];
+    }
+    if (is_channel_first) { channel_index = im_offset; }
+
+    // calc the out_range of this im element
+    bool is_in_dim_wrong = false;
+    for (int i = 0; i < dim_num; ++i) {
+      const int kernel_extent = shared_dilation[i] * (shared_kernel[i] - 1) + 1;
+      if (im_index[i] < kernel_extent) {
+        out_begin[i] = 0;
+      } else {
+        // original equation: ((im_index[i]-kernel_extent+1)+(stride[i]-1))/stride[i]
+        out_begin[i] = (im_index[i] - kernel_extent) / shared_stride[i] + 1;
+      }
+      out_end[i] = min(im_index[i] / shared_stride[i] + 1, shared_out[i]);
+      out_index[i] = out_begin[i];
+
+      if (out_begin[i] >= out_end[i]) {  // im element not covered by any kernel window
+        is_in_dim_wrong = true;
+        break;
+      }
+    }
+    if (is_in_dim_wrong) {
+      im_diff_dptr[index] = 0;
+      continue;
+    }
+
+    T val = 0;
+    while (true) {
+      bool is_skip = false;
+      // calc kernel_index
+      for (int i = 0; i < dim_num; ++i) {
+        kernel_index[i] = im_index[i] - out_index[i] * shared_stride[i];
+        if (kernel_index[i] % shared_dilation[i] == 0) {
+          kernel_index[i] /= shared_dilation[i];
+        } else {
+          is_skip = true;
+          break;
+        }
+      }
+
+      // calc col_buf_offset
+      if (is_skip == false) {
+        int col_buf_offset = 0;
+        if (is_channel_first) { col_buf_offset = channel_index; }
+        for (int i = 0; i < dim_num; ++i) {
+          col_buf_offset *= shared_kernel[i];
+          col_buf_offset += kernel_index[i];
+        }
+        if (is_channel_first == false) {
+          col_buf_offset *= channel;
+          col_buf_offset += channel_index;
+        }
+        for (int i = 0; i < dim_num; ++i) {
+          col_buf_offset *= shared_out[i];
+          col_buf_offset += out_index[i];
+        }
+        val += col_buf_dptr[col_buf_offset];
+      }
+
+      // advance out_index[] to the next position
+      bool is_iter_completed = true;
+      for (int i = dim_num - 1; i >= 0; --i) {
+        if (out_index[i] == out_end[i] - 1) {
+          out_index[i] = out_begin[i];
+        } else {
+          out_index[i] += 1;
+          is_iter_completed = false;
+          break;
+        }
+      }
+      if (is_iter_completed) { break; }
+    }
+    im_diff_dptr[index] = val;
+  }
+}
+
+}  // namespace
+
+#define IM2COL_KERNEL_CALL(kernel_func_name, dim_num, is_channel_first, kernel_num, src_dptr,     \
+                           dst_dptr)                                                              \
+  kernel_func_name<T, dim_num, is_channel_first>                                                  \
+      <<<BlocksNum4ThreadsNum(kernel_num), kCudaThreadsNumPerBlock, 0,                            \
+         device_ctx->cuda_stream()>>>(                                                            \
+          kernel_num, src_dptr, in_shape.At(1), in_shape.At(2), in_shape.At(3), in_shape.At(4),   \
+          weight_shape.At(2), weight_shape.At(3), weight_shape.At(4), out_shape.At(2),            \
+          out_shape.At(3), out_shape.At(4), strides[0], strides[1], strides[2], dilation_rate[0], \
+          dilation_rate[1], dilation_rate[2], padding_before[0], padding_before[1],               \
+          padding_before[2], dst_dptr)
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NCDHWIm2Col(
+    const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_dptr) {
+  int32_t kernels = weight_shape.At(1) * out_shape.Count(2);
+  switch (dim_num) {
+    case 1: IM2COL_KERNEL_CALL(Im2ColGpu, 1, true, kernels, in_dptr, col_buf_dptr); break;
+    case 2: IM2COL_KERNEL_CALL(Im2ColGpu, 2, true, kernels, in_dptr, col_buf_dptr); break;
+    case 3: IM2COL_KERNEL_CALL(Im2ColGpu, 3, true, kernels, in_dptr, col_buf_dptr); break;
+    default: UNIMPLEMENTED();
+  }
+}
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NDHWCIm2Col(
+    const int dim_num, DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf_dptr) {
+  int32_t kernels = weight_shape.At(1) * out_shape.Count(2);
+  switch (dim_num) {
+    case 1: IM2COL_KERNEL_CALL(Im2ColGpu, 1, false, kernels, in_dptr, col_buf_dptr); break;
+    case 2: IM2COL_KERNEL_CALL(Im2ColGpu, 2, false, kernels, in_dptr, col_buf_dptr); break;
+    case 3: IM2COL_KERNEL_CALL(Im2ColGpu, 3, false, kernels, in_dptr, col_buf_dptr); break;
+    default: UNIMPLEMENTED();
+  }
+}
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NCDHWCol2Im(
+    const int dim_num, DeviceCtx* device_ctx, const T* col_buf_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_dptr) {
+  int32_t im_size = in_shape.Count(1);
+  switch (dim_num) {
+    case 1: IM2COL_KERNEL_CALL(Col2ImGpu, 1, true, im_size, col_buf_dptr, in_diff_dptr); break;
+    case 2: IM2COL_KERNEL_CALL(Col2ImGpu, 2, true, im_size, col_buf_dptr, in_diff_dptr); break;
+    case 3: IM2COL_KERNEL_CALL(Col2ImGpu, 3, true, im_size, col_buf_dptr, in_diff_dptr); break;
+    default: UNIMPLEMENTED();
+  }
+}
+
+template<typename T>
+void ConvKernelUtil<DeviceType::kGPU, T>::NDHWCCol2Im(
+    const int dim_num, DeviceCtx* device_ctx, const T* col_buf_dptr, const Shape& in_shape,
+    const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
+    const int32_t* dilation_rate, const int32_t* padding_before, T* in_diff_dptr) {
+  int32_t im_size = in_shape.Count(1);
+  switch (dim_num) {
+    case 1: IM2COL_KERNEL_CALL(Col2ImGpu, 1, false, im_size, col_buf_dptr, in_diff_dptr); break;
+    case 2: IM2COL_KERNEL_CALL(Col2ImGpu, 2, false, im_size, col_buf_dptr, in_diff_dptr); break;
+    case 3: IM2COL_KERNEL_CALL(Col2ImGpu, 3, false, im_size, col_buf_dptr, in_diff_dptr); break;
+    default: UNIMPLEMENTED();
+  }
+}
+
+#undef IM2COL_KERNEL_CALL
+
 #define INSTANTIATE_CONV_KERNEL(type_cpp, type_proto) \
   template class ConvKernel<DeviceType::kGPU, type_cpp>;
 OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CONV_KERNEL, FLOATING_DATA_TYPE_SEQ)
 
+#define INSTANTIATE_CONV_KERNEL_UTIL(type_cpp, type_proto) \
+  template class ConvKernelUtil<DeviceType::kGPU, type_cpp>;
+OF_PP_FOR_EACH_TUPLE(INSTANTIATE_CONV_KERNEL_UTIL, FLOATING_DATA_TYPE_SEQ)
+
 }  // namespace oneflow
diff --git a/oneflow/core/kernel/conv_kernel.h b/oneflow/core/kernel/conv_kernel.h
index 2aaf5ceb1..271290292 100644
--- a/oneflow/core/kernel/conv_kernel.h
+++ b/oneflow/core/kernel/conv_kernel.h
@@ -7,6 +7,20 @@
 
 namespace oneflow {
 
+namespace {
+
+template<typename T>
+const T* GetImgDptr(const Blob* blob, int64_t idx) {
+  return blob->dptr<T>() + blob->shape().Count(1) * idx;
+}
+
+template<typename T>
+T* GetImgMutDptr(Blob* blob, int64_t idx) {
+  return const_cast<T*>(GetImgDptr<T>(blob, idx));
+}
+
+}  // namespace
+
 template<DeviceType device_type, typename T>
 class ConvKernelIf : public KernelIfWithActivation<device_type, T>,
                      public KernelIfWithModel<device_type, T> {
@@ -44,16 +58,18 @@ class ConvKernelIf : public KernelIfWithActivation<device_type, T>,
 };
 
 template<typename T>
-using Im2ColFunc = void (*)(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
-                            const Shape& weight_shape, const Shape& out_shape,
-                            const int32_t* strides, const int32_t* dilation_rate,
-                            const int32_t* padding_before, T* col_buf);
+using Im2ColFunc = void (*)(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+                            const Shape& in_shape, const Shape& weight_shape,
+                            const Shape& out_shape, const int32_t* strides,
+                            const int32_t* dilation_rate, const int32_t* padding_before,
+                            T* col_buf);
 
 template<typename T>
-using Col2ImFunc = void (*)(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape,
-                            const Shape& weight_shape, const Shape& out_shape,
-                            const int32_t* strides, const int32_t* dilation_rate,
-                            const int32_t* padding_before, T* in_diff_ptr);
+using Col2ImFunc = void (*)(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+                            const Shape& in_shape, const Shape& weight_shape,
+                            const Shape& out_shape, const int32_t* strides,
+                            const int32_t* dilation_rate, const int32_t* padding_before,
+                            T* in_diff_ptr);
 
 template<typename T>
 using GemmFunc = void (*)(DeviceCtx* ctx, enum CBLAS_TRANSPOSE, enum CBLAS_TRANSPOSE, const int m,
@@ -61,16 +77,13 @@ using GemmFunc = void (*)(DeviceCtx* ctx, enum CBLAS_TRANSPOSE, enum CBLAS_TRANS
                           const T beta, T* c);
 
 template<DeviceType device_type, typename T>
-class ConvKernel;
-
-template<typename T>
-class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kCPU, T> {
+class ConvKernelImplByIm2Col : public ConvKernelIf<device_type, T> {
  public:
-  OF_DISALLOW_COPY_AND_MOVE(ConvKernel);
-  ConvKernel() = default;
-  ~ConvKernel() = default;
+  OF_DISALLOW_COPY_AND_MOVE(ConvKernelImplByIm2Col);
+  ConvKernelImplByIm2Col() = default;
+  ~ConvKernelImplByIm2Col() = default;
 
- private:
+ protected:
   void VirtualKernelInit(const ParallelContext*) override;
   void DoForwardDataContent(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob,
                             Blob* out_blob,
@@ -80,6 +93,8 @@ class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kC
                       std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
   void BiasBackward(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob,
                     std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+
+ private:
   Im2ColFunc<T> im2col_func_;
   Col2ImFunc<T> col2im_func_;
   GemmFunc<T> forward_func_;
@@ -93,8 +108,21 @@ class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelIf<DeviceType::kC
   Shape weight_shape_;
 };
 
+template<DeviceType device_type, typename T>
+class ConvKernel;
+
+template<typename T>
+class ConvKernel<DeviceType::kCPU, T> final : public ConvKernelImplByIm2Col<DeviceType::kCPU, T> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(ConvKernel);
+  ConvKernel() = default;
+  ~ConvKernel() = default;
+
+ private:
+};
+
 template<typename T>
-class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelIf<DeviceType::kGPU, T> {
+class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelImplByIm2Col<DeviceType::kGPU, T> {
  public:
   OF_DISALLOW_COPY_AND_MOVE(ConvKernel);
   ConvKernel() = default;
@@ -102,14 +130,26 @@ class ConvKernel<DeviceType::kGPU, T> final : public ConvKernelIf<DeviceType::kG
 
  private:
   void VirtualKernelInit(const ParallelContext*) override;
+  void KernelInitWithCudnn(const ParallelContext*);
+
   void DoForwardDataContent(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob,
                             Blob* out_blob,
                             std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void DoForwardDataContentWithCudnn(DeviceCtx*, const Blob* in_blob, const Blob* weight_blob,
+                                     Blob* out_blob,
+                                     std::function<Blob*(const std::string&)> BnInOp2Blob) const;
+
   void WeightBackward(DeviceCtx*, const Blob* out_diff_blob, const Blob* in_blob,
                       Blob* weight_diff_blob, Blob* in_diff_blob,
                       std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void WeightBackwardWithCudnn(DeviceCtx*, const Blob* out_diff_blob, const Blob* in_blob,
+                               Blob* weight_diff_blob, Blob* in_diff_blob,
+                               std::function<Blob*(const std::string&)> BnInOp2Blob) const;
+
   void BiasBackward(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob,
                     std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  void BiasBackwardWithCudnn(DeviceCtx*, const Blob* out_diff_blob, Blob* bias_diff_blob,
+                             std::function<Blob*(const std::string&)> BnInOp2Blob) const;
 
   std::unique_ptr<CudnnTensorDesc> in_desc_;
   std::unique_ptr<CudnnTensorDesc> out_desc_;
@@ -194,26 +234,31 @@ class ColBufUtil final {
   DHWValidFunc<T> dhw_valid_func_;
 };
 
+template<DeviceType device_type, typename T>
+class ConvKernelUtil;
+
 template<typename T>
-struct ConvKernelUtil final {
+struct ConvKernelUtil<DeviceType::kCPU, T> final {
  public:
-  static void NCDHWIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
-                          const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
-                          const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf);
+  static void NCDHWIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* col_buf);
 
-  static void NDHWCIm2Col(DeviceCtx* device_ctx, const T* in_dptr, const Shape& in_shape,
-                          const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
-                          const int32_t* dilation_rate, const int32_t* padding_before, T* col_buf);
+  static void NDHWCIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* col_buf);
 
-  static void NCDHWCol2Im(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape,
-                          const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
-                          const int32_t* dilation_rate, const int32_t* padding_before,
-                          T* in_diff_ptr);
+  static void NCDHWCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* in_diff_ptr);
 
-  static void NDHWCCol2Im(DeviceCtx* device_ctx, const T* col_buf, const Shape& in_shape,
-                          const Shape& weight_shape, const Shape& out_shape, const int32_t* strides,
-                          const int32_t* dilation_rate, const int32_t* padding_before,
-                          T* in_diff_ptr);
+  static void NDHWCCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* in_diff_ptr);
 
  private:
   static void DoNCDWHFunc(const Shape& weight_shape, ColBufUtil<T>& conv_util,
@@ -223,6 +268,32 @@ struct ConvKernelUtil final {
                           ColBufWriter<T>* col_buf_writer);
 };
 
+template<typename T>
+struct ConvKernelUtil<DeviceType::kGPU, T> final {
+ public:
+  static void NCDHWIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* col_buf);
+
+  static void NDHWCIm2Col(const int dim_num, DeviceCtx* device_ctx, const T* in_dptr,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* col_buf);
+
+  static void NCDHWCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* in_diff_ptr);
+
+  static void NDHWCCol2Im(const int dim_num, DeviceCtx* device_ctx, const T* col_buf,
+                          const Shape& in_shape, const Shape& weight_shape, const Shape& out_shape,
+                          const int32_t* strides, const int32_t* dilation_rate,
+                          const int32_t* padding_before, T* in_diff_ptr);
+
+ private:
+};
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_KERNEL_CONV_KERNEL_H_
diff --git a/oneflow/core/kernel/kernel.h b/oneflow/core/kernel/kernel.h
index 001b45b57..4424847af 100644
--- a/oneflow/core/kernel/kernel.h
+++ b/oneflow/core/kernel/kernel.h
@@ -142,7 +142,8 @@ class KernelIf : public Kernel {
                  const PbRpf<std::string>& from_bns, const PbRpf<std::string>& to_bns,
                  void (Blob::*Copy)(DeviceCtx*, const Blob*)) const;
 
-  bool UseCudnn() const { return device_type == DeviceType::kGPU && op_conf().use_cudnn_on_gpu(); }
+  bool UseCudnn() const { return device_type == DeviceType::kGPU && UseCudnnOnGpu(); }
+  bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); }
 };
 
 template<DeviceType device_type, typename ModelType>
diff --git a/oneflow/core/operator/conv_op.cpp b/oneflow/core/operator/conv_op.cpp
index 6f0baca02..20003955d 100644
--- a/oneflow/core/operator/conv_op.cpp
+++ b/oneflow/core/operator/conv_op.cpp
@@ -120,7 +120,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G
   if (GetValFromCustomizedConf<bool>("use_bias")) {
     // bias and bias_multiplier
     GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({filters, 1});
-    if (!UseCudnn()) {
+    if (!UseCudnnOnGpu()) {
       std::vector<int64_t> bias_mul_shape(NDims + 1, 1);
       for (size_t i = 0; i != NDims; ++i) { bias_mul_shape[i + 1] = out_shape[dhw_offset + i]; }
       GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape(bias_mul_shape);
@@ -130,7 +130,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G
   ConvOpCtx* conv_op_ctx = new ConvOpCtx();
   EnrollOpCtx(conv_op_ctx);
 
-  if (!UseCudnn()) {
+  if (device_type() == DeviceType::kCPU || !UseCudnnOnGpu()) {
     // col_buf
     size_t col_buf_elem_cnt = 1;
     for (size_t i = 0; i != NDims + 1; ++i) { col_buf_elem_cnt *= weight_shape[i + 1]; }
@@ -139,7 +139,7 @@ void ConvOp<NDims>::InferBlobDescs(std::function<BlobDesc*(const std::string)> G
   }
 
 #ifdef WITH_CUDA
-  if (device_type() == DeviceType::kGPU) {
+  if (device_type() == DeviceType::kGPU && UseCudnnOnGpu()) {
     // cudnn_buf
     InferCudnnAlgo(GetBlobDesc4BnInOp, &(conv_op_ctx->cudnn_conv_algo_ctx));
     *buf_size = std::max({conv_op_ctx->cudnn_conv_algo_ctx.fwd_ws_size,
@@ -235,7 +235,7 @@ void ConvOp<NDims>::VirtualGenKernelConf(
     const ParallelContext* parallel_ctx, KernelConf* kernel_conf, const OpContext* op_ctx) const {
   ConvKernelConf* conv_conf = kernel_conf->mutable_conv_conf();
   conv_conf->set_dim(NDims);
-  if (!UseCudnn()) {
+  if (!UseCudnnOnGpu()) {
     GenKernelConfWithoutCudnn(GetBlobDesc4BnInOp, conv_conf);
   } else {
     GenKernelConfWithCudnn(GetBlobDesc4BnInOp, kernel_conf, conv_conf, op_ctx);
diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp
index 79501498d..91069547c 100644
--- a/oneflow/core/operator/operator.cpp
+++ b/oneflow/core/operator/operator.cpp
@@ -41,10 +41,6 @@ LogicalBlobId* Operator::MutBnInOp2Lbi(const std::string& bn_in_op) {
   }
 }
 
-bool Operator::UseCudnn() const {
-  return device_type() == DeviceType::kGPU && op_conf().use_cudnn_on_gpu();
-}
-
 const std::string& Operator::SoleIbn() const {
   CHECK_EQ(input_bns().size(), 1);
   return input_bns().Get(0);
diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h
index 6572b1acc..e08bf058a 100644
--- a/oneflow/core/operator/operator.h
+++ b/oneflow/core/operator/operator.h
@@ -47,7 +47,8 @@ class Operator {
   // Getters
   const std::string& op_name() const { return op_conf().name(); }
   DeviceType device_type() const { return op_attribute_.op_conf().device_type(); }
-  bool UseCudnn() const;
+  bool UseCudnn() const { return device_type() == DeviceType::kGPU && UseCudnnOnGpu(); }
+  bool UseCudnnOnGpu() const { return op_conf().use_cudnn_on_gpu(); }
   const OperatorConf& op_conf() const { return op_attribute_.op_conf(); }
   virtual const PbMessage& GetCustomizedConf() const { UNIMPLEMENTED(); }
 
-- 
GitLab