Skip to content
Snippets Groups Projects
Unverified Commit 6caa52f3 authored by tangnana925's avatar tangnana925 Committed by GitHub
Browse files

init of op diag (#5298)


* init of op diag

* init of op diag

* modify  op diag

* merge testcase

* delete no use tensor diag

* amend diag api docs

* resolve conficts

* resolve confict

* resolve confict

* resolve confict

* auto format by CI

* add tensor.diag

* auto format by CI

* amend API description

* add non_square test

* amend diag API description

* amend diag API description

* compact test code

* compact code

* auto format by CI

* amend API description

* amend API description

* motify functional API

* motify test_diag

* amend diag functional API

* amend diag docstring

* auto format by CI

* amend code standards

* auto format by CI

Co-authored-by: default avataroneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent f7f80b4a
No related branches found
No related tags found
No related merge requests found
......@@ -233,6 +233,8 @@ Experimental features
.. autofunction:: oneflow.experimental.meshgrid
.. autofunction:: oneflow.experimental.topk
.. autofunction:: oneflow.experimental.Tensor.topk
.. autofunction:: oneflow.experimental.diag
.. autofunction:: oneflow.experimental.Tensor.diag
.. autofunction:: oneflow.experimental.nn.GroupNorm
.. autofunction:: oneflow.experimental.nn.ZeroPad2d
.. autofunction:: oneflow.experimental.tensor_buffer_to_tensor
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DiagInterpState : public OpExprInterpState {
bool requires_grad;
int32_t diagonal;
};
class Diag : public OpExprGradFunction<DiagInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DiagInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const DiagInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> Diag::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> Diag::Capture(DiagInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->diagonal = JUST(composed_attrs.GetAttr<int32_t>("diagonal"));
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> Diag::Apply(const DiagInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(2);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::DiagGrad(out_grads.at(0), x, ctx->diagonal));
}
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("diag", Diag);
} // namespace one
} // namespace oneflow
......@@ -594,6 +594,15 @@
signature: "Tensor PadGrad(Tensor dy, *, Int64List pad, String mode=\"constant\", Scalar value=0)"
bind_python: False
- name: "diag"
signature: "Tensor Diag(Tensor x, *, Int32 diagonal=0)"
bind_python: True
- name: "diag_grad"
signature: "Tensor DiagGrad(Tensor dy, Tensor in, *, Int32 diagonal=0)"
bind_python: False
- name: "tensor_getitem"
signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
bind_python: True
......@@ -458,6 +458,35 @@ class TriuFunctor {
std::shared_ptr<OpExpr> op_;
};
class DiagFunctor {
public:
DiagFunctor() { op_ = CHECK_JUST(one::OpBuilder("diag").Input("in").Output("out").Build()); }
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const int32_t& diagonal) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("diagonal", diagonal));
return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class DiagGradFunctor {
public:
DiagGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("diag_grad").Input("dy").Input("in").Output("dx").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::shared_ptr<one::Tensor>& x, const int32_t& diagonal) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("diagonal", diagonal));
return OpInterpUtil::Dispatch<Tensor>(*op_, {dy, x}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class TensorGetItemFunctor {
public:
TensorGetItemFunctor() {}
......@@ -545,6 +574,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::UpsampleFunctor>("Upsample");
m.add_functor<impl::UnsortedSegmentSumLikeFunctor>("UnsortedSegmentSumLike");
m.add_functor<impl::TriuFunctor>("Triu");
m.add_functor<impl::DiagFunctor>("Diag");
m.add_functor<impl::DiagGradFunctor>("DiagGrad");
m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
};
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.oneflow_export import oneflow_export, experimental_api
from oneflow.python.nn.module import Module
from oneflow.python.framework.tensor import register_tensor_op
class Diag(Module):
def __init__(self, diagonal=0):
super().__init__()
self.diagonal = diagonal
def forward(self, input):
return flow.F.diag(input, self.diagonal)
@oneflow_export("diag")
@experimental_api
def diag_op(input, diagonal=0):
r"""
If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal.
If input is a matrix (2-D tensor), then returns a 1-D tensor with diagonal elements of input.
Args:
input (Tensor): the input tensor.
diagonal (Optional[int], 0): The diagonal to consider.
If diagonal = 0, it is the main diagonal. If diagonal > 0, it is above the main diagonal. If diagonal < 0, it is below the main diagonal. Defaults to 0.
Returns:
oneflow.Tensor: the output Tensor.
For example:
.. code-block:: python
>>> import oneflow.experimental as flow
>>> import numpy as np
>>> flow.enable_eager_execution()
>>> arr = np.array(
... [
... [1.0, 2.0, 3.0],
... [4.0, 5.0, 6.0],
... [7.0, 8.0, 9.0],
... ]
... )
>>> input = flow.Tensor(arr, dtype=flow.float32)
>>> flow.diag(input)
tensor([1., 5., 9.], dtype=oneflow.float32)
"""
return Diag(diagonal)(input)
@register_tensor_op("diag")
@experimental_api
def diag_op_tensor(input, diagonal=0):
r"""
diag() -> Tensor
See :func:`oneflow.experimental.diag`
"""
return Diag(diagonal)(input)
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
......@@ -16,13 +16,14 @@ limitations under the License.
import oneflow as flow
import oneflow.python.framework.id_util as id_util
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.oneflow_export import oneflow_export, stable_api
import oneflow.python.framework.remote_blob as remote_blob_util
from typing import Optional
import oneflow._oneflow_internal
@oneflow_export("diag")
@stable_api
def diag(
input: oneflow._oneflow_internal.BlobDesc,
diagonal: Optional[int] = 0,
......@@ -32,8 +33,9 @@ def diag(
If input is a vector, then returns a square matrix with the elements of input as the diagonal.
If input is a matrix, then returns a vector with the diagonal elements of input.
Args:
input (remote_blob_util.BlobDef): The input Blob.
input (remote_blob_util.BlobDef): The input Blob
diagonal (Optional[int], 0): The diagonal to consider. If diagonal = 0, it is the main diagonal. If diagonal > 0, it is above the main diagonal. If diagonal < 0, it is below the main diagonal. Defaults to 0.
Returns:
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
def _test_diag_forward(test_case, shape, diagonal, device):
input = flow.Tensor(np.random.randn(*shape), device=flow.device(device))
of_out = flow.diag(input, diagonal)
np_out = np.diag(input.numpy(), diagonal)
test_case.assertTrue(
np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
)
test_case.assertTrue(
np.allclose(
input.diag(diagonal=diagonal).numpy(), np_out, 1e-5, 1e-5, equal_nan=True
)
)
def _test_diag_one_dim_backward(test_case, diagonal, device):
input = flow.Tensor(
np.random.randn(3), device=flow.device(device), requires_grad=True
)
of_out = flow.diag(input, diagonal).sum()
of_out.backward()
np_grad = np.ones(shape=3)
test_case.assertTrue(
np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
)
input = flow.Tensor(
np.random.randn(3), device=flow.device(device), requires_grad=True
)
of_out = input.diag(diagonal=diagonal).sum()
of_out.backward()
np_grad = np.ones(shape=3)
test_case.assertTrue(
np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
)
def _test_diag_other_dim_backward(test_case, diagonal, device):
input = flow.Tensor(
np.random.randn(3, 3), device=flow.device(device), requires_grad=True
)
of_out = flow.diag(input, diagonal).sum()
of_out.backward()
if diagonal > 0:
np_grad = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0],])
elif diagonal < 0:
np_grad = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0],])
else:
np_grad = np.identity(3)
test_case.assertTrue(
np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
)
input = flow.Tensor(
np.random.randn(3, 3), device=flow.device(device), requires_grad=True
)
of_out = input.diag(diagonal=diagonal).sum()
of_out.backward()
if diagonal > 0:
np_grad = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0],])
elif diagonal < 0:
np_grad = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0],])
else:
np_grad = np.identity(3)
test_case.assertTrue(
np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
)
def _test_diag_other_dim_non_square_backward(test_case, diagonal, device):
input = flow.Tensor(
np.random.randn(3, 4), device=flow.device(device), requires_grad=True
)
of_out = flow.diag(input, diagonal).sum()
of_out.backward()
if diagonal > 0:
np_tmp = np.zeros([3, 1])
np_grad = np.identity(3)
np_grad = np.hstack((np_tmp, np_grad))
elif diagonal < 0:
np_grad = np.array([[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0],])
else:
np_tmp = np.zeros([3, 1])
np_grad = np.identity(3)
np_grad = np.hstack((np_grad, np_tmp))
test_case.assertTrue(
np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
)
input = flow.Tensor(
np.random.randn(3, 4), device=flow.device(device), requires_grad=True
)
of_out = input.diag(diagonal=diagonal).sum()
of_out.backward()
if diagonal > 0:
np_tmp = np.zeros([3, 1])
np_grad = np.identity(3)
np_grad = np.hstack((np_tmp, np_grad))
elif diagonal < 0:
np_grad = np.array([[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0],])
else:
np_tmp = np.zeros([3, 1])
np_grad = np.identity(3)
np_grad = np.hstack((np_grad, np_tmp))
test_case.assertTrue(
np.allclose(input.grad.numpy(), np_grad, 1e-5, 1e-5, equal_nan=True)
)
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestDiag(flow.unittest.TestCase):
def test_diag_forward(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(3,), (3, 3), (3, 4)]
arg_dict["diagonal"] = [1, 0, -1]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
_test_diag_forward(test_case, *arg[0:])
def test_diag_backward(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_diag_one_dim_backward,
_test_diag_other_dim_backward,
_test_diag_other_dim_non_square_backward,
]
arg_dict["diagonal"] = [1, 0, -1]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment