From 543370b492a2ac58e3b1b202d6b8259c1483edc9 Mon Sep 17 00:00:00 2001
From: Houjiang Chen <chenhoujiangcug@gmail.com>
Date: Sun, 9 May 2021 16:13:28 +0800
Subject: [PATCH] Rewrite sparse_softmax_cross_entropy and reduce ops gradient
 funcs (#4823)

* Rewrite sparse softmax cross entropy gradient func.

* Rewrite reduce ops gradient funcs.

* Add crossentropyloss grad unittest.

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: hjchen2 <hjchen2>
---
 .../autograd/gradient_funcs/reduce_ops.cpp    | 164 ++++++++++++++++++
 .../sparse_softmax_cross_entropy.cpp          |  88 ++++++++++
 oneflow/core/framework/op_expr_helper.cpp     |  52 ++++++
 oneflow/core/framework/op_expr_helper.h       |  16 ++
 .../modules/test_crossentropyloss_grad.py     | 156 +++++++++++++++++
 5 files changed, 476 insertions(+)
 create mode 100644 oneflow/core/autograd/gradient_funcs/reduce_ops.cpp
 create mode 100644 oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp
 create mode 100644 oneflow/python/test/modules/test_crossentropyloss_grad.py

diff --git a/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp b/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp
new file mode 100644
index 000000000..e5ec46bb9
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/reduce_ops.cpp
@@ -0,0 +1,164 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+
+namespace oneflow {
+namespace one {
+
+struct ReduceSumOpInterpState : public OpExprInterpState {
+  std::vector<int32_t> axis;
+};
+
+class ReduceSumOp : public OpExprGradFunction<ReduceSumOpInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+  std::shared_ptr<OpExpr> grad_op_;
+};
+
+Maybe<void> ReduceSumOp::Init(const OpExpr& op) {
+  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  const std::string& op_name = fw_op_expr->op_name();
+  grad_op_ = JUST(op_expr_helper::BroadcastLikeOp(/*axis=*/{-1}, GradientOpName(op_name)));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> ReduceSumOp::Capture(ReduceSumOpInterpState* ctx, const TensorTuple& inputs,
+                                 const TensorTuple& outputs, const AttrMap& attrs) const {
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
+  ctx->SaveTensorForBackward(inputs.at(0));
+  return Maybe<void>::Ok();
+}
+
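+// The backward of reduce_sum simply broadcasts dy back to the input shape along the
+// reduced axes, since every input element contributes to the sum with weight 1:
+//   dx = broadcast_like(dy, x, axis)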
+Maybe<void> ReduceSumOp::Apply(const ReduceSumOpInterpState* ctx, const TensorTuple& out_grads,
+                               TensorTuple* in_grads) const {
+  const auto& input = ctx->SavedTensors().at(0);
+  const auto& dy = out_grads.at(0);
+  MutableAttrMap attrs;
+  JUST(attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
+  in_grads->resize(1);
+  in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {dy, input}, attrs));
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_sum", ReduceSumOp);
+
+struct ReduceMaxOrMinOpInterpState : public OpExprInterpState {
+  std::vector<int32_t> axis;
+  bool keepdims;
+};
+
+class ReduceMaxOrMinOp : public OpExprGradFunction<ReduceMaxOrMinOpInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+  std::shared_ptr<OpExpr> bcast_like_op_;
+  std::shared_ptr<OpExpr> bcast_equal_op_;
+  std::shared_ptr<OpExpr> cast_like_op_;
+  std::shared_ptr<OpExpr> reduce_sum_op_;
+  std::shared_ptr<OpExpr> bcast_div_op_;
+  std::shared_ptr<OpExpr> multiply_op_;
+};
+
+Maybe<void> ReduceMaxOrMinOp::Init(const OpExpr& op) {
+  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  const std::string& op_name = fw_op_expr->op_name();
+  bcast_like_op_ =
+      JUST(op_expr_helper::BroadcastLikeOp(/*axis=*/{-1}, GradientOpName(op_name + "_bcast_like")));
+  bcast_equal_op_ = JUST(op_expr_helper::BroadcastEqualOp(GradientOpName(op_name + "_bcast_eq")));
+  cast_like_op_ = JUST(op_expr_helper::CastLikeOp(GradientOpName(op_name + "_cast_like")));
+  reduce_sum_op_ = JUST(op_expr_helper::ReduceSumOp(/*axis=*/{-1}, /*keepdims=*/false,
+                                                    GradientOpName(op_name + "_reduce_sum")));
+  bcast_div_op_ = JUST(op_expr_helper::BroadcastDivOp(GradientOpName(op_name + "_bcast_div")));
+  multiply_op_ = JUST(op_expr_helper::MultiplyOp(GradientOpName(op_name + "_multiply")));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> ReduceMaxOrMinOp::Capture(ReduceMaxOrMinOpInterpState* ctx, const TensorTuple& inputs,
+                                      const TensorTuple& outputs, const AttrMap& attrs) const {
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->axis = JUST(composed_attrs.GetAttr<std::vector<int32_t>>("axis"));
+  ctx->keepdims = JUST(composed_attrs.GetAttr<bool>("keepdims"));
+  ctx->SaveTensorForBackward(inputs.at(0));
+  ctx->SaveTensorForBackward(outputs.at(0));
+  return Maybe<void>::Ok();
+}
+
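+// The backward of reduce_max/reduce_min routes dy only to the input elements that are
+// equal to the reduced result; when several elements tie, dy is split evenly among them:
+//   mask  = cast_like(x == broadcast_like(y, x), x)
+//   count = reduce_sum(mask, axis, keepdims)
+//   dx    = broadcast_like(dy / count, x) * mask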
+Maybe<void> ReduceMaxOrMinOp::Apply(const ReduceMaxOrMinOpInterpState* ctx,
+                                    const TensorTuple& out_grads, TensorTuple* in_grads) const {
+  const auto& input = ctx->SavedTensors().at(0);
+  const auto& output = ctx->SavedTensors().at(1);
+  const auto& dy = out_grads.at(0);
+
+  MutableAttrMap bcast_attrs;
+  JUST(bcast_attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
+  const auto& bcast_like =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_like_op_, {output, input}, bcast_attrs));
+  const auto& bcast_eq =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_equal_op_, {input, bcast_like}));
+  const auto& cast_like = JUST(OpInterpUtil::Dispatch<Tensor>(*cast_like_op_, {bcast_eq, input}));
+
+  MutableAttrMap reduce_sum_attrs;
+  JUST(reduce_sum_attrs.SetAttr<std::vector<int32_t>>("axis", ctx->axis));
+  JUST(reduce_sum_attrs.SetAttr<bool>("keepdims", ctx->keepdims));
+  const auto& reduce_sum =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op_, {cast_like}, reduce_sum_attrs));
+  const auto& broadcast_div =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_div_op_, {dy, reduce_sum}));
+  const auto& bcast_like_div =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*bcast_like_op_, {broadcast_div, input}, bcast_attrs));
+
+  in_grads->resize(1);
+  in_grads->at(0) =
+      JUST(OpInterpUtil::Dispatch<Tensor>(*multiply_op_, {bcast_like_div, cast_like}));
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_min", ReduceMaxOrMinOp);
+REGISTER_OP_EXPR_GRAD_FUNCTION("reduce_max", ReduceMaxOrMinOp);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp
new file mode 100644
index 000000000..859010ccb
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/sparse_softmax_cross_entropy.cpp
@@ -0,0 +1,88 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/framework/op_expr_helper.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+
+namespace oneflow {
+namespace one {
+
+struct SparseSoftmaxCrossEntropyInterpState : public OpExprInterpState {
+  int64_t depth;
+};
+
+class SparseSoftmaxCrossEntropy : public OpExprGradFunction<SparseSoftmaxCrossEntropyInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const SparseSoftmaxCrossEntropyInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+  std::shared_ptr<OpExpr> grad_op_;
+};
+
+Maybe<void> SparseSoftmaxCrossEntropy::Init(const OpExpr& op) {
+  const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  const std::string& op_name = fw_op_expr->op_name();
+  grad_op_ =
+      JUST(op_expr_helper::SparseSoftmaxCrossEntropyGradOp(/*depth=*/-1, GradientOpName(op_name)));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> SparseSoftmaxCrossEntropy::Capture(SparseSoftmaxCrossEntropyInterpState* ctx,
+                                               const TensorTuple& inputs,
+                                               const TensorTuple& outputs,
+                                               const AttrMap& attrs) const {
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->depth = JUST(composed_attrs.GetAttr<int64_t>("depth"));
+  CHECK_EQ_OR_RETURN(inputs.size(), 2);
+  CHECK_EQ_OR_RETURN(outputs.size(), 2);
+  ctx->SaveTensorForBackward(outputs.at(0));  // prob
+  ctx->SaveTensorForBackward(inputs.at(1));   // label
+  return Maybe<void>::Ok();
+}
+
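+// The fused sparse_softmax_cross_entropy_grad kernel consumes the saved softmax
+// probabilities and the labels and produces the usual softmax cross entropy gradient,
+// effectively prediction_diff = (prob - onehot(label)) * dy (dy broadcast over classes).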
+Maybe<void> SparseSoftmaxCrossEntropy::Apply(const SparseSoftmaxCrossEntropyInterpState* ctx,
+                                             const TensorTuple& out_grads,
+                                             TensorTuple* in_grads) const {
+  CHECK_EQ_OR_RETURN(out_grads.size(), 2);
+  const auto& dy = out_grads.at(1);
+  const auto& prob = ctx->SavedTensors().at(0);
+  const auto& label = ctx->SavedTensors().at(1);
+  MutableAttrMap attrs;
+  JUST(attrs.SetAttr<int64_t>("depth", ctx->depth));
+  // SparseSoftmaxCrossEntropy has two inputs (prediction and label); the label input does
+  // not require a gradient, so only in_grads->at(0) is filled.
+  in_grads->resize(2);
+  in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {prob, label, dy}, attrs));
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("sparse_softmax_cross_entropy", SparseSoftmaxCrossEntropy);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp
index b0b8ee8c7..eb0c228bb 100644
--- a/oneflow/core/framework/op_expr_helper.cpp
+++ b/oneflow/core/framework/op_expr_helper.cpp
@@ -245,6 +245,25 @@ Maybe<one::UserOpExpr> BroadcastDivOp(const std::string& name) {
   return one::OpBuilder("broadcast_div", name).Input("x").Input("y").Output("z").Build();
 }
 
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis) {
+  return BroadcastLikeOp(axis, UniqueOpName("broadcast_like"));
+}
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis, const std::string& name) {
+  return one::OpBuilder("broadcast_like", name)
+      .Input("x")
+      .Input("like")
+      .Output("y")
+      .Attr<std::vector<int32_t>>("broadcast_axes", axis)
+      .Build();
+}
+
+Maybe<one::UserOpExpr> BroadcastEqualOp() {
+  return BroadcastEqualOp(UniqueOpName("broadcast_equal"));
+}
+Maybe<one::UserOpExpr> BroadcastEqualOp(const std::string& name) {
+  return one::OpBuilder("broadcast_equal", name).Input("x").Input("y").Output("z").Build();
+}
+
 Maybe<one::UserOpExpr> CastOp(const DataType& to_type) {
   return CastOp(to_type, UniqueOpName("cast"));
 }
@@ -256,6 +275,11 @@ Maybe<one::UserOpExpr> CastOp(const DataType& to_type, const std::string& name)
       .Build();
 }
 
+Maybe<one::UserOpExpr> CastLikeOp() { return CastLikeOp(UniqueOpName("cast_like")); }
+Maybe<one::UserOpExpr> CastLikeOp(const std::string& name) {
+  return one::OpBuilder("cast_like", name).Input("in").Input("dtype_like").Output("out").Build();
+}
+
 Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon) {
   return NormalizationGradOp(axis, epsilon, UniqueOpName("normalization_grad"));
 }
@@ -456,5 +480,33 @@ Maybe<one::UserOpExpr> ConvNdFilterGradOp(const std::vector<int32_t>& kernel_siz
       .Build();
 }
 
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth) {
+  return SparseSoftmaxCrossEntropyGradOp(depth, UniqueOpName("sparse_softmax_cross_entropy"));
+}
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth,
+                                                       const std::string& name) {
+  return one::OpBuilder("sparse_softmax_cross_entropy_grad", name)
+      .Input("prob")
+      .Input("label")
+      .Input("dy")
+      .Output("prediction_diff")
+      .Attr<int64_t>("depth", depth)
+      .Build();
+}
+
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth) {
+  return SparseSoftmaxCrossEntropyMsGradOp(depth, UniqueOpName("sparse_softmax_cross_entropy_ms"));
+}
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth,
+                                                         const std::string& name) {
+  return one::OpBuilder("sparse_softmax_cross_entropy_ms_grad", name)
+      .Input("prob")
+      .Input("label")
+      .Input("dy")
+      .Output("prediction_diff")
+      .Attr<int64_t>("depth", depth)
+      .Build();
+}
+
 }  // namespace op_expr_helper
 }  // namespace oneflow
diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h
index 6d2f7268c..c3b282799 100644
--- a/oneflow/core/framework/op_expr_helper.h
+++ b/oneflow/core/framework/op_expr_helper.h
@@ -85,9 +85,18 @@ Maybe<one::UserOpExpr> BroadcastMulOp(const std::string& name);
 Maybe<one::UserOpExpr> BroadcastDivOp();
 Maybe<one::UserOpExpr> BroadcastDivOp(const std::string& name);
 
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis);
+Maybe<one::UserOpExpr> BroadcastLikeOp(const std::vector<int32_t>& axis, const std::string& name);
+
+Maybe<one::UserOpExpr> BroadcastEqualOp();
+Maybe<one::UserOpExpr> BroadcastEqualOp(const std::string& name);
+
 Maybe<one::UserOpExpr> CastOp(const DataType& to_type);
 Maybe<one::UserOpExpr> CastOp(const DataType& to_type, const std::string& name);
 
+Maybe<one::UserOpExpr> CastLikeOp();
+Maybe<one::UserOpExpr> CastLikeOp(const std::string& name);
+
 Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon);
 Maybe<one::UserOpExpr> NormalizationGradOp(const int32_t& axis, const float& epsilon,
                                            const std::string& name);
@@ -152,5 +161,12 @@ Maybe<one::UserOpExpr> ConvNdFilterGradOp(const std::vector<int32_t>& kernel_siz
                                           const int& groups, const std::string& data_format,
                                           const std::string& name);
 
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth);
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyGradOp(const int64_t& depth,
+                                                       const std::string& name);
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth);
+Maybe<one::UserOpExpr> SparseSoftmaxCrossEntropyMsGradOp(const int64_t& depth,
+                                                         const std::string& name);
+
 }  // namespace op_expr_helper
 }  // namespace oneflow
diff --git a/oneflow/python/test/modules/test_crossentropyloss_grad.py b/oneflow/python/test/modules/test_crossentropyloss_grad.py
new file mode 100644
index 000000000..faa31a0c9
--- /dev/null
+++ b/oneflow/python/test/modules/test_crossentropyloss_grad.py
@@ -0,0 +1,156 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+import unittest
+import numpy as np
+
+
+def gen_random_input():
+    return np.array(
+        [
+            [1.1909, -1.5726, 0.9973, -0.7698, -1.1273],
+            [1.1354, -1.1815, -1.0553, -0.6178, -2.1103],
+        ]
+    )
+
+
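+# The expected "target" arrays below are the analytic gradient of cross entropy w.r.t. the
+# logits, i.e. softmax(predict) - onehot(label), scaled by 1/batch_size for reduction="mean".
+# For reduction="none" an all-ones upstream gradient is passed, giving the same values as "sum".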
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestCrossEntropyLossModuleGrad(flow.unittest.TestCase):
+    def test_CrossEntropyLoss_mean(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="mean")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.4000, 0.1000, 0.1000, 0.1000, 0.1000],
+                [0.1000, -0.4000, 0.1000, 0.1000, 0.1000],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_sum(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="sum")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.8000, 0.2000, 0.2000, 0.2000, 0.2000],
+                [0.2000, -0.8000, 0.2000, 0.2000, 0.2000],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_none(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(np.ones([2, 5]), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="none")
+        loss = CrossEntropyLoss(predict, label)
+        grad = flow.Tensor(np.ones([2]))
+        loss.backward(grad)
+        target = np.array(
+            [
+                [-0.8000, 0.2000, 0.2000, 0.2000, 0.2000],
+                [0.2000, -0.8000, 0.2000, 0.2000, 0.2000],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-4, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_mean_with_random_input(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(gen_random_input(), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="mean")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.2648, 0.0148, 0.1938, 0.0331, 0.0232],
+                [0.3515, -0.4654, 0.0393, 0.0609, 0.0137],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_sum_with_random_input(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(gen_random_input(), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="sum")
+        loss = CrossEntropyLoss(predict, label)
+        loss.backward()
+        target = np.array(
+            [
+                [-0.5297, 0.0297, 0.3875, 0.0662, 0.0463],
+                [0.7029, -0.9307, 0.0786, 0.1218, 0.0274],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
+        )
+
+    def test_CrossEntropyLoss_none_with_random_input(test_case):
+        label = flow.Tensor(np.array([0, 1]), dtype=flow.int32)
+        predict = flow.Tensor(gen_random_input(), requires_grad=True)
+
+        CrossEntropyLoss = flow.nn.CrossEntropyLoss(reduction="none")
+        loss = CrossEntropyLoss(predict, label)
+        grad = flow.Tensor(np.ones([2]))
+        loss.backward(grad)
+        target = np.array(
+            [
+                [-0.5297, 0.0297, 0.3875, 0.0662, 0.0463],
+                [0.7029, -0.9307, 0.0786, 0.1218, 0.0274],
+            ]
+        )
+
+        test_case.assertTrue(predict.grad is not None)
+        test_case.assertTrue(
+            np.allclose(predict.grad.numpy(), target, rtol=1e-2, atol=1e-8)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab