From 543370b492a2ac58e3b1b202d6b8259c1483edc9 Mon Sep 17 00:00:00 2001
From: Houjiang Chen <chenhoujiangcug@gmail.com>
Date: Sun, 9 May 2021 16:13:28 +0800
Subject: [PATCH] Rewrite sparse_softmax_cross_entropy and reduce ops gradient funcs (#4823)

* Rewrite sparse softmax cross entropy gradient func.

* Rewrite reduce ops gradient funcs.

* Add crossentropyloss grad unittest.

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: hjchen2 <hjchen2>
---
 .../autograd/gradient_funcs/reduce_ops.cpp    | 156 ++++++++++++++++++
 .../sparse_softmax_cross_entropy.cpp          |  85 ++++++++++
 oneflow/core/framework/op_expr_helper.cpp     |  52 ++++++
 oneflow/core/framework/op_expr_helper.h       |  16 ++
 .../modules/test_crossentropyloss_grad.py     | 153 +++++++++++++++++
 5 files changed, 462 insertions(+)
 create mode 100644 oneflow/core/autograd/gradient_funcs/reduce_ops.cpp
 create mode 100644 oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp
 create mode 100644 oneflow/python/test/modules/test_crossentropyloss_grad.py

diff --git a/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp b/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp
new file mode 100644
index 000000000..e5ec46bb9
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp
@@ -0,0 +1,156 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+
+namespace oneflow {
+namespace one {
+
+struct ReduceSumOpInterpState : public OpExprInterpState {
+  std::vector<int32_t> axis;
+};
+
+class ReduceSumOp : public OpExprGradFunction<ReduceSumOpInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+  std::shared_ptr<OpExpr> grad_op_;
+};
+
+Maybe<void> ReduceSumOp::Init(const OpExpr& op) {
+  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  const std::string& op_name = fw_op_expr->op_name();
+  grad_op_ = JUST(op_expr_helper::BroadcastLikeOp(/*axis=*/{-1}, GradientOpName(op_name)));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> ReduceSumOp::Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs,
+                                 const TensorTuple& outputs, const AttrMap& attrs) const {
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
+  ctx->SaveTensorForBackward(inputs.at(0));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> ReduceSumOp::Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads,
+                               TensorTuple* in_grads) const {
+  const auto& input = ctx->SavedTensors().at(0);
+  const auto& dy = out_grads.at(0);
+  MutableAttrMap attrs;
+  JUST(attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
+  in_grads->resize(1);
+  in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {dy, input}, attrs));
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum", ReduceSumOp);
+
+struct ReduceMaxOrMinOpInterpState : public OpExprInterpState {
+  std::vector<int32_t> axis;
+  bool keepdims;
+};
+
+class ReduceMaxOrMinOp : public OpExprGradFunction<ReduceMaxOrMinOpInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+  std::shared_ptr<OpExpr> bcast_like_op_;
+  std::shared_ptr<OpExpr> bcast_equal_op_;
+  std::shared_ptr<OpExpr> cast_like_op_;
+  std::shared_ptr<OpExpr> reduce_sum_op_;
+  std::shared_ptr<OpExpr> bcast_div_op_;
+  std::shared_ptr<OpExpr> multiply_op_;
+};
+
+Maybe<void> ReduceMaxOrMinOp::Init(const OpExpr& op) {
+  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  const std::string& op_name = fw_op_expr->op_name();
+  bcast_like_op_ =
+      JUST(op_expr_helper::BroadcastLikeOp(/*axis=*/{-1}, GradientOpName(op_name + "_bcast_like")));
+  bcast_equal_op_ = JUST(op_expr_helper::BroadcastEqualOp(GradientOpName(op_name + "_bcast_eq")));
+  cast_like_op_ = JUST(op_expr_helper::CastLikeOp(GradientOpName(op_name + "_cast_like")));
+  reduce_sum_op_ = JUST(op_expr_helper::ReduceSumOp(/*axis=*/{-1}, /*keepdims=*/false,
+                                                    GradientOpName(op_name + "_reduce_sum")));
+  bcast_div_op_ = JUST(op_expr_helper::BroadcastDivOp(GradientOpName(op_name + "_bcast_div")));
+  multiply_op_ = JUST(op_expr_helper::MultiplyOp(op_name + "_multiply"));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> ReduceMaxOrMinOp::Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs,
+                                      const TensorTuple& outputs, const AttrMap& attrs) const {
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
+  ctx->keepdims = JUST(composed_attrs.GetAttr<bool>("keepdims"));
+  ctx->SaveTensorForBackward(inputs.at(0));
+  ctx->SaveTensorForBackward(outputs.at(0));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> ReduceMaxOrMinOp::Apply(const ReduceMaxOrMinOpInterpState* ctx,
+                                    const TensorTuple& out_grads, TensorTuple* in_grads) const {
+  const auto& input = ctx->SavedTensors().at(0);
+  const auto& output = ctx->SavedTensors().at(1);
+  const auto& dy = out_grads.at(0);
+
+  MutableAttrMap bcast_attrs;
+  JUST(bcast_attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
+  const auto& bcast_like =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_like_op_, {output, input}, bcast_attrs));
+  const auto& bcast_eq =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_equal_op_, {input, bcast_like}));
+  const auto& cast_like = JUST(OpInterpUtil::Dispatch<Tensor>(*cast_like_op_, {bcast_eq, input}));
+
+  MutableAttrMap reduce_sum_attrs;
+  JUST(reduce_sum_attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
+  JUST(reduce_sum_attrs.SetAttr<bool>("keepdims", ctx->keepdims));
+  const auto& reduce_sum =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op_, {cast_like}, reduce_sum_attrs));
+  const auto& broadcast_div =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_div_op_, {dy, reduce_sum}));
+  const auto& bcast_like_div =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_like_op_, {broadcast_div, input}, bcast_attrs));
+
+  in_grads->resize(1);
+  in_grads->at(0) =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*multiply_op_, {bcast_like_div, cast_like}));
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min", ReduceMaxOrMinOp);
+REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max", ReduceMaxOrMinOp);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp
new file mode 100644
index 000000000..859010ccb
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp
@@ -0,0 +1,85 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+
+namespace oneflow {
+namespace one {
+
+struct SparseSoftmaxCrossEntropyInterpState : public OpExprInterpState {
+  int64_t depth;
+};
+
+class SparseSoftmaxCrossEntropy : public OpExprGradFunction<SparseSoftmaxCrossEntropyInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+  std::shared_ptr<OpExpr> grad_op_;
+};
+
+Maybe<void> SparseSoftmaxCrossEntropy::Init(const OpExpr& op) {
+  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  const std::string& op_name = fw_op_expr->op_name();
+  grad_op_ =
+      JUST(op_expr_helper::SparseSoftmaxCrossEntropyGradOp(/*depth=*/-1, GradientOpName(op_name)));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyInterpState* ctx,
+                                               const TensorTuple& inputs,
+                                               const TensorTuple& outputs,
+                                               const AttrMap& attrs) const {
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
+  CHECK_EQ_OR_RETURN(inputs.size(), 2);
+  CHECK_EQ_OR_RETURN(outputs.size(), 2);
+  ctx->SaveTensorForBackward(outputs.at(0));  // prob
+  ctx->SaveTensorForBackward(inputs.at(1));   // label
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyInterpState* ctx,
+                                             const TensorTuple& out_grads,
+                                             TensorTuple* in_grads) const {
+  CHECK_EQ_OR_RETURN(out_grads.size(), 2);
+  const auto& dy = out_grads.at(1);
+  const auto& prob = ctx->SavedTensors().at(0);
+  const auto& label = ctx->SavedTensors().at(1);
+  MutableAttrMap attrs;
+  JUST(attrs.SetAttr<int64_t>("depth", ctx->depth));
+  // SparseSoftmaxCrossEntropy has 2 inputs (prediction and label), and the second input does not
+  // require gradient.
+  in_grads->resize(2);
+  in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {prob, label, dy}, attrs));
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy", SparseSoftmaxCrossEntropy);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp
index b0b8ee8c7..eb0c228bb 100644
--- a/oneflow/core/framework/op_expr_helper.cpp
+++ b/oneflow/core/framework/op_expr_helper.cpp
@@ -245,6 +245,25 @@ Maybe<one::UserOpExpr> BroadcastDivOp(const std::string& name) {
   return one::OpBuilder("broadcast_div", name).Input("x").Input("y").Output("z").Build();
 }
 
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis) {
+  return BroadcastLikeOp(axis, UniqueOpName("broadcast_like"));
+}
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis, const std::string& name) {
+  return one::OpBuilder("broadcast_like", name)
+      .Input("x")
+      .Input("like")
+      .Output("y")
+      .Attr<std::vector<int32_t>>("broadcast_axes", axis)
+      .Build();
+}
+
+Maybe<one::UserOpExpr> BroadcastEqualOp() {
+  return BroadcastEqualOp(UniqueOpName("broadcast_equal"));
+}
+Maybe<one::UserOpExpr> BroadcastEqualOp(const std::string& name) {
+  return one::OpBuilder("broadcast_equal", name).Input("x").Input("y").Output("z").Build();
+}
+
 Maybe<one::UserOpExpr> CastOp(const DataType& to_type) {
   return CastOp(to_type, UniqueOpName("cast"));
 }
@@ -256,6 +275,11 @@ Maybe<one::UserOpExpr> CastOp(const DataType& to_type, const std::string& name)
       .Build();
 }
 
+Maybe<one::UserOpExpr> CastLikeOp() { return CastLikeOp(UniqueOpName("cast_like")); }
+Maybe<one::UserOpExpr> CastLikeOp(const std::string& name) {
+  return one::OpBuilder("cast_like", name).Input("in").Input("dtype_like").Output("out").Build();
+}
+
 Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon) {
   return NormalizationGradOp(axis, epsilon, UniqueOpName("normalization_grad"));
 }
@@ -456,5 +480,33 @@ Maybe<one::UserOpExpr> ConvNdFilterGradOp(const std::vector<int32_t>& kernel_siz
       .Build();
 }
 
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth) {
+  return SparseSoftmaxCrossEntropyGradOp(depth, UniqueOpName("sparse_softmax_cross_entropy"));
+}
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth,
+                                                       const std::string& name) {
+  return one::OpBuilder("sparse_softmax_cross_entropy_grad", name)
+      .Input("prob")
+      .Input("label")
+      .Input("dy")
+      .Output("prediction_diff")
+      .Attr<int64_t>("depth", depth)
+      .Build();
+}
+
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth) {
+  return SparseSoftmaxCrossEntropyMsGradOp(depth, UniqueOpName("sparse_softmax_cross_entropy_ms"));
+}
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth,
+                                                         const std::string& name) {
+  return one::OpBuilder("sparse_softmax_cross_entropy_ms_grad", name)
+      .Input("prob")
+      .Input("label")
+      .Input("dy")
+      .Output("prediction_diff")
+      .Attr<int64_t>("depth", depth)
+      .Build();
+}
+
 }  // namespace op_expr_helper
 }  // namespace oneflow
diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h
index 6d2f7268c..c3b282799 100644
--- a/oneflow/core/framework/op_expr_helper.h
+++ b/oneflow/core/framework/op_expr_helper.h
@@ -85,9 +85,18 @@ Maybe<one::UserOpExpr> BroadcastMulOp(const std::string& name);
 Maybe<one::UserOpExpr> BroadcastDivOp();
 Maybe<one::UserOpExpr> BroadcastDivOp(const std::string& name);
 
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis);
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis, const std::string& name);
+
+Maybe<one::UserOpExpr> BroadcastEqualOp();
+Maybe<one::UserOpExpr> BroadcastEqualOp(const std::string& name);
+
 Maybe<one::UserOpExpr> CastOp(const DataType& to_type);
 Maybe<one::UserOpExpr> CastOp(const DataType& to_type, const std::string& name);
 
+Maybe<one::UserOpExpr> CastLikeOp();
+Maybe<one::UserOpExpr> CastLikeOp(const std::string& name);
+
 Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon);
 Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon,
                                            const std::string& name);
@@ -152,5 +161,12 @@ Maybe<one::UserOpExpr> ConvNdFilterGradOp(const std::vector<int32_t>& kernel_siz
                                           const int& groups, const std::string& data_format,
                                           const std::string& name);
 
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth);
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth,
+                                                       const std::string& name);
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth);
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth,
+                                                         const std::string& name);
+
 }  // namespace op_expr_helper
 }  // namespace oneflow
diff --git a/oneflow/python/test/modules/test_crossentropyloss_grad.py b/oneflow/python/test/modules/test_crossentropyloss_grad.py
new file mode 100644
index 000000000..faa31a0c9
--- /dev/null
+++ b/oneflow/python/test/modules/test_crossentropyloss_grad.py
@@ -0,0 +1,153 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+def gen_random_input():
+    return np.array(
+        [
+            [1.1909, -1.5726, 0.9973, -0.7698, -1.1273],
+            [1.1354, -1.1815, -1.0553, -0.6178, -2.1103],
+        ]
+    )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestCrossEntropyLossModuleGrad(flow.unittest.TestCase):
+    def test_CrossEntropyLoss_mean(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="mean")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.4000, 0.1000, 0.1000, 0.1000, 0.1000],
+                [0.1000, -0.4000, 0.1000, 0.1000, 0.1000],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_sum(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="sum")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.8000, 0.2000, 0.2000, 0.2000, 0.2000],
+                [0.2000, -0.8000, 0.2000, 0.2000, 0.2000],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_none(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="none")
+        loss = CrossEntropyLoss(predict, label)
+        grad = flow.Tensor(np.ones([2]))
+        loss.backward(grad)
+        target = np.array(
+            [
+                [-0.8000, 0.2000, 0.2000, 0.2000, 0.2000],
+                [0.2000, -0.8000, 0.2000, 0.2000, 0.2000],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_mean_with_random_input(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(gen_random_input(), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="mean")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.2648, 0.0148, 0.1938, 0.0331, 0.0232],
+                [0.3515, -0.4654, 0.0393, 0.0609, 0.0137],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_sum_with_random_input(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(gen_random_input(), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="sum")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.5297, 0.0297, 0.3875, 0.0662, 0.0463],
+                [0.7029, -0.9307, 0.0786, 0.1218, 0.0274],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_none_with_random_input(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(gen_random_input(), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="none")
+        loss = CrossEntropyLoss(predict, label)
+        grad = flow.Tensor(np.ones([2]))
+        loss.backward(grad)
+        target = np.array(
+            [
+                [-0.5297, 0.0297, 0.3875, 0.0662, 0.0463],
+                [0.7029, -0.9307, 0.0786, 0.1218, 0.0274],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
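
For reference, a minimal NumPy sketch (not part of the patch) that reproduces the gradient
targets hard-coded in the tests above, assuming the standard softmax cross-entropy identity
d(loss)/d(logits) = softmax(logits) - onehot(labels), scaled by 1/batch_size when
reduction is "mean":

    import numpy as np

    def softmax(x):
        # Row-wise softmax with max-subtraction for numerical stability.
        e = np.exp(x - x.max(axis=1, keepdims=True))
        return e / e.sum(axis=1, keepdims=True)

    def cross_entropy_grad(logits, labels, reduction="mean"):
        # Gradient of cross-entropy w.r.t. logits: softmax(logits) - onehot(labels).
        grad = softmax(logits)
        grad[np.arange(len(labels)), labels] -= 1.0
        if reduction == "mean":
            grad /= len(labels)
        return grad

    # Matches the "mean" test target:
    # [[-0.4, 0.1, 0.1, 0.1, 0.1], [0.1, -0.4, 0.1, 0.1, 0.1]]
    print(cross_entropy_grad(np.ones([2, 5]), np.array([0, 1])))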