Commit 5b63e769 authored by Shijie, committed by GitHub

Dev minimum maximum (#5576)


* add minimum and maximum

* add testcase

* fix docs

* move to python

* add elementwise min max grad func

* add autotest

* add broadcast min max grad func

* add broadcast min max testcase

* add broadcast_binary related functional

* convert gradient func to functional

* delete elementwise op_expr

* delete grad_op

* bind python

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent ee5f09d4
@@ -16,8 +16,7 @@ limitations under the License.
 #include "oneflow/core/framework/op_expr_grad_function.h"
 #include "oneflow/core/framework/op_builder.h"
 #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
-#include "oneflow/core/framework/op_expr.h"
-#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/functional/functional.h"

 namespace oneflow {
 namespace one {
@@ -26,40 +25,27 @@ namespace {

 class ReduceSumLikeModule {
  public:
-  ReduceSumLikeModule(const std::string& op_name) {
-    identity_op_ = CHECK_JUST(op_expr_helper::IdentityOp(op_name + "_identity"));
-    reshape_like_op_ = CHECK_JUST(op_expr_helper::ReshapeLikeOp(op_name + "_reshape_like"));
-    reduce_sum_like_op_ =
-        CHECK_JUST(op_expr_helper::ReduceSumLikeOp({-1}, op_name + "reduce_sum_like"));
-  }
+  ReduceSumLikeModule() = default;
   ~ReduceSumLikeModule() = default;

-  Maybe<Tensor> forward(const std::shared_ptr<Tensor>& input,
-                        const std::shared_ptr<Tensor>& like) const {
+  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,
+                           const std::shared_ptr<Tensor>& like) const {
     const auto& in_shape = *(input->shape());
     const auto& like_shape = *(like->shape());
-    TensorTuple inputs{input};
-    MutableAttrMap attrs;
-    std::shared_ptr<OpExpr> op = identity_op_;
     if (in_shape != like_shape) {
       const Shape& left_extended_shape =
           CreateLeftExtendedShape(ShapeView(like_shape), in_shape.NumAxes());
       if (in_shape == left_extended_shape) {
-        op = reshape_like_op_;
+        return JUST(functional::ReshapeLike(input, like));
       } else {
-        op = reduce_sum_like_op_;
         const AxisVector& broadcast_axis_vec = left_extended_shape.Axes4BroadcastTo(in_shape);
-        JUST(attrs.SetAttr<std::vector<int32_t>>(
-            "axis", std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}));
+        return JUST(functional::ReduceSumLike(
+            input, like,
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()}));
       }
-      inputs.push_back(like);
     }
-    return JUST(OpInterpUtil::Dispatch<Tensor>(*op, inputs, attrs));
+    return JUST(functional::Identity(input));
   }
-
- private:
-  std::shared_ptr<OpExpr> identity_op_;
-  std::shared_ptr<OpExpr> reshape_like_op_;
-  std::shared_ptr<OpExpr> reduce_sum_like_op_;
 };

 }  // namespace
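For intuition: this helper maps a gradient that has the broadcast output shape back to an operand's shape, reshaping when only leading 1-axes were added and summing over the broadcast axes otherwise. A minimal NumPy sketch of the same semantics (the helper name reduce_sum_like is illustrative, not the OneFlow API; the snippets below reuse it):

import numpy as np

def reduce_sum_like(grad, like_shape):
    # Left-extend like_shape with ones, as CreateLeftExtendedShape does.
    extended = (1,) * (grad.ndim - len(like_shape)) + tuple(like_shape)
    # Broadcast axes are those where the operand had extent 1 but the output did not.
    axes = tuple(i for i, (g, l) in enumerate(zip(grad.shape, extended)) if l == 1 and g != 1)
    return grad.sum(axis=axes, keepdims=True).reshape(like_shape)

dz = np.ones((2, 3, 4))                   # gradient with the broadcast output shape
print(reduce_sum_like(dz, (3, 1)).shape)  # (3, 1); each entry sums 2 * 4 ones -> 8.0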
@@ -69,12 +55,7 @@ class BroadcastBinaryGrad : public OpExprGradFunction<OpExprInterpState> {
   BroadcastBinaryGrad() = default;
   virtual ~BroadcastBinaryGrad() = default;

-  virtual Maybe<void> Init(const OpExpr& op) {
-    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
-    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
-    op_name_ = fw_op_expr->op_name();
-    return Maybe<void>::Ok();
-  }
+  virtual Maybe<void> Init(const OpExpr& op) { return Maybe<void>::Ok(); }

   Maybe<void> Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
                       const AttrMap& attrs) const override {
@@ -85,120 +66,64 @@ class BroadcastBinaryGrad : public OpExprGradFunction<OpExprInterpState> {
     ctx->SaveTensorForBackward(outputs.at(0));
     return Maybe<void>::Ok();
   }
-
- protected:
-  std::string op_name_;
 };
 class BroadcastAdd : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    y_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_y");
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
     in_grads->resize(2);
-    if (x->requires_grad()) { in_grads->at(0) = JUST(x_grad_op_->forward(out_grads.at(0), x)); }
-    if (y->requires_grad()) { in_grads->at(1) = JUST(y_grad_op_->forward(out_grads.at(0), y)); }
+    if (x->requires_grad()) { in_grads->at(0) = JUST(ReduceSumLikeModule()(out_grads.at(0), x)); }
+    if (y->requires_grad()) { in_grads->at(1) = JUST(ReduceSumLikeModule()(out_grads.at(0), y)); }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<ReduceSumLikeModule> y_grad_op_;
 };
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_add", BroadcastAdd);
 class BroadcastSub : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    y_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_y");
-    y_grad_mul_op_ =
-        JUST(op_expr_helper::ScalarMulOp(-1.f, GradientOpName(op_name_ + "_y_scalar_mul")));
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
     in_grads->resize(2);
-    if (x->requires_grad()) { in_grads->at(0) = JUST(x_grad_op_->forward(out_grads.at(0), x)); }
+    if (x->requires_grad()) { in_grads->at(0) = JUST(ReduceSumLikeModule()(out_grads.at(0), x)); }
     if (y->requires_grad()) {
-      const auto& grad = JUST(OpInterpUtil::Dispatch<Tensor>(*y_grad_mul_op_, {out_grads.at(0)}));
-      in_grads->at(1) = JUST(y_grad_op_->forward(grad, y));
+      const auto& grad = JUST(functional::ScalarMul(out_grads.at(0), functional::Scalar(-1.f)));
+      in_grads->at(1) = JUST(ReduceSumLikeModule()(grad, y));
     }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<ReduceSumLikeModule> y_grad_op_;
-  std::shared_ptr<OpExpr> y_grad_mul_op_;
 };
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_sub", BroadcastSub);
 class BroadcastMul : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    y_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_y");
-    x_grad_mul_op_ =
-        JUST(op_expr_helper::BroadcastMulOp(GradientOpName(op_name_ + "_x_broadcast_mul")));
-    y_grad_mul_op_ =
-        JUST(op_expr_helper::BroadcastMulOp(GradientOpName(op_name_ + "_y_broadcast_mul")));
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
     const auto& y = ctx->SavedTensors().at(1);
     in_grads->resize(2);
     if (x->requires_grad()) {
-      const auto& x_grad =
-          JUST(OpInterpUtil::Dispatch<Tensor>(*x_grad_mul_op_, {out_grads.at(0), y}));
-      in_grads->at(0) = JUST(x_grad_op_->forward(x_grad, x));
+      const auto& x_grad = JUST(functional::BroadcastMul(out_grads.at(0), y));
+      in_grads->at(0) = JUST(ReduceSumLikeModule()(x_grad, x));
     }
     if (y->requires_grad()) {
-      const auto& y_grad =
-          JUST(OpInterpUtil::Dispatch<Tensor>(*y_grad_mul_op_, {out_grads.at(0), x}));
-      in_grads->at(1) = JUST(y_grad_op_->forward(y_grad, y));
+      const auto& y_grad = JUST(functional::BroadcastMul(out_grads.at(0), x));
+      in_grads->at(1) = JUST(ReduceSumLikeModule()(y_grad, y));
     }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<ReduceSumLikeModule> y_grad_op_;
-  std::shared_ptr<OpExpr> x_grad_mul_op_;
-  std::shared_ptr<OpExpr> y_grad_mul_op_;
 };
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_mul", BroadcastMul);
 class BroadcastDiv : public BroadcastBinaryGrad {
  public:
-  Maybe<void> Init(const OpExpr& op) override {
-    JUST(BroadcastBinaryGrad::Init(op));
-    x_grad_op_ = std::make_shared<ReduceSumLikeModule>(op_name_ + "_x");
-    x_grad_div_op_ =
-        JUST(op_expr_helper::BroadcastDivOp(GradientOpName(op_name_ + "_x_broadcast_div")));
-    y_grad_op_ = JUST(op_expr_helper::BroadcastDivGradOp(GradientOpName(op_name_ + "_y")));
-    return Maybe<void>::Ok();
-  }
-
   Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
     const auto& x = ctx->SavedTensors().at(0);
@@ -206,23 +131,86 @@ class BroadcastDiv : public BroadcastBinaryGrad {
     const auto& z = ctx->SavedTensors().at(2);
     in_grads->resize(2);
     if (x->requires_grad()) {
-      const auto& x_grad =
-          JUST(OpInterpUtil::Dispatch<Tensor>(*x_grad_div_op_, {out_grads.at(0), y}));
-      in_grads->at(0) = JUST(x_grad_op_->forward(x_grad, x));
+      const auto& x_grad = JUST(functional::BroadcastDiv(out_grads.at(0), y));
+      in_grads->at(0) = JUST(ReduceSumLikeModule()(x_grad, x));
     }
     if (y->requires_grad()) {
-      in_grads->at(1) = JUST(OpInterpUtil::Dispatch<Tensor>(*y_grad_op_, {out_grads.at(0), z, y}));
+      in_grads->at(1) = JUST(functional::BroadcastDivGrad(out_grads.at(0), z, y));
     }
     return Maybe<void>::Ok();
   }
-
- private:
-  std::shared_ptr<ReduceSumLikeModule> x_grad_op_;
-  std::shared_ptr<OpExpr> x_grad_div_op_;
-  std::shared_ptr<OpExpr> y_grad_op_;
 };
 REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_div", BroadcastDiv);
+class BroadcastMinMax : public BroadcastBinaryGrad {
+ public:
+  Maybe<void> Apply(const OpExprInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override {
+    const auto& x = ctx->SavedTensors().at(0);
+    const auto& y = ctx->SavedTensors().at(1);
+    const auto& out = ctx->SavedTensors().at(2);
+    const auto& out_shape = *(out->shape());
+    in_grads->resize(2);
+    if (x->requires_grad() || y->requires_grad()) {
+      const auto& x_shape = *(x->shape());
+      const auto& y_shape = *(y->shape());
+      auto broad_x_ = x;
+      auto broad_y_ = y;
+      if (x_shape != out_shape) {
+        const Shape& left_extended_x_shape =
+            CreateLeftExtendedShape(ShapeView(x_shape), out_shape.NumAxes());
+        const AxisVector& broadcast_axis_vec = left_extended_x_shape.Axes4BroadcastTo(out_shape);
+        const std::vector<int32_t> x_axis =
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+        broad_x_ = JUST(functional::BroadcastLike(x, out, x_axis));
+      }
+      if (y_shape != out_shape) {
+        const Shape& left_extended_y_shape =
+            CreateLeftExtendedShape(ShapeView(y_shape), out_shape.NumAxes());
+        const AxisVector& broadcast_axis_vec = left_extended_y_shape.Axes4BroadcastTo(out_shape);
+        const std::vector<int32_t> y_axis =
+            std::vector<int32_t>{broadcast_axis_vec.begin(), broadcast_axis_vec.end()};
+        broad_y_ = JUST(functional::BroadcastLike(y, out, y_axis));
+      }
+      const auto& broad_grads =
+          JUST(elementwise_grad_functor_(out_grads.at(0), broad_x_, broad_y_));
+      if (x->requires_grad()) {
+        in_grads->at(0) = JUST(ReduceSumLikeModule()(broad_grads->at(0), x));
+      }
+      if (y->requires_grad()) {
+        in_grads->at(1) = JUST(ReduceSumLikeModule()(broad_grads->at(1), y));
+      }
+    }
+    return Maybe<void>::Ok();
+  }
+
+ protected:
+  std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
+                                   const std::shared_ptr<Tensor>&)>
+      elementwise_grad_functor_;
+};
+
+class BroadcastMinimum : public BroadcastMinMax {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    JUST(BroadcastMinMax::Init(op));
+    elementwise_grad_functor_ = functional::ElementwiseMinGrad;
+    return Maybe<void>::Ok();
+  }
+};
+
+class BroadcastMaximum : public BroadcastMinMax {
+ public:
+  Maybe<void> Init(const OpExpr& op) override {
+    JUST(BroadcastMinMax::Init(op));
+    elementwise_grad_functor_ = functional::ElementwiseMaxGrad;
+    return Maybe<void>::Ok();
+  }
+};
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_minimum", BroadcastMinimum);
+REGISTER_OP_EXPR_GRAD_FUNCTION("broadcast_maximum", BroadcastMaximum);

 }  // namespace one
 }  // namespace oneflow
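The min/max backward above broadcasts both operands to the output shape, routes dz with the elementwise rule, and reduces back. A NumPy sketch of the whole pipeline for broadcast_minimum (continuing the sketch above; how the real kernels split the gradient on exact ties x == y is kernel-defined, and the mask below is just one convention):

x = np.random.randn(3, 1)
y = np.random.randn(2, 3, 4)
dz = np.ones((2, 3, 4))
bx = np.broadcast_to(x, dz.shape)               # functional::BroadcastLike
by = np.broadcast_to(y, dz.shape)
dx = reduce_sum_like(dz * (bx <= by), x.shape)  # dz goes where x won
dy = reduce_sum_like(dz * (bx > by), y.shape)   # and to y elsewhere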
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct ElementwiseXimumOpExprInterpState : public OpExprInterpState {
  bool x_requires_grad;
  bool y_requires_grad;
};

class ElementwiseXimumOp : public OpExprGradFunction<ElementwiseXimumOpExprInterpState> {
 public:
  Maybe<void> Capture(ElementwiseXimumOpExprInterpState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    ctx->x_requires_grad = inputs.at(0)->requires_grad();
    ctx->y_requires_grad = inputs.at(1)->requires_grad();
    ctx->SaveTensorForBackward(inputs.at(0));
    ctx->SaveTensorForBackward(inputs.at(1));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const ElementwiseXimumOpExprInterpState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    if (!(ctx->x_requires_grad || ctx->y_requires_grad)) { return Maybe<void>::Ok(); }
    in_grads->resize(2);
    const std::shared_ptr<one::Tensor>& x = ctx->SavedTensors().at(0);
    const std::shared_ptr<one::Tensor>& y = ctx->SavedTensors().at(1);
    if (ctx->x_requires_grad || ctx->y_requires_grad) {
      const auto& grads = JUST(grad_functor(out_grads.at(0), x, y));
      if (ctx->x_requires_grad) { in_grads->at(0) = grads->at(0); }
      if (ctx->y_requires_grad) { in_grads->at(1) = grads->at(1); }
    }
    return Maybe<void>::Ok();
  }

 protected:
  std::function<Maybe<TensorTuple>(const std::shared_ptr<Tensor>&, const std::shared_ptr<Tensor>&,
                                   const std::shared_ptr<Tensor>&)>
      grad_functor;
};

class ElementwiseMinimum : public ElementwiseXimumOp {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    grad_functor = functional::ElementwiseMinGrad;
    return Maybe<void>::Ok();
  }
};

class ElementwiseMaximum : public ElementwiseXimumOp {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    grad_functor = functional::ElementwiseMaxGrad;
    return Maybe<void>::Ok();
  }
};

REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_minimum", ElementwiseMinimum);
REGISTER_OP_EXPR_GRAD_FUNCTION("elementwise_maximum", ElementwiseMaximum);

}  // namespace one
}  // namespace oneflow
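A quick end-to-end check of the elementwise registrations through eager autograd (a sketch assuming this PR's build; again, behavior on exact ties is kernel-defined):

import oneflow as flow

x = flow.tensor([1.0, 5.0], requires_grad=True)
y = flow.tensor([3.0, 2.0], requires_grad=True)
flow.minimum(x, y).sum().backward()
print(x.grad)  # dz routed to x where x < y: [1., 0.]
print(y.grad)  # and to y where y < x: [0., 1.]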
@@ -860,5 +860,6 @@ Maybe<one::UserOpExpr> SoftmaxGradOp() { return SoftmaxGradOp("softmax_grad"); }

 Maybe<one::UserOpExpr> SoftmaxGradOp(const std::string& name) {
   return one::OpBuilder("softmax_grad", name).Input("y").Input("dy").Output("dx").Build();
 }

 }  // namespace op_expr_helper
 }  // namespace oneflow
@@ -77,6 +77,10 @@
   signature: "Tensor BroadcastDiv(Tensor x, Tensor y)"
   bind_python: True

+- name: "broadcast_div_grad"
+  signature: "Tensor BroadcastDivGrad(Tensor dz, Tensor z, Tensor y)"
+  bind_python: False

 - name: "broadcast_equal"
   signature: "Tensor BroadcastEqual(Tensor x, Tensor y)"
   bind_python: True
@@ -775,6 +779,42 @@
   signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)"
   bind_python: True

+- name: "broadcast_min"
+  signature: "Tensor BroadcastMin(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "broadcast_max"
+  signature: "Tensor BroadcastMax(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "elementwise_min"
+  signature: "Tensor ElementwiseMin(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "elementwise_max"
+  signature: "Tensor ElementwiseMax(Tensor x, Tensor y)"
+  bind_python: True
+
+- name: "elementwise_min_grad"
+  signature: "TensorTuple ElementwiseMinGrad(Tensor dz, Tensor x, Tensor y)"
+  bind_python: False
+
+- name: "elementwise_max_grad"
+  signature: "TensorTuple ElementwiseMaxGrad(Tensor dz, Tensor x, Tensor y)"
+  bind_python: False

 - name: "stack"
   signature: "Tensor Stack(TensorTuple inputs, *, Int64 dim=0)"
   bind_python: True

+- name: "identity"
+  signature: "Tensor Identity(Tensor in)"
+  bind_python: True
+
+- name: "reshape_like"
+  signature: "Tensor ReshapeLike(Tensor in, Tensor like)"
+  bind_python: True
+
+- name: "reduce_sum_like"
+  signature: "Tensor ReduceSumLike(Tensor in, Tensor like, *, Int32List axis)"
+  bind_python: True
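Entries with bind_python: True surface on the functional namespace flow.F, which is how the Python modules below call them. For example (illustrative, assuming this PR's build):

import oneflow as flow

x = flow.tensor([[1.0, 4.0]])
y = flow.tensor([[2.0, 3.0]])
print(flow.F.elementwise_min(x, y))                          # same-shape fast path
print(flow.F.broadcast_min(x, flow.tensor([[1.5], [2.5]])))  # broadcasting path, (2, 2) result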
@@ -1012,6 +1012,99 @@ class TensorSetItemFunctor {
   }
 };

+class ElementwiseMinimumGradFunctor {
+ public:
+  ElementwiseMinimumGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("elementwise_minimum_backward")
+                         .Input("dz")
+                         .Input("x")
+                         .Input("y")
+                         .Output("dx")
+                         .Output("dy")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dz,
+                                const std::shared_ptr<one::Tensor>& x,
+                                const std::shared_ptr<one::Tensor>& y) const {
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dz, x, y});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class ElementwiseMaximumGradFunctor {
+ public:
+  ElementwiseMaximumGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("elementwise_maximum_backward")
+                         .Input("dz")
+                         .Input("x")
+                         .Input("y")
+                         .Output("dx")
+                         .Output("dy")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dz,
+                                const std::shared_ptr<one::Tensor>& x,
+                                const std::shared_ptr<one::Tensor>& y) const {
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dz, x, y});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class BroadcastDivGradFunctor {
+ public:
+  BroadcastDivGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("broadcast_div_grad")
+                         .Input("dz")
+                         .Input("z")
+                         .Input("y")
+                         .Output("dy")
+                         .Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dz,
+                           const std::shared_ptr<one::Tensor>& z,
+                           const std::shared_ptr<one::Tensor>& y) const {
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dz, z, y});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class IdentityFunctor {
+ public:
+  IdentityFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("identity").Input("in").Output("out").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& in) const {
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {in});
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class ReduceSumLikeFunctor {
+ public:
+  ReduceSumLikeFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("reduce_sum_like").Input("x").Input("like").Output("y").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
+                           const std::shared_ptr<one::Tensor>& like,
+                           const std::vector<int32_t>& axis) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axis));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x, like}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
 }  // namespace impl

 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -1061,6 +1154,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::DiagGradFunctor>("DiagGrad");
   m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
   m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
+  m.add_functor<impl::ElementwiseMinimumGradFunctor>("ElementwiseMinGrad");
+  m.add_functor<impl::ElementwiseMaximumGradFunctor>("ElementwiseMaxGrad");
+  m.add_functor<impl::BroadcastDivGradFunctor>("BroadcastDivGrad");
+  m.add_functor<impl::IdentityFunctor>("Identity");
+  m.add_functor<impl::ReduceSumLikeFunctor>("ReduceSumLike");
 };

 }  // namespace functional
@@ -162,6 +162,44 @@ class ScalarDivByTensorFunctor : public BinaryFunctor {
   }
 };

+class BroadcastMinimumFunctor : public BinaryFunctor {
+ public:
+  BroadcastMinimumFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("broadcast_minimum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class BroadcastMaximumFunctor : public BinaryFunctor {
+ public:
+  BroadcastMaximumFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("broadcast_maximum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class ElementwiseMinimumFunctor : public BinaryFunctor {
+ public:
+  ElementwiseMinimumFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("elementwise_minimum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class ElementwiseMaximumFunctor : public BinaryFunctor {
+ public:
+  ElementwiseMaximumFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("elementwise_maximum").Input("x").Input("y").Output("z").Build());
+  }
+};
+
+class ReshapeLikeFunctor : public BinaryFunctor {
+ public:
+  ReshapeLikeFunctor() {
+    op_ =
+        CHECK_JUST(one::OpBuilder("reshape_like").Input("in").Input("like").Output("out").Build());
+  }
+};

 }  // namespace impl

 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -172,6 +210,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::BroadcastSubFunctor>("BroadcastSub");
   m.add_functor<impl::BroadcastMulFunctor>("BroadcastMul");
   m.add_functor<impl::BroadcastDivFunctor>("BroadcastDiv");
+  m.add_functor<impl::BroadcastMinimumFunctor>("BroadcastMin");
+  m.add_functor<impl::BroadcastMaximumFunctor>("BroadcastMax");
   m.add_functor<impl::BroadcastEqualFunctor>("BroadcastEqual");
   m.add_functor<impl::BroadcastNotEqualFunctor>("BroadcastNotEqual");
   m.add_functor<impl::BroadcastGreaterFunctor>("BroadcastGreater");
@@ -182,7 +222,10 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::ScalarSubByTensorFunctor>("ScalarSubByTensor");
   m.add_functor<impl::ScalarMulByTensorFunctor>("ScalarMulByTensor");
   m.add_functor<impl::ScalarDivByTensorFunctor>("ScalarDivByTensor");
+  m.add_functor<impl::ElementwiseMinimumFunctor>("ElementwiseMin");
+  m.add_functor<impl::ElementwiseMaximumFunctor>("ElementwiseMax");
   m.add_functor<impl::BroadcastFModFunctor>("BroadcastFMod");
+  m.add_functor<impl::ReshapeLikeFunctor>("ReshapeLike");
 };

 }  // namespace functional
@@ -284,6 +284,8 @@ from oneflow.nn.modules.math_ops import erfc_op as erfc
 from oneflow.nn.modules.math_ops import expm1_op as expm1
 from oneflow.nn.modules.math_ops import fmod_op as fmod
 from oneflow.nn.modules.math_ops import log_op as log
+from oneflow.nn.modules.math_ops import minimum as minimum
+from oneflow.nn.modules.math_ops import maximum as maximum
 from oneflow.nn.modules.math_ops import pow_op as pow
 from oneflow.nn.modules.math_ops import rsqrt_op as rsqrt
 from oneflow.nn.modules.math_ops import sin_op as sin
@@ -1637,6 +1637,92 @@ def topk_op(input, k, dim: int = None, largest: bool = True, sorted: bool = True
     return Topk(k=k, dim=dim, largest=largest, sorted=sorted)(input)

+class ElementwiseMinimum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.elementwise_min(x, y)
+
+
+class BroadcastMinimum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.broadcast_min(x, y)
+
+
+@register_tensor_op("minimum")
+def minimum(x, y):
+    r"""Computes the element-wise minimum of x and y.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow as flow
+
+        >>> x = flow.tensor((1, 2, -1), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.minimum(x, y)
+        tensor([ 1., 0., -1.], dtype=oneflow.float32)
+
+        >>> x = flow.tensor((1,), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.minimum(x, y)
+        tensor([1., 0., 1.], dtype=oneflow.float32)
+    """
+    if x.shape == y.shape:
+        return ElementwiseMinimum()(x, y)
+    else:
+        return BroadcastMinimum()(x, y)
+
+
+class ElementwiseMaximum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.elementwise_max(x, y)
+
+
+class BroadcastMaximum(Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x, y):
+        return flow.F.broadcast_max(x, y)
+
+
+@register_tensor_op("maximum")
+def maximum(x, y):
+    r"""Computes the element-wise maximum of x and y.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow as flow
+
+        >>> x = flow.tensor((1, 2, -1), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.maximum(x, y)
+        tensor([3., 2., 4.], dtype=oneflow.float32)
+
+        >>> x = flow.tensor((1,), dtype=flow.float)
+        >>> y = flow.tensor((3, 0, 4), dtype=flow.float)
+        >>> flow.maximum(x, y)
+        tensor([3., 1., 4.], dtype=oneflow.float32)
+    """
+    if x.shape == y.shape:
+        return ElementwiseMaximum()(x, y)
+    else:
+        return BroadcastMaximum()(x, y)
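Because minimum and maximum dispatch on shape equality, broadcasting inputs exercise the BroadcastMinMax backward shown earlier, and the gradients come back reduced to the input shapes. A short sketch (assuming this PR's build):

import oneflow as flow

x = flow.tensor([[1.0], [4.0]], requires_grad=True)   # shape (2, 1)
y = flow.tensor([2.0, 3.0, 5.0], requires_grad=True)  # shape (3,)
z = flow.maximum(x, y)                                # broadcasts to (2, 3)
z.sum().backward()
print(x.grad.shape, y.grad.shape)                     # reduced back to (2, 1) and (3,)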

 if __name__ == "__main__":
     import doctest
@@ -602,5 +602,160 @@ class TestAtan2(flow.unittest.TestCase):
         return y


+def _test_elementwise_minimum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [(10, 10, 200), (3, 12), (12,)]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.minimum(input_x, input_y)
+        np_values = np.minimum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+def _test_broadcast_minimum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [[(10, 10, 200), (10, 1, 1)], [(3, 12), (1, 12)]]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape[0]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape[1]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.minimum(input_x, input_y)
+        np_values = np.minimum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+@flow.unittest.skip_unless_1n1d()
+class TestMinimum(flow.unittest.TestCase):
+    def test_minimum(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_elementwise_minimum,
+            _test_broadcast_minimum,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+    @autotest()
+    def test_flow_elementwise_minimum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        x = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        y = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        return torch.minimum(x, y)
+
+    @autotest()
+    def test_flow_broadcast_minimum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        k3 = random(2, 6)
+        x = random_pytorch_tensor(ndim=3, dim0=k1, dim1=1, dim2=1)
+        y = random_pytorch_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3)
+        return torch.minimum(x, y)
+def _test_elementwise_maximum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [(10, 10, 200), (3, 12), (12,)]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.maximum(input_x, input_y)
+        np_values = np.maximum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+def _test_broadcast_maximum(test_case, device):
+    arg_dict = OrderedDict()
+    arg_dict["shape"] = [[(10, 10, 200), (10, 1, 1)], [(3, 12), (1, 12)]]
+    arg_dict["data_type"] = ["float32", "double"]
+    for (shape, data_type) in GenArgList(arg_dict):
+        input_x = flow.Tensor(
+            np.random.randn(*shape[0]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        input_y = flow.Tensor(
+            np.random.randn(*shape[1]),
+            dtype=type_name_to_flow_type[data_type],
+            device=flow.device(device),
+        )
+        of_values = flow.maximum(input_x, input_y)
+        np_values = np.maximum(input_x.numpy(), input_y.numpy())
+        test_case.assertTrue(
+            np.array_equal(of_values.numpy().flatten(), np_values.flatten())
+        )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestMaximum(flow.unittest.TestCase):
+    def test_maximum(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fun"] = [
+            _test_elementwise_maximum,
+            _test_broadcast_maximum,
+        ]
+        arg_dict["device"] = ["cpu", "cuda"]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+    @autotest()
+    def test_flow_elementwise_maximum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        x = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        y = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2)
+        return torch.maximum(x, y)
+
+    @autotest()
+    def test_flow_broadcast_maximum_with_random_data(test_case):
+        k1 = random(2, 6)
+        k2 = random(2, 6)
+        k3 = random(2, 6)
+        x = random_pytorch_tensor(ndim=3, dim0=k1, dim1=1, dim2=1)
+        y = random_pytorch_tensor(ndim=3, dim0=1, dim1=k2, dim2=k3)
+        return torch.maximum(x, y)
+

 if __name__ == "__main__":
     unittest.main()