From 6caa52f3c5ba11a1d67f183bac4c1559b2a58ef5 Mon Sep 17 00:00:00 2001
From: tangnana925 <85614052+tangnana925@users.noreply.github.com>
Date: Tue, 13 Jul 2021 13:06:44 +0800
Subject: [PATCH] init of op diag (#5298)

* init of op diag

* init of op diag

* modify op diag

* merge testcase

* delete unused tensor diag

* amend diag api docs

* resolve conflicts

* resolve conflict

* resolve conflict

* resolve conflict

* auto format by CI

* add tensor.diag

* auto format by CI

* amend API description

* add non_square test

* amend diag API description

* amend diag API description

* compact test code

* compact code

* auto format by CI

* amend API description

* amend API description

* modify functional API

* modify test_diag

* amend diag functional API

* amend diag docstring

* auto format by CI

* amend code standards

* auto format by CI

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 docs/source/experimental.rst                  |   2 +
 oneflow/core/autograd/gradient_funcs/diag.cpp |  72 ++++++++
 oneflow/core/functional/functional_api.yaml   |   9 +
 .../core/functional/impl/array_functor.cpp    |  31 ++++
 oneflow/python/nn/modules/diag.py             |  86 ++++++++++
 oneflow/python/ops/diag_ops.py                |   6 +-
 oneflow/python/test/modules/test_diag.py      | 162 ++++++++++++++++++
 7 files changed, 366 insertions(+), 2 deletions(-)
 create mode 100644 oneflow/core/autograd/gradient_funcs/diag.cpp
 create mode 100644 oneflow/python/nn/modules/diag.py
 create mode 100644 oneflow/python/test/modules/test_diag.py

diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst
index 370c41f89..9cafbe695 100644
--- a/docs/source/experimental.rst
+++ b/docs/source/experimental.rst
@@ -233,6 +233,8 @@ Experimental features
 .. autofunction:: oneflow.experimental.meshgrid
 .. autofunction:: oneflow.experimental.topk
 .. autofunction:: oneflow.experimental.Tensor.topk
+.. autofunction:: oneflow.experimental.diag
+.. autofunction:: oneflow.experimental.Tensor.diag
 .. autofunction:: oneflow.experimental.nn.GroupNorm
 .. autofunction:: oneflow.experimental.nn.ZeroPad2d
 .. autofunction:: oneflow.experimental.tensor_buffer_to_tensor
diff --git a/oneflow/core/autograd/gradient_funcs/diag.cpp b/oneflow/core/autograd/gradient_funcs/diag.cpp
new file mode 100644
index 000000000..cfd0aee9d
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/diag.cpp
@@ -0,0 +1,72 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/attr_map.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/functional/functional.h"
+
+namespace oneflow {
+namespace one {
+
+struct DiagInterpState : public OpExprInterpState {
+  bool requires_grad;
+  int32_t diagonal;
+};
+
+class Diag : public OpExprGradFunction<DiagInterpState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(DiagInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
+                      const AttrMap& attrs) const override;
+  Maybe<void> Apply(const DiagInterpState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+};
+
+Maybe<void> Diag::Init(const OpExpr& op) {
+  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> Diag::Capture(DiagInterpState* ctx, const TensorTuple& inputs,
+                          const TensorTuple& outputs, const AttrMap& attrs) const {
+  CHECK_EQ_OR_RETURN(outputs.size(), 1);
+  ctx->requires_grad = inputs.at(0)->requires_grad();
+  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->diagonal = JUST(composed_attrs.GetAttr<int32_t>("diagonal"));
+  ctx->SaveTensorForBackward(inputs.at(0));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> Diag::Apply(const DiagInterpState* ctx, const TensorTuple& out_grads,
+                        TensorTuple* in_grads) const {
+  CHECK_EQ_OR_RETURN(out_grads.size(), 1);
+  in_grads->resize(1);  // "diag" has exactly one input, so one gradient slot.
+  if (ctx->requires_grad) {
+    const auto& x = ctx->SavedTensors().at(0);
+    in_grads->at(0) = JUST(functional::DiagGrad(out_grads.at(0), x, ctx->diagonal));
+  }
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("diag", Diag);
+
+}  // namespace one
+}  // namespace oneflow
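For readers new to this gradient, here is an illustrative NumPy sketch of the semantics the diag/diag_grad op pair is expected to implement; the helper names below are hypothetical, not OneFlow APIs. Apply above delegates the real work to functional::DiagGrad, passing the saved forward input and the captured diagonal attribute.

import numpy as np

def diag_ref(x, diagonal=0):
    # Forward contract (same as np.diag): a 1-D input is embedded on the
    # diagonal-th diagonal of a fresh square matrix; from a 2-D input that
    # diagonal is copied out into a 1-D tensor.
    return np.diag(x, k=diagonal)

def diag_grad_ref(dy, x, diagonal=0):
    # Backward contract: route each element of dy back to the position in x
    # it was read from (or written to) during the forward pass.
    if x.ndim == 1:
        # Forward built a matrix from a vector, so the gradient w.r.t. x is
        # simply the same diagonal of dy.
        return np.diag(dy, k=diagonal)
    # Forward extracted a diagonal from a matrix, so scatter dy back onto
    # that diagonal of a zero matrix shaped like x.
    dx = np.zeros_like(x)
    n = dy.shape[0]
    rows = np.arange(n) + max(-diagonal, 0)
    cols = np.arange(n) + max(diagonal, 0)
    dx[rows, cols] = dy
    return dx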
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 1299e6fd6..eff0d32cb 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -594,6 +594,15 @@
     signature: "Tensor PadGrad(Tensor dy, *, Int64List pad, String mode=\"constant\", Scalar value=0)"
   bind_python: False
 
+- name: "diag"
+  signature: "Tensor Diag(Tensor x, *, Int32 diagonal=0)"
+  bind_python: True
+
+- name: "diag_grad"
+  signature: "Tensor DiagGrad(Tensor dy, Tensor in, *, Int32 diagonal=0)"
+  bind_python: False
+
 - name: "tensor_getitem"
   signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
   bind_python: True
+
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index 6a9dc89fc..1653768c3 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -458,6 +458,35 @@ class TriuFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class DiagFunctor {
+ public:
+  DiagFunctor() { op_ = CHECK_JUST(one::OpBuilder("diag").Input("in").Output("out").Build()); }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& diagonal) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<int32_t>("diagonal", diagonal));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class DiagGradFunctor {
+ public:
+  DiagGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("diag_grad").Input("dy").Input("in").Output("dx").Build());
+  }
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
+                           const std::shared_ptr<one::Tensor>& x, const int32_t& diagonal) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<int32_t>("diagonal", diagonal));
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class TensorGetItemFunctor {
  public:
   TensorGetItemFunctor() {}
@@ -545,6 +574,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::UpsampleFunctor>("Upsample");
   m.add_functor<impl::UnsortedSegmentSumLikeFunctor>("UnsortedSegmentSumLike");
   m.add_functor<impl::TriuFunctor>("Triu");
+  m.add_functor<impl::DiagFunctor>("Diag");
+  m.add_functor<impl::DiagGradFunctor>("DiagGrad");
   m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
 };
diff --git a/oneflow/python/nn/modules/diag.py b/oneflow/python/nn/modules/diag.py
new file mode 100644
index 000000000..49ed84fa5
--- /dev/null
+++ b/oneflow/python/nn/modules/diag.py
@@ -0,0 +1,86 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import oneflow as flow
+from oneflow.python.oneflow_export import oneflow_export, experimental_api
+from oneflow.python.nn.module import Module
+from oneflow.python.framework.tensor import register_tensor_op
+
+
+class Diag(Module):
+    def __init__(self, diagonal=0):
+        super().__init__()
+        self.diagonal = diagonal
+
+    def forward(self, input):
+        return flow.F.diag(input, self.diagonal)
+
+
+@oneflow_export("diag")
+@experimental_api
+def diag_op(input, diagonal=0):
+    r"""
+    If input is a vector (1-D tensor), returns a 2-D square tensor with the elements of input as the diagonal.
+    If input is a matrix (2-D tensor), returns a 1-D tensor with the diagonal elements of input.
+
+    Args:
+        input (Tensor): the input tensor.
+        diagonal (int, optional): the diagonal to consider. If diagonal = 0 (the default), it is
+            the main diagonal; if diagonal > 0, the diagonal above it; if diagonal < 0, the diagonal below it.
+
+    Returns:
+        oneflow.Tensor: the output Tensor.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+
+        >>> arr = np.array(
+        ...     [
+        ...         [1.0, 2.0, 3.0],
+        ...         [4.0, 5.0, 6.0],
+        ...         [7.0, 8.0, 9.0],
+        ...     ]
+        ... )
+
+        >>> input = flow.Tensor(arr, dtype=flow.float32)
+        >>> flow.diag(input)
+        tensor([1., 5., 9.], dtype=oneflow.float32)
+    """
+
+    return Diag(diagonal)(input)
+
+
+@register_tensor_op("diag")
+@experimental_api
+def diag_op_tensor(input, diagonal=0):
+    r"""
+    diag() -> Tensor
+
+    See :func:`oneflow.experimental.diag`
+    """
+
+    return Diag(diagonal)(input)
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(raise_on_error=True)
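A complementary eager-mode sketch (using the same experimental API the docstring above exercises) shows the 1-D to 2-D direction together with a non-zero offset; reading the same offset back recovers the original vector:

import numpy as np
import oneflow.experimental as flow

flow.enable_eager_execution()

v = flow.Tensor(np.array([1.0, 2.0, 3.0]), dtype=flow.float32)
# 1-D input: embed the vector on the diagonal above the main one,
# yielding a 4x4 matrix.
m = flow.diag(v, 1)
# 2-D input with the same offset reads that diagonal back out,
# recovering tensor([1., 2., 3.], dtype=oneflow.float32).
d = m.diag(diagonal=1)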
diff --git a/oneflow/python/ops/diag_ops.py b/oneflow/python/ops/diag_ops.py
index 9e12f004e..2c75e748f 100644
--- a/oneflow/python/ops/diag_ops.py
+++ b/oneflow/python/ops/diag_ops.py
@@ -16,13 +16,14 @@ limitations under the License.
 import oneflow as flow
 import oneflow.python.framework.id_util as id_util
-from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.oneflow_export import oneflow_export, stable_api
 import oneflow.python.framework.remote_blob as remote_blob_util
 from typing import Optional
 import oneflow._oneflow_internal
 
 
 @oneflow_export("diag")
+@stable_api
 def diag(
     input: oneflow._oneflow_internal.BlobDesc,
     diagonal: Optional[int] = 0,
@@ -32,8 +33,9 @@ def diag(
     If input is a vector, then returns a square matrix with the elements of input as the diagonal.
     If input is a matrix, then returns a vector with the diagonal elements of input.
 
+
     Args:
-        input (remote_blob_util.BlobDef): The input Blob.
+        input (remote_blob_util.BlobDef): The input Blob.
         diagonal (Optional[int], 0): The diagonal to consider.
         If diagonal = 0, it is the main diagonal. If diagonal > 0, it is above the main diagonal. If diagonal < 0, it is below the main diagonal. Defaults to 0.
 
     Returns:
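For the stable (lazy) API marked with @stable_api above, usage would look roughly like the following; the @flow.global_function job and the oneflow.typing placeholder follow the usual lazy-mode pattern and are an assumption, not taken from this patch:

import numpy as np
import oneflow as flow
import oneflow.typing as tp

@flow.global_function()
def diag_job(x: tp.Numpy.Placeholder((3, 3), dtype=flow.float32)) -> tp.Numpy:
    # Extract the main diagonal of a 3x3 input blob.
    return flow.diag(x)

out = diag_job(np.arange(9, dtype=np.float32).reshape(3, 3))
# Expected: array([0., 4., 8.], dtype=float32)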
diff --git a/oneflow/python/test/modules/test_diag.py b/oneflow/python/test/modules/test_diag.py
new file mode 100644
index 000000000..c9288b40a
--- /dev/null
+++ b/oneflow/python/test/modules/test_diag.py
@@ -0,0 +1,162 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+
+import oneflow.experimental as flow
+from test_util import GenArgList
+
+
+def _test_diag_forward(test_case, shape, diagonal, device):
+    input = flow.Tensor(np.random.randn(*shape), device=flow.device(device))
+    of_out = flow.diag(input, diagonal)
+    np_out = np.diag(input.numpy(), diagonal)
+    test_case.assertTrue(
+        np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
+    )
+
+    test_case.assertTrue(
+        np.allclose(
+            input.diag(diagonal=diagonal).numpy(), np_out, 1e-5, 1e-5, equal_nan=True
+        )
+    )
+
+
+def _test_diag_one_dim_backward(test_case, diagonal, device):
+    input = flow.Tensor(
+        np.random.randn(3), device=flow.device(device), requires_grad=True
+    )
+    of_out = flow.diag(input, diagonal).sum()
+    of_out.backward()
+    np_grad = np.ones(shape=3)
+    test_case.assertTrue(
+        np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
+    )
+
+    input = flow.Tensor(
+        np.random.randn(3), device=flow.device(device), requires_grad=True
+    )
+    of_out = input.diag(diagonal=diagonal).sum()
+    of_out.backward()
+    np_grad = np.ones(shape=3)
+    test_case.assertTrue(
+        np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
+    )
+
+
+def _test_diag_other_dim_backward(test_case, diagonal, device):
+    input = flow.Tensor(
+        np.random.randn(3, 3), device=flow.device(device), requires_grad=True
+    )
+    of_out = flow.diag(input, diagonal).sum()
+    of_out.backward()
+    if diagonal > 0:
+        np_grad = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0],])
+    elif diagonal < 0:
+        np_grad = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0],])
+    else:
+        np_grad = np.identity(3)
+    test_case.assertTrue(
+        np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
+    )
+
+    input = flow.Tensor(
+        np.random.randn(3, 3), device=flow.device(device), requires_grad=True
+    )
+    of_out = input.diag(diagonal=diagonal).sum()
+    of_out.backward()
+    if diagonal > 0:
+        np_grad = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0],])
+    elif diagonal < 0:
+        np_grad = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0],])
+    else:
+        np_grad = np.identity(3)
+    test_case.assertTrue(
+        np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
+    )
+
+
+def _test_diag_other_dim_non_square_backward(test_case, diagonal, device):
+    input = flow.Tensor(
+        np.random.randn(3, 4), device=flow.device(device), requires_grad=True
+    )
+    of_out = flow.diag(input, diagonal).sum()
+    of_out.backward()
+    if diagonal > 0:
+        np_tmp = np.zeros([3, 1])
+        np_grad = np.identity(3)
+        np_grad = np.hstack((np_tmp, np_grad))
+    elif diagonal < 0:
+        np_grad = np.array([[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0],])
+    else:
+        np_tmp = np.zeros([3, 1])
+        np_grad = np.identity(3)
+        np_grad = np.hstack((np_grad, np_tmp))
+    test_case.assertTrue(
+        np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
+    )
+
+    input = flow.Tensor(
+        np.random.randn(3, 4), device=flow.device(device), requires_grad=True
+    )
+    of_out = input.diag(diagonal=diagonal).sum()
+    of_out.backward()
+    if diagonal > 0:
+        np_tmp = np.zeros([3, 1])
+        np_grad = np.identity(3)
+        np_grad = np.hstack((np_tmp, np_grad))
+    elif diagonal < 0:
+        np_grad = np.array([[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0],])
+    else:
+        np_tmp = np.zeros([3, 1])
+        np_grad = np.identity(3)
+        np_grad = np.hstack((np_grad, np_tmp))
+    test_case.assertTrue(
+        np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
+    )
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
mode", +) +class TestDiag(flow.unittest.TestCase): + def test_diag_forward(test_case): + arg_dict = OrderedDict() + arg_dict["shape"] = [(3,), (3, 3), (3, 4)] + arg_dict["diagonal"] = [1, 0, -1] + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + _test_diag_forward(test_case, *arg[0:]) + + def test_diag_backward(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_diag_one_dim_backward, + _test_diag_other_dim_backward, + _test_diag_other_dim_non_square_backward, + ] + arg_dict["diagonal"] = [1, 0, -1] + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + +if __name__ == "__main__": + unittest.main() -- GitLab