From 5b63e7691902cc46ed62d40a273b33f2e12e1381 Mon Sep 17 00:00:00 2001
From: Shijie <821898965@qq.com>
Date: Sat, 31 Jul 2021 03:46:19 +0800
Subject: [PATCH] Dev minimum maximum (#5576)

* add minimum and maximum (usage sketch at the end of this message)

* add testcase

* fix docs

* move to python

* add elementwise min max grad func

* add autotest

* add broadcast min max grad func

* add broadcast min max testcase

* add broadcast_binary related functional

* convert gradient func to functional

* delete elementwise op_expr

* delete grad_op

* bind python
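
A minimal usage sketch of the API added by this patch (eager mode; the
dispatch notes follow python/oneflow/nn/modules/math_ops.py, and the
backward check assumes eager autograd through the new gradient functors,
so the values in the comments are illustrative expectations):

    import oneflow as flow

    # same shapes: flow.minimum/maximum dispatch to flow.F.elementwise_min/max
    x = flow.tensor((1, 2, -1), dtype=flow.float)
    y = flow.tensor((3, 0, 4), dtype=flow.float)
    print(flow.minimum(x, y))  # expected: tensor([ 1.,  0., -1.])
    print(flow.maximum(x, y))  # expected: tensor([3., 2., 4.])

    # different shapes: dispatch to flow.F.broadcast_min/max
    a = flow.tensor((1,), dtype=flow.float)
    print(flow.minimum(a, y))  # expected: tensor([1., 0., 1.])

    # backward: the gradient of minimum flows to the smaller operand
    u = flow.tensor((1, 5), dtype=flow.float, requires_grad=True)
    v = flow.tensor((2, 3), dtype=flow.float, requires_grad=True)
    flow.minimum(u, v).sum().backward()
    print(u.grad)  # expected: tensor([1., 0.])
    print(v.grad)  # expected: tensor([0., 1.])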

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 .../gradient_funcs/broadcast_binary_ops.cpp   | 196 ++++++++----------
 .../elementwise_minimum_maximum.cpp           |  83 ++++++++
 oneflow/core/framework/op_expr_helper.cpp     |   1 +
 oneflow/core/functional/functional_api.yaml   |  40 ++++
 .../core/functional/impl/array_functor.cpp    |  98 +++++++++
 .../core/functional/impl/binary_functor.cpp   |  43 ++++
 python/oneflow/__init__.py                    |   2 +
 python/oneflow/nn/modules/math_ops.py         |  86 ++++++++
 python/oneflow/test/modules/test_math_ops.py  | 155 ++++++++++++++
 9 files changed, 600 insertions(+), 104 deletions(-)
 create mode 100644 oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp

diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
index 24fdc1dc6..43a657b31 100644
--- a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
+++ b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
@@ -16,8 +16,7 @@ limitations under the License.
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/framework/op_builder.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
-#include "oneflow/core/framework/op_expr.h"
-#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/functional/functional.h"
 
 namespace oneflow {
 namespace one {
@@ -26,40 +25,27 @@ namespace {
 
 class ReduceSumLikeModule {
  public:
-  ReduceSumLikeModule(const std::string& op_name) {
-    identity_op_ = CHECK_JUST(op_expr_helper::IdentityOp(op_name + "_identity"));
-    reshape_like_op_ = CHECK_JUST(op_expr_helper::ReshapeLikeOp(op_name + "_reshape_like"));
-    reduce_sum_like_op_ =
-        CHECK_JUST(op_expr_helper::ReduceSumLikeOp({-1}, op_name + "reduce_sum_like"));
-  }
+  ReduceSumLikeModule() = default;
+  ~ReduceSumLikeModule() = default;
 
-  Maybe<Tensor> forward(const std::shared_ptr<Tensor>& input,
-                        const std::shared_ptr<Tensor>& like) const {
+  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,
+                           const std::shared_ptr<Tensor>& like) const {
     const auto& in_shape = *(input->shape());
     const auto& like_shape = *(like->shape());
-    TensorTuple inputs{input};
-    MutableAttrMap attrs;
-    std::shared_ptr<OpExpr> op = identity_op_;
     if (in_shape != like_shape) {
       const Shape& left_extended_shape =
           CreateLeftExtendedShape(ShapeView(like_shape), in_shape.NumAxes());
       if (in_shape == left_extended_shape) {
-        op = reshape_like_op_;
+        return JUST(functional::ReshapeLike(input, like));
       } else {
-        op = reduce_sum_like_op_;
         const AxisVector& broadcast_axis_vec = left_extended_shape.Axes4BroadcastTo(in_shape);
-        JUST(attrs.SetAttr<std::vector<int32_t>>(
-            "axis", std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}));
+        return JUST(functional::ReduceSumLike(
+            input, like,
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}));
       }
-      inputs.push_back(like);
     }
-    return JUST(OpInterpUtil::Dispatch<Tensor>(*op, inputs, attrs));
+    return JUST(functional::Identity(input));
   }
-
- private:
-  std::shared_ptr<OpExpr> identity_op_;
-  std::shared_ptr<OpExpr> reshape_like_op_;
-  std::shared_ptr<OpExpr> reduce_sum_like_op_;
 };
 
 }  // namespace
@@ -69,12 +55,7 @@ class BroadcastBinaryGrad : public OpExprGradFunction<OpExprInterpState> {
   BroadcastBinaryGrad() = default;
   virtual ~BroadcastBinaryGrad() = default;
 
-  virtual Maybe<void> Init(const OpExpr& op) {
-    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
-    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
-    op_name_ = fw_op_expr->op_name();
-    return Maybe<void>::Ok();
-  }
+  virtual Maybe<void> Init(const OpExpr& op) { return Maybe<void>::Ok(); }
 
   Maybe<void> Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override {
@@ -85,120 +66,64 @@ class BroadcastBinaryGrad : public OpExprGradFunction<OpExprInterpState> {
     ctx->SaveTensorForBackward(outputs.at(0));
     return Maybe<void>::Ok();
   }
-
- protected:
-  std::string op_name_;
 };
 
 class BroadcastAdd : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    y_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_y");
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
     in_grads->resize(2);
-    if (x->requires_grad()) { in_grads->at(0) = JUST(x_grad_op_->forward(out_grads.at(0), x)); }
-    if (y->requires_grad()) { in_grads->at(1) = JUST(y_grad_op_->forward(out_grads.at(0), y)); }
+    if (x->requires_grad()) { in_grads->at(0) = JUST(ReduceSumLikeModule()(out_grads.at(0), x)); }
+    if (y->requires_grad()) { in_grads->at(1) = JUST(ReduceSumLikeModule()(out_grads.at(0), y)); }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<ReduceSumLikeModule> y_grad_op_;
 };
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_add", BroadcastAdd);
 
 class BroadcastSub : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    y_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_y");
-    y_grad_mul_op_ =
-        JUST(op_expr_helper::ScalarMulOp(-1.f, GradientOpName(op_name_ + "_y_scalar_mul")));
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
     in_grads->resize(2);
-    if (x->requires_grad()) { in_grads->at(0) = JUST(x_grad_op_->forward(out_grads.at(0), x)); }
+    if (x->requires_grad()) { in_grads->at(0) = JUST(ReduceSumLikeModule()(out_grads.at(0), x)); }
     if (y->requires_grad()) {
-      const auto& grad = JUST(OpInterpUtil::Dispatch<Tensor>(*y_grad_mul_op_, {out_grads.at(0)}));
-      in_grads->at(1) = JUST(y_grad_op_->forward(grad, y));
+      const auto& grad = JUST(functional::ScalarMul(out_grads.at(0), functional::Scalar(-1.f)));
+      in_grads->at(1) = JUST(ReduceSumLikeModule()(grad, y));
     }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<ReduceSumLikeModule> y_grad_op_;
-  std::shared_ptr<OpExpr> y_grad_mul_op_;
 };
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_sub", BroadcastSub);
 
 class BroadcastMul : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    y_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_y");
-    x_grad_mul_op_ =
-        JUST(op_expr_helper::BroadcastMulOp(GradientOpName(op_name_ + "_x_broadcast_mul")));
-    y_grad_mul_op_ =
-        JUST(op_expr_helper::BroadcastMulOp(GradientOpName(op_name_ + "_y_broadcast_mul")));
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
     in_grads->resize(2);
     if (x->requires_grad()) {
-      const auto& x_grad =
-          JUST(OpInterpUtil::Dispatch<Tensor>(*x_grad_mul_op_, {out_grads.at(0), y}));
-      in_grads->at(0) = JUST(x_grad_op_->forward(x_grad, x));
+      const auto& x_grad = JUST(functional::BroadcastMul(out_grads.at(0), y));
+      in_grads->at(0) = JUST(ReduceSumLikeModule()(x_grad, x));
     }
     if (y->requires_grad()) {
-      const auto& y_grad =
-          JUST(OpInterpUtil::Dispatch<Tensor>(*y_grad_mul_op_, {out_grads.at(0), x}));
-      in_grads->at(1) = JUST(y_grad_op_->forward(y_grad, y));
+      const auto& y_grad = JUST(functional::BroadcastMul(out_grads.at(0), x));
+      in_grads->at(1) = JUST(ReduceSumLikeModule()(y_grad, y));
     }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<ReduceSumLikeModule> y_grad_op_;
-  std::shared_ptr<OpExpr> x_grad_mul_op_;
-  std::shared_ptr<OpExpr> y_grad_mul_op_;
 };
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_mul", BroadcastMul);
 
 class BroadcastDiv : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    x_grad_div_op_ =
-        JUST(op_expr_helper::BroadcastDivOp(GradientOpName(op_name_ + "_x_broadcast_div")));
-    y_grad_op_ = JUST(op_expr_helper::BroadcastDivGradOp(GradientOpName(op_name_ + "_y")));
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
@@ -206,23 +131,86 @@ class BroadcastDiv : public BroadcastBinaryGrad {
     const auto& z = ctx->SavedTensors().at(2);
     in_grads->resize(2);
     if (x->requires_grad()) {
-      const auto& x_grad =
-          JUST(OpInterpUtil::Dispatch<Tensor>(*x_grad_div_op_, {out_grads.at(0), y}));
-      in_grads->at(0) = JUST(x_grad_op_->forward(x_grad, x));
+      const auto& x_grad = JUST(functional::BroadcastDiv(out_grads.at(0), y));
+      in_grads->at(0) = JUST(ReduceSumLikeModule()(x_grad, x));
     }
     if (y->requires_grad()) {
-      in_grads->at(1) = JUST(OpInterpUtil::Dispatch<Tensor>(*y_grad_op_, {out_grads.at(0), z, y}));
+      in_grads->at(1) = JUST(functional::BroadcastDivGrad(out_grads.at(0), z, y));
     }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<OpExpr> x_grad_div_op_;
-  std::shared_ptr<OpExpr> y_grad_op_;
 };
 
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div", BroadcastDiv);
 
+class BroadcastMinMax : public BroadcastBinaryGrad {
+ public:
+  Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    const auto& x = ctx->SavedTensors().at(0);
+    const auto& y = ctx->SavedTensors().at(1);
+    const auto& out = ctx->SavedTensors().at(2);
+    const auto& out_shape = *(out->shape());
+    in_grads->resize(2);
+    if (x->requires_grad() || y->requires_grad()) {
+      const auto& x_shape = *(x->shape());
+      const auto& y_shape = *(y->shape());
+      auto broad_x = x;
+      auto broad_y = y;
+      if (x_shape != out_shape) {
+        const Shape& left_extended_x_shape =
+            CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
+        const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
+        const std::vector<int32_t> x_axis =
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+        broad_x = JUST(functional::BroadcastLike(x, out, x_axis));
+      }
+      if (y_shape != out_shape) {
+        const Shape& left_extended_y_shape =
+            CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
+        const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
+        const std::vector<int32_t> y_axis =
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+        broad_y = JUST(functional::BroadcastLike(y, out, y_axis));
+      }
+      const auto& broad_grads =
+          JUST(elementwise_grad_functor_(out_grads.at(0), broad_x, broad_y));
+      if (x->requires_grad()) {
+        in_grads->at(0) = JUST(ReduceSumLikeModule()(broad_grads->at(0), x));
+      }
+      if (y->requires_grad()) {
+        in_grads->at(1) = JUST(ReduceSumLikeModule()(broad_grads->at(1), y));
+      }
+    }
+    return Maybe<void>::Ok();
+  }
+
+ protected:
+  std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
+                                   const std::shared_ptr<Tensor>&)>
+      elementwise_grad_functor_;
+};
+
+class BroadcastMinimum : public BroadcastMinMax {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    JUST(BroadcastMinMax::Init(op));
+    elementwise_grad_functor_ = functional::ElementwiseMinGrad;
+    return Maybe<void>::Ok();
+  }
+};
+
+class BroadcastMaximum : public BroadcastMinMax {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    JUST(BroadcastMinMax::Init(op));
+    elementwise_grad_functor_ = functional::ElementwiseMaxGrad;
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum);
+REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum);
+
 }  // namespace one
 }  // namespace oneflow
diff --git a/oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp b/oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp
new file mode 100644
index 000000000..e155c6310
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/elementwise_minimum_maximum.cpp
@@ -0,0 +1,83 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct ElementwiseXimumOpExprInterpState : public OpExprInterpState {
+  bool x_requires_grad;
+  bool y_requires_grad;
+};
+
+class ElementwiseXimumOp : public OpExprGradFunction<ElementwiseXimumOpExprInterpState> {
+ public:
+  Maybe<void> Capture(ElementwiseXimumOpExprInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
+    ctx->x_requires_grad = inputs.at(0)->requires_grad();
+    ctx->y_requires_grad = inputs.at(1)->requires_grad();
+    ctx->SaveTensorForBackward(inputs.at(0));
+    ctx->SaveTensorForBackward(inputs.at(1));
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<void> Apply(const ElementwiseXimumOpExprInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe<void>::Ok(); }
+
+    in_grads->resize(2);
+    const std::shared_ptr<one::Tensor>& x = ctx->SavedTensors().at(0);
+    const std::shared_ptr<one::Tensor>& y = ctx->SavedTensors().at(1);
+    if (ctx->x_requires_grad || ctx->y_requires_grad) {
+      const auto& grads = JUST(grad_functor(out_grads.at(0), x, y));
+      if (ctx->x_requires_grad) { in_grads->at(0) = grads->at(0); }
+      if (ctx->y_requires_grad) { in_grads->at(1) = grads->at(1); }
+    }
+
+    return Maybe<void>::Ok();
+  }
+
+ protected:
+  std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
+                                   const std::shared_ptr<Tensor>&)>
+      grad_functor;
+};
+
+class ElementwiseMinimum : public ElementwiseXimumOp {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    grad_functor = functional::ElementwiseMinGrad;
+    return Maybe<void>::Ok();
+  }
+};
+
+class ElementwiseMaximum : public ElementwiseXimumOp {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    grad_functor = functional::ElementwiseMaxGrad;
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_minimum", ElementwiseMinimum);
+REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_maximum", ElementwiseMaximum);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp
index efd7f7dba..ccb56d84c 100644
--- a/oneflow/core/framework/op_expr_helper.cpp
+++ b/oneflow/core/framework/op_expr_helper.cpp
@@ -860,5 +860,6 @@ Maybe<one::UserOpExpr> SoftmaxGradOp() { return SoftmaxGradOp("softmax_grad"); }
 Maybe<one::UserOpExpr> SoftmaxGradOp(const std::string& name) {
   return one::OpBuilder("softmax_grad", name).Input("y").Input("dy").Output("dx").Build();
 }
+
 }  // namespace op_expr_helper
 }  // namespace oneflow
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 1c529d99f..f69e203e2 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -77,6 +77,10 @@
   signature: "Tensor BroadcastDiv(Tensor x, Tensor y)"
   bind_python: True
 
+- name: "broadcast_div_grad"
+  signature: "Tensor BroadcastDivGrad(Tensor y, Tensor z, Tensor dz)"
+  bind_python: False
+
 - name: "broadcast_equal"
   signature: "Tensor BroadcastEqual(Tensor x, Tensor y)"
   bind_python: True
@@ -775,6 +779,42 @@
   signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)"
   bind_python: True
 
+- name: "broadcast_min"
+  signature: "Tensor BroadcastMin(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "broadcast_max"
+  signature: "Tensor BroadcastMax(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "elementwise_min"
+  signature: "Tensor ElementwiseMin(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "elementwise_max"
+  signature: "Tensor ElementwiseMax(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "elementwise_min_grad"
+  signature: "TensorTuple ElementwiseMinGrad(Tensor dz, Tensor x, Tensor y)"
+  bind_python: False
+
+- name: "elementwise_max_grad"
+  signature: "TensorTuple ElementwiseMaxGrad(Tensor dz, Tensor x, Tensor y)"
+  bind_python: False
+
 - name: "stack"
   signature: "Tensor Stack(TensorTuple inputs, *, Int64 dim=0)"
   bind_python: True
+
+- name: "identity"
+  signature: "Tensor Identity(Tensor in)"
+  bind_python: True
+
+- name: "reshape_like"
+  signature: "Tensor ReshapeLike(Tensor in, Tensor like)"
+  bind_python: True
+
+- name: "reduce_sum_like"
+  signature: "Tensor ReduceSumLike(Tensor in, Tensor like, *, Int32List axis)"
+  bind_python: True
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index f40ebc64e..730b7b868 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -1012,6 +1012,99 @@ class TensorSetItemFunctor {
   }
 };
 
+class ElementwiseMinimumGradFunctor {
+ public:
+  ElementwiseMinimumGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("elementwise_minimum_backward")
+                         .Input("dz")
+                         .Input("x")
+                         .Input("y")
+                         .Output("dx")
+                         .Output("dy")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dz,
+                                const std::shared_ptr<one::Tensor>& x,
+                                const std::shared_ptr<one::Tensor>& y) const {
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dz, x, y});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class ElementwiseMaximumGradFunctor {
+ public:
+  ElementwiseMaximumGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("elementwise_maximum_backward")
+                         .Input("dz")
+                         .Input("x")
+                         .Input("y")
+                         .Output("dx")
+                         .Output("dy")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dz,
+                                const std::shared_ptr<one::Tensor>& x,
+                                const std::shared_ptr<one::Tensor>& y) const {
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dz, x, y});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class BroadcastDivGradFunctor {
+ public:
+  BroadcastDivGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("broadcast_div_grad")
+                         .Input("dz")
+                         .Input("z")
+                         .Input("y")
+                         .Output("dy")
+                         .Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dz,
+                           const std::shared_ptr<one::Tensor>& z,
+                           const std::shared_ptr<one::Tensor>& y) const {
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dz, z, y});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class IdentityFunctor {
+ public:
+  IdentityFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("identity").Input("in").Output("out").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in) const {
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {in});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class ReduceSumLikeFunctor {
+ public:
+  ReduceSumLikeFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("reduce_sum_like").Input("x").Input("like").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::shared_ptr<one::Tensor>& like,
+                           const std::vector<int32_t>& axis) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axis));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, like}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -1061,6 +1154,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::DiagGradFunctor>("DiagGrad");
   m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
   m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
+  m.add_functor<impl::ElementwiseMinimumGradFunctor>("ElementwiseMinGrad");
+  m.add_functor<impl::ElementwiseMaximumGradFunctor>("ElementwiseMaxGrad");
+  m.add_functor<impl::BroadcastDivGradFunctor>("BroadcastDivGrad");
+  m.add_functor<impl::IdentityFunctor>("Identity");
+  m.add_functor<impl::ReduceSumLikeFunctor>("ReduceSumLike");
 };
 
 }  // namespace functional
diff --git a/oneflow/core/functional/impl/binary_functor.cpp b/oneflow/core/functional/impl/binary_functor.cpp
index fb8aaf9cf..924a26604 100644
--- a/oneflow/core/functional/impl/binary_functor.cpp
+++ b/oneflow/core/functional/impl/binary_functor.cpp
@@ -162,6 +162,44 @@ class ScalarDivByTensorFunctor : public BinaryFunctor {
   }
 };
 
+class BroadcastMinimumFunctor : public BinaryFunctor {
+ public:
+  BroadcastMinimumFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("broadcast_minimum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class BroadcastMaximumFunctor : public BinaryFunctor {
+ public:
+  BroadcastMaximumFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("broadcast_maximum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class ElementwiseMinimumFunctor : public BinaryFunctor {
+ public:
+  ElementwiseMinimumFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("elementwise_minimum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class ElementwiseMaximumFunctor : public BinaryFunctor {
+ public:
+  ElementwiseMaximumFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("elementwise_maximum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class ReshapeLikeFunctor : public BinaryFunctor {
+ public:
+  ReshapeLikeFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("reshape_like").Input("in").Input("like").Output("out").Build());
+  }
+};
+
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -172,6 +210,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::BroadcastSubFunctor>("BroadcastSub");
   m.add_functor<impl::BroadcastMulFunctor>("BroadcastMul");
   m.add_functor<impl::BroadcastDivFunctor>("BroadcastDiv");
+  m.add_functor<impl::BroadcastMinimumFunctor>("BroadcastMin");
+  m.add_functor<impl::BroadcastMaximumFunctor>("BroadcastMax");
   m.add_functor<impl::BroadcastEqualFunctor>("BroadcastEqual");
   m.add_functor<impl::BroadcastNotEqualFunctor>("BroadcastNotEqual");
   m.add_functor<impl::BroadcastGreaterFunctor>("BroadcastGreater");
@@ -182,7 +222,10 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::ScalarSubByTensorFunctor>("ScalarSubByTensor");
   m.add_functor<impl::ScalarMulByTensorFunctor>("ScalarMulByTensor");
   m.add_functor<impl::ScalarDivByTensorFunctor>("ScalarDivByTensor");
+  m.add_functor<impl::ElementwiseMinimumFunctor>("ElementwiseMin");
+  m.add_functor<impl::ElementwiseMaximumFunctor>("ElementwiseMax");
   m.add_functor<impl::BroadcastFModFunctor>("BroadcastFMod");
+  m.add_functor<impl::ReshapeLikeFunctor>("ReshapeLike");
 };
 
 }  // namespace functional
diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py
index 8e6f9d5cf..70e3eb198 100644
--- a/python/oneflow/__init__.py
+++ b/python/oneflow/__init__.py
@@ -284,6 +284,8 @@ from oneflow.nn.modules.math_ops import erfc_op as erfc
 from oneflow.nn.modules.math_ops import expm1_op as expm1
 from oneflow.nn.modules.math_ops import fmod_op as fmod
 from oneflow.nn.modules.math_ops import log_op as log
+from oneflow.nn.modules.math_ops import minimum as minimum
+from oneflow.nn.modules.math_ops import maximum as maximum
 from oneflow.nn.modules.math_ops import pow_op as pow
 from oneflow.nn.modules.math_ops import rsqrt_op as rsqrt
 from oneflow.nn.modules.math_ops import sin_op as sin
diff --git a/python/oneflow/nn/modules/math_ops.py b/python/oneflow/nn/modules/math_ops.py
index fd664de6b..053110f4a 100644
--- a/python/oneflow/nn/modules/math_ops.py
+++ b/python/oneflow/nn/modules/math_ops.py
@@ -1637,6 +1637,92 @@ def topk_op(input, k, dim: int = None, largest: bool = True, sorted: bool = True
     return Topk(k=k, dim=dim, largest=largest, sorted=sorted)(input)
 
 
+class ElementwiseMinimum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.elementwise_min(x, y)
+
+
+class BroadcastMinimum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.broadcast_min(x, y)
+
+
+@register_tensor_op("minimum")
+def minimum(x, y):
+    r"""Computes the element-wise minimum of x and y, broadcasting when their shapes differ.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow as flow
+
+        >>> x = flow.tensor((1, 2, -1), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.minimum(x, y)
+        tensor([ 1.,  0., -1.], dtype=oneflow.float32)
+
+        >>> x = flow.tensor((1,), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.minimum(x, y)
+        tensor([1., 0., 1.], dtype=oneflow.float32)
+    """
+    if x.shape == y.shape:
+        return ElementwiseMinimum()(x, y)
+    else:
+        return BroadcastMinimum()(x, y)
+
+
+class ElementwiseMaximum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.elementwise_max(x, y)
+
+
+class BroadcastMaximum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.broadcast_max(x, y)
+
+
+@register_tensor_op("maximum")
+def maximum(x, y):
+    r"""Computes the element-wise maximum of x and y, broadcasting when their shapes differ.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow as flow
+
+        >>> x = flow.tensor((1, 2, -1), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.maximum(x, y)
+        tensor([3., 2., 4.], dtype=oneflow.float32)
+
+        >>> x = flow.tensor((1,), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.maximum(x, y)
+        tensor([3., 1., 4.], dtype=oneflow.float32)
+    """
+    if x.shape == y.shape:
+        return ElementwiseMaximum()(x, y)
+    else:
+        return BroadcastMaximum()(x, y)
+
+
 if __name__ == "__main__":
     import doctest
 
diff --git a/python/oneflow/test/modules/test_math_ops.py b/python/oneflow/test/modules/test_math_ops.py
index b22b189cc..ca072fc22 100644
--- a/python/oneflow/test/modules/test_math_ops.py
+++ b/python/oneflow/test/modules/test_math_ops.py
@@ -602,5 +602,160 @@ class TestAtan2(flow.unittest.TestCase):
         return y
 
 
+def _test_elementwise_minimum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [(10, 10, 200), (3, 12), (12,)]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.minimum(input_x, input_y)
+        np_values = np.minimum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+def _test_broadcast_minimum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [[(10, 10, 200), (10, 1, 1)], [(3, 12), (1, 12)]]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape[0]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape[1]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.minimum(input_x, input_y)
+        np_values = np.minimum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+@flow.unittest.skip_unless_1n1d()
+class TestMinimum(flow.unittest.TestCase):
+    def test_minimum(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_elementwise_minimum,
+            _test_broadcast_minimum,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+    @autotest()
+    def test_flow_elementwise_minimum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        x = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        y = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        return torch.minimum(x, y)
+
+    @autotest()
+    def test_flow_broadcast_minimum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        k3 = random(2, 6)
+        x = random_pytorch_tensor(ndim=3, dim0=k1, dim1=1, dim2=1)
+        y = random_pytorch_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3)
+        return torch.minimum(x, y)
+
+
+def _test_elementwise_maximum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [(10, 10, 200), (3, 12), (12,)]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.maximum(input_x, input_y)
+        np_values = np.maximum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+def _test_broadcast_maximum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [[(10, 10, 200), (10, 1, 1)], [(3, 12), (1, 12)]]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape[0]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape[1]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.maximum(input_x, input_y)
+        np_values = np.maximum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestMaximum(flow.unittest.TestCase):
+    def test_maximum(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_elementwise_maximum,
+            _test_broadcast_maximum,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+    @autotest()
+    def test_flow_elementwise_maximum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        x = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        y = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        return torch.maximum(x, y)
+
+    @autotest()
+    def test_flow_broadcast_maximum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        k3 = random(2, 6)
+        x = random_pytorch_tensor(ndim=3, dim0=k1, dim1=1, dim2=1)
+        y = random_pytorch_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3)
+        return torch.maximum(x, y)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab