diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index ab607b227d27dd21abc69fdee40bb5055e902683..cd07e80cded51810d06bff41eba86ef9986fd9f7 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -204,6 +204,7 @@ Experimental features
 .. autofunction:: oneflow.experimental.nn.Hardswish
 .. autofunction:: oneflow.experimental.nn.PReLU
 .. autofunction:: oneflow.experimental.nn.Hardtanh
+.. autofunction:: oneflow.experimental.nn.functional.interpolate
 .. autofunction:: oneflow.experimental.nn.Upsample
 .. autofunction:: oneflow.experimental.nn.UpsamplingNearest2d
 .. autofunction:: oneflow.experimental.nn.UpsamplingBilinear2d
diff --git a/oneflow/core/autograd/gradient_funcs/upsample.cpp b/oneflow/core/autograd/gradient_funcs/upsample.cpp
index 545d245b37c8b25ac80e2f0cedf8b6d1ebe81017..51ab7d589370ef4474c427f6a1142c57d9cd5ee0 100644
--- a/oneflow/core/autograd/gradient_funcs/upsample.cpp
+++ b/oneflow/core/autograd/gradient_funcs/upsample.cpp
@@ -17,6 +17,7 @@ limitations under the License.
 #include "oneflow/core/framework/op_builder.h"
 #include "oneflow/core/framework/op_expr.h"
 #include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/functional/functional.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
 
 namespace oneflow {
@@ -93,5 +94,324 @@ Maybe<void> Upsample::Apply(const UpsampleInterpState* ctx, const TensorTuple& o
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample);
 
+struct UpsampleNearest2DInterpState : public OpExprInterpState {
+  bool requires_grad;
+  float height_scale;
+  float width_scale;
+  std::string data_format;
+};
+
+class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
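+  // Capture records the attrs the forward op ran with and saves the input
+  // tensor, whose shape the grad kernel needs to size dx; Apply then calls the
+  // matching *Grad functor. The classes below all follow this same pattern.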
+  Maybe<void> Capture(UpsampleNearest2DInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->height_scale = JUST(composed_attrs.GetAttr<float>("height_scale"));
+    ctx->width_scale = JUST(composed_attrs.GetAttr<float>("width_scale"));
+    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const UpsampleNearest2DInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+    in_grads->resize(1);
+    in_grads->at(0) = JUST(functional::UpsampleNearest2DGrad(out_grads.at(0), x, ctx->height_scale,
+                                                             ctx->width_scale, ctx->data_format));
+
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_2d", UpsampleNearest2D);
+
+struct UpsampleBilinear2DInterpState : public OpExprInterpState {
+  bool requires_grad;
+  float height_scale;
+  float width_scale;
+  bool align_corners;
+  std::string data_format;
+};
+
+class UpsampleBilinear2D : public OpExprGradFunction<UpsampleBilinear2DInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(UpsampleBilinear2DInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->height_scale = JUST(composed_attrs.GetAttr<float>("height_scale"));
+    ctx->width_scale = JUST(composed_attrs.GetAttr<float>("width_scale"));
+    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
+    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const UpsampleBilinear2DInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+    in_grads->resize(1);
+    in_grads->at(0) = JUST(functional::UpsampleBilinear2DGrad(out_grads.at(0), x, ctx->height_scale,
+                                                              ctx->width_scale, ctx->align_corners,
+                                                              ctx->data_format));
+
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bilinear_2d", UpsampleBilinear2D);
+
+struct UpsampleLinear1DInterpState : public OpExprInterpState {
+  bool requires_grad;
+  float scale_factor;
+  bool align_corners;
+  std::string data_format;
+};
+
+class UpsampleLinear1D : public OpExprGradFunction<UpsampleLinear1DInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(UpsampleLinear1DInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->scale_factor = JUST(composed_attrs.GetAttr<float>("scale_factor"));
+    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
+    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const UpsampleLinear1DInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+    in_grads->resize(1);
+    in_grads->at(0) = JUST(functional::UpsampleLinear1DGrad(out_grads.at(0), x, ctx->scale_factor,
+                                                            ctx->align_corners, ctx->data_format));
+
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D);
+
+struct UpsampleNearest1DInterpState : public OpExprInterpState {
+  bool requires_grad;
+  float scale_factor;
+  std::string data_format;
+};
+
+class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(UpsampleNearest1DInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->scale_factor = JUST(composed_attrs.GetAttr<float>("scale_factor"));
+    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const UpsampleNearest1DInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+    in_grads->resize(1);
+    in_grads->at(0) = JUST(
+        functional::UpsampleNearest1DGrad(out_grads.at(0), x, ctx->scale_factor, ctx->data_format));
+
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_1d", UpsampleNearest1D);
+
+struct UpsampleBicubic2DInterpState : public OpExprInterpState {
+  bool requires_grad;
+  float height_scale;
+  float width_scale;
+  bool align_corners;
+  std::string data_format;
+};
+
+class UpsampleBicubic2D : public OpExprGradFunction<UpsampleBicubic2DInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(UpsampleBicubic2DInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->height_scale = JUST(composed_attrs.GetAttr<float>("height_scale"));
+    ctx->width_scale = JUST(composed_attrs.GetAttr<float>("width_scale"));
+    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
+    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const UpsampleBicubic2DInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+    in_grads->resize(1);
+    in_grads->at(0) = JUST(functional::UpsampleBicubic2DGrad(out_grads.at(0), x, ctx->height_scale,
+                                                             ctx->width_scale, ctx->align_corners,
+                                                             ctx->data_format));
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bicubic_2d", UpsampleBicubic2D);
+
+struct UpsampleNearest3DInterpState : public OpExprInterpState {
+  bool requires_grad;
+  float depth_scale;
+  float height_scale;
+  float width_scale;
+  std::string data_format;
+};
+
+class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(UpsampleNearest3DInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->depth_scale = JUST(composed_attrs.GetAttr<float>("depth_scale"));
+    ctx->height_scale = JUST(composed_attrs.GetAttr<float>("height_scale"));
+    ctx->width_scale = JUST(composed_attrs.GetAttr<float>("width_scale"));
+    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const UpsampleNearest3DInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+    in_grads->resize(1);
+    in_grads->at(0) = JUST(functional::UpsampleNearest3DGrad(out_grads.at(0), x, ctx->depth_scale,
+                                                             ctx->height_scale, ctx->width_scale,
+                                                             ctx->data_format));
+
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_nearest_3d", UpsampleNearest3D);
+
+struct UpsampleTrilinear3DInterpState : public OpExprInterpState {
+  bool requires_grad;
+  float depth_scale;
+  float height_scale;
+  float width_scale;
+  bool align_corners;
+  std::string data_format;
+};
+
+class UpsampleTrilinear3D : public OpExprGradFunction<UpsampleTrilinear3DInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Capture(UpsampleTrilinear3DInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    CHECK_EQ_OR_RETURN(inputs.size(), 1);
+    CHECK_EQ_OR_RETURN(outputs.size(), 1);
+    ctx->requires_grad = inputs.at(0)->requires_grad();
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    ComposedAttrMap composed_attrs(attrs, base_attrs_);
+    ctx->depth_scale = JUST(composed_attrs.GetAttr<float>("depth_scale"));
+    ctx->height_scale = JUST(composed_attrs.GetAttr<float>("height_scale"));
+    ctx->width_scale = JUST(composed_attrs.GetAttr<float>("width_scale"));
+    ctx->align_corners = JUST(composed_attrs.GetAttr<bool>("align_corners"));
+    ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
+    ctx->SaveTensorForBackward(inputs.at(0));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const UpsampleTrilinear3DInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+    const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
+    in_grads->resize(1);
+    in_grads->at(0) = JUST(functional::UpsampleTrilinear3DGrad(
+        out_grads.at(0), x, ctx->depth_scale, ctx->height_scale, ctx->width_scale,
+        ctx->align_corners, ctx->data_format));
+
+    return Maybe<void>::Ok();
+  }
+
+ private:
+  AttrMap base_attrs_;
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_trilinear_3d", UpsampleTrilinear3D);
+
 }  // namespace one
 }  // namespace oneflow
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 5e33c253d161d68e621d5ca314e498dc09b8b156..9eaa451c9deaac2822aa75aa1bd4749ebd8770a1 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -476,6 +476,90 @@
                      String interpolation, String data_format=\"channels_first\")"
   bind_python: True
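+
+# Note: in each signature below, arguments after "*" are keyword-only on the
+# Python side; entries with "bind_python: False" are gradient ops that are only
+# reachable from C++ (the autograd classes in gradient_funcs/upsample.cpp).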
 
+- name: "upsample_linear_1d"
+  signature:
+    "Tensor UpsampleLinear1D(Tensor x, *, Float scale_factor, Bool align_corners,
+                     String data_format=\"channels_first\")"
+  bind_python: True
+
+- name: "upsample_linear_1d_grad"
+  signature:
+    "Tensor UpsampleLinear1DGrad(Tensor dy, Tensor x, *, Float scale_factor, Bool align_corners,
+                     String data_format=\"channels_first\")"
+  bind_python: False
+
+- name: "upsample_nearest_1d"
+  signature:
+    "Tensor UpsampleNearest1D(Tensor x, *, Float scale_factor,
+                     String data_format=\"channels_first\")"
+  bind_python: True
+
+- name: "upsample_nearest_1d_grad"
+  signature:
+    "Tensor UpsampleNearest1DGrad(Tensor dy, Tensor x, *, Float scale_factor,
+                     String data_format=\"channels_first\")"
+  bind_python: False
+
+- name: "upsample_nearest_2d"
+  signature:
+    "Tensor UpsampleNearest2D(Tensor x, *, Float height_scale, Float width_scale, 
+                     String data_format=\"channels_first\")"
+  bind_python: True  
+
+- name: "upsample_nearest_2d_grad"
+  signature:
+    "Tensor UpsampleNearest2DGrad(Tensor dy, Tensor x, *, Float height_scale, Float width_scale, 
+                     String data_format=\"channels_first\")"
+  bind_python: False 
+
+- name: "upsample_bilinear_2d"
+  signature:
+    "Tensor UpsampleBilinear2D(Tensor x, *, Float height_scale, Float width_scale, Bool align_corners,
+                     String data_format=\"channels_first\")"
+  bind_python: True
+
+- name: "upsample_bilinear_2d_grad"
+  signature:
+    "Tensor UpsampleBilinear2DGrad(Tensor dy, Tensor x, *, Float height_scale, Float width_scale, Bool align_corners,
+                     String data_format=\"channels_first\")"
+  bind_python: False
+
+- name: "upsample_bicubic_2d"
+  signature:
+    "Tensor UpsampleBicubic2D(Tensor x, *, Float height_scale, Float width_scale, Bool align_corners,
+                     String data_format=\"channels_first\")"
+  bind_python: True
+
+- name: "upsample_bicubic_2d_grad"
+  signature:
+    "Tensor UpsampleBicubic2DGrad(Tensor dy, Tensor x, *, Float height_scale, Float width_scale, Bool align_corners,
+                     String data_format=\"channels_first\")"
+  bind_python: False
+
+- name: "upsample_nearest_3d"
+  signature:
+    "Tensor UpsampleNearest3D(Tensor x, *, Float depth_scale, Float height_scale, Float width_scale,
+                     String data_format=\"channels_first\")"
+  bind_python: True
+
+- name: "upsample_nearest_3d_grad"
+  signature:
+    "Tensor UpsampleNearest3DGrad(Tensor dy, Tensor x, *, Float depth_scale, Float height_scale, Float width_scale,
+                     String data_format=\"channels_first\")"
+  bind_python: False
+
+- name: "upsample_trilinear_3d"
+  signature:
+    "Tensor UpsampleTrilinear3D(Tensor x, *, Float depth_scale, Float height_scale, Float width_scale,
+                     Bool align_corners, String data_format=\"channels_first\")"
+  bind_python: True
+
+- name: "upsample_trilinear_3d_grad"
+  signature:
+    "Tensor UpsampleTrilinear3DGrad(Tensor dy, Tensor x, *, Float depth_scale, Float height_scale, Float width_scale,
+                     Bool align_corners, String data_format=\"channels_first\")"
+  bind_python: False
+
 - name: "abs"
   signature: "Tensor Abs(Tensor x)"
   bind_python: True
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index 1653768c3e35777d4ff5c048439f50d772f4425f..53ddfa662b9b5cf2e167a6a04fdf742266962841 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -423,6 +423,288 @@ class UpsampleFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
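+// Each functor below wraps a single user op: the constructor builds the OpExpr
+// once, and operator() packs the Python-facing arguments into attrs and
+// dispatches the op on its inputs.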
+class UpsampleLinear1DFunctor {
+ public:
+  UpsampleLinear1DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("upsample_linear_1d").Input("x").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& scale_factor,
+                           const bool& align_corners, const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("scale_factor", scale_factor));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleLinear1DGradFunctor {
+ public:
+  UpsampleLinear1DGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("upsample_linear_1d_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const float& scale_factor,
+                           const bool& align_corners, const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("scale_factor", scale_factor));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleNearest1DFunctor {
+ public:
+  UpsampleNearest1DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_1d").Input("x").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& scale_factor,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("scale_factor", scale_factor));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleNearest1DGradFunctor {
+ public:
+  UpsampleNearest1DGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("upsample_nearest_1d_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const float& scale_factor,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("scale_factor", scale_factor));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleNearest2DFunctor {
+ public:
+  UpsampleNearest2DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_2d").Input("x").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& height_scale,
+                           const float& width_scale, const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleNearest2DGradFunctor {
+ public:
+  UpsampleNearest2DGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("upsample_nearest_2d_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const float& height_scale,
+                           const float& width_scale, const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleBilinear2DFunctor {
+ public:
+  UpsampleBilinear2DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("upsample_bilinear_2d").Input("x").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& height_scale,
+                           const float& width_scale, const bool& align_corners,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleBilinear2DGradFunctor {
+ public:
+  UpsampleBilinear2DGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("upsample_bilinear_2d_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const float& height_scale,
+                           const float& width_scale, const bool& align_corners,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleBicubic2DFunctor {
+ public:
+  UpsampleBicubic2DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("upsample_bicubic_2d").Input("x").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& height_scale,
+                           const float& width_scale, const bool& align_corners,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleBicubic2DGradFunctor {
+ public:
+  UpsampleBicubic2DGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("upsample_bicubic_2d_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const float& height_scale,
+                           const float& width_scale, const bool& align_corners,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleNearest3DFunctor {
+ public:
+  UpsampleNearest3DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_3d").Input("x").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& depth_scale,
+                           const float& height_scale, const float& width_scale,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("depth_scale", depth_scale));
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleNearest3DGradFunctor {
+ public:
+  UpsampleNearest3DGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("upsample_nearest_3d_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const float& depth_scale,
+                           const float& height_scale, const float& width_scale,
+                           const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("depth_scale", depth_scale));
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleTrilinear3DFunctor {
+ public:
+  UpsampleTrilinear3DFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("upsample_trilinear_3d").Input("x").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const float& depth_scale,
+                           const float& height_scale, const float& width_scale,
+                           const bool& align_corners, const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("depth_scale", depth_scale));
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class UpsampleTrilinear3DGradFunctor {
+ public:
+  UpsampleTrilinear3DGradFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("upsample_trilinear_3d_grad").Input("dy").Input("x").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const float& depth_scale,
+                           const float& height_scale, const float& width_scale,
+                           const bool& align_corners, const std::string& data_format) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<float>("depth_scale", depth_scale));
+    JUST(attrs.SetAttr<float>("height_scale", height_scale));
+    JUST(attrs.SetAttr<float>("width_scale", width_scale));
+    JUST(attrs.SetAttr<bool>("align_corners", align_corners));
+    JUST(attrs.SetAttr<std::string>("data_format", data_format));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class UnsortedSegmentSumLikeFunctor {
  public:
   UnsortedSegmentSumLikeFunctor() {
@@ -572,6 +854,20 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::SqueezeFunctor>("Squeeze");
   m.add_functor<impl::CopyFunctor>("Copy");
   m.add_functor<impl::UpsampleFunctor>("Upsample");
+  m.add_functor<impl::UpsampleNearest2DFunctor>("UpsampleNearest2D");
+  m.add_functor<impl::UpsampleNearest2DGradFunctor>("UpsampleNearest2DGrad");
+  m.add_functor<impl::UpsampleBilinear2DFunctor>("UpsampleBilinear2D");
+  m.add_functor<impl::UpsampleBilinear2DGradFunctor>("UpsampleBilinear2DGrad");
+  m.add_functor<impl::UpsampleLinear1DFunctor>("UpsampleLinear1D");
+  m.add_functor<impl::UpsampleLinear1DGradFunctor>("UpsampleLinear1DGrad");
+  m.add_functor<impl::UpsampleNearest1DFunctor>("UpsampleNearest1D");
+  m.add_functor<impl::UpsampleNearest1DGradFunctor>("UpsampleNearest1DGrad");
+  m.add_functor<impl::UpsampleBicubic2DFunctor>("UpsampleBicubic2D");
+  m.add_functor<impl::UpsampleBicubic2DGradFunctor>("UpsampleBicubic2DGrad");
+  m.add_functor<impl::UpsampleNearest3DFunctor>("UpsampleNearest3D");
+  m.add_functor<impl::UpsampleNearest3DGradFunctor>("UpsampleNearest3DGrad");
+  m.add_functor<impl::UpsampleTrilinear3DFunctor>("UpsampleTrilinear3D");
+  m.add_functor<impl::UpsampleTrilinear3DGradFunctor>("UpsampleTrilinear3DGrad");
   m.add_functor<impl::UnsortedSegmentSumLikeFunctor>("UnsortedSegmentSumLike");
   m.add_functor<impl::TriuFunctor>("Triu");
   m.add_functor<impl::DiagFunctor>("Diag");
diff --git a/oneflow/python/nn/modules/interpolate.py b/oneflow/python/nn/modules/interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..64b07535a54cf6afcb67bf003e68942181b87d45
--- /dev/null
+++ b/oneflow/python/nn/modules/interpolate.py
@@ -0,0 +1,318 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import math
+import warnings
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.framework.tensor import register_tensor_op
+from typing import Optional, Union, Tuple
+
+
+class Interpolate(Module):
+    def __init__(
+        self,
+        size: Optional[Union[int, Tuple[int, ...]]] = None,
+        scale_factor: Optional[Union[float, Tuple[float, ...]]] = None,
+        mode: str = "nearest",
+        align_corners: Optional[bool] = None,
+        recompute_scale_factor: Optional[bool] = None,
+    ):
+        super().__init__()
+        self.size = size
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor is not None else None
+
+        if mode in ("nearest", "area") and align_corners is not None:
+            raise ValueError(
+                "align_corners option can only be set with the "
+                "interpolating modes: linear | bilinear | bicubic | trilinear"
+            )
+
+        self.mode = mode
+        self.recompute_scale_factor = recompute_scale_factor
+        if align_corners is None:
+            align_corners = False
+
+        self.align_corners = align_corners
+
+        if self.mode not in (
+            "nearest",
+            "bilinear",
+            "linear",
+            "area",
+            "bicubic",
+            "trilinear",
+        ):
+            raise ValueError(
+                'interpolation must be "nearest" or "bilinear" or "linear" or "area" or "bicubic" or "trilinear".'
+            )
+
+        if self.mode == "nearest" and self.align_corners:
+            raise ValueError('interpolation "nearest" does not support align_corners.')
+
+    def forward(self, x):
+        dim = len(x.shape) - 2
+        if self.size is not None and self.scale_factor is not None:
+            raise ValueError("only one of size or scale_factor should be defined")
+        elif self.size is not None:
+            assert self.scale_factor is None
+            scale_factors = None
+            if isinstance(self.size, (list, tuple)):
+                if len(self.size) != dim:
+                    raise ValueError(
+                        "size shape must match input shape. "
+                        "Input is {}D, size is {}".format(dim, len(self.size))
+                    )
+                output_size = self.size
+            else:
+                output_size = [self.size for _ in range(dim)]
+        elif self.scale_factor is not None:
+            assert self.size is None
+            output_size = None
+            if isinstance(self.scale_factor, (list, tuple)):
+                if len(self.scale_factor) != dim:
+                    raise ValueError(
+                        "scale_factor shape must match input shape. "
+                        "Input is {}D, scale_factor is {}".format(
+                            dim, len(self.scale_factor)
+                        )
+                    )
+                scale_factors = self.scale_factor
+            else:
+                scale_factors = [self.scale_factor for _ in range(dim)]
+        else:
+            raise ValueError("either size or scale_factor should be defined")
+
+        if self.recompute_scale_factor is None:
+            if scale_factors is not None:
+                for scale in scale_factors:
+                    if math.floor(scale) != scale:
+                        warnings.warn(
+                            "The default behavior for interpolate/upsample with float scale_factor changed "
+                            "in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, "
+                            "instead of relying on the computed output size. "
+                            "If you wish to restore the old behavior, please set recompute_scale_factor=True. "
+                            "See the documentation of nn.Upsample for details. "
+                        )
+                        break
+        elif self.recompute_scale_factor and self.size is not None:
+            raise ValueError(
+                "recompute_scale_factor is not meaningful with an explicit size."
+            )
+
+        # "area" mode always requires an explicit size rather than scale factor.
+        # Re-use the recompute_scale_factor code path.
+        if self.mode == "area" and output_size is None:
+            self.recompute_scale_factor = True
+
+        if self.recompute_scale_factor:
+            assert scale_factors is not None
+            output_size = [
+                int(math.floor(float(x.shape[i + 2]) * scale_factors[i]))
+                for i in range(dim)
+            ]
+            scale_factors = None
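+
+        # The kernels below are driven by float scales, so when only an output
+        # size is known (an explicit `size`, or one recomputed above) derive the
+        # per-dimension scales from it. With floor-based output sizing, the
+        # derived scale may be off by one element for rare size/input
+        # combinations due to float rounding.
+        if scale_factors is None:
+            assert output_size is not None
+            scale_factors = [output_size[i] / x.shape[i + 2] for i in range(dim)]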
+
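+        # Dispatch on the spatial rank (len(x.shape) - 2) and mode; every kernel
+        # here assumes NCHW-style layout, hence data_format="channels_first".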
+        if len(x.shape) == 3 and self.mode == "nearest":
+            return flow.F.upsample_nearest_1d(
+                x, scale_factor=scale_factors[0], data_format="channels_first"
+            )
+
+        if len(x.shape) == 4 and self.mode == "nearest":
+            return flow.F.upsample_nearest_2d(
+                x,
+                height_scale=scale_factors[0],
+                width_scale=scale_factors[1],
+                data_format="channels_first",
+            )
+
+        if len(x.shape) == 5 and self.mode == "nearest":
+            return flow.F.upsample_nearest_3d(
+                x,
+                depth_scale=scale_factors[0],
+                height_scale=scale_factors[1],
+                width_scale=scale_factors[2],
+                data_format="channels_first",
+            )
+
+        # TODO(bbuf) Add adaptive_avg_pool op
+
+        if self.mode == "area":
+            raise NotImplementedError("adaptive_avg_pool1d not impleted now!")
+
+        if len(x.shape) == 3 and self.mode == "linear":
+            assert self.align_corners is not None
+            return flow.F.upsample_linear_1d(
+                x,
+                scale_factor=scale_factors[0],
+                align_corners=self.align_corners,
+                data_format="channels_first",
+            )
+
+        if len(x.shape) == 4 and self.mode == "bilinear":
+            assert self.align_corners is not None
+            return flow.F.upsample_bilinear_2d(
+                x,
+                height_scale=scale_factors[0],
+                width_scale=scale_factors[1],
+                align_corners=self.align_corners,
+                data_format="channels_first",
+            )
+
+        if len(x.shape) == 4 and self.mode == "bicubic":
+            assert self.align_corners is not None
+            return flow.F.upsample_bicubic_2d(
+                x,
+                height_scale=scale_factors[0],
+                width_scale=scale_factors[1],
+                align_corners=self.align_corners,
+                data_format="channels_first",
+            )
+
+        if len(x.shape) == 5 and self.mode == "trilinear":
+            assert self.align_corners is not None
+            return flow.F.upsample_trilinear_3d(
+                x,
+                depth_scale=scale_factors[0],
+                height_scale=scale_factors[1],
+                width_scale=scale_factors[2],
+                align_corners=self.align_corners,
+                data_format="channels_first",
+            )
+
+
+@oneflow_export("nn.functional.interpolate")
+@experimental_api
+def interpolate(
+    input,
+    size=None,
+    scale_factor=None,
+    mode="nearest",
+    align_corners=None,
+    recompute_scale_factor=None,
+):
+    r"""The interface is consistent with PyTorch.    
+    
+    The documentation is referenced from: https://pytorch.org/docs/1.9.0/_modules/torch/nn/functional.html#interpolate
+    
+
+    Down/up samples the input to either the given :attr:`size` or the given
+    :attr:`scale_factor`
+
+    The algorithm used for interpolation is determined by :attr:`mode`.
+
+    Currently temporal, spatial and volumetric sampling are supported, i.e.
+    expected inputs are 3-D, 4-D or 5-D in shape.
+
+    The input dimensions are interpreted in the form:
+    `mini-batch x channels x [optional depth] x [optional height] x width`.
+
+    The modes available for resizing are: `nearest`, `linear` (3D-only),
+    `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`
+
+    Args:
+        input (Tensor): the input tensor
+        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
+            output spatial size.
+        scale_factor (float or Tuple[float]): multiplier for spatial size. If a tuple,
+            its length has to match the number of spatial dimensions of the input.
+        mode (str): algorithm used for upsampling:
+            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+            ``'trilinear'`` | ``'area'``. Default: ``'nearest'``
+        align_corners (bool, optional): Geometrically, we consider the pixels of the
+            input and output as squares rather than points.
+            If set to ``True``, the input and output tensors are aligned by the
+            center points of their corner pixels, preserving the values at the corner pixels.
+            If set to ``False``, the input and output tensors are aligned by the corner
+            points of their corner pixels, and the interpolation uses edge value padding
+            for out-of-boundary values, making this operation *independent* of input size
+            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
+            is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``.
+            Default: ``False``
+        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
+            interpolation calculation.  When `scale_factor` is passed as a parameter, it is used
+            to compute the `output_size`.  If `recompute_scale_factor` is ``False`` or not specified,
+            the passed-in `scale_factor` will be used in the interpolation computation.
+            Otherwise, a new `scale_factor` will be computed based on the output and input sizes for
+            use in the interpolation computation (i.e. the computation will be identical to if the computed
+            `output_size` were passed-in explicitly).  Note that when `scale_factor` is floating-point,
+            the recomputed scale_factor may differ from the one passed in due to rounding and precision
+            issues.
+
+    .. note::
+        With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce
+        negative values or values greater than 255 for images.
+        Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot
+        when displaying the image.
+
+    .. warning::
+        With ``align_corners = True``, the linearly interpolating modes
+        (`linear`, `bilinear`, and `trilinear`) don't proportionally align the
+        output and input pixels, and thus the output values can depend on the
+        input size. This was the default behavior for these modes up to version
+        0.3.1. Since then, the default behavior is ``align_corners = False``.
+        See :class:`~oneflow.experimental.nn.Upsample` for concrete examples on how
+        this affects the outputs.
+
+    .. warning::
+        When scale_factor is specified, if recompute_scale_factor=True,
+        scale_factor is used to compute the output_size which will then
+        be used to infer new scales for the interpolation.
+        The default behavior for recompute_scale_factor changed to False
+        in 1.6.0, and scale_factor is used in the interpolation
+        calculation.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+
+        >>> input = flow.Tensor(np.arange(1, 5).reshape((1, 1, 4)), dtype=flow.float32)
+        >>> output = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="linear")
+        >>> output
+        tensor([[[1.  , 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.  ]]],
+               dtype=oneflow.float32)
+
+    """
+    return Interpolate(
+        size=size,
+        scale_factor=scale_factor,
+        mode=mode,
+        align_corners=align_corners,
+        recompute_scale_factor=recompute_scale_factor,
+    )(input)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
diff --git a/oneflow/python/test/modules/test_interpolate.py b/oneflow/python/test/modules/test_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..521b43c0602c972d58d1f1cd45c057003cae8999
--- /dev/null
+++ b/oneflow/python/test/modules/test_interpolate.py
@@ -0,0 +1,395 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList
+
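+# The expected outputs and gradients below match PyTorch's
+# torch.nn.functional.interpolate for the same inputs, so failures can be
+# cross-checked against PyTorch directly.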
+
+def _test_interpolate_linear_1d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 5).reshape((1, 1, 4)),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="linear")
+    np_out = [[[1.0, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.0]]]
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[2.0, 2.0, 2.0, 2.0]]]
+    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 1e-4, 1e-4))
+
+    input.grad = None
+    of_out = flow.nn.functional.interpolate(
+        input, scale_factor=2.0, mode="linear", align_corners=True
+    )
+    np_out = [
+        [
+            [
+                1.0,
+                1.4285714626312256,
+                1.8571429252624512,
+                2.2857141494750977,
+                2.7142856121063232,
+                3.142857074737549,
+                3.5714285373687744,
+                4.0,
+            ]
+        ]
+    ]
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [
+        [
+            [
+                1.7142856121063232,
+                2.2857141494750977,
+                2.2857143878936768,
+                1.7142856121063232,
+            ]
+        ]
+    ]
+    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 1e-4, 1e-4))
+
+
+def _test_interpolate_nearest_1d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 5).reshape((1, 1, 4)),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest")
+    np_out = [[[1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]]]
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[2.0, 2.0, 2.0, 2.0]]]
+    test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 1e-4, 1e-4))
+
+
+def _test_interpolate_nearest_2d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 5).reshape((1, 1, 2, 2)),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest")
+    np_out = np.array(
+        [
+            [
+                [
+                    [1.0, 1.0, 2.0, 2.0],
+                    [1.0, 1.0, 2.0, 2.0],
+                    [3.0, 3.0, 4.0, 4.0],
+                    [3.0, 3.0, 4.0, 4.0],
+                ]
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+def _test_interpolate_nearest_3d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 9).reshape((1, 1, 2, 2, 2)),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest")
+    np_out = np.array(
+        [
+            [
+                [
+                    [
+                        [1.0, 1.0, 2.0, 2.0],
+                        [1.0, 1.0, 2.0, 2.0],
+                        [3.0, 3.0, 4.0, 4.0],
+                        [3.0, 3.0, 4.0, 4.0],
+                    ],
+                    [
+                        [1.0, 1.0, 2.0, 2.0],
+                        [1.0, 1.0, 2.0, 2.0],
+                        [3.0, 3.0, 4.0, 4.0],
+                        [3.0, 3.0, 4.0, 4.0],
+                    ],
+                    [
+                        [5.0, 5.0, 6.0, 6.0],
+                        [5.0, 5.0, 6.0, 6.0],
+                        [7.0, 7.0, 8.0, 8.0],
+                        [7.0, 7.0, 8.0, 8.0],
+                    ],
+                    [
+                        [5.0, 5.0, 6.0, 6.0],
+                        [5.0, 5.0, 6.0, 6.0],
+                        [7.0, 7.0, 8.0, 8.0],
+                        [7.0, 7.0, 8.0, 8.0],
+                    ],
+                ]
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[[[8.0, 8.0], [8.0, 8.0]], [[8.0, 8.0], [8.0, 8.0]]]]]
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+def _test_interpolate_bilinear_2d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 5).reshape((1, 1, 2, 2)),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="bilinear")
+    np_out = np.array(
+        [
+            [
+                [
+                    [1.0, 1.25, 1.75, 2.0],
+                    [1.5, 1.75, 2.25, 2.5],
+                    [2.5, 2.75, 3.25, 3.5],
+                    [3.0, 3.25, 3.75, 4.0],
+                ]
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+def _test_interpolate_bicubic_2d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 5).reshape((1, 1, 2, 2)).astype(np.float32),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="bicubic")
+    np_out = np.array(
+        [
+            [
+                [
+                    [0.68359375, 1.015625, 1.5625, 1.89453125],
+                    [1.34765625, 1.6796875, 2.2265625, 2.55859375],
+                    [2.44140625, 2.7734375, 3.3203125, 3.65234375],
+                    [3.10546875, 3.4375, 3.984375, 4.31640625],
+                ]
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[[4.0, 4.0], [4.0, 4.0]]]]
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+def _test_interpolate_bicubic_same_dim_2d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 5).reshape((1, 1, 2, 2)).astype(np.float32),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=1.0, mode="bicubic")
+    np_out = [[[[1.0, 2.0], [3.0, 4.0]]]]
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[[1.0, 1.0], [1.0, 1.0]]]]
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+def _test_interpolate_trilinear_3d(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 9).reshape((1, 1, 2, 2, 2)),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(input, scale_factor=2.0, mode="trilinear")
+    np_out = np.array(
+        [
+            [
+                [
+                    [
+                        [1.0, 1.25, 1.75, 2.0],
+                        [1.5, 1.75, 2.25, 2.5],
+                        [2.5, 2.75, 3.25, 3.5],
+                        [3.0, 3.25, 3.75, 4.0],
+                    ],
+                    [
+                        [2.0, 2.25, 2.75, 3.0],
+                        [2.5, 2.75, 3.25, 3.5],
+                        [3.5, 3.75, 4.25, 4.5],
+                        [4.0, 4.25, 4.75, 5.0],
+                    ],
+                    [
+                        [4.0, 4.25, 4.75, 5.0],
+                        [4.5, 4.75, 5.25, 5.5],
+                        [5.5, 5.75, 6.25, 6.5],
+                        [6.0, 6.25, 6.75, 7.0],
+                    ],
+                    [
+                        [5.0, 5.25, 5.75, 6.0],
+                        [5.5, 5.75, 6.25, 6.5],
+                        [6.5, 6.75, 7.25, 7.5],
+                        [7.0, 7.25, 7.75, 8.0],
+                    ],
+                ]
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [[[[[8.0, 8.0], [8.0, 8.0]], [[8.0, 8.0], [8.0, 8.0]]]]]
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+def _test_interpolate_trilinear_3d_align_corners(test_case, device):
+    input = flow.Tensor(
+        np.arange(1, 9).reshape((1, 1, 2, 2, 2)),
+        device=flow.device(device),
+        dtype=flow.float32,
+        requires_grad=True,
+    )
+    of_out = flow.nn.functional.interpolate(
+        input, scale_factor=2.0, mode="trilinear", align_corners=True
+    )
+    np_out = np.array(
+        [
+            [
+                [
+                    [
+                        [1.0, 1.3333332538604736, 1.6666667461395264, 2.0],
+                        [
+                            1.6666666269302368,
+                            2.0,
+                            2.3333334922790527,
+                            2.6666665077209473,
+                        ],
+                        [
+                            2.3333332538604736,
+                            2.6666665077209473,
+                            3.0,
+                            3.3333334922790527,
+                        ],
+                        [3.0, 3.3333332538604736, 3.6666667461395264, 4.0],
+                    ],
+                    [
+                        [
+                            2.3333334922790527,
+                            2.6666665077209473,
+                            3.0,
+                            3.3333332538604736,
+                        ],
+                        [3.0, 3.3333330154418945, 3.6666665077209473, 4.0],
+                        [
+                            3.6666665077209473,
+                            4.0,
+                            4.333333492279053,
+                            4.6666669845581055,
+                        ],
+                        [4.333333492279053, 4.666666030883789, 5.0, 5.3333330154418945],
+                    ],
+                    [
+                        [3.6666667461395264, 4.0, 4.333333492279053, 4.666666507720947],
+                        [4.333333492279053, 4.666666507720947, 5.0, 5.3333330154418945],
+                        [5.0, 5.333333492279053, 5.6666669845581055, 6.0],
+                        [
+                            5.6666669845581055,
+                            6.0,
+                            6.333333492279053,
+                            6.6666669845581055,
+                        ],
+                    ],
+                    [
+                        [5.0, 5.3333330154418945, 5.666666507720947, 6.0],
+                        [
+                            5.666666507720947,
+                            5.999999523162842,
+                            6.3333330154418945,
+                            6.666666507720947,
+                        ],
+                        [6.333333492279053, 6.666666030883789, 7.0, 7.333333492279053],
+                        [7.0, 7.3333330154418945, 7.6666669845581055, 8.0],
+                    ],
+                ]
+            ]
+        ]
+    )
+    test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4))
+    of_out = of_out.sum()
+    of_out.backward()
+    np_grad = [
+        [
+            [
+                [[7.999999523162842, 8.0], [7.999999523162842, 8.0]],
+                [[8.0, 8.0], [8.0, 8.0]],
+            ]
+        ]
+    ]
+    test_case.assertTrue(np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5))
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestUpsample2d(flow.unittest.TestCase):
+    def test_upsample2d(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_interpolate_linear_1d,
+            _test_interpolate_nearest_1d,
+            _test_interpolate_nearest_2d,
+            _test_interpolate_nearest_3d,
+            _test_interpolate_bilinear_2d,
+            _test_interpolate_bicubic_2d,
+            _test_interpolate_bicubic_same_dim_2d,
+            _test_interpolate_trilinear_3d,
+            _test_interpolate_trilinear_3d_align_corners,
+        ]
+        arg_dict["device"] = [
+            "cpu",
+            "cuda",
+        ]
+        for arg in GenArgList(arg_dict):
+            for _ in range(100):
+                arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/user/kernels/upsample_bicubic2d_kernel.cpp b/oneflow/user/kernels/upsample_bicubic2d_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d21efb9919da07385f2256e4da908e9c3b49499f
--- /dev/null
+++ b/oneflow/user/kernels/upsample_bicubic2d_kernel.cpp
@@ -0,0 +1,183 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+template<typename T>
+class UpsampleBicubic2dCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBicubic2dCPUKernel() = default;
+  ~UpsampleBicubic2dCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const T* in_ptr = x_tensor->dptr<T>();
+    T* out_ptr = y_tensor->mut_dptr<T>();
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+
+    const int nbatch = x_tensor->shape().At(0);
+    const int channels = x_tensor->shape().At(1);
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t in_width = x_tensor->shape().At(3);
+    const int64_t out_height = y_tensor->shape().At(2);
+    const int64_t out_width = y_tensor->shape().At(3);
+
+    if (in_height == out_height && in_width == out_width) {
+      memcpy(out_ptr, in_ptr, sizeof(T) * nbatch * channels * in_height * in_width);
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
+      for (int64_t output_y = 0; output_y < out_height; output_y++) {
+        for (int64_t output_x = 0; output_x < out_width; output_x++) {
+          const T* in = in_ptr;
+          T* out = out_ptr;
+
+          const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true);
+          int64_t input_x = std::floor(real_x);
+          const T t_x = real_x - input_x;
+
+          const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true);
+          int64_t input_y = std::floor(real_y);
+          const T t_y = real_y - input_y;
+
+          for (int64_t c = 0; c < channels * nbatch; c++) {
+            T coefficients[4];
+
+            // Interpolate 4 times in the x direction
+            for (int64_t i = 0; i < 4; i++) {
+              coefficients[i] =
+                  cubic_interp1d<T>(upsample_get_value_bounded<T>(in, in_width, in_height,
+                                                                  input_x - 1, input_y - 1 + i),
+                                    upsample_get_value_bounded<T>(in, in_width, in_height,
+                                                                  input_x + 0, input_y - 1 + i),
+                                    upsample_get_value_bounded<T>(in, in_width, in_height,
+                                                                  input_x + 1, input_y - 1 + i),
+                                    upsample_get_value_bounded<T>(in, in_width, in_height,
+                                                                  input_x + 2, input_y - 1 + i),
+                                    t_x);
+            }
+
+            // Interpolate in the y direction using x interpolations
+            out[output_y * out_width + output_x] = cubic_interp1d<T>(
+                coefficients[0], coefficients[1], coefficients[2], coefficients[3], t_y);
+
+            // Move to next channel
+            in += in_width * in_height;
+            out += out_width * out_height;
+          }
+        }
+      }
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleBicubic2dGradCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBicubic2dGradCPUKernel() = default;
+  ~UpsampleBicubic2dGradCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    T* in_ptr = dx_tensor->mut_dptr<T>();
+    const T* out_ptr = dy_tensor->dptr<T>();
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+
+    const int nbatch = dx_tensor->shape().At(0);
+    const int channels = dx_tensor->shape().At(1);
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t in_width = dx_tensor->shape().At(3);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    const int64_t out_width = dy_tensor->shape().At(3);
+
+    if (in_height == out_height && in_width == out_width) {
+      memcpy(in_ptr, out_ptr, sizeof(T) * nbatch * channels * in_height * in_width);
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
+      for (int64_t output_y = 0; output_y < out_height; output_y++) {
+        for (int64_t output_x = 0; output_x < out_width; output_x++) {
+          T* in = in_ptr;
+          const T* out = out_ptr;
+
+          T real_x = GetAreaPixel(scale_width, output_x, align_corners, true);
+          int64_t input_x = std::floor(real_x);
+          T t_x = real_x - input_x;
+
+          T real_y = GetAreaPixel(scale_height, output_y, align_corners, true);
+          int64_t input_y = std::floor(real_y);
+          T t_y = real_y - input_y;
+
+          T x_coeffs[4];
+          T y_coeffs[4];
+
+          get_cubic_upsample_coefficients<T>(x_coeffs, t_x);
+          get_cubic_upsample_coefficients<T>(y_coeffs, t_y);
+
+          for (int64_t c = 0; c < channels * nbatch; c++) {
+            T out_value = out[output_y * out_width + output_x];
+
+            for (int64_t i = 0; i < 4; i++) {
+              for (int64_t j = 0; j < 4; j++) {
+                upsample_increment_value_bounded<T>(in, in_width, in_height, input_x - 1 + i,
+                                                    input_y - 1 + j,
+                                                    out_value * y_coeffs[j] * x_coeffs[i]);
+              }
+            }
+
+            in += in_width * in_height;
+            out += out_width * out_height;
+          }
+        }
+      }
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(dtype)                                    \
+  REGISTER_USER_KERNEL("upsample_bicubic_2d")                                          \
+      .SetCreateFn<UpsampleBicubic2dCPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_bicubic_2d_grad")                                     \
+      .SetCreateFn<UpsampleBicubic2dGradCPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(float)
+REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(double)
+REGISTER_UPSAMPLE_BICUBIC_CPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_bicubic2d_kernel.cu b/oneflow/user/kernels/upsample_bicubic2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2fa1baca6a362072967afaaa509c0c4121ef4356
--- /dev/null
+++ b/oneflow/user/kernels/upsample_bicubic2d_kernel.cu
@@ -0,0 +1,222 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/cuda/atomic.cuh"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
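+// CUDA counterpart of upsample_increment_value_bounded: coordinates are clamped
+// to the image border, and the accumulation is atomic because several output
+// pixels may scatter gradient into the same input cell concurrently.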
+template<typename T>
+__device__ void upsample_increment_value_bounded_cuda(T* data, int64_t width, int64_t height,
+                                                      int64_t x, int64_t y, T value) {
+  int64_t access_x = max(min(x, width - 1), static_cast<int64_t>(0));
+  int64_t access_y = max(min(y, height - 1), static_cast<int64_t>(0));
+  cuda::atomic::Add(data + access_y * width + access_x, value);
+}
+
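+// One thread per output (y, x) location; each thread walks all nbatch * channels
+// planes and performs separable bicubic interpolation: four 1-D passes along x,
+// then one pass along y over the intermediate results.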
+template<typename T>
+__global__ void UpsampleBicubic2dForward(const int64_t elem_cnt, const T* in_dptr,
+                                         const int64_t nbatch, const int64_t channels,
+                                         const int64_t in_height, const int64_t in_width,
+                                         const int64_t out_height, const int64_t out_width,
+                                         const float scale_height, const float scale_width,
+                                         bool align_corners, T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(idx, elem_cnt) {
+    const int output_x = idx % out_width;
+    const int output_y = idx / out_width;
+
+    const T* in = in_dptr;
+    T* out = out_dptr;
+
+    const T real_x = GetAreaPixel(scale_width, output_x, align_corners, /*cubic=*/true);
+    int64_t input_x = std::floor(1.0 * real_x);
+    const T t_x = real_x - input_x;
+
+    const T real_y = GetAreaPixel(scale_height, output_y, align_corners, /*cubic=*/true);
+    int64_t input_y = std::floor(1.0 * real_y);
+    const T t_y = real_y - input_y;
+
+    for (int64_t c = 0; c < channels * nbatch; c++) {
+      T coefficients[4];
+
+      // Interpolate 4 times in the x direction
+      for (int64_t i = 0; i < 4; i++) {
+        coefficients[i] = cubic_interp1d<T>(
+            upsample_get_value_bounded<T>(in, in_width, in_height, input_x - 1, input_y - 1 + i),
+            upsample_get_value_bounded<T>(in, in_width, in_height, input_x + 0, input_y - 1 + i),
+            upsample_get_value_bounded<T>(in, in_width, in_height, input_x + 1, input_y - 1 + i),
+            upsample_get_value_bounded<T>(in, in_width, in_height, input_x + 2, input_y - 1 + i),
+            t_x);
+      }
+
+      // Interpolate in the y direction using x interpolations
+      out[output_y * out_width + output_x] = cubic_interp1d<T>(
+          coefficients[0], coefficients[1], coefficients[2], coefficients[3], t_y);
+
+      // Move to next channel
+      in += in_width * in_height;
+      out += out_width * out_height;
+    }
+  }
+}
+
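+// Backward pass: each output gradient is scattered into its 4x4 source window
+// with the separable weights y_coeffs[j] * x_coeffs[i].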
+template<typename T>
+__global__ void UpsampleBicubic2dBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                          const int64_t nbatch, const int64_t channels,
+                                          const int64_t in_height, const int64_t in_width,
+                                          const int64_t out_height, const int64_t out_width,
+                                          const float scale_height, const float scale_width,
+                                          bool align_corners, T* dx_dptr) {
+  CUDA_1D_KERNEL_LOOP(idx, elem_cnt) {
+    const int output_x = idx % out_width;
+    const int output_y = idx / out_width;
+
+    T* in = dx_dptr;
+    const T* out = dy_dptr;
+
+    T real_x = GetAreaPixel(scale_width, output_x, align_corners, true);
+    int64_t input_x = std::floor(1.0 * real_x);
+    T t_x = real_x - input_x;
+
+    T real_y = GetAreaPixel(scale_height, output_y, align_corners, true);
+    int64_t input_y = std::floor(1.0 * real_y);
+    T t_y = real_y - input_y;
+
+    T x_coeffs[4];
+    T y_coeffs[4];
+
+    get_cubic_upsample_coefficients<T>(x_coeffs, t_x);
+    get_cubic_upsample_coefficients<T>(y_coeffs, t_y);
+
+    for (int64_t c = 0; c < channels * nbatch; c++) {
+      T out_value = out[output_y * out_width + output_x];
+
+      for (int64_t i = 0; i < 4; i++) {
+        for (int64_t j = 0; j < 4; j++) {
+          upsample_increment_value_bounded_cuda<T>(in, in_width, in_height, input_x - 1 + i,
+                                                   input_y - 1 + j,
+                                                   out_value * y_coeffs[j] * x_coeffs[i]);
+        }
+      }
+
+      in += in_width * in_height;
+      out += out_width * out_height;
+    }
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleBicubic2dGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBicubic2dGPUKernel() = default;
+  ~UpsampleBicubic2dGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const T* in_ptr = x_tensor->dptr<T>();
+    T* out_ptr = y_tensor->mut_dptr<T>();
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+
+    const int nbatch = x_tensor->shape().At(0);
+    const int channels = x_tensor->shape().At(1);
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t in_width = x_tensor->shape().At(3);
+    const int64_t out_height = y_tensor->shape().At(2);
+    const int64_t out_width = y_tensor->shape().At(3);
+    const int64_t elem_cnt = out_height * out_width;
+
+    if (in_height == out_height && in_width == out_width) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+          x_tensor->shape().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
+      RUN_CUDA_KERNEL((UpsampleBicubic2dForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      x_tensor->dptr<T>(), nbatch, channels, in_height, in_width, out_height,
+                      out_width, scale_height, scale_width, align_corners, y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleBicubic2dGradGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBicubic2dGradGPUKernel() = default;
+  ~UpsampleBicubic2dGradGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+
+    const int nbatch = dx_tensor->shape().At(0);
+    const int channels = dx_tensor->shape().At(1);
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t in_width = dx_tensor->shape().At(3);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    const int64_t out_width = dy_tensor->shape().At(3);
+    const int64_t elem_cnt = out_height * out_width;
+
+    if (in_height == out_height && in_width == out_width) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+          dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
+      RUN_CUDA_KERNEL((UpsampleBicubic2dBackward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      dy_tensor->dptr<T>(), nbatch, channels, in_height, in_width, out_height,
+                      out_width, scale_height, scale_width, align_corners,
+                      dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLE_BICUBIC_GPU_KERNEL(dtype)                                    \
+  REGISTER_USER_KERNEL("upsample_bicubic_2d")                                          \
+      .SetCreateFn<UpsampleBicubic2dGPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_bicubic_2d_grad")                                     \
+      .SetCreateFn<UpsampleBicubic2dGradGPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLE_BICUBIC_GPU_KERNEL(float)
+REGISTER_UPSAMPLE_BICUBIC_GPU_KERNEL(double)
+REGISTER_UPSAMPLE_BICUBIC_GPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..978e8c5092dc432a5432b61385315ea8bdf316c5
--- /dev/null
+++ b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cpp
@@ -0,0 +1,172 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
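+// Bilinear interpolation: lerp the two top neighbors along w, lerp the two
+// bottom neighbors along w, then lerp the two results along h.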
+template<typename T>
+static void UpsampleBilinear2DForward(const int64_t elem_cnt, const T* in_dptr,
+                                      NdIndexOffsetHelper<int64_t, 4> in_helper,
+                                      NdIndexOffsetHelper<int64_t, 4> out_helper,
+                                      const int64_t in_height, const int64_t in_width,
+                                      const T scale_h, const T scale_w, const bool align_corners,
+                                      T* out_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, h, w);
+    BilinearParam<T> params;
+    GetBilinearParam(align_corners, h, w, in_height, in_width, scale_h, scale_w, &params);
+    const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0);
+    const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);
+    const T top_left = in_dptr[top_offset + params.left_w_index];
+    const T top_right = in_dptr[top_offset + params.right_w_index];
+    const T bottom_left = in_dptr[bottom_offset + params.left_w_index];
+    const T bottom_right = in_dptr[bottom_offset + params.right_w_index];
+    const T top = top_left + (top_right - top_left) * params.w_lerp;
+    const T bottom = bottom_left + (bottom_right - bottom_left) * params.w_lerp;
+    out_dptr[index] = top + (bottom - top) * params.h_lerp;
+  }
+}
+
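+// Distributes each dy element to its four source neighbors using the same lerp
+// weights as the forward pass; plain += is safe because this CPU loop is serial.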
+template<typename T>
+static void UpsampleBilinearBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                     NdIndexOffsetHelper<int64_t, 4> dy_helper,
+                                     NdIndexOffsetHelper<int64_t, 4> dx_helper,
+                                     const int64_t dx_height, const int64_t dx_width,
+                                     const T scale_h, const T scale_w, const bool align_corners,
+                                     T* dx_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, h, w);
+    BilinearParam<T> params;
+    GetBilinearParam(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, &params);
+    const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0);
+    const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);
+    const T dy = dy_dptr[index];
+    const T dbottom = params.h_lerp * dy;
+    T* dx_dptr_bottom_offset = dx_dptr + bottom_offset;
+    *(dx_dptr_bottom_offset + params.left_w_index) += static_cast<T>((1 - params.w_lerp) * dbottom);
+    *(dx_dptr_bottom_offset + params.right_w_index) += static_cast<T>(params.w_lerp * dbottom);
+    const T dtop = dy - dbottom;
+    T* dx_dptr_top_offset = dx_dptr + top_offset;
+    *(dx_dptr_top_offset + params.left_w_index) += static_cast<T>((1 - params.w_lerp) * dtop);
+    *(dx_dptr_top_offset + params.right_w_index) += static_cast<T>(params.w_lerp * dtop);
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleBilinear2DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBilinear2DCPUKernel() = default;
+  ~UpsampleBilinear2DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 4> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2), x_tensor->shape().At(3));
+    NdIndexOffsetHelper<int64_t, 4> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2), y_tensor->shape().At(3));
+
+    const int64_t nbatch = x_tensor->shape().At(0);
+    const int64_t channels = x_tensor->shape().At(1);
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t in_width = x_tensor->shape().At(3);
+    const int64_t out_height = y_tensor->shape().At(2);
+    const int64_t out_width = y_tensor->shape().At(3);
+
+    if (in_height == out_height && in_width == out_width) {
+      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height * in_width);
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+      UpsampleBilinear2DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper, in_height,
+                                   in_width, scale_height, scale_width, align_corners,
+                                   y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleBilinear2DGradCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBilinear2DGradCPUKernel() = default;
+  ~UpsampleBilinear2DGradCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 4> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2), dy_tensor->shape().At(3));
+    NdIndexOffsetHelper<int64_t, 4> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2), dx_tensor->shape().At(3));
+
+    const int64_t nbatch = dx_tensor->shape().At(0);
+    const int64_t channels = dx_tensor->shape().At(1);
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t in_width = dx_tensor->shape().At(3);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    const int64_t out_width = dy_tensor->shape().At(3);
+    if (in_height == out_height && in_width == out_width) {
+      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height * in_width);
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+      UpsampleBilinearBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height,
+                                  in_width, scale_height, scale_width, align_corners,
+                                  dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(dtype)                                \
+  REGISTER_USER_KERNEL("upsample_bilinear_2d")                                         \
+      .SetCreateFn<UpsampleBilinear2DCPUKernel<dtype>>()                               \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_bilinear_2d_grad")                                    \
+      .SetCreateFn<UpsampleBilinear2DGradCPUKernel<dtype>>()                           \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(float)
+REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(double)
+REGISTER_UPSAMPLE_BILINEAR_2D_CPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_bilinear_2d_kernel.cu b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..61f3ccf0d28f4e4c82a9f125668d6f15b11e426b
--- /dev/null
+++ b/oneflow/user/kernels/upsample_bilinear_2d_kernel.cu
@@ -0,0 +1,174 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/cuda/atomic.cuh"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+__global__ void UpsampleBilinear2DForward(const int64_t elem_cnt, const T* in_dptr,
+                                          NdIndexOffsetHelper<int64_t, 4> in_helper,
+                                          NdIndexOffsetHelper<int64_t, 4> out_helper,
+                                          const int64_t in_height, const int64_t in_width,
+                                          const T scale_h, const T scale_w,
+                                          const bool align_corners, T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, h, w);
+    BilinearParam<T> params;
+    GetBilinearParam(align_corners, h, w, in_height, in_width, scale_h, scale_w, &params);
+    const int64_t top_offset = in_helper.NdIndexToOffset(n, c, params.top_h_index, 0);
+    const int64_t bottom_offset = in_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);
+    const T top_left = in_dptr[top_offset + params.left_w_index];
+    const T top_right = in_dptr[top_offset + params.right_w_index];
+    const T bottom_left = in_dptr[bottom_offset + params.left_w_index];
+    const T bottom_right = in_dptr[bottom_offset + params.right_w_index];
+    const T top = top_left + (top_right - top_left) * params.w_lerp;
+    const T bottom = bottom_left + (bottom_right - bottom_left) * params.w_lerp;
+    out_dptr[index] = top + (bottom - top) * params.h_lerp;
+  }
+}
+
+template<typename T>
+__global__ void UpsampleBilinearBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                         NdIndexOffsetHelper<int64_t, 4> dy_helper,
+                                         NdIndexOffsetHelper<int64_t, 4> dx_helper,
+                                         const int64_t dx_height, const int64_t dx_width,
+                                         const T scale_h, const T scale_w, const bool align_corners,
+                                         T* dx_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, h, w);
+    BilinearParam<T> params;
+    GetBilinearParam(align_corners, h, w, dx_height, dx_width, scale_h, scale_w, &params);
+    const int64_t top_offset = dx_helper.NdIndexToOffset(n, c, params.top_h_index, 0);
+    const int64_t bottom_offset = dx_helper.NdIndexToOffset(n, c, params.bottom_h_index, 0);
+    const T dy = dy_dptr[index];
+    const T dbottom = params.h_lerp * dy;
+    T* dx_dptr_bottom_offset = dx_dptr + bottom_offset;
+    cuda::atomic::Add(dx_dptr_bottom_offset + params.left_w_index,
+                      static_cast<T>((1 - params.w_lerp) * dbottom));
+    cuda::atomic::Add(dx_dptr_bottom_offset + params.right_w_index,
+                      static_cast<T>(params.w_lerp * dbottom));
+    const T dtop = dy - dbottom;
+    T* dx_dptr_top_offset = dx_dptr + top_offset;
+    cuda::atomic::Add(dx_dptr_top_offset + params.left_w_index,
+                      static_cast<T>((1 - params.w_lerp) * dtop));
+    cuda::atomic::Add(dx_dptr_top_offset + params.right_w_index,
+                      static_cast<T>(params.w_lerp * dtop));
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleBilinear2DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBilinear2DGPUKernel() = default;
+  ~UpsampleBilinear2DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 4> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2), x_tensor->shape().At(3));
+    NdIndexOffsetHelper<int64_t, 4> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2), y_tensor->shape().At(3));
+
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t in_width = x_tensor->shape().At(3);
+    const int64_t out_height = y_tensor->shape().At(2);
+    const int64_t out_width = y_tensor->shape().At(3);
+    if (in_height == out_height && in_width == out_width) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+          x_tensor->shape().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+      RUN_CUDA_KERNEL((UpsampleBilinear2DForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      x_tensor->dptr<T>(), in_helper, out_helper, in_height, in_width, scale_height,
+                      scale_width, align_corners, y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleBilinear2DGradGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleBilinear2DGradGPUKernel() = default;
+  ~UpsampleBilinear2DGradGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 4> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2), dy_tensor->shape().At(3));
+    NdIndexOffsetHelper<int64_t, 4> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2), dx_tensor->shape().At(3));
+
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t in_width = dx_tensor->shape().At(3);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    const int64_t out_width = dy_tensor->shape().At(3);
+    if (in_height == out_height && in_width == out_width) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+          dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+      RUN_CUDA_KERNEL((UpsampleBilinearBackward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height, in_width, scale_height,
+                      scale_width, align_corners, dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLE_BILINEAR_2D_GPU_KERNEL(dtype)                                \
+  REGISTER_USER_KERNEL("upsample_bilinear_2d")                                         \
+      .SetCreateFn<UpsampleBilinear2DGPUKernel<dtype>>()                               \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_bilinear_2d_grad")                                    \
+      .SetCreateFn<UpsampleBilinear2DGradGPUKernel<dtype>>()                           \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLE_BILINEAR_2D_GPU_KERNEL(float)
+REGISTER_UPSAMPLE_BILINEAR_2D_GPU_KERNEL(double)
+REGISTER_UPSAMPLE_BILINEAR_2D_GPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_kernel.h b/oneflow/user/kernels/upsample_kernel.h
index b2103d6411b711c4a37814a6b47a56efb9ad3246..d783413a57e44533119c9c85117b135e988d0f57 100644
--- a/oneflow/user/kernels/upsample_kernel.h
+++ b/oneflow/user/kernels/upsample_kernel.h
@@ -15,6 +15,16 @@ limitations under the License.
 */
 #include "oneflow/core/common/nd_index_offset_helper.h"
 
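+// Maps an output index to a fractional input index. With align_corners the
+// corner pixels of input and output coincide, so src = scale * dst; otherwise
+// pixels are treated as unit areas with centers at half-integer positions,
+// giving src = scale * (dst + 0.5) - 0.5, clamped to be non-negative.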
+template<typename T>
+OF_DEVICE_FUNC T GetLinearInputIndex(const int64_t out_dim_idx, const T scale, bool align_corners) {
+  if (align_corners) {
+    return static_cast<T>(scale * out_dim_idx);
+  } else {
+    T src_idx = scale * (out_dim_idx + 0.5) - 0.5;
+    return static_cast<T>(src_idx < 0 ? 0 : src_idx);
+  }
+}
+
 OF_DEVICE_FUNC static int64_t GetNearestInputIndex(const int64_t out_dim_idx, const float scale,
                                                    const int64_t in_dim_size) {
   int64_t index = static_cast<int64_t>(std::floor((static_cast<float>(out_dim_idx) * scale)));
@@ -37,6 +47,17 @@ OF_DEVICE_FUNC T GetAreaPixelScale(const int64_t input_size, const int64_t outpu
   }
 }
 
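+// Same source-index mapping as GetLinearInputIndex, except that for cubic
+// interpolation negative fractional indices are preserved: the 4-tap window is
+// clamped per tap by upsample_get_value_bounded instead of at the index itself.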
+template<typename T>
+OF_DEVICE_FUNC T GetAreaPixel(const T scale, const int64_t dst_index, bool align_corners,
+                              bool cubic = false) {
+  if (align_corners) {
+    return scale * dst_index;
+  } else {
+    T src_idx = scale * (dst_index + 0.5) - 0.5;
+    return (!cubic && src_idx < 0) ? static_cast<T>(0) : src_idx;
+  }
+}
+
 template<typename T>
 struct BilinearParam {
   int64_t top_h_index;
@@ -78,3 +99,60 @@ OF_DEVICE_FUNC void GetBilinearParam(const bool align_corners, const int64_t h,
   params->right_w_index = w1 + w1p;
   params->w_lerp = w1r - w1;
 }
+
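+// Accumulates `value` into data[y][x] with clamp-to-edge semantics, so
+// contributions whose bicubic window falls outside the image are folded into
+// the nearest border cell.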
+template<typename T>
+OF_DEVICE_FUNC void upsample_increment_value_bounded(T* data, int64_t width, int64_t height,
+                                                     int64_t x, int64_t y, T value) {
+  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
+  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
+  data[access_y * width + access_x] += value;
+}
+
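+// Reads data[y][x] with clamp-to-edge boundary handling, since the 4x4 bicubic
+// window can extend past the image border.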
+template<typename T>
+OF_DEVICE_FUNC T upsample_get_value_bounded(const T* data, const int64_t width,
+                                            const int64_t height, const int64_t x,
+                                            const int64_t y) {
+  int64_t access_x = x;
+  access_x = access_x > width - 1 ? width - 1 : access_x;
+  access_x = access_x < 0 ? 0 : access_x;
+
+  int64_t access_y = y;
+  access_y = access_y > height - 1 ? height - 1 : access_y;
+  access_y = access_y < 0 ? 0 : access_y;
+
+  return data[access_y * width + access_x];
+}
+
+// Based on
+// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+
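+// The convolution kernel with parameter A (here A = -0.75, matching PyTorch and
+// OpenCV) is:
+//   W(x) = (A+2)|x|^3 - (A+3)|x|^2 + 1          for |x| <= 1
+//   W(x) = A|x|^3 - 5A|x|^2 + 8A|x| - 4A        for 1 < |x| < 2
+// cubic_convolution1 evaluates the first branch, cubic_convolution2 the second.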
+template<typename T>
+OF_DEVICE_FUNC T cubic_convolution1(const T x, const T A) {
+  return ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0;
+}
+
+template<typename T>
+OF_DEVICE_FUNC T cubic_convolution2(const T x, const T A) {
+  return ((A * x - 5.0 * A) * x + 8.0 * A) * x - 4.0 * A;
+}
+
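+// Weights for the four taps at offsets {-1, 0, 1, 2} around the sample point,
+// where t in [0, 1) is the fractional position; the tap distances are
+// {1 + t, t, 1 - t, 2 - t} and the resulting weights sum to 1.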
+template<typename T>
+OF_DEVICE_FUNC void get_cubic_upsample_coefficients(T coeffs[4], const T t) {
+  T A = -0.75;
+
+  T x1 = t;
+  coeffs[0] = cubic_convolution2<T>(x1 + 1.0, A);
+  coeffs[1] = cubic_convolution1<T>(x1, A);
+
+  // opposite coefficients
+  T x2 = 1.0 - t;
+  coeffs[2] = cubic_convolution1<T>(x2, A);
+  coeffs[3] = cubic_convolution2<T>(x2 + 1.0, A);
+}
+
+template<typename T>
+OF_DEVICE_FUNC T cubic_interp1d(const T x0, const T x1, const T x2, const T x3, const T t) {
+  T coeffs[4];
+  get_cubic_upsample_coefficients<T>(coeffs, t);
+  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
+}
diff --git a/oneflow/user/kernels/upsample_linear_1d_kernel.cpp b/oneflow/user/kernels/upsample_linear_1d_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9a79d7dccee4802a7f9eb0b2b1d48e6f9c9988ce
--- /dev/null
+++ b/oneflow/user/kernels/upsample_linear_1d_kernel.cpp
@@ -0,0 +1,148 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
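+// 1-D linear interpolation along the length dimension: h1 is the left source
+// index, h1p selects the right neighbor (0 at the last element so indexing
+// stays in bounds), and h0lambda/h1lambda are the interpolation weights.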
+template<typename T>
+static void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr,
+                                    NdIndexOffsetHelper<int64_t, 3> in_helper,
+                                    NdIndexOffsetHelper<int64_t, 3> out_helper, const int in_height,
+                                    const float scale_factor, bool align_corners, T* out_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h;
+    out_helper.OffsetToNdIndex(index, n, c, h);
+    const T h1r = GetLinearInputIndex(h, scale_factor, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+    out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)]
+                      + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)];
+  }
+}
+
+template<typename T>
+static void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                     NdIndexOffsetHelper<int64_t, 3> dy_helper,
+                                     NdIndexOffsetHelper<int64_t, 3> dx_helper, const int in_height,
+                                     const float scale_factor, bool align_corners, T* dx_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h;
+    dy_helper.OffsetToNdIndex(index, n, c, h);
+    const T h1r = GetLinearInputIndex(h, scale_factor, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+
+    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1)) += h0lambda * dy_dptr[index];
+    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1 + h1p)) += h1lambda * dy_dptr[index];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleLinear1DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleLinear1DCPUKernel() = default;
+  ~UpsampleLinear1DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 3> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2));
+    NdIndexOffsetHelper<int64_t, 3> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2));
+    const int64_t nbatch = x_tensor->shape().At(0);
+    const int64_t channels = x_tensor->shape().At(1);
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t out_height = y_tensor->shape().At(2);
+    if (in_height == out_height) {
+      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height);
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      UpsampleLinear1DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper, in_height,
+                                 scale_height, align_corners, y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleLinearGrad1DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleLinearGrad1DCPUKernel() = default;
+  ~UpsampleLinearGrad1DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+
+    NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2));
+    NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2));
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+
+    const int64_t nbatch = dx_tensor->shape().At(0);
+    const int64_t channels = dx_tensor->shape().At(1);
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    if (in_height == out_height) {
+      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height);
+    } else {
+      const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      UpsampleLinear1DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height,
+                                  scale_height, align_corners, dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(dtype)                                    \
+  REGISTER_USER_KERNEL("upsample_linear_1d")                                           \
+      .SetCreateFn<UpsampleLinear1DCPUKernel<dtype>>()                                 \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_linear_1d_grad")                                      \
+      .SetCreateFn<UpsampleLinearGrad1DCPUKernel<dtype>>()                             \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(float)
+REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(double)
+REGISTER_UPSAMPLELINEAR1D_CPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_linear_1d_kernel.cu b/oneflow/user/kernels/upsample_linear_1d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..67088f84903072c874fc4708e94eb1da217e74d0
--- /dev/null
+++ b/oneflow/user/kernels/upsample_linear_1d_kernel.cu
@@ -0,0 +1,151 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/cuda/atomic.cuh"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+__global__ void UpsampleLinear1DForward(const int64_t elem_cnt, const T* in_dptr,
+                                        NdIndexOffsetHelper<int64_t, 3> in_helper,
+                                        NdIndexOffsetHelper<int64_t, 3> out_helper,
+                                        const int in_height, const float scale_factor,
+                                        bool align_corners, T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h;
+    out_helper.OffsetToNdIndex(index, n, c, h);
+    const T h1r = GetLinearInputIndex(h, scale_factor, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+    out_dptr[index] = h0lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1)]
+                      + h1lambda * in_dptr[in_helper.NdIndexToOffset(n, c, h1 + h1p)];
+  }
+}
+
+template<typename T>
+__global__ void UpsampleLinear1DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                         NdIndexOffsetHelper<int64_t, 3> dy_helper,
+                                         NdIndexOffsetHelper<int64_t, 3> dx_helper,
+                                         const int in_height, const float scale_factor,
+                                         bool align_corners, T* dx_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h;
+    dy_helper.OffsetToNdIndex(index, n, c, h);
+    const T h1r = GetLinearInputIndex(h, scale_factor, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+
+    cuda::atomic::Add(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1), h0lambda * dy_dptr[index]);
+    cuda::atomic::Add(dx_dptr + dx_helper.NdIndexToOffset(n, c, h1 + h1p),
+                      h1lambda * dy_dptr[index]);
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleLinear1DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleLinear1DGPUKernel() = default;
+  ~UpsampleLinear1DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 3> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2));
+    NdIndexOffsetHelper<int64_t, 3> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2));
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t out_height = y_tensor->shape().At(2);
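+    // Same spatial size makes the op an identity, so a device-to-device copy
+    // avoids launching the interpolation kernel.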
+    if (in_height == out_height) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+          x_tensor->shape().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));
+    } else {
+      const float scale_height =
+          GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      RUN_CUDA_KERNEL((UpsampleLinear1DForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      x_tensor->dptr<T>(), in_helper, out_helper, in_height, scale_height,
+                      align_corners, y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleLinearGrad1DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleLinearGrad1DGPUKernel() = default;
+  ~UpsampleLinearGrad1DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
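+    // dx is filled by scatter-add below, so it has to start out zeroed.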
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+
+    NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2));
+    NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2));
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    if (in_height == out_height) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+          dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
+    } else {
+      const float scale_height =
+          GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+      RUN_CUDA_KERNEL((UpsampleLinear1DBackward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      dy_tensor->dptr<T>(), dy_helper, dx_helper, in_height, scale_height,
+                      align_corners, dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLELINEAR1D_GPU_KERNEL(dtype)                                    \
+  REGISTER_USER_KERNEL("upsample_linear_1d")                                           \
+      .SetCreateFn<UpsampleLinear1DGPUKernel<dtype>>()                                 \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_linear_1d_grad")                                      \
+      .SetCreateFn<UpsampleLinearGrad1DGPUKernel<dtype>>()                             \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLELINEAR1D_GPU_KERNEL(float)
+REGISTER_UPSAMPLELINEAR1D_GPU_KERNEL(double)
+REGISTER_UPSAMPLELINEAR1D_GPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_nearest_kernel.cpp b/oneflow/user/kernels/upsample_nearest_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b627005e8810c5d3b041e4699e27330ee3da92b3
--- /dev/null
+++ b/oneflow/user/kernels/upsample_nearest_kernel.cpp
@@ -0,0 +1,369 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+static void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr,
+                                     NdIndexOffsetHelper<int64_t, 3> in_helper,
+                                     NdIndexOffsetHelper<int64_t, 3> out_helper,
+                                     const int64_t in_height, const float scale_factor,
+                                     T* out_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h;
+    out_helper.OffsetToNdIndex(index, n, c, h);
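+    // Pure gather: every output element copies its single nearest input
+    // element, as selected by GetNearestInputIndex (upsample_kernel.h).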
+    const int64_t in_h = GetNearestInputIndex(h, scale_factor, in_height);
+    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h)];
+  }
+}
+
+template<typename T>
+static void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                      NdIndexOffsetHelper<int64_t, 3> dy_helper,
+                                      NdIndexOffsetHelper<int64_t, 3> dx_helper,
+                                      const int64_t in_height, const float scale_factor,
+                                      T* dx_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h;
+    dy_helper.OffsetToNdIndex(index, n, c, h);
+    const int64_t dx_h = GetNearestInputIndex(h, scale_factor, in_height);
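+    // The serial CPU loop may accumulate with a plain +=; the CUDA version of
+    // this scatter needs cuda::atomic::Add instead.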
+    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h)) += dy_dptr[index];
+  }
+}
+
+template<typename T>
+static void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dptr,
+                                     NdIndexOffsetHelper<int64_t, 4> in_helper,
+                                     NdIndexOffsetHelper<int64_t, 4> out_helper,
+                                     const int64_t in_height, const int64_t in_width,
+                                     const float scale_h, const float scale_w, T* out_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, h, w);
+    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);
+    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);
+    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h, in_w)];
+  }
+}
+
+template<typename T>
+static void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                      NdIndexOffsetHelper<int64_t, 4> dy_helper,
+                                      NdIndexOffsetHelper<int64_t, 4> dx_helper,
+                                      const int64_t dx_height, const int64_t dx_width,
+                                      const float scale_h, const float scale_w, T* dx_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, h, w);
+    const int64_t dx_h = GetNearestInputIndex(h, scale_h, dx_height);
+    const int64_t dx_w = GetNearestInputIndex(w, scale_w, dx_width);
+    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h, dx_w)) += dy_dptr[index];
+  }
+}
+
+template<typename T>
+static void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr,
+                                     NdIndexOffsetHelper<int64_t, 5> in_helper,
+                                     NdIndexOffsetHelper<int64_t, 5> out_helper,
+                                     const int64_t in_depth, const int64_t in_height,
+                                     const int64_t in_width, const float scale_d,
+                                     const float scale_h, const float scale_w, T* out_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, d, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, d, h, w);
+    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);
+    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);
+    const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth);
+    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_d, in_h, in_w)];
+  }
+}
+
+template<typename T>
+static void UpsampleNearest3DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                      NdIndexOffsetHelper<int64_t, 5> dy_helper,
+                                      NdIndexOffsetHelper<int64_t, 5> dx_helper,
+                                      const int64_t in_depth, const int64_t in_height,
+                                      const int64_t in_width, const float scale_d,
+                                      const float scale_h, const float scale_w, T* dx_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, d, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);
+    const int64_t dx_h = GetNearestInputIndex(h, scale_h, in_height);
+    const int64_t dx_w = GetNearestInputIndex(w, scale_w, in_width);
+    const int64_t dx_d = GetNearestInputIndex(d, scale_d, in_depth);
+    *(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_d, dx_h, dx_w)) += dy_dptr[index];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleNearest1DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest1DCPUKernel() = default;
+  ~UpsampleNearest1DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+
+    const int64_t nbatch = x_tensor->shape().At(0);
+    const int64_t channels = x_tensor->shape().At(1);
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t out_height = y_tensor->shape().At(2);
+
+    if (in_height == out_height) {
+      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height);
+    } else {
+      NdIndexOffsetHelper<int64_t, 3> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                                x_tensor->shape().At(2));
+      NdIndexOffsetHelper<int64_t, 3> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                                 y_tensor->shape().At(2));
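+      // The kernel expects the output-to-input step, i.e. the reciprocal of
+      // the user-facing scale_factor attribute.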
+      UpsampleNearest1DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper,
+                                  x_tensor->shape().At(2), 1.f / height_scale,
+                                  y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleNearestGrad1DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearestGrad1DCPUKernel() = default;
+  ~UpsampleNearestGrad1DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    const int64_t nbatch = dx_tensor->shape().At(0);
+    const int64_t channels = dx_tensor->shape().At(1);
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    if (in_height == out_height) {
+      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height);
+    } else {
+      NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                                dy_tensor->shape().At(2));
+      NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                                dx_tensor->shape().At(2));
+      UpsampleNearest1DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper,
+                                   dx_tensor->shape().At(2), 1.f / height_scale,
+                                   dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(dtype)                                     \
+  REGISTER_USER_KERNEL("upsample_nearest_1d")                                          \
+      .SetCreateFn<UpsampleNearest1DCPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_nearest_1d_grad")                                     \
+      .SetCreateFn<UpsampleNearestGrad1DCPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(float)
+REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(double)
+REGISTER_UPSAMPNEAREST1D_CPU_KERNEL(int)
+
+template<typename T>
+class UpsampleNearest2DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest2DCPUKernel() = default;
+  ~UpsampleNearest2DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+
+    const int64_t nbatch = x_tensor->shape().At(0);
+    const int64_t channels = x_tensor->shape().At(1);
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t in_width = x_tensor->shape().At(3);
+    const int64_t out_height = y_tensor->shape().At(2);
+    const int64_t out_width = y_tensor->shape().At(3);
+
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+
+    if (in_height == out_height && in_width == out_width) {
+      memcpy(y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height * in_width);
+    } else {
+      NdIndexOffsetHelper<int64_t, 4> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                                x_tensor->shape().At(2), x_tensor->shape().At(3));
+      NdIndexOffsetHelper<int64_t, 4> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                                 y_tensor->shape().At(2), y_tensor->shape().At(3));
+      UpsampleNearest2DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper,
+                                  x_tensor->shape().At(2), x_tensor->shape().At(3),
+                                  1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleNearest2DGradCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest2DGradCPUKernel() = default;
+  ~UpsampleNearest2DGradCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+
+    const int64_t nbatch = dx_tensor->shape().At(0);
+    const int64_t channels = dx_tensor->shape().At(1);
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t in_width = dx_tensor->shape().At(3);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    const int64_t out_width = dy_tensor->shape().At(3);
+
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+
+    if (in_height == out_height && in_width == out_width) {
+      memcpy(dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+             sizeof(T) * nbatch * channels * in_height * in_width);
+    } else {
+      NdIndexOffsetHelper<int64_t, 4> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                                dy_tensor->shape().At(2), dy_tensor->shape().At(3));
+      NdIndexOffsetHelper<int64_t, 4> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                                dx_tensor->shape().At(2), dx_tensor->shape().At(3));
+      UpsampleNearest2DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper,
+                                   dx_tensor->shape().At(2), dx_tensor->shape().At(3),
+                                   1.f / height_scale, 1.f / width_scale, dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(dtype)                                 \
+  REGISTER_USER_KERNEL("upsample_nearest_2d")                                          \
+      .SetCreateFn<UpsampleNearest2DCPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_nearest_2d_grad")                                     \
+      .SetCreateFn<UpsampleNearest2DGradCPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(float)
+REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(double)
+REGISTER_UPSAMPLE_NEAREST_2D_CPU_KERNEL(int)
+
+template<typename T>
+class UpsampleNearest3DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest3DCPUKernel() = default;
+  ~UpsampleNearest3DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2), x_tensor->shape().At(3),
+                                              x_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2), y_tensor->shape().At(3),
+                                               y_tensor->shape().At(4));
+    UpsampleNearest3DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper,
+                                x_tensor->shape().At(2), x_tensor->shape().At(3),
+                                x_tensor->shape().At(4), 1.f / depth_scale, 1.f / height_scale,
+                                1.f / width_scale, y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleNearestGrad3DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearestGrad3DCPUKernel() = default;
+  ~UpsampleNearestGrad3DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+    if (dx_tensor == nullptr) { return; }
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2), dy_tensor->shape().At(3),
+                                              dy_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2), dx_tensor->shape().At(3),
+                                              dx_tensor->shape().At(4));
+    UpsampleNearest3DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper,
+                                 dx_tensor->shape().At(2), dx_tensor->shape().At(3),
+                                 dx_tensor->shape().At(4), 1.f / depth_scale, 1.f / height_scale,
+                                 1.f / width_scale, dx_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(dtype)                                     \
+  REGISTER_USER_KERNEL("upsample_nearest_3d")                                          \
+      .SetCreateFn<UpsampleNearest3DCPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_nearest_3d_grad")                                     \
+      .SetCreateFn<UpsampleNearestGrad3DCPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(float)
+REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(double)
+REGISTER_UPSAMPNEAREST3D_CPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_nearest_kernel.cu b/oneflow/user/kernels/upsample_nearest_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8ade66a771f95e3bb5936c42325a526ae582bb58
--- /dev/null
+++ b/oneflow/user/kernels/upsample_nearest_kernel.cu
@@ -0,0 +1,360 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/cuda/atomic.cuh"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+__global__ void UpsampleNearest1DForward(const int64_t elem_cnt, const T* in_dptr,
+                                         NdIndexOffsetHelper<int64_t, 3> in_helper,
+                                         NdIndexOffsetHelper<int64_t, 3> out_helper,
+                                         const int64_t in_height, const float scale_factor,
+                                         T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h;
+    out_helper.OffsetToNdIndex(index, n, c, h);
+    const int64_t in_h = GetNearestInputIndex(h, scale_factor, in_height);
+    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h)];
+  }
+}
+
+template<typename T>
+__global__ void UpsampleNearest1DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                          NdIndexOffsetHelper<int64_t, 3> dy_helper,
+                                          NdIndexOffsetHelper<int64_t, 3> dx_helper,
+                                          const int64_t in_height, const float scale_factor,
+                                          T* dx_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h;
+    dy_helper.OffsetToNdIndex(index, n, c, h);
+    const int64_t dx_h = GetNearestInputIndex(h, scale_factor, in_height);
+    cuda::atomic::Add(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h), dy_dptr[index]);
+  }
+}
+
+template<typename T>
+__global__ void UpsampleNearest2DForward(const int64_t elem_cnt, const T* in_dptr,
+                                         NdIndexOffsetHelper<int64_t, 4> in_helper,
+                                         NdIndexOffsetHelper<int64_t, 4> out_helper,
+                                         const int64_t in_height, const int64_t in_width,
+                                         const float scale_h, const float scale_w, T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, h, w);
+    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);
+    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);
+    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_h, in_w)];
+  }
+}
+
+template<typename T>
+__global__ void UpsampleNearest2DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                          NdIndexOffsetHelper<int64_t, 4> dy_helper,
+                                          NdIndexOffsetHelper<int64_t, 4> dx_helper,
+                                          const int64_t dx_height, const int64_t dx_width,
+                                          const float scale_h, const float scale_w, T* dx_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, h, w);
+    const int64_t dx_h = GetNearestInputIndex(h, scale_h, dx_height);
+    const int64_t dx_w = GetNearestInputIndex(w, scale_w, dx_width);
+    cuda::atomic::Add(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_h, dx_w), dy_dptr[index]);
+  }
+}
+
+template<typename T>
+__global__ void UpsampleNearest3DForward(const int64_t elem_cnt, const T* in_dptr,
+                                         NdIndexOffsetHelper<int64_t, 5> in_helper,
+                                         NdIndexOffsetHelper<int64_t, 5> out_helper,
+                                         const int64_t in_depth, const int64_t in_height,
+                                         const int64_t in_width, const float scale_d,
+                                         const float scale_h, const float scale_w, T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, d, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, d, h, w);
+    const int64_t in_h = GetNearestInputIndex(h, scale_h, in_height);
+    const int64_t in_w = GetNearestInputIndex(w, scale_w, in_width);
+    const int64_t in_d = GetNearestInputIndex(d, scale_d, in_depth);
+    out_dptr[index] = in_dptr[in_helper.NdIndexToOffset(n, c, in_d, in_h, in_w)];
+  }
+}
+
+template<typename T>
+__global__ void UpsampleNearest3DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                          NdIndexOffsetHelper<int64_t, 5> dy_helper,
+                                          NdIndexOffsetHelper<int64_t, 5> dx_helper,
+                                          const int64_t in_depth, const int64_t in_height,
+                                          const int64_t in_width, const float scale_d,
+                                          const float scale_h, const float scale_w, T* dx_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, d, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);
+    const int64_t dx_h = GetNearestInputIndex(h, scale_h, in_height);
+    const int64_t dx_w = GetNearestInputIndex(w, scale_w, in_width);
+    const int64_t dx_d = GetNearestInputIndex(d, scale_d, in_depth);
+    cuda::atomic::Add(dx_dptr + dx_helper.NdIndexToOffset(n, c, dx_d, dx_h, dx_w), dy_dptr[index]);
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleNearest1DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest1DGPUKernel() = default;
+  ~UpsampleNearest1DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t out_height = y_tensor->shape().At(2);
+    if (in_height == out_height) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+          x_tensor->shape().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));
+    } else {
+      NdIndexOffsetHelper<int64_t, 3> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                                x_tensor->shape().At(2));
+      NdIndexOffsetHelper<int64_t, 3> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                                 y_tensor->shape().At(2));
+      RUN_CUDA_KERNEL((UpsampleNearest1DForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      x_tensor->dptr<T>(), in_helper, out_helper, x_tensor->shape().At(2),
+                      1.f / height_scale, y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearestGrad1DGPUKernel() = default;
+  ~UpsampleNearestGrad1DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("scale_factor");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    if (in_height == out_height) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+          dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
+    } else {
+      NdIndexOffsetHelper<int64_t, 3> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                                dy_tensor->shape().At(2));
+      NdIndexOffsetHelper<int64_t, 3> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                                dx_tensor->shape().At(2));
+      RUN_CUDA_KERNEL((UpsampleNearest1DBackward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape().At(2),
+                      1.f / height_scale, dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPNEAREST1D_GPU_KERNEL(dtype)                                     \
+  REGISTER_USER_KERNEL("upsample_nearest_1d")                                          \
+      .SetCreateFn<UpsampleNearest1DGPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_nearest_1d_grad")                                     \
+      .SetCreateFn<UpsampleNearestGrad1DGPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPNEAREST1D_GPU_KERNEL(float)
+REGISTER_UPSAMPNEAREST1D_GPU_KERNEL(double)
+REGISTER_UPSAMPNEAREST1D_GPU_KERNEL(int)
+
+template<typename T>
+class UpsampleNearest2DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest2DGPUKernel() = default;
+  ~UpsampleNearest2DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+
+    const int64_t in_height = x_tensor->shape().At(2);
+    const int64_t in_width = x_tensor->shape().At(3);
+    const int64_t out_height = y_tensor->shape().At(2);
+    const int64_t out_width = y_tensor->shape().At(3);
+    if (in_height == out_height && in_width == out_width) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), y_tensor->mut_dptr<void>(), x_tensor->dptr<void>(),
+          x_tensor->shape().elem_cnt() * GetSizeOfDataType(x_tensor->data_type()));
+    } else {
+      NdIndexOffsetHelper<int64_t, 4> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                                x_tensor->shape().At(2), x_tensor->shape().At(3));
+      NdIndexOffsetHelper<int64_t, 4> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                                 y_tensor->shape().At(2), y_tensor->shape().At(3));
+      RUN_CUDA_KERNEL((UpsampleNearest2DForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      x_tensor->dptr<T>(), in_helper, out_helper, x_tensor->shape().At(2),
+                      x_tensor->shape().At(3), 1.f / height_scale, 1.f / width_scale,
+                      y_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest2DGradGPUKernel() = default;
+  ~UpsampleNearest2DGradGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    const int64_t in_height = dx_tensor->shape().At(2);
+    const int64_t in_width = dx_tensor->shape().At(3);
+    const int64_t out_height = dy_tensor->shape().At(2);
+    const int64_t out_width = dy_tensor->shape().At(3);
+    if (in_height == out_height && in_width == out_width) {
+      Memcpy<DeviceType::kGPU>(
+          ctx->device_ctx(), dx_tensor->mut_dptr<void>(), dy_tensor->dptr<void>(),
+          dy_tensor->shape().elem_cnt() * GetSizeOfDataType(dy_tensor->data_type()));
+    } else {
+      NdIndexOffsetHelper<int64_t, 4> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                                dy_tensor->shape().At(2), dy_tensor->shape().At(3));
+      NdIndexOffsetHelper<int64_t, 4> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                                dx_tensor->shape().At(2), dx_tensor->shape().At(3));
+      RUN_CUDA_KERNEL((UpsampleNearest2DBackward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                      dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape().At(2),
+                      dx_tensor->shape().At(3), 1.f / height_scale, 1.f / width_scale,
+                      dx_tensor->mut_dptr<T>());
+    }
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLE_NEAREST_2D_GPU_KERNEL(dtype)                                 \
+  REGISTER_USER_KERNEL("upsample_nearest_2d")                                          \
+      .SetCreateFn<UpsampleNearest2DGPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_nearest_2d_grad")                                     \
+      .SetCreateFn<UpsampleNearest2DGradGPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLE_NEAREST_2D_GPU_KERNEL(float)
+REGISTER_UPSAMPLE_NEAREST_2D_GPU_KERNEL(double)
+REGISTER_UPSAMPLE_NEAREST_2D_GPU_KERNEL(int)
+
+template<typename T>
+class UpsampleNearest3DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearest3DGPUKernel() = default;
+  ~UpsampleNearest3DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2), x_tensor->shape().At(3),
+                                              x_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2), y_tensor->shape().At(3),
+                                               y_tensor->shape().At(4));
+    RUN_CUDA_KERNEL((UpsampleNearest3DForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                    x_tensor->dptr<T>(), in_helper, out_helper, x_tensor->shape().At(2),
+                    x_tensor->shape().At(3), x_tensor->shape().At(4), 1.f / depth_scale,
+                    1.f / height_scale, 1.f / width_scale, y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleNearestGrad3DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleNearestGrad3DGPUKernel() = default;
+  ~UpsampleNearestGrad3DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2), dy_tensor->shape().At(3),
+                                              dy_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2), dx_tensor->shape().At(3),
+                                              dx_tensor->shape().At(4));
+    RUN_CUDA_KERNEL((UpsampleNearest3DBackward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                    dy_tensor->dptr<T>(), dy_helper, dx_helper, dx_tensor->shape().At(2),
+                    dx_tensor->shape().At(3), dx_tensor->shape().At(4), 1.f / depth_scale,
+                    1.f / height_scale, 1.f / width_scale, dx_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPNEAREST3D_GPU_KERNEL(dtype)                                     \
+  REGISTER_USER_KERNEL("upsample_nearest_3d")                                          \
+      .SetCreateFn<UpsampleNearest3DGPUKernel<dtype>>()                                \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_nearest_3d_grad")                                     \
+      .SetCreateFn<UpsampleNearestGrad3DGPUKernel<dtype>>()                            \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPNEAREST3D_GPU_KERNEL(float)
+REGISTER_UPSAMPNEAREST3D_GPU_KERNEL(double)
+REGISTER_UPSAMPNEAREST3D_GPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b0952f7d5e7bc3765e530b60f9716ff0b16448d3
--- /dev/null
+++ b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cpp
@@ -0,0 +1,218 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
+template<typename T>
+static void UpsampleTrilinear3DForward(const int64_t elem_cnt, const T* in_dptr,
+                                       NdIndexOffsetHelper<int64_t, 5> in_helper,
+                                       NdIndexOffsetHelper<int64_t, 5> out_helper,
+                                       const int64_t in_depth, const int64_t in_height,
+                                       const int64_t in_width, const T rdepth, const T rheight,
+                                       const T rwidth, const bool align_corners, T* out_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, d, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, d, h, w);
+
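+    // GetAreaPixel maps an output coordinate into input space. With
+    // align_corners the corner samples of input and output coincide exactly;
+    // otherwise sample centers are aligned, following the usual PyTorch
+    // interpolation convention.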
+    const T t1r = GetAreaPixel(rdepth, d, align_corners);
+    const int64_t t1 = t1r;
+    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;
+    const T t1lambda = t1r - t1;
+    const T t0lambda = static_cast<T>(1.) - t1lambda;
+
+    const T h1r = GetAreaPixel(rheight, h, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+
+    const T w1r = GetAreaPixel(rwidth, w, align_corners);
+    const int64_t w1 = w1r;
+    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;
+    const T w1lambda = w1r - w1;
+    const T w0lambda = static_cast<T>(1.) - w1lambda;
+
+    const T* pos1 = &in_dptr[in_helper.NdIndexToOffset(n, c, t1, h1, w1)];
+
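+    // Trilinear blend of the 8 voxels around (t1, h1, w1): pos1 points at the
+    // low corner, and w1p, h1p * in_width and t1p * in_height * in_width step
+    // one voxel along width, height and depth (each step is 0 on the boundary,
+    // which clamps instead of reading out of bounds).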
+    out_dptr[index] =
+        t0lambda
+            * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p])
+               + h1lambda
+                     * (w0lambda * pos1[h1p * in_width] + w1lambda * pos1[h1p * in_width + w1p]))
+        + t1lambda
+              * (h0lambda
+                     * (w0lambda * pos1[t1p * in_height * in_width]
+                        + w1lambda * pos1[t1p * in_height * in_width + w1p])
+                 + h1lambda
+                       * (w0lambda * pos1[t1p * in_height * in_width + h1p * in_width]
+                          + w1lambda * pos1[t1p * in_height * in_width + h1p * in_width + w1p]));
+  }
+}
+
+template<typename T>
+static void UpsampleTrilinear3DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                        NdIndexOffsetHelper<int64_t, 5> dy_helper,
+                                        NdIndexOffsetHelper<int64_t, 5> dx_helper,
+                                        const int64_t in_depth, const int64_t in_height,
+                                        const int64_t in_width, const T rdepth, const T rheight,
+                                        const T rwidth, const bool align_corners, T* dx_dptr) {
+  for (int64_t index = 0; index < elem_cnt; ++index) {
+    int64_t n, c, d, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);
+
+    const T t1r = GetAreaPixel(rdepth, d, align_corners);
+    const int64_t t1 = t1r;
+    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;
+    const T t1lambda = t1r - t1;
+    const T t0lambda = static_cast<T>(1.) - t1lambda;
+
+    const T h1r = GetAreaPixel(rheight, h, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+
+    const T w1r = GetAreaPixel(rwidth, w, align_corners);
+    const int64_t w1 = w1r;
+    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;
+    const T w1lambda = w1r - w1;
+    const T w0lambda = static_cast<T>(1.) - w1lambda;
+
+    T* pos1 = &dx_dptr[dx_helper.NdIndexToOffset(n, c, t1, h1, w1)];
+    const T* pos2 = &dy_dptr[index];
+
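+    // Transpose of the forward blend: dy is scattered into the same 8 corner
+    // voxels, weighted by the same trilinear lambdas.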
+    pos1[0] += t0lambda * h0lambda * w0lambda * pos2[0];
+    pos1[w1p] += t0lambda * h0lambda * w1lambda * pos2[0];
+    pos1[h1p * in_width] += t0lambda * h1lambda * w0lambda * pos2[0];
+    pos1[h1p * in_width + w1p] += t0lambda * h1lambda * w1lambda * pos2[0];
+    pos1[t1p * in_height * in_width] += t1lambda * h0lambda * w0lambda * pos2[0];
+    pos1[t1p * in_height * in_width + w1p] += t1lambda * h0lambda * w1lambda * pos2[0];
+    pos1[t1p * in_height * in_width + h1p * in_width] += t1lambda * h1lambda * w0lambda * pos2[0];
+    pos1[t1p * in_height * in_width + h1p * in_width + w1p] +=
+        t1lambda * h1lambda * w1lambda * pos2[0];
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleTrilinear3DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleTrilinear3DCPUKernel() = default;
+  ~UpsampleTrilinear3DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2), x_tensor->shape().At(3),
+                                              x_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2), y_tensor->shape().At(3),
+                                               y_tensor->shape().At(4));
+
+    const int64_t in_depth = x_tensor->shape().At(2);
+    const int64_t in_height = x_tensor->shape().At(3);
+    const int64_t in_width = x_tensor->shape().At(4);
+
+    const int64_t out_depth = y_tensor->shape().At(2);
+    const int64_t out_height = y_tensor->shape().At(3);
+    const int64_t out_width = y_tensor->shape().At(4);
+
+    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);
+    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
+    UpsampleTrilinear3DForward<T>(elem_cnt, x_tensor->dptr<T>(), in_helper, out_helper,
+                                  x_tensor->shape().At(2), x_tensor->shape().At(3),
+                                  x_tensor->shape().At(4), scale_depth, scale_height, scale_width,
+                                  align_corners, y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleTrilinearGrad3DCPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleTrilinearGrad3DCPUKernel() = default;
+  ~UpsampleTrilinearGrad3DCPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
+    Memset<DeviceType::kCPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2), dy_tensor->shape().At(3),
+                                              dy_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2), dx_tensor->shape().At(3),
+                                              dx_tensor->shape().At(4));
+
+    const int64_t in_depth = dx_tensor->shape().At(2);
+    const int64_t in_height = dx_tensor->shape().At(3);
+    const int64_t in_width = dx_tensor->shape().At(4);
+
+    const int64_t out_depth = dy_tensor->shape().At(2);
+    const int64_t out_height = dy_tensor->shape().At(3);
+    const int64_t out_width = dy_tensor->shape().At(4);
+
+    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);
+    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
+    UpsampleTrilinear3DBackward<T>(elem_cnt, dy_tensor->dptr<T>(), dy_helper, dx_helper,
+                                   dx_tensor->shape().At(2), dx_tensor->shape().At(3),
+                                   dx_tensor->shape().At(4), scale_depth, scale_height, scale_width,
+                                   align_corners, dx_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(dtype)                                   \
+  REGISTER_USER_KERNEL("upsample_trilinear_3d")                                        \
+      .SetCreateFn<UpsampleTrilinear3DCPUKernel<dtype>>()                              \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_trilinear_3d_grad")                                   \
+      .SetCreateFn<UpsampleTrilinearGrad3DCPUKernel<dtype>>()                          \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(float)
+REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(double)
+REGISTER_UPSAMPTRILINEAR3D_CPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/kernels/upsample_trilinear_3d_kernel.cu b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8cdfabc212438481f7fa5342fb0bb995ca288338
--- /dev/null
+++ b/oneflow/user/kernels/upsample_trilinear_3d_kernel.cu
@@ -0,0 +1,221 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/kernel/new_kernel_util.h"
+#include "oneflow/core/common/nd_index_offset_helper.h"
+#include "oneflow/core/cuda/atomic.cuh"
+#include "oneflow/user/kernels/upsample_kernel.h"
+
+namespace oneflow {
+
+namespace {
+
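+// Trilinear forward: each output voxel (n, c, d, h, w) is interpolated from its
+// eight nearest input voxels. GetAreaPixel maps the output coordinate to a
+// (possibly fractional) input coordinate; t1/h1/w1 are the low-corner indices,
+// t1p/h1p/w1p are the per-axis steps to the high corner (0 at the boundary, so
+// all reads stay in bounds), and the *lambda values are the linear weights.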
+template<typename T>
+__global__ void UpsampleTrilinear3DForward(const int64_t elem_cnt, const T* in_dptr,
+                                           NdIndexOffsetHelper<int64_t, 5> in_helper,
+                                           NdIndexOffsetHelper<int64_t, 5> out_helper,
+                                           const int64_t in_depth, const int64_t in_height,
+                                           const int64_t in_width, const T rdepth, const T rheight,
+                                           const T rwidth, const bool align_corners, T* out_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, d, h, w;
+    out_helper.OffsetToNdIndex(index, n, c, d, h, w);
+
+    const T t1r = GetAreaPixel(rdepth, d, align_corners);
+    const int64_t t1 = t1r;
+    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;
+    const T t1lambda = t1r - t1;
+    const T t0lambda = static_cast<T>(1.) - t1lambda;
+
+    const T h1r = GetAreaPixel(rheight, h, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+
+    const T w1r = GetAreaPixel(rwidth, w, align_corners);
+    const int64_t w1 = w1r;
+    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;
+    const T w1lambda = w1r - w1;
+    const T w0lambda = static_cast<T>(1.) - w1lambda;
+
+    const T* pos1 = &in_dptr[in_helper.NdIndexToOffset(n, c, t1, h1, w1)];
+
+    out_dptr[index] =
+        t0lambda
+            * (h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p])
+               + h1lambda
+                     * (w0lambda * pos1[h1p * in_width] + w1lambda * pos1[h1p * in_width + w1p]))
+        + t1lambda
+              * (h0lambda
+                     * (w0lambda * pos1[t1p * in_height * in_width]
+                        + w1lambda * pos1[t1p * in_height * in_width + w1p])
+                 + h1lambda
+                       * (w0lambda * pos1[t1p * in_height * in_width + h1p * in_width]
+                          + w1lambda * pos1[t1p * in_height * in_width + h1p * in_width + w1p]));
+  }
+}
+
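+// Trilinear backward: the adjoint of the forward gather. Each dy element is
+// scattered back to the same eight input corners with the same weights; several
+// output voxels can map to one input voxel, so the accumulation must be atomic.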
+template<typename T>
+__global__ void UpsampleTrilinear3DBackward(const int64_t elem_cnt, const T* dy_dptr,
+                                            NdIndexOffsetHelper<int64_t, 5> dy_helper,
+                                            NdIndexOffsetHelper<int64_t, 5> dx_helper,
+                                            const int64_t in_depth, const int64_t in_height,
+                                            const int64_t in_width, const T rdepth, const T rheight,
+                                            const T rwidth, const bool align_corners, T* dx_dptr) {
+  CUDA_1D_KERNEL_LOOP(index, elem_cnt) {
+    int64_t n, c, d, h, w;
+    dy_helper.OffsetToNdIndex(index, n, c, d, h, w);
+
+    const T t1r = GetAreaPixel(rdepth, d, align_corners);
+    const int64_t t1 = t1r;
+    const int64_t t1p = (t1 < in_depth - 1) ? 1 : 0;
+    const T t1lambda = t1r - t1;
+    const T t0lambda = static_cast<T>(1.) - t1lambda;
+
+    const T h1r = GetAreaPixel(rheight, h, align_corners);
+    const int64_t h1 = h1r;
+    const int64_t h1p = (h1 < in_height - 1) ? 1 : 0;
+    const T h1lambda = h1r - h1;
+    const T h0lambda = static_cast<T>(1.) - h1lambda;
+
+    const T w1r = GetAreaPixel(rwidth, w, align_corners);
+    const int64_t w1 = w1r;
+    const int64_t w1p = (w1 < in_width - 1) ? 1 : 0;
+    const T w1lambda = w1r - w1;
+    const T w0lambda = static_cast<T>(1.) - w1lambda;
+
+    T* pos1 = &dx_dptr[dx_helper.NdIndexToOffset(n, c, t1, h1, w1)];
+    const T* pos2 = &dy_dptr[index];
+
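+    // Scatter dy to the eight corners, weighted exactly as in the forward pass.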
+    cuda::atomic::Add(pos1 + 0, t0lambda * h0lambda * w0lambda * pos2[0]);
+    cuda::atomic::Add(pos1 + w1p, t0lambda * h0lambda * w1lambda * pos2[0]);
+    cuda::atomic::Add(pos1 + h1p * in_width, t0lambda * h1lambda * w0lambda * pos2[0]);
+    cuda::atomic::Add(pos1 + h1p * in_width + w1p, t0lambda * h1lambda * w1lambda * pos2[0]);
+    cuda::atomic::Add(pos1 + t1p * in_height * in_width, t1lambda * h0lambda * w0lambda * pos2[0]);
+    cuda::atomic::Add(pos1 + t1p * in_height * in_width + w1p,
+                      t1lambda * h0lambda * w1lambda * pos2[0]);
+    cuda::atomic::Add(pos1 + t1p * in_height * in_width + h1p * in_width,
+                      t1lambda * h1lambda * w0lambda * pos2[0]);
+    cuda::atomic::Add(pos1 + t1p * in_height * in_width + h1p * in_width + w1p,
+                      t1lambda * h1lambda * w1lambda * pos2[0]);
+  }
+}
+
+}  // namespace
+
+template<typename T>
+class UpsampleTrilinear3DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleTrilinear3DGPUKernel() = default;
+  ~UpsampleTrilinear3DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0);
+    user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = y_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> in_helper(x_tensor->shape().At(0), x_tensor->shape().At(1),
+                                              x_tensor->shape().At(2), x_tensor->shape().At(3),
+                                              x_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> out_helper(y_tensor->shape().At(0), y_tensor->shape().At(1),
+                                               y_tensor->shape().At(2), y_tensor->shape().At(3),
+                                               y_tensor->shape().At(4));
+
+    const int64_t in_depth = x_tensor->shape().At(2);
+    const int64_t in_height = x_tensor->shape().At(3);
+    const int64_t in_width = x_tensor->shape().At(4);
+
+    const int64_t out_depth = y_tensor->shape().At(2);
+    const int64_t out_height = y_tensor->shape().At(3);
+    const int64_t out_width = y_tensor->shape().At(4);
+
+    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);
+    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
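+    // RUN_CUDA_KERNEL's third argument sizes the launch; elem_cnt then repeats
+    // as the kernel's own first parameter.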
+    RUN_CUDA_KERNEL((UpsampleTrilinear3DForward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                    x_tensor->dptr<T>(), in_helper, out_helper, in_depth, in_height, in_width,
+                    scale_depth, scale_height, scale_width, align_corners,
+                    y_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+template<typename T>
+class UpsampleTrilinearGrad3DGPUKernel final : public user_op::OpKernel {
+ public:
+  UpsampleTrilinearGrad3DGPUKernel() = default;
+  ~UpsampleTrilinearGrad3DGPUKernel() = default;
+
+ private:
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0);
+
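+    // dx is accumulated with atomic adds, so it must start out zeroed.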
+    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_tensor->mut_dptr<T>(), 0,
+                             dx_tensor->shape().elem_cnt() * sizeof(T));
+    const user_op::Tensor* dy_tensor = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const float depth_scale = ctx->Attr<float>("depth_scale");
+    const float height_scale = ctx->Attr<float>("height_scale");
+    const float width_scale = ctx->Attr<float>("width_scale");
+    const bool align_corners = ctx->Attr<bool>("align_corners");
+    const int64_t elem_cnt = dy_tensor->shape().elem_cnt();
+    NdIndexOffsetHelper<int64_t, 5> dy_helper(dy_tensor->shape().At(0), dy_tensor->shape().At(1),
+                                              dy_tensor->shape().At(2), dy_tensor->shape().At(3),
+                                              dy_tensor->shape().At(4));
+    NdIndexOffsetHelper<int64_t, 5> dx_helper(dx_tensor->shape().At(0), dx_tensor->shape().At(1),
+                                              dx_tensor->shape().At(2), dx_tensor->shape().At(3),
+                                              dx_tensor->shape().At(4));
+
+    const int64_t in_depth = dx_tensor->shape().At(2);
+    const int64_t in_height = dx_tensor->shape().At(3);
+    const int64_t in_width = dx_tensor->shape().At(4);
+
+    const int64_t out_depth = dy_tensor->shape().At(2);
+    const int64_t out_height = dy_tensor->shape().At(3);
+    const int64_t out_width = dy_tensor->shape().At(4);
+
+    const T scale_depth = GetAreaPixelScale(in_depth, out_depth, align_corners, depth_scale);
+    const T scale_height = GetAreaPixelScale(in_height, out_height, align_corners, height_scale);
+    const T scale_width = GetAreaPixelScale(in_width, out_width, align_corners, width_scale);
+
+    RUN_CUDA_KERNEL((UpsampleTrilinear3DBackward<T>), ctx->device_ctx(), elem_cnt, elem_cnt,
+                    dy_tensor->dptr<T>(), dy_helper, dx_helper, in_depth, in_height, in_width,
+                    scale_depth, scale_height, scale_width, align_corners,
+                    dx_tensor->mut_dptr<T>());
+  }
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_UPSAMPLE_TRILINEAR_3D_GPU_KERNEL(dtype)                               \
+  REGISTER_USER_KERNEL("upsample_trilinear_3d")                                        \
+      .SetCreateFn<UpsampleTrilinear3DGPUKernel<dtype>>()                              \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("y", 0) == GetDataType<dtype>::value)); \
+  REGISTER_USER_KERNEL("upsample_trilinear_3d_grad")                                   \
+      .SetCreateFn<UpsampleTrilinearGrad3DGPUKernel<dtype>>()                          \
+      .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu")                              \
+                       & (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));
+
+REGISTER_UPSAMPLE_TRILINEAR_3D_GPU_KERNEL(float)
+REGISTER_UPSAMPLE_TRILINEAR_3D_GPU_KERNEL(double)
+REGISTER_UPSAMPLE_TRILINEAR_3D_GPU_KERNEL(int)
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp
index 6a539c6690743bb370c63a019924cf5bfe745f86..f3f865e9b183e4c35084115bb383ba63916c01cb 100644
--- a/oneflow/user/ops/upsample_op.cpp
+++ b/oneflow/user/ops/upsample_op.cpp
@@ -17,6 +17,144 @@ limitations under the License.
 
 namespace oneflow {
 
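+// Every upsample_* forward op below follows one pattern: tensor-desc inference
+// checks the channels_first layout and rank, then scales the spatial extents
+// (e.g. x of shape (1, 2, 8) with scale_factor=1.5 yields y of shape (1, 2, 12));
+// SBP splits both tensors on the batch axis, and y inherits x's dtype.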
+REGISTER_USER_OP("upsample_linear_1d")
+    .Input("x")
+    .Output("y")
+    .Attr<float>("scale_factor")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      const float scale_factor = ctx->Attr<float>("scale_factor");
+
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && x_desc->shape().NumAxes() == 3)
+          << "upsample_linear_1d only supports NCH";
+      *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1),
+                                    static_cast<int64_t>(scale_factor * x_desc->shape().At(2))});
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_nearest_1d")
+    .Input("x")
+    .Output("y")
+    .Attr<float>("scale_factor")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      const float scale_factor = ctx->Attr<float>("scale_factor");
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && x_desc->shape().NumAxes() == 3)
+          << "upsample_nearest_1d only supports NCH";
+      *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1),
+                                    static_cast<int64_t>(scale_factor * x_desc->shape().At(2))});
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_nearest_2d")
+    .Input("x")
+    .Output("y")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      const float height_scale = ctx->Attr<float>("height_scale");
+      const float width_scale = ctx->Attr<float>("width_scale");
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && x_desc->shape().NumAxes() == 4)
+          << "upsample_nearest_2d only supports NCHW";
+      *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1),
+                                    static_cast<int64_t>(height_scale * x_desc->shape().At(2)),
+                                    static_cast<int64_t>(width_scale * x_desc->shape().At(3))});
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_bilinear_2d")
+    .Input("x")
+    .Output("y")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      const float height_scale = ctx->Attr<float>("height_scale");
+      const float width_scale = ctx->Attr<float>("width_scale");
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && x_desc->shape().NumAxes() == 4)
+          << "upsample_bilinear_2d only supports NCHW";
+      *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1),
+                                    static_cast<int64_t>(height_scale * x_desc->shape().At(2)),
+                                    static_cast<int64_t>(width_scale * x_desc->shape().At(3))});
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_bicubic_2d")
+    .Input("x")
+    .Output("y")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      const float height_scale = ctx->Attr<float>("height_scale");
+      const float width_scale = ctx->Attr<float>("width_scale");
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && x_desc->shape().NumAxes() == 4)
+          << "upsample_bicubic_2d only supports NCHW";
+      *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1),
+                                    static_cast<int64_t>(height_scale * x_desc->shape().At(2)),
+                                    static_cast<int64_t>(width_scale * x_desc->shape().At(3))});
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
 REGISTER_USER_OP("upsample")
     .Input("x")
     .Output("y")
@@ -48,6 +186,195 @@ REGISTER_USER_OP("upsample")
       return Maybe<void>::Ok();
     });
 
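+// The 3D variants extend the 2D pattern with a depth_scale attr and require
+// rank-5 NCDHW input.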
+REGISTER_USER_OP("upsample_nearest_3d")
+    .Input("x")
+    .Output("y")
+    .Attr<float>("depth_scale")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      const float depth_scale = ctx->Attr<float>("depth_scale");
+      const float height_scale = ctx->Attr<float>("height_scale");
+      const float width_scale = ctx->Attr<float>("width_scale");
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && x_desc->shape().NumAxes() == 5)
+          << "upsample_nearest_3d only supports NCDHW";
+      *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1),
+                                    static_cast<int64_t>(depth_scale * x_desc->shape().At(2)),
+                                    static_cast<int64_t>(height_scale * x_desc->shape().At(3)),
+                                    static_cast<int64_t>(width_scale * x_desc->shape().At(4))});
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_trilinear_3d")
+    .Input("x")
+    .Output("y")
+    .Attr<float>("depth_scale")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0);
+      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+      const float depth_scale = ctx->Attr<float>("depth_scale");
+      const float height_scale = ctx->Attr<float>("height_scale");
+      const float width_scale = ctx->Attr<float>("width_scale");
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && x_desc->shape().NumAxes() == 5)
+          << "upsample_trilinear_3d only supports NCDHW";
+      *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1),
+                                    static_cast<int64_t>(depth_scale * x_desc->shape().At(2)),
+                                    static_cast<int64_t>(height_scale * x_desc->shape().At(3)),
+                                    static_cast<int64_t>(width_scale * x_desc->shape().At(4))});
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+      return Maybe<void>::Ok();
+    });
+
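+// The *_grad ops mirror their forward op's attrs; they take dy plus the forward
+// input x (only its shape is read) and produce dx shaped like x.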
+REGISTER_USER_OP("upsample_linear_1d_grad")
+    .Input("dy")
+    .Input("x")
+    .Output("dx")
+    .Attr<float>("scale_factor")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && dy_shape.NumAxes() == 3)
+          << "upsample_linear_1d_grad only supports NCH";
+      *dx_shape = ctx->InputShape("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_nearest_1d_grad")
+    .Input("dy")
+    .Input("x")
+    .Output("dx")
+    .Attr<float>("scale_factor")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && dy_shape.NumAxes() == 3)
+          << "upsample_nearest_1d_grad only supports NCH";
+      *dx_shape = ctx->InputShape("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_nearest_2d_grad")
+    .Input("dy")
+    .Input("x")
+    .Output("dx")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && dy_shape.NumAxes() == 4)
+          << "upsample_nearest_2d_grad only supports NCHW";
+      *dx_shape = ctx->InputShape("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_bilinear_2d_grad")
+    .Input("dy")
+    .Input("x")
+    .Output("dx")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && dy_shape.NumAxes() == 4)
+          << "upsample_bilinear_2d_grad only supports NCHW";
+      *dx_shape = ctx->InputShape("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_bicubic_2d_grad")
+    .Input("dy")
+    .Input("x")
+    .Output("dx")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && dy_shape.NumAxes() == 4)
+          << "upsample_bicubic_2d_grad only supports NCHW";
+      *dx_shape = ctx->InputShape("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
 REGISTER_USER_OP("upsample_grad")
     .Input("dy")
     .Input("x")
@@ -75,6 +402,150 @@ REGISTER_USER_OP("upsample_grad")
       return Maybe<void>::Ok();
     });
 
+REGISTER_USER_OP("upsample_nearest_3d_grad")
+    .Input("dy")
+    .Input("x")
+    .Output("dx")
+    .Attr<float>("depth_scale")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && dy_shape.NumAxes() == 5)
+          << "upsample_nearest_3d_grad only supports NCDHW";
+      *dx_shape = ctx->InputShape("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("upsample_trilinear_3d_grad")
+    .Input("dy")
+    .Input("x")
+    .Output("dx")
+    .Attr<float>("depth_scale")
+    .Attr<float>("height_scale")
+    .Attr<float>("width_scale")
+    .Attr<bool>("align_corners")
+    .Attr<std::string>("data_format")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      const Shape& dy_shape = ctx->InputShape("dy", 0);
+      Shape* dx_shape = ctx->OutputShape("dx", 0);
+      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                      && dy_shape.NumAxes() == 5)
+          << "upsample_trilinear_3d_grad only supports NCDHW";
+      *dx_shape = ctx->InputShape("x", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
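+// Backward op-conf generators: when x needs a gradient, build the matching
+// *_grad op, feed it y's gradient as dy, copy the forward attrs through, and
+// bind the dx output as x's gradient.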
+REGISTER_USER_OP_GRAD("upsample_linear_1d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("upsample_linear_1d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Input("x", op.input("x", 0))
+                .Output("dx")
+                .Attr("scale_factor", op.attr<float>("scale_factor"))
+                .Attr("align_corners", op.attr<bool>("align_corners"))
+                .Attr("data_format", op.attr<std::string>("data_format"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+    });
+
+REGISTER_USER_OP_GRAD("upsample_nearest_1d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("upsample_nearest_1d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Input("x", op.input("x", 0))
+                .Output("dx")
+                .Attr("scale_factor", op.attr<float>("scale_factor"))
+                .Attr("data_format", op.attr<std::string>("data_format"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+    });
+
+REGISTER_USER_OP_GRAD("upsample_nearest_2d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("upsample_nearest_2d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Input("x", op.input("x", 0))
+                .Output("dx")
+                .Attr("height_scale", op.attr<float>("height_scale"))
+                .Attr("width_scale", op.attr<float>("width_scale"))
+                .Attr("data_format", op.attr<std::string>("data_format"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+    });
+
+REGISTER_USER_OP_GRAD("upsample_bilinear_2d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("upsample_bilinear_2d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Input("x", op.input("x", 0))
+                .Output("dx")
+                .Attr("height_scale", op.attr<float>("height_scale"))
+                .Attr("width_scale", op.attr<float>("width_scale"))
+                .Attr("align_corners", op.attr<bool>("align_corners"))
+                .Attr("data_format", op.attr<std::string>("data_format"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+    });
+
+REGISTER_USER_OP_GRAD("upsample_bicubic_2d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("upsample_bicubic_2d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Input("x", op.input("x", 0))
+                .Output("dx")
+                .Attr("height_scale", op.attr<float>("height_scale"))
+                .Attr("width_scale", op.attr<float>("width_scale"))
+                .Attr("align_corners", op.attr<bool>("align_corners"))
+                .Attr("data_format", op.attr<std::string>("data_format"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+    });
+
 REGISTER_USER_OP_GRAD("upsample")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
       if (op.NeedGenGradTensor4OpInput("x", 0)) {
@@ -95,4 +566,43 @@ REGISTER_USER_OP_GRAD("upsample")
       }
     });
 
+REGISTER_USER_OP_GRAD("upsample_nearest_3d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("upsample_nearest_3d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Input("x", op.input("x", 0))
+                .Output("dx")
+                .Attr("depth_scale", op.attr<float>("depth_scale"))
+                .Attr("height_scale", op.attr<float>("height_scale"))
+                .Attr("width_scale", op.attr<float>("width_scale"))
+                .Attr("data_format", op.attr<std::string>("data_format"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+    });
+
+REGISTER_USER_OP_GRAD("upsample_trilinear_3d")
+    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
+      if (op.NeedGenGradTensor4OpInput("x", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("upsample_trilinear_3d_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
+                .Input("x", op.input("x", 0))
+                .Output("dx")
+                .Attr("depth_scale", op.attr<float>("depth_scale"))
+                .Attr("height_scale", op.attr<float>("height_scale"))
+                .Attr("width_scale", op.attr<float>("width_scale"))
+                .Attr("align_corners", op.attr<bool>("align_corners"))
+                .Attr("data_format", op.attr<std::string>("data_format"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
+        AddOp(grad_op);
+      }
+    });
+
 }  // namespace oneflow