diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp
index c572a2885fda464bbf8eca2ca11acac1e02343ab..57d4280f678121eca8c8af36bf2fcb26e4a34f5f 100644
--- a/oneflow/api/python/functional/python_arg.cpp
+++ b/oneflow/api/python/functional/python_arg.cpp
@@ -70,6 +70,11 @@ Maybe<std::shared_ptr<one::Tensor>> PythonArg::ObjectAs<std::shared_ptr<one::Ten
   return detail::cast<std::shared_ptr<one::Tensor>>(Borrow());
 }
 
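+// Returns the tensor by value (dereferencing its shared_ptr) so that arguments
+// typed as plain `one::Tensor`, e.g. inside Optional<one::Tensor>, can be cast.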
+template<>
+Maybe<one::Tensor> PythonArg::ObjectAs<one::Tensor>() const {
+  return *JUST(detail::cast<std::shared_ptr<one::Tensor>>(Borrow()));
+}
+
 template<>
 Maybe<std::shared_ptr<one::TensorTuple>> PythonArg::ObjectAs<std::shared_ptr<one::TensorTuple>>()
     const {
diff --git a/oneflow/api/python/functional/python_arg.h b/oneflow/api/python/functional/python_arg.h
index 519f07e07a7eb20150bf863af83cc856f834d2fc..1ac8d231ebc787b371af7d855650d8dda8d307e6 100644
--- a/oneflow/api/python/functional/python_arg.h
+++ b/oneflow/api/python/functional/python_arg.h
@@ -61,6 +61,21 @@ class PythonArg {
 
   virtual ~PythonArg() = default;
 
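+  // ObjectAsHelper forwards to ObjectAs<T>() and is specialized for Optional<T>
+  // so that a Python `None` argument becomes an empty Optional instead of an error.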
+  template<typename T>
+  friend struct ObjectAsHelper;
+
+  template<typename T>
+  struct ObjectAsHelper {
+    Maybe<T> operator()(const PythonArg* self) { return self->ObjectAs<T>(); }
+  };
+  template<typename T>
+  struct ObjectAsHelper<Optional<T>> {
+    Maybe<Optional<T>> operator()(const PythonArg* self) {
+      if (self->object_ == Py_None) { return std::make_shared<Optional<T>>(); }
+      return std::make_shared<Optional<T>>(JUST(self->ObjectAs<T>()));
+    }
+  };
+
   template<typename T>
   operator T() const {
     if (active_tag_ == HAS_IMMEDIATE) {
@@ -70,7 +85,7 @@ class PythonArg {
       return *reinterpret_cast<const T*>(immediate_->Ptr());
     }
     CHECK_EQ_OR_THROW(active_tag_, HAS_OBJECT);
-    return this->ObjectAs<oneflow::detail::remove_cvref_t<T>>().GetOrThrow();
+    return ObjectAsHelper<oneflow::detail::remove_cvref_t<T>>()(this).GetOrThrow();
   }
 
  private:
diff --git a/oneflow/core/autograd/gradient_funcs/conv.cpp b/oneflow/core/autograd/gradient_funcs/conv.cpp
index d87ecd8b928744d1de3653835f82bb3f517d302c..8080a93161878ea6b88cc4429aa030ee05b7056a 100644
--- a/oneflow/core/autograd/gradient_funcs/conv.cpp
+++ b/oneflow/core/autograd/gradient_funcs/conv.cpp
@@ -13,103 +13,93 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "oneflow/core/framework/attr_map.h"
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/framework/op_builder.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
 #include "oneflow/core/framework/op_expr.h"
-#include "oneflow/core/framework/op_expr_helper.h"
-#include "oneflow/core/framework/user_op_conf_trait.h"
+#include "oneflow/core/functional/functional.h"
 
 namespace oneflow {
 namespace one {
 
-namespace {
-
-struct ConvInterpState : public OpExprInterpState {
-  bool weight_requires_grad = true;
-  bool input_requires_grad = true;
+struct ConvolutionNdInterpState : public OpExprInterpState {
+  bool input_requires_grad = false;
+  bool weight_requires_grad = false;
+  size_t input_index = 0;
+  size_t weight_index = 0;
+
+  std::string data_format;
+  std::vector<int32_t> padding_before;
+  std::vector<int32_t> kernel_size;
+  std::vector<int32_t> strides;
+  std::vector<int32_t> dilation_rate;
+  int32_t groups = 1;
 };
 
-class ConvNdGrad : public OpExprGradFunction<ConvInterpState> {
+class ConvolutionNd : public OpExprGradFunction<ConvolutionNdInterpState> {
  public:
   Maybe<void> Init(const OpExpr& op) override;
-  Maybe<void> Capture(ConvInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
-                      const AttrMap& attrs) const override;
-  Maybe<void> Apply(const ConvInterpState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override;
 
  private:
-  std::shared_ptr<user_op::UserOpConfTrait> op_trait_;
-  std::shared_ptr<std::string> data_format_;
-  std::shared_ptr<std::vector<int32_t>> padding_before_;
-  std::shared_ptr<std::vector<int32_t>> kernel_size_;
-  std::shared_ptr<std::vector<int32_t>> strides_;
-  std::shared_ptr<std::vector<int32_t>> dilation_rate_;
-  int32_t groups_;
-
-  std::shared_ptr<OpExpr> data_grad_op_;
-  std::shared_ptr<OpExpr> weight_grad_op_;
+  AttrMap base_attrs_;
 };
 
-Maybe<void> ConvNdGrad::Init(const OpExpr& op) {
+Maybe<void> ConvolutionNd::Init(const OpExpr& op) {
   const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
   CHECK_NOTNULL_OR_RETURN(fw_op_expr);
-  const std::string& op_name = fw_op_expr->op_name();
-  op_trait_ = std::make_shared<user_op::UserOpConfTrait>(op_name, fw_op_expr->proto());
-
-  data_format_ = JUST(op_trait_->GetAttr<std::string>("data_format"));
-  padding_before_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("padding_before"));
-  kernel_size_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("kernel_size"));
-  strides_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("strides"));
-  dilation_rate_ = JUST(op_trait_->GetAttr<std::vector<int32_t>>("dilation_rate"));
-  groups_ = JUST(op_trait_->GetAttr<int32_t>("groups"));
-  int32_t ndims = kernel_size_->size();
-  CHECK_EQ_OR_RETURN(ndims, strides_->size());
-  CHECK_EQ_OR_RETURN(ndims, dilation_rate_->size());
-  data_grad_op_ = JUST(op_expr_helper::ConvNdDataGradOp(*kernel_size_, *strides_, *padding_before_,
-                                                        *dilation_rate_, groups_, *data_format_));
-
-  weight_grad_op_ = JUST(op_expr_helper::ConvNdFilterGradOp(
-      *kernel_size_, *strides_, *padding_before_, *dilation_rate_, groups_, *data_format_));
-
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
   return Maybe<void>::Ok();
 }
 
-Maybe<void> ConvNdGrad::Capture(ConvInterpState* ctx, const TensorTuple& inputs,
-                                const TensorTuple& outputs, const AttrMap& attrs) const {
+Maybe<void> ConvolutionNd::Capture(ConvolutionNdInterpState* ctx, const TensorTuple& inputs,
+                                   const TensorTuple& outputs, const AttrMap& attrs) const {
+  CHECK_EQ_OR_RETURN(inputs.size(), 2);
   ctx->input_requires_grad = inputs.at(0)->requires_grad();
   ctx->weight_requires_grad = inputs.at(1)->requires_grad();
-  ctx->SaveTensorForBackward(inputs.at(0));  // x
+  if (!ctx->input_requires_grad && !ctx->weight_requires_grad) { return Maybe<void>::Ok(); }
   if (ctx->input_requires_grad) {
-    ctx->SaveTensorForBackward(inputs.at(1));  // weight
+    ctx->weight_index = ctx->SaveTensorForBackward(inputs.at(1));  // weight
   }
+  ctx->input_index = ctx->SaveTensorForBackward(inputs.at(0));  // input
+
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+  ctx->padding_before = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("padding_before"));
+  ctx->kernel_size = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("kernel_size"));
+  ctx->strides = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("strides"));
+  ctx->dilation_rate = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("dilation_rate"));
+  ctx->groups = JUST(composed_attrs.GetAttr<int32_t>("groups"));
   return Maybe<void>::Ok();
 }
 
-Maybe<void> ConvNdGrad::Apply(const ConvInterpState* ctx, const TensorTuple& out_grads,
-                              TensorTuple* in_grads) const {
-  CHECK_EQ_OR_RETURN(out_grads.size(), 1);
-  const auto& dy = out_grads.at(0);
-
+Maybe<void> ConvolutionNd::Apply(const ConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
+                                 TensorTuple* in_grads) const {
   in_grads->resize(2);
+  size_t num_spatial_dims = ctx->kernel_size.size();
   if (ctx->input_requires_grad) {
-    const auto& x = ctx->SavedTensors().at(0);
-    const auto& weight = ctx->SavedTensors().at(1);
-    in_grads->at(0) =
-        JUST(OpInterpUtil::Dispatch<Tensor>(*data_grad_op_, {dy, weight, x}, /*attrs=*/{}));
+    const auto& weight = ctx->SavedTensors().at(ctx->weight_index);
+    const auto& input = ctx->SavedTensors().at(ctx->input_index);
+    in_grads->at(0) = JUST(functional::ConvDataGrad(
+        out_grads.at(0), weight, input, num_spatial_dims, ctx->kernel_size, ctx->strides,
+        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
   }
   if (ctx->weight_requires_grad) {
-    const auto& x = ctx->SavedTensors().at(0);
-    in_grads->at(1) = JUST(OpInterpUtil::Dispatch<Tensor>(*weight_grad_op_, {dy, x}, /*attrs=*/{}));
+    const auto& input = ctx->SavedTensors().at(ctx->input_index);
+    in_grads->at(1) = JUST(functional::ConvFilterGrad(
+        out_grads.at(0), input, num_spatial_dims, ctx->kernel_size, ctx->strides,
+        ctx->padding_before, ctx->dilation_rate, ctx->groups, ctx->data_format));
   }
   return Maybe<void>::Ok();
 }
 
-}  // namespace
-
-REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvNdGrad);
-REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvNdGrad);
-REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvNdGrad);
+REGISTER_OP_EXPR_GRAD_FUNCTION("conv1d", ConvolutionNd);
+REGISTER_OP_EXPR_GRAD_FUNCTION("conv2d", ConvolutionNd);
+REGISTER_OP_EXPR_GRAD_FUNCTION("conv3d", ConvolutionNd);
 
 }  // namespace one
 }  // namespace oneflow
diff --git a/oneflow/core/common/optional.h b/oneflow/core/common/optional.h
new file mode 100644
index 0000000000000000000000000000000000000000..866b6f5c339da118c798455c3d6757baa95454d2
--- /dev/null
+++ b/oneflow/core/common/optional.h
@@ -0,0 +1,202 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#ifndef ONEFLOW_CORE_COMMON_OPTIONAL_H_
+#define ONEFLOW_CORE_COMMON_OPTIONAL_H_
+
+#include "oneflow/core/common/type_traits.h"
+#include "oneflow/core/common/maybe.h"
+
+namespace oneflow {
+
+namespace internal {
+
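+// Storage for the wrapped value: scalar types are stored inline, while
+// non-scalar types are kept behind a shared_ptr so copies stay cheap.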
+template<typename T, typename U = void>
+class Storage;
+
+template<typename T>
+class Storage<T, typename std::enable_if<IsScalarType<T>::value>::type> {
+ public:
+  Storage() = default;
+
+  template<typename... Args,
+           typename std::enable_if<std::is_constructible<T, Args...>::value, int>::type = 0>
+  Storage(Args&&... args) {
+    new (&value_) T(std::forward<Args>(args)...);
+  }
+
+  Storage& operator=(const T& value) {
+    value_ = value;
+    return *this;
+  }
+  Storage& operator=(T&& value) {
+    value_ = std::move(value);
+    return *this;
+  }
+  Storage& operator=(const Storage<T>& rhs) {
+    value_ = rhs.value_;
+    return *this;
+  }
+  Storage& operator=(Storage<T>&& rhs) {
+    value_ = std::move(rhs.value_);
+    return *this;
+  }
+
+  Maybe<T> value() const { return value_; }
+
+ private:
+  T value_;
+};
+
+template<typename T>
+class Storage<T, typename std::enable_if<!IsScalarType<T>::value>::type> {
+ public:
+  Storage() = default;
+
+  template<typename... Args,
+           typename std::enable_if<std::is_constructible<T, Args...>::value, int>::type = 0>
+  Storage(Args&&... args) {
+    value_ = std::make_shared<T>(std::forward<Args>(args)...);
+  }
+
+  Storage(const std::shared_ptr<T>& value) : value_(value) {}
+
+  Storage& operator=(const T& value) {
+    if (value_) {
+      *value_ = value;
+    } else {
+      value_ = std::make_shared<T>(value);
+    }
+    return *this;
+  }
+  Storage& operator=(T&& value) {
+    if (value_) {
+      *value_ = std::move(value);
+    } else {
+      value_ = std::make_shared<T>(std::move(value));
+    }
+    return *this;
+  }
+  Storage& operator=(const Storage<T>& rhs) {
+    value_ = rhs.value_;
+    return *this;
+  }
+  Storage& operator=(Storage<T>&& rhs) {
+    value_ = std::move(rhs.value_);
+    return *this;
+  }
+
+  Maybe<T> value() const { return value_; }
+
+ private:
+  std::shared_ptr<T> value_;
+};
+
+}  // namespace internal
+
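+// A minimal optional wrapper around internal::Storage. value() returns Maybe<T>,
+// so call sites unwrap it with JUST()/CHECK_JUST() like other fallible accessors.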
+template<typename T>
+class Optional {
+ public:
+  Optional() : init_(false) {}
+
+  template<typename... Args,
+           typename std::enable_if<std::is_constructible<internal::Storage<T>, Args...>::value,
+                                   int>::type = 0>
+  Optional(Args&&... args) : init_(true), storage_(std::forward<Args>(args)...) {}
+
+  ~Optional() = default;
+
+  Optional(const Optional<T>& rhs) : init_(rhs.init_) {
+    if (init_) { storage_ = rhs.storage_; }
+  }
+
+  Optional(Optional<T>&& rhs) : init_(rhs.init_) {
+    if (init_) { storage_ = std::move(rhs.storage_); }
+  }
+
+  Optional& operator=(const T& val) {
+    init_ = true;
+    storage_ = val;
+    return *this;
+  }
+
+  Optional& operator=(T&& val) {
+    init_ = true;
+    storage_ = std::move(val);
+    return *this;
+  }
+
+  Optional& operator=(const Optional<T>& rhs) {
+    init_ = rhs.init_;
+    if (init_) { storage_ = rhs.storage_; }
+    return *this;
+  }
+
+  Optional& operator=(Optional<T>&& rhs) {
+    init_ = rhs.init_;
+    if (init_) { storage_ = std::move(rhs.storage_); }
+    return *this;
+  }
+
+  Maybe<T> value() const {
+    CHECK_OR_RETURN(has_value()) << "Optional has no value.";
+    return storage_.value();
+  }
+
+  bool has_value() const { return init_; }
+  operator bool() const { return has_value(); }
+
+ private:
+  bool init_;
+  internal::Storage<T> storage_;
+};
+
+template<typename T>
+class Optional<T&> {
+ public:
+  Optional() : value_ptr_(nullptr) {}
+
+  Optional(T& val) : value_ptr_(&val) {}
+
+  ~Optional() = default;
+
+  Optional& operator=(T& val) {
+    value_ptr_ = &val;
+    return *this;
+  }
+
+  Optional& operator=(const Optional<T&>& rhs) {
+    value_ptr_ = rhs.value_ptr_;
+    return *this;
+  }
+
+  Maybe<T&> value() const {
+    CHECK_OR_RETURN(has_value()) << "Optional has no value.";
+    return *value_ptr_;
+  }
+
+  void Clear() { value_ptr_ = nullptr; }
+
+  bool has_value() const { return value_ptr_ != nullptr; }
+  operator bool() const { return has_value(); }
+
+ private:
+  T* value_ptr_;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_COMMON_OPTIONAL_H_
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 24884e18f2d95f746e50669dda70490d3d18a4e2..2175ce0598bfe6e7443e02b89f3f875361065c48 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -238,6 +238,31 @@
   signature: "Tensor BiasAdd(Tensor x, Tensor bias, *, Int32 axis=1)"
   bind_python: True
 
+- name: "conv2d"
+  signature:
+    "Tensor Conv2D(Tensor x, Tensor weight, *, Tensor bias=None, Int32List stride, 
+                   Int32List padding, Int32List dilation, Int32 groups=1)"
+  bind_python: True
+
+- name: "conv_data_grad"
+  signature:
+    "Tensor ConvDataGrad(Tensor dy, Tensor weight, Tensor x, *, Int32 num_spatial_dims,
+                         Int32List kernel_size, Int32List strides, Int32List padding_before,
+                         Int32List dilation_rate, Int32 groups=1,
+                         String data_format=\"channels_first\")"
+  bind_python: False
+
+- name: "conv_filter_grad"
+  signature:
+    "Tensor ConvFilterGrad(Tensor dy, Tensor x, *, Int32 num_spatial_dims, Int32List kernel_size,
+                           Int32List strides, Int32List padding_before, Int32List dilation_rate,
+                           Int32 groups=1, String data_format=\"channels_first\")"
+  bind_python: False
+
+- name: "conv_bias_grad"
+  signature: "Tensor ConvBiasGrad(Tensor dy, *, Int32 num_spatial_dims, String data_format=\"channels_first\")"
+  bind_python: False
+
 - name: "expand"
   signature: "Tensor Expand(Tensor x, *, Int32List in_shape, Int32List out_shape, Int32List stride)"
   bind_python: True
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 3fb8358a724c28ede4b62b1f482fc10f5ec8a7cb..4c3606a0e9048c2d63811f9fec246e065d0bfda7 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
+#include "oneflow/core/common/optional.h"
 #include "oneflow/core/framework/attr_map.h"
 #include "oneflow/core/framework/op_builder.h"
 #include "oneflow/core/framework/op_expr.h"
@@ -47,6 +48,44 @@ class BiasAddFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class Conv2DFunctor {
+ public:
+  Conv2DFunctor() {
+    conv_op_ =
+        CHECK_JUST(one::OpBuilder("conv2d").Input("in").Input("weight").Output("out").Build());
+    bias_op_ = CHECK_JUST(one::OpBuilder("bias_add").Input("a").Input("b").Output("out").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::shared_ptr<one::Tensor>& weight,
+                           const Optional<one::Tensor>& bias, const std::vector<int32_t>& stride,
+                           const std::vector<int32_t>& padding,
+                           const std::vector<int32_t>& dilation, const int32_t& groups) const {
+    MutableAttrMap conv_attrs;
+    std::vector<int32_t> kernel_size_vec;
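+    // Infer the kernel size from the weight shape: (filters, in_channels / groups, h, w).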
+    for (int i = 0; i < 2; i++) { kernel_size_vec.push_back((weight->shape())->At(i + 2)); }
+    JUST(conv_attrs.SetAttr<int32_t>("filters", (weight->shape())->At(0)));
+    JUST(conv_attrs.SetAttr<std::vector<int32_t>>("padding_before", padding));
+    JUST(conv_attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size_vec));
+    JUST(conv_attrs.SetAttr<std::vector<int32_t>>("strides", stride));
+    JUST(conv_attrs.SetAttr<std::vector<int32_t>>("dilation_rate", dilation));
+    JUST(conv_attrs.SetAttr<int32_t>("groups", groups));
+    JUST(conv_attrs.SetAttr<std::string>("data_format", std::string("channels_first")));
+    const std::shared_ptr<one::Tensor>& conv_out =
+        JUST(OpInterpUtil::Dispatch<Tensor>(*conv_op_, {x, weight}, conv_attrs));
+    if (bias) {
+      MutableAttrMap bias_attrs;
+      JUST(bias_attrs.SetAttr<int32_t>("axis", 1));
+      return OpInterpUtil::Dispatch<Tensor>(*bias_op_, {conv_out, JUST(bias.value())}, bias_attrs);
+    } else {
+      return conv_out;
+    }
+  }
+
+ private:
+  std::shared_ptr<OpExpr> conv_op_;
+  std::shared_ptr<OpExpr> bias_op_;
+};
+
 class MatMulBaseFunctor {
  public:
   MatMulBaseFunctor() = default;
@@ -217,6 +256,7 @@ class NormalizationFunctor {
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::BiasAddFunctor>("BiasAdd");
+  m.add_functor<impl::Conv2DFunctor>("Conv2D");
   m.add_functor<impl::MatMulFunctor>("MatMul");
   m.add_functor<impl::BatchMatMulFunctor>("BatchMatMul");
   m.add_functor<impl::BroadcastMatMulFunctor>("BroadcastMatMul");
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6892ea4669ae5e02b9e2e8079748b0332e410e23
--- /dev/null
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -0,0 +1,122 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+#include "oneflow/core/framework/tensor.h"
+#include "oneflow/core/framework/tensor_tuple.h"
+#include "oneflow/core/functional/function_library.h"
+#include "oneflow/core/functional/impl/common.h"
+#include "oneflow/core/functional/impl/unary_functor.h"
+#include "oneflow/core/functional/scalar.h"
+
+namespace oneflow {
+namespace one {
+namespace functional {
+
+namespace impl {
+
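+// Functors wrapping the conv backward user ops ("conv_bias_grad", "conv_filter_grad",
+// "conv_data_grad") so that autograd can dispatch them through the functional API.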
+class ConvBiasGradFunctor {
+ public:
+  ConvBiasGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("conv_bias_grad").Input("dy").Output("bias_diff").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy, const int32_t& num_spatial_dims,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<int32_t>("num_spatial_dims", num_spatial_dims));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class ConvFilterGradFunctor {
+ public:
+  ConvFilterGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("conv_filter_grad").Input("dy").Input("x").Output("filter_diff").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const int32_t& num_spatial_dims,
+                           const std::vector<int32_t>& kernel_size,
+                           const std::vector<int32_t>& strides,
+                           const std::vector<int32_t>& padding_before,
+                           const std::vector<int32_t>& dilation_rate, const int32_t& groups,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<int32_t>("num_spatial_dims", num_spatial_dims));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("strides", strides));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("dilation_rate", dilation_rate));
+    JUST(attrs.SetAttr<int32_t>("groups", groups));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class ConvDataGradFunctor {
+ public:
+  ConvDataGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("conv_data_grad")
+                         .Input("dy")
+                         .Input("filter")
+                         .Input("x_like")
+                         .Output("dx")
+                         .Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& weight,
+                           const std::shared_ptr<one::Tensor>& x, const int32_t& num_spatial_dims,
+                           const std::vector<int32_t>& kernel_size,
+                           const std::vector<int32_t>& strides,
+                           const std::vector<int32_t>& padding_before,
+                           const std::vector<int32_t>& dilation_rate, const int32_t& groups,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<int32_t>("num_spatial_dims", num_spatial_dims));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("kernel_size", kernel_size));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("strides", strides));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("padding_before", padding_before));
+    JUST(attrs.SetAttr<std::vector<int32_t>>("dilation_rate", dilation_rate));
+    JUST(attrs.SetAttr<int32_t>("groups", groups));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, weight, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+}  // namespace impl
+
+ONEFLOW_FUNCTION_LIBRARY(m) {
+  m.add_functor<impl::ConvBiasGradFunctor>("ConvBiasGrad");
+  m.add_functor<impl::ConvFilterGradFunctor>("ConvFilterGrad");
+  m.add_functor<impl::ConvDataGradFunctor>("ConvDataGrad");
+};
+
+}  // namespace functional
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/functional/value_types.h b/oneflow/core/functional/value_types.h
index cd8904face4017123eedb497a2c3560de6e94a10..f47035a8e1dc3c0ff76a8dad69428bae48264474 100644
--- a/oneflow/core/functional/value_types.h
+++ b/oneflow/core/functional/value_types.h
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "oneflow/core/common/data_type.pb.h"
 #include "oneflow/core/common/maybe.h"
+#include "oneflow/core/common/optional.h"
 
 namespace oneflow {
 class Shape;
@@ -79,6 +80,11 @@ enum ValueType {
 
 #define VALUE_TYPE_OF_IMPL(cpp_type, value_type)                                                 \
   template<typename T, typename std::enable_if<std::is_same<T, cpp_type>::value, int>::type = 0> \
+  inline ValueType ValueTypeOf() {                                                               \
+    return value_type;                                                                           \
+  }                                                                                              \
+  template<typename T,                                                                           \
+           typename std::enable_if<std::is_same<T, Optional<cpp_type>>::value, int>::type = 0>   \
   inline ValueType ValueTypeOf() {                                                               \
     return value_type;                                                                           \
   }
diff --git a/oneflow/python/nn/modules/conv.py b/oneflow/python/nn/modules/conv.py
index 069e94b6a61fd5ea8eb8d80dcc7e733cfb2b5751..c1ab45fc754fe4dc5aeaafaf9b60950f27a4a47f 100644
--- a/oneflow/python/nn/modules/conv.py
+++ b/oneflow/python/nn/modules/conv.py
@@ -209,60 +209,22 @@ class Conv2d(Module):
         super().__init__()
 
         assert padding_mode == "zeros"
-        kernel_size = _pair(kernel_size)
-        self.kernel_size = kernel_size
-        stride = _pair(stride)
-        padding = _pair(padding)
-        dilation = _pair(dilation)
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
         self.groups = groups
         assert in_channels % groups == 0
         assert out_channels % groups == 0
+        self.in_channels = in_channels
         self.out_channels = out_channels
         self.weight = flow.nn.Parameter(
-            flow.Tensor(out_channels, in_channels // groups, *kernel_size)
+            flow.Tensor(out_channels, in_channels // groups, *self.kernel_size)
         )
         self.out_channel_groups = out_channels // groups
         self.bias = None
-        self._bias_add_op = None
         if bias:
             self.bias = flow.nn.Parameter(flow.Tensor(out_channels))
-            self._bias_add_op = (
-                flow.builtin_op("bias_add")
-                .Input("a")
-                .Input("b")
-                .Output("out")
-                .Attr("axis", 1)
-                .Build()
-            )
-
-        self._op = (
-            flow.builtin_op("conv2d")
-            .Input("in")
-            .Input("weight")
-            .Attr("filters", out_channels)
-            .Attr("padding_before", padding)
-            .Attr("strides", stride)
-            .Attr("kernel_size", kernel_size)
-            .Attr("dilation_rate", dilation)
-            .Attr("groups", groups)
-            .Attr("data_format", "channels_first")
-            .Output("out")
-            .Build()
-        )
-        self._cpu_op = (
-            flow.builtin_op("conv2d")
-            .Input("in")
-            .Input("weight")
-            .Attr("filters", out_channels // groups)
-            .Attr("padding_before", padding)
-            .Attr("strides", stride)
-            .Attr("kernel_size", kernel_size)
-            .Attr("dilation_rate", dilation)
-            .Attr("groups", 1)
-            .Attr("data_format", "channels_first")
-            .Output("out")
-            .Build()
-        )
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -273,6 +235,8 @@ class Conv2d(Module):
             init.uniform_(self.bias, -bound, bound)
 
     def forward(self, x):
+        if x.shape[1] != self.in_channels:
+            raise ValueError(
+                "Expected input with {} channels, but got {} channels".format(
+                    self.in_channels, x.shape[1]
+                )
+            )
         if x.device.type == "cpu" and self.groups > 1:
             in_channel_axis = 1
             in_split_list = ConvUtil.split(
@@ -281,7 +245,7 @@ class Conv2d(Module):
             out_list = []
             for i in range(len(in_split_list)):
                 out_list.append(
-                    self._cpu_op(
+                    flow.F.conv2d(
                         in_split_list[i],
                         self.weight[
                             i
@@ -291,14 +255,30 @@ class Conv2d(Module):
                             :,
                             :,
                         ],
-                    )[0]
+                        self.bias[
+                            i
+                            * self.out_channel_groups : (i + 1)
+                            * self.out_channel_groups
+                        ]
+                        if self.bias is not None
+                        else None,
+                        stride=self.stride,
+                        padding=self.padding,
+                        dilation=self.dilation,
+                        groups=1,
+                    )
                 )
             res = flow.experimental.cat(out_list, dim=in_channel_axis)
         else:
-            res = self._op(x, self.weight)[0]
-
-        if self._bias_add_op is not None:
-            res = self._bias_add_op(res, self.bias)[0]
+            res = flow.F.conv2d(
+                x,
+                self.weight,
+                self.bias,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+            )
         return res
 
 
diff --git a/oneflow/python/test/modules/test_conv.py b/oneflow/python/test/modules/test_conv.py
index 9057dc77d3c36b1f9a891645a9c569af330dfd51..4f2e3e73b6b88fbd050d5a9f54ce4429286aa896 100644
--- a/oneflow/python/test/modules/test_conv.py
+++ b/oneflow/python/test/modules/test_conv.py
@@ -1478,6 +1478,119 @@ def _test_conv2d_large_in_channel(test_case, device):
     test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6))
 
 
+def _test_conv2d_large_out_channel(test_case, device):
+    np_arr = np.array(
+        [
+            [
+                [
+                    [0.56573248, -0.19689320, -0.67875558, 0.34328273, 0.31964567],
+                    [-1.33715475, 0.33422229, -1.27643383, 0.37904647, 0.35891593],
+                    [0.84579802, 2.12729621, -0.51423287, 0.61297560, -1.31156564],
+                    [-0.71047139, 1.02679253, -0.76686019, -0.72969633, 0.73425150],
+                    [-0.13592879, -1.03207183, -0.22554775, 0.74148071, 0.96601510],
+                ],
+                [
+                    [0.51595992, 0.49624804, 0.91145641, 0.49247262, 0.41002217],
+                    [-1.08001196, 1.55497086, -0.81963140, -0.45511565, -0.60269165],
+                    [0.05563145, -0.94318372, -1.17058158, -0.73568577, 0.57810956],
+                    [-0.40260276, -0.10309298, 1.12378800, -0.23510537, -0.73893374],
+                    [-0.52712536, -0.00717016, -1.85051966, -1.50790560, 1.38335907],
+                ],
+            ]
+        ]
+    )
+    input = flow.Tensor(
+        np_arr, dtype=flow.float32, device=flow.device(device), requires_grad=True
+    )
+    weight = np.array(
+        [
+            [
+                [
+                    [-0.19489679, -0.32377058, 0.21736273],
+                    [0.04095296, -0.21552679, -0.14626531],
+                    [-0.19359522, -0.00742865, -0.19832158],
+                ]
+            ],
+            [
+                [
+                    [0.29926914, 0.00931164, 0.26197660],
+                    [0.27611443, -0.15439281, -0.19027126],
+                    [-0.28909120, 0.30367029, -0.05168664],
+                ]
+            ],
+            [
+                [
+                    [-0.03155736, 0.17610769, 0.22111714],
+                    [0.22790670, -0.32897446, -0.03260243],
+                    [-0.10274851, -0.06903386, -0.19438276],
+                ]
+            ],
+            [
+                [
+                    [-0.24573688, -0.06723209, -0.21363299],
+                    [-0.02136187, -0.24994437, -0.18691199],
+                    [0.12189507, 0.29469389, 0.03398871],
+                ]
+            ],
+        ]
+    )
+    m = flow.nn.Conv2d(2, 4, 3, groups=2, bias=False)
+    m.weight = flow.nn.Parameter(flow.Tensor(weight), requires_grad=True)
+    m = m.to(device)
+    output = m(input)
+    np_out = np.array(
+        [
+            [
+                [
+                    [-0.21170563, 0.03652292, 0.25926736],
+                    [-0.19168918, 0.49044561, 0.25099146],
+                    [-1.02489340, 0.25361472, -0.51828313],
+                ],
+                [
+                    [0.23977707, -0.56090075, -0.19285655],
+                    [-0.17167747, 0.24558367, -0.30935860],
+                    [-0.33303234, 1.52472734, -0.49013454],
+                ],
+                [
+                    [-0.17137986, 1.21333742, 0.18988736],
+                    [0.31785482, -0.12121570, -0.18676008],
+                    [-0.10680684, -0.30298883, 0.41809759],
+                ],
+                [
+                    [-0.87821335, -0.51665992, -0.44061098],
+                    [0.74804580, 0.53107250, 0.50418228],
+                    [-0.00512899, -0.36455840, -0.23643512],
+                ],
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-6, 1e-6))
+    output = output.sum()
+    output.backward()
+    np_grad = np.array(
+        [
+            [
+                [
+                    [0.10437235, -0.21008658, 0.26925275, 0.16488039, 0.47933933],
+                    [0.42143974, -0.26293880, -0.12013602, -0.54157579, 0.14280275],
+                    [-0.06124666, -0.44938356, -0.55658901, -0.49534237, -0.10720548],
+                    [-0.16561902, -0.23929697, -0.82584178, -0.66022277, -0.58654481],
+                    [-0.48268640, -0.18644476, -0.43645298, 0.04623342, -0.25000823],
+                ],
+                [
+                    [-0.27729425, -0.16841865, -0.16093449, 0.11635975, 0.00748415],
+                    [-0.07074942, -0.54079264, -0.75282294, -0.68207347, -0.21203026],
+                    [-0.05160286, -0.29598606, -0.66841042, -0.61680746, -0.37242430],
+                    [0.22569139, -0.12756741, -0.50747585, -0.73316729, -0.37990844],
+                    [0.01914656, 0.24480659, 0.08441254, 0.06526598, -0.16039404],
+                ],
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-6, 1e-6))
+
+
 @unittest.skipIf(
     not flow.unittest.env.eager_execution_enabled(),
     ".numpy() doesn't work in lazy mode",
@@ -1677,7 +1790,7 @@ class TestConv2d(flow.unittest.TestCase):
                 device=device,
             )
 
-    def test_large_channel_group_conv(test_case):
+    def test_large_in_channel_group_conv(test_case):
         arg_dict = OrderedDict()
         arg_dict["test_fun"] = [
             _test_conv2d_large_in_channel,
@@ -1686,6 +1799,15 @@ class TestConv2d(flow.unittest.TestCase):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
+    def test_large_out_channel_group_conv(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_conv2d_large_out_channel,
+        ]
+        arg_dict["device"] = ["cuda", "cpu"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp
index a095e2404410c9b1fdd739c7dd3056509d35c339..c32975f92765a2acbfc7d24a60765378ba8e8f9f 100644
--- a/oneflow/user/kernels/conv_cudnn_kernels.cpp
+++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp
@@ -149,8 +149,8 @@ class ConvGpuKernel final : public user_op::OpKernel {
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 
-  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
-      user_op::KernelInitContext* ctx) const {
+  std::shared_ptr<ConvCudnnOpKernelState> CreateConvCudnnOpKernelState(
+      user_op::KernelComputeContext* ctx) const {
     const auto& data_format = ctx->Attr<std::string>("data_format");
     int32_t filters = ctx->Attr<int32_t>("filters");
 
@@ -166,7 +166,7 @@ class ConvGpuKernel final : public user_op::OpKernel {
   }
 
  private:
-  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
+  void Compute(user_op::KernelComputeContext* ctx) const override {
     const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
     const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0);
     user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
@@ -185,8 +185,8 @@ class ConvGpuKernel final : public user_op::OpKernel {
 
     const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0);
     if (bias != nullptr) {
-      ConvCudnnOpKernelState* conv_state = dynamic_cast<ConvCudnnOpKernelState*>(state);
-      CHECK_NOTNULL(conv_state);
+      const auto& conv_state = CreateConvCudnnOpKernelState(ctx);
+      CHECK_NOTNULL(conv_state.get());
       OF_CUDNN_CHECK(cudnnAddTensor(ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr<T>(),
                                     conv_state->bias_desc->Get(), bias->dptr<T>(),
                                     CudnnSPOnePtr<T>(), args.ydesc.Get(), out->mut_dptr<T>()));
@@ -352,8 +352,8 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel {
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 
-  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
-      user_op::KernelInitContext* ctx) const {
+  std::shared_ptr<ConvBiasGradState> CreateConvBiasGradState(
+      user_op::KernelComputeContext* ctx) const {
     const auto* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0);
     const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0);
     const auto& data_format = ctx->Attr<std::string>("data_format");
@@ -375,7 +375,7 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel {
   }
 
  private:
-  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
+  void Compute(user_op::KernelComputeContext* ctx) const override {
     const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
     user_op::Tensor* bias_diff = ctx->Tensor4ArgNameAndIndex("bias_diff", 0);
     CHECK_EQ(bias_diff->shape().NumAxes(), 1);
@@ -386,8 +386,8 @@ class ConvBiasGradGpuKernel final : public user_op::OpKernel {
 
     std::unique_ptr<CudnnTensorDesc> dy_desc;
     dy_desc.reset(new CudnnTensorDesc(dy->data_type(), dy->shape(), data_format));
-    auto* bias_grad_state = dynamic_cast<ConvBiasGradState*>(state);
-    CHECK_NOTNULL(bias_grad_state);
+    const auto& bias_grad_state = CreateConvBiasGradState(ctx);
+    CHECK_NOTNULL(bias_grad_state.get());
     OF_CUDNN_CHECK(cudnnConvolutionBackwardBias(
         ctx->device_ctx()->cudnn_handle(), CudnnSPOnePtr<T>(), dy_desc->Get(), dy->dptr<T>(),
         CudnnSPZeroPtr<T>(), bias_grad_state->bias_diff_desc->Get(), bias_diff->mut_dptr<T>()));
diff --git a/oneflow/user/kernels/conv_kernels.cpp b/oneflow/user/kernels/conv_kernels.cpp
index 6c5a857347499abd34262c63a4073fb1cd6d5667..bb01035f37773297375fc30c4b818ba9801a6721 100644
--- a/oneflow/user/kernels/conv_kernels.cpp
+++ b/oneflow/user/kernels/conv_kernels.cpp
@@ -326,10 +326,10 @@ struct ConvOpKernelState final : public user_op::OpKernelState {
 };
 
 template<typename T>
-std::shared_ptr<user_op::OpKernelState> CreateConvOpKernelState(user_op::KernelInitContext* ctx,
-                                                                const std::string& in_name,
-                                                                const std::string& out_name,
-                                                                const std::string& weight_name) {
+std::shared_ptr<ConvOpKernelState<T>> CreateConvOpKernelState(user_op::KernelComputeContext* ctx,
+                                                              const std::string& in_name,
+                                                              const std::string& out_name,
+                                                              const std::string& weight_name) {
   const auto& data_format = ctx->Attr<std::string>("data_format");
 
   std::shared_ptr<ConvOpKernelState<T>> state(new ConvOpKernelState<T>());
@@ -394,13 +394,11 @@ class ConvCpuKernel final : public user_op::OpKernel {
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 
-  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
-      user_op::KernelInitContext* ctx) const {
-    return CreateConvOpKernelState<T>(ctx, "in", "out", "weight");
-  }
-
  private:
-  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const auto& conv_state = CreateConvOpKernelState<T>(ctx, "in", "out", "weight");
+    CHECK_NOTNULL(conv_state.get());
+
     const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
     const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0);
     user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
@@ -408,9 +406,6 @@ class ConvCpuKernel final : public user_op::OpKernel {
 
     T* col_buf_dptr = tmp_buffer->mut_dptr<T>();
 
-    auto* conv_state = dynamic_cast<ConvOpKernelState<T>*>(state);
-    conv_state->Update(in->shape(), out->shape());
-    CHECK_NOTNULL(conv_state);
     bool is_bias_mul_inited = false;
     for (int64_t i = 0; i < in->shape().At(0); ++i) {
       conv_state->im2col_func_(GetImgDptr<T>(in, i), ShapeView(conv_state->in_5d_shape_),
@@ -495,20 +490,16 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel {
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 
-  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
-      user_op::KernelInitContext* ctx) const {
-    return CreateConvOpKernelState<T>(ctx, "dx", "dy", "filter");
-  }
-
  private:
-  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
-    auto* conv_state = dynamic_cast<ConvOpKernelState<T>*>(state);
-    CHECK_NOTNULL(conv_state);
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const auto& conv_state = CreateConvOpKernelState<T>(ctx, "dx", "dy", "filter");
+    CHECK_NOTNULL(conv_state.get());
+
     const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
     const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0);
     user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
     user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    conv_state->Update(dx->shape(), dy->shape());
+
     Memset<DeviceType::kCPU>(ctx->device_ctx(), dx->mut_dptr<T>(), 0,
                              dx->shape().elem_cnt() * sizeof(T));
 
@@ -571,21 +562,15 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel {
 
   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
 
-  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
-      user_op::KernelInitContext* ctx) const {
-    return CreateConvOpKernelState<T>(ctx, "x", "dy", "filter_diff");
-  }
-
  private:
-  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
-    auto* conv_state = dynamic_cast<ConvOpKernelState<T>*>(state);
-    CHECK_NOTNULL(conv_state);
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const auto& conv_state = CreateConvOpKernelState<T>(ctx, "x", "dy", "filter_diff");
+    CHECK_NOTNULL(conv_state.get());
 
     const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
     const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
     user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0);
     user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    conv_state->Update(x->shape(), dy->shape());
 
     Memset<DeviceType::kCPU>(ctx->device_ctx(), filter_diff->mut_dptr<T>(), 0,
                              filter_diff->shape().elem_cnt() * sizeof(T));
diff --git a/tools/generate_functional_api.py b/tools/generate_functional_api.py
index 4164dfc3b6f78196b7786d2d2cb0c474dd3d1340..409858934ad92125077a2a9d3836bf3c53db4233 100644
--- a/tools/generate_functional_api.py
+++ b/tools/generate_functional_api.py
@@ -61,6 +61,7 @@ header_fmt = (
 #ifndef ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_
 #define ONEFLOW_CORE_FUNCTIONAL_GENERATED_FUNCTIONAL_API_H_
 
+#include "oneflow/core/common/optional.h"
 #include "oneflow/core/framework/tensor.h"
 #include "oneflow/core/framework/tensor_tuple.h"
 #include "oneflow/core/functional/scalar.h"
@@ -104,6 +105,7 @@ pybind_fmt = (
 #include "oneflow/api/python/functional/function_def.h"
 #include "oneflow/api/python/functional/py_function.h"
 #include "oneflow/core/common/maybe.h"
+#include "oneflow/core/common/optional.h"
 #include "oneflow/core/functional/functional.h"
 
 namespace oneflow {{
@@ -174,6 +176,24 @@ argument_type_aliases = {
     **generic_type_aliases,
 }
 
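+# C++ argument types used when a parameter declares `=None`: the argument is
+# optional and is passed through as an Optional<T>.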
+optional_argument_type_aliases = {
+    "Tensor": "const Optional<one::Tensor>&",
+    "TensorTuple": "const Optional<TensorTuple>&",
+    "Scalar": "const Optional<Scalar>&",
+    "ScalarList": "const Optional<std::vector<Scalar>>&",
+    "IntList": "const Optional<std::vector<int32_t>>&",
+    "Int32List": "const Optional<std::vector<int32_t>>&",
+    "Int64List": "const Optional<std::vector<int64_t>>&",
+    "FloatList": "const Optional<std::vector<float>>&",
+    "DoubleList": "const Optional<std::vector<double>>&",
+    "String": "const Optional<std::string>&",
+    "StringList": "const Optional<std::vector<std::string>>&",
+    "BoolList": "const Optional<std::vector<bool>>&",
+    "DataType": "const Optional<DataType>&",
+    "Shape": "const Optional<Shape>&",
+    **{k: "const Optional<{0}>".format(v) for k, v in generic_type_aliases.items()},
+}
+
 return_type_aliases = {
     "Void": "Maybe<void>",
     "Tensor": "Maybe<one::Tensor>",
@@ -254,20 +274,27 @@ class Argument:
         self._type = _normalize(fmt[0:sp])
         assert self._type in types_allowed, "Unknow type: " + self._type
 
-        if self._type in argument_type_aliases:
-            self._cpp_type = argument_type_aliases[self._type]
-        else:
-            self._cpp_type = self._type
+        optional = False
         self._name = _normalize(fmt[sp + 1 :])
         sp = self._name.find("=")
         if sp != -1:
             self._default_value = _normalize(self._name[sp + 1 :])
-            if self._default_value in value_aliases:
+            if self._default_value == "None":
+                optional = True
+                self._default_cpp_value = ""
+            elif self._default_value in value_aliases:
                 self._default_cpp_value = value_aliases[self._default_value]
             else:
                 self._default_cpp_value = self._default_value
             self._name = _normalize(self._name[0:sp])
 
+        if not optional and self._type in argument_type_aliases:
+            self._cpp_type = argument_type_aliases[self._type]
+        elif optional and self._type in optional_argument_type_aliases:
+            self._cpp_type = optional_argument_type_aliases[self._type]
+        else:
+            self._cpp_type = self._type
+
     @property
     def has_default_value(self):
         return self._default_value is not None