Commit 543370b4, authored by Houjiang Chen, committed by GitHub

Rewrite sparse_softmax_cross_entropy and reduce ops gradient funcs (#4823)


* Rewrite sparse softmax cross entropy gradient func.

* Rewrite reduce ops gradient funcs.

* Add CrossEntropyLoss grad unittest.

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: hjchen2 <hjchen2>
parent 0a0c550a
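The sparse softmax cross entropy rewrite leans on the standard identity: with logits $z$, probabilities $p = \operatorname{softmax}(z)$, and integer label $y$,

$$\frac{\partial L}{\partial z_k} = p_k - \mathbf{1}[k = y],$$

so the backward op only needs the saved prob tensor, the label, and the incoming loss gradient dy, rather than recomputing the softmax.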
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
namespace oneflow {
namespace one {
struct ReduceSumOpInterpState : public OpExprInterpState {
std::vector<int32_t> axis;
};
class ReduceSumOp : public OpExprGradFunction<ReduceSumOpInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::shared_ptr<OpExpr> grad_op_;
};
Maybe<void> ReduceSumOp::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
const std::string& op_name = fw_op_expr->op_name();
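// The axis passed below is only a placeholder; the captured axis is supplied
// per-call through the attrs at dispatch time (see Apply).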
grad_op_ = JUST(op_expr_helper::BroadcastLikeOp(/*axis=*/{-1}, GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> ReduceSumOp::Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> ReduceSumOp::Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
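// The gradient of a sum reduction just broadcasts dy back to the input shape
// along the reduced axes, e.g. an input of shape (2, 3) summed over axis 1
// yields dy of shape (2,), which broadcast_like expands back to (2, 3).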
const auto& input = ctx->SavedTensors().at(0);
const auto& dy = out_grads.at(0);
MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
in_grads->resize(1);
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {dy, input}, attrs));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum", ReduceSumOp);
struct ReduceMaxOrMinOpInterpState : public OpExprInterpState {
std::vector<int32_t> axis;
bool keepdims;
};
class ReduceMaxOrMinOp : public OpExprGradFunction<ReduceMaxOrMinOpInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::shared_ptr<OpExpr> bcast_like_op_;
std::shared_ptr<OpExpr> bcast_equal_op_;
std::shared_ptr<OpExpr> cast_like_op_;
std::shared_ptr<OpExpr> reduce_sum_op_;
std::shared_ptr<OpExpr> bcast_div_op_;
std::shared_ptr<OpExpr> multiply_op_;
};
Maybe<void> ReduceMaxOrMinOp::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
const std::string& op_name = fw_op_expr->op_name();
bcast_like_op_ =
JUST(op_expr_helper::BroadcastLikeOp(/*axis=*/{-1}, GradientOpName(op_name + "_bcast_like")));
bcast_equal_op_ = JUST(op_expr_helper::BroadcastEqualOp(GradientOpName(op_name + "_bcast_eq")));
cast_like_op_ = JUST(op_expr_helper::CastLikeOp(GradientOpName(op_name + "_cast_like")));
reduce_sum_op_ = JUST(op_expr_helper::ReduceSumOp(/*axis=*/{-1}, /*keepdims=*/false,
GradientOpName(op_name + "_reduce_sum")));
bcast_div_op_ = JUST(op_expr_helper::BroadcastDivOp(GradientOpName(op_name + "_bcast_div")));
multiply_op_ = JUST(op_expr_helper::MultiplyOp(op_name + "_multiply"));
return Maybe<void>::Ok();
}
Maybe<void> ReduceMaxOrMinOp::Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
// Apply() reads ctx->keepdims when rebuilding the reduce_sum attrs, so it
// must be captured here as well.
ctx->keepdims = JUST(composed_attrs.GetAttr<bool>("keepdims"));
ctx->SaveTensorForBackward(inputs.at(0));
ctx->SaveTensorForBackward(outputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> ReduceMaxOrMinOp::Apply(const ReduceMaxOrMinOpInterpState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
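// For max/min reductions the gradient flows only to the elements that attained
// the extremum. The steps below: (1) broadcast the reduced output back to the
// input shape, (2) mark the extremal positions with broadcast_equal, (3) count
// ties per reduced slice with reduce_sum, (4) divide dy by that count so tied
// elements share the gradient evenly, and (5) scatter the result onto the mask
// via multiply.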
const auto& input = ctx->SavedTensors().at(0);
const auto& output = ctx->SavedTensors().at(1);
const auto& dy = out_grads.at(0);
MutableAttrMap bcast_attrs;
JUST(bcast_attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
const auto& bcast_like =
JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_like_op_, {output, input}, bcast_attrs));
const auto& bcast_eq =
JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_equal_op_, {input, bcast_like}));
const auto& cast_like = JUST(OpInterpUtil::Dispatch<Tensor>(*cast_like_op_, {bcast_eq, input}));
MutableAttrMap reduce_sum_attrs;
JUST(reduce_sum_attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
JUST(reduce_sum_attrs.SetAttr<bool>("keepdims", ctx->keepdims));
const auto& reduce_sum =
JUST(OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op_, {cast_like}, reduce_sum_attrs));
const auto& broadcast_div =
JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_div_op_, {dy, reduce_sum}));
const auto& bcast_like_div =
JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_like_op_, {broadcast_div, input}, bcast_attrs));
in_grads->resize(1);
in_grads->at(0) =
JUST(OpInterpUtil::Dispatch<Tensor>(*multiply_op_, {bcast_like_div, cast_like}));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min", ReduceMaxOrMinOp);
REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max", ReduceMaxOrMinOp);
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
namespace oneflow {
namespace one {
struct SparseSoftmaxCrossEntropyInterpState : public OpExprInterpState {
int64_t depth;
};
class SparseSoftmaxCrossEntropy : public OpExprGradFunction<SparseSoftmaxCrossEntropyInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
std::shared_ptr<OpExpr> grad_op_;
};
Maybe<void> SparseSoftmaxCrossEntropy::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
const std::string& op_name = fw_op_expr->op_name();
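// depth=-1 below is a placeholder; the captured depth is set through attrs
// when the grad op is dispatched in Apply.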
grad_op_ =
JUST(op_expr_helper::SparseSoftmaxCrossEntropyGradOp(/*depth=*/-1, GradientOpName(op_name)));
return Maybe<void>::Ok();
}
Maybe<void> SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyInterpState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
CHECK_EQ_OR_RETURN(inputs.size(), 2);
CHECK_EQ_OR_RETURN(outputs.size(), 2);
ctx->SaveTensorForBackward(outputs.at(0)); // prob
ctx->SaveTensorForBackward(inputs.at(1)); // label
return Maybe<void>::Ok();
}
Maybe<void> SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyInterpState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 2);
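// The forward op has two outputs (prob, out); only the loss output at index 1
// carries an incoming gradient, while prob is saved for reuse in the backward.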
const auto& dy = out_grads.at(1);
const auto& prob = ctx->SavedTensors().at(0);
const auto& label = ctx->SavedTensors().at(1);
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("depth", ctx->depth));
// SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not
// require gradient.
in_grads->resize(2);
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {prob, label, dy}, attrs));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy", SparseSoftmaxCrossEntropy);
} // namespace one
} // namespace oneflow
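For reference, a minimal NumPy sketch (illustrative only, not part of this commit) of the quantity the sparse_softmax_cross_entropy_grad op computes: prob minus the one-hot label, scaled row-wise by the incoming dy.

import numpy as np

# Hypothetical reference helper, not OneFlow API.
def sparse_softmax_ce_grad(logits, labels, dy):
    # Row-wise softmax with the usual max-shift for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    prob = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    onehot = np.eye(logits.shape[-1])[labels]
    # dy carries one loss-gradient entry per row; broadcast over classes.
    return dy[:, None] * (prob - onehot)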
@@ -245,6 +245,25 @@ Maybe<one::UserOpExpr> BroadcastDivOp(const std::string& name) {
return one::OpBuilder("broadcast_div", name).Input("x").Input("y").Output("z").Build();
}
Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis) {
return BroadcastLikeOp(axis, UniqueOpName("broadcast_like"));
}
Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis, const std::string& name) {
return one::OpBuilder("broadcast_like", name)
.Input("x")
.Input("like")
.Output("y")
.Attr<std::vector<int32_t>>("broadcast_axes", axis)
.Build();
}
Maybe<one::UserOpExpr> BroadcastEqualOp() {
return BroadcastEqualOp(UniqueOpName("broadcast_equal"));
}
Maybe<one::UserOpExpr> BroadcastEqualOp(const std::string& name) {
return one::OpBuilder("broadcast_equal", name).Input("x").Input("y").Output("z").Build();
}
Maybe<one::UserOpExpr> CastOp(const DataType& to_type) {
return CastOp(to_type, UniqueOpName("cast"));
}
@@ -256,6 +275,11 @@ Maybe<one::UserOpExpr> CastOp(const DataType& to_type, const std::string& name)
.Build();
}
Maybe<one::UserOpExpr> CastLikeOp() { return CastLikeOp(UniqueOpName("cast_like")); }
Maybe<one::UserOpExpr> CastLikeOp(const std::string& name) {
return one::OpBuilder("cast_like", name).Input("in").Input("dtype_like").Output("out").Build();
}
Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon) {
return NormalizationGradOp(axis, epsilon, UniqueOpName("normalization_grad"));
}
@@ -456,5 +480,33 @@ Maybe<one::UserOpExpr> ConvNdFilterGradOp(const std::vector<int32_t>& kernel_siz
.Build();
}
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth) {
return SparseSoftmaxCrossEntropyGradOp(depth, UniqueOpName("sparse_softmax_cross_entropy"));
}
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth,
const std::string& name) {
return one::OpBuilder("sparse_softmax_cross_entropy_grad", name)
.Input("prob")
.Input("label")
.Input("dy")
.Output("prediction_diff")
.Attr<int64_t>("depth", depth)
.Build();
}
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth) {
return SparseSoftmaxCrossEntropyMsGradOp(depth, UniqueOpName("sparse_softmax_cross_entropy_ms"));
}
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth,
const std::string& name) {
return one::OpBuilder("sparse_softmax_cross_entropy_ms_grad", name)
.Input("prob")
.Input("label")
.Input("dy")
.Output("prediction_diff")
.Attr<int64_t>("depth", depth)
.Build();
}
} // namespace op_expr_helper
} // namespace oneflow
@@ -85,9 +85,18 @@ Maybe<one::UserOpExpr> BroadcastMulOp(const std::string& name);
Maybe<one::UserOpExpr> BroadcastDivOp();
Maybe<one::UserOpExpr> BroadcastDivOp(const std::string& name);
Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis);
Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis, const std::string& name);
Maybe<one::UserOpExpr> BroadcastEqualOp();
Maybe<one::UserOpExpr> BroadcastEqualOp(const std::string& name);
Maybe<one::UserOpExpr> CastOp(const DataType& to_type);
Maybe<one::UserOpExpr> CastOp(const DataType& to_type, const std::string& name);
Maybe<one::UserOpExpr> CastLikeOp();
Maybe<one::UserOpExpr> CastLikeOp(const std::string& name);
Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon);
Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon,
const std::string& name);
@@ -152,5 +161,12 @@ Maybe<one::UserOpExpr> ConvNdFilterGradOp(const std::vector<int32_t>& kernel_siz
const int& groups, const std::string& data_format,
const std::string& name);
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth);
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth,
const std::string& name);
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth);
Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth,
const std::string& name);
} // namespace op_expr_helper
} // namespace oneflow
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import unittest
import numpy as np
def gen_random_input():
return np.array(
[
[1.1909, -1.5726, 0.9973, -0.7698, -1.1273],
[1.1354, -1.1815, -1.0553, -0.6178, -2.1103],
]
)
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestCrossEntropyLossModuleGrad(flow.unittest.TestCase):
def test_CrossEntropyLoss_mean(test_case):
label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="mean")
loss = CrossEntropyLoss(predict, label)
loss.backward()
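        # With uniform logits the softmax gives p = 0.2 per class, so the
        # expected gradient is (p - onehot(label)) / batch_size: -0.4 for
        # the labeled class and 0.1 elsewhere.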
target = np.array(
[
[-0.4000, 0.1000, 0.1000, 0.1000, 0.1000],
[0.1000, -0.4000, 0.1000, 0.1000, 0.1000],
]
)
test_case.assertTrue(predict.grad is not None)
test_case.assertTrue(
np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
)
def test_CrossEntropyLoss_sum(test_case):
label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="sum")
loss = CrossEntropyLoss(predict, label)
loss.backward()
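        # Same derivation without the 1/batch_size factor: (p - onehot(label))
        # gives -0.8 for the labeled class and 0.2 elsewhere.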
target = np.array(
[
[-0.8000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, -0.8000, 0.2000, 0.2000, 0.2000],
]
)
test_case.assertTrue(predict.grad is not None)
test_case.assertTrue(
np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
)
def test_CrossEntropyLoss_none(test_case):
label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="none")
loss = CrossEntropyLoss(predict, label)
grad = flow.Tensor(np.ones([2]))
loss.backward(grad)
target = np.array(
[
[-0.8000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, -0.8000, 0.2000, 0.2000, 0.2000],
]
)
test_case.assertTrue(predict.grad is not None)
test_case.assertTrue(
np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
)
def test_CrossEntropyLoss_mean_with_random_input(test_case):
label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
predict = flow.Tensor(gen_random_input(), requires_grad=True)
CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="mean")
loss = CrossEntropyLoss(predict, label)
loss.backward()
target = np.array(
[
[-0.2648, 0.0148, 0.1938, 0.0331, 0.0232],
[0.3515, -0.4654, 0.0393, 0.0609, 0.0137],
]
)
test_case.assertTrue(predict.grad is not None)
test_case.assertTrue(
np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
)
def test_CrossEntropyLoss_sum_with_random_input(test_case):
label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
predict = flow.Tensor(gen_random_input(), requires_grad=True)
CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="sum")
loss = CrossEntropyLoss(predict, label)
loss.backward()
target = np.array(
[
[-0.5297, 0.0297, 0.3875, 0.0662, 0.0463],
[0.7029, -0.9307, 0.0786, 0.1218, 0.0274],
]
)
test_case.assertTrue(predict.grad is not None)
test_case.assertTrue(
np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
)
def test_CrossEntropyLoss_none_with_random_input(test_case):
label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
predict = flow.Tensor(gen_random_input(), requires_grad=True)
CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="none")
loss = CrossEntropyLoss(predict, label)
grad = flow.Tensor(np.ones([2]))
loss.backward(grad)
target = np.array(
[
[-0.5297, 0.0297, 0.3875, 0.0662, 0.0463],
[0.7029, -0.9307, 0.0786, 0.1218, 0.0274],
]
)
test_case.assertTrue(predict.grad is not None)
test_case.assertTrue(
np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
)
if __name__ == "__main__":
unittest.main()