From 498e00b6e3da13ef149486458821cc99f6c896b0 Mon Sep 17 00:00:00 2001
From: Yao Chi <later@usopp.net>
Date: Sun, 11 Jul 2021 22:00:36 -0500
Subject: [PATCH] Dev add docstring (#5449)

* startup: use python, not C extension

* rst changes test

* refine

* add F.rst

* add bernoulli add docstr example

* add doctest for functional method

* del docstr utils immediately in init.py

* refine test case

* use sin, cos instead of add, mul

* refine

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 docs/source/F.rst                             |  8 ++
 docs/source/index.rst                         |  1 +
 oneflow/init.py                               |  7 ++
 oneflow/python/framework/docstr/__init__.py   | 17 ++++
 oneflow/python/framework/docstr/math_ops.py   | 80 +++++++++++++++++++
 oneflow/python/framework/docstr/random.py     | 58 ++++++++++++++
 oneflow/python/framework/docstr/utils.py      | 26 ++++++
 .../test/modules/test_functional_docstr.py    | 70 ++++++++++++++++
 8 files changed, 267 insertions(+)
 create mode 100644 docs/source/F.rst
 create mode 100644 oneflow/python/framework/docstr/__init__.py
 create mode 100644 oneflow/python/framework/docstr/math_ops.py
 create mode 100644 oneflow/python/framework/docstr/random.py
 create mode 100644 oneflow/python/framework/docstr/utils.py
 create mode 100644 oneflow/python/test/modules/test_functional_docstr.py

diff --git a/docs/source/F.rst b/docs/source/F.rst
new file mode 100644
index 000000000..b271c03f2
--- /dev/null
+++ b/docs/source/F.rst
@@ -0,0 +1,8 @@
+oneflow.F
+===================================
+Functional functions
+----------------------------------
+.. currentmodule:: oneflow.F
+.. autofunction:: oneflow.F.bernoulli
+.. autofunction:: oneflow.F.cos
+.. autofunction:: oneflow.F.sin
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 6014d139b..f365979ea 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -26,6 +26,7 @@ OneFlow API Reference
     tensorrt 
     deprecated  
     experimental 
+    F
     scope 
     sysconfig    
     random 
diff --git a/oneflow/init.py b/oneflow/init.py
index 321b2942e..ffc0d3800 100644
--- a/oneflow/init.py
+++ b/oneflow/init.py
@@ -128,3 +128,10 @@ if not oneflow._oneflow_internal.IsMultiClient():
 
 del absolute_import
 del oneflow
+
+import oneflow.python.framework.docstr as docstr
+from oneflow.python.framework.docstr.utils import register_docstr
+
+register_docstr()
+del register_docstr
+del docstr
diff --git a/oneflow/python/framework/docstr/__init__.py b/oneflow/python/framework/docstr/__init__.py
new file mode 100644
index 000000000..2b00d5567
--- /dev/null
+++ b/oneflow/python/framework/docstr/__init__.py
@@ -0,0 +1,17 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from .math_ops import *
+from .random import *
diff --git a/oneflow/python/framework/docstr/math_ops.py b/oneflow/python/framework/docstr/math_ops.py
new file mode 100644
index 000000000..42fc5f47a
--- /dev/null
+++ b/oneflow/python/framework/docstr/math_ops.py
@@ -0,0 +1,80 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow
+from oneflow.python.framework.docstr.utils import add_docstr
+
+add_docstr(
+    oneflow.F.sin,
+    r"""
+    sin(x: Tensor) -> Tensor
+
+    Returns a new tensor with the sine of the elements of :attr:`input`.
+
+    .. math::
+
+        \text{y}_{i} = \sin(\text{x}_{i})
+
+    Args:
+        x (Tensor): the input tensor.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+        >>> x1 = flow.Tensor(np.array([-0.5461,  0.1347, -2.7266, -0.2746]).astype(np.float32))
+        >>> y1 = flow.F.sin(x1)
+        >>> y1
+        tensor([-0.5194,  0.1343, -0.4032, -0.2712], dtype=oneflow.float32)
+        >>> x2 = flow.Tensor(np.array([-1.4, 2.6, 3.7]).astype(np.float32),device=flow.device('cuda'))
+        >>> y2 = flow.F.sin(x2)
+        >>> y2
+        tensor([-0.9854,  0.5155, -0.5298], device='cuda:0', dtype=oneflow.float32)
+
+
+""",
+)
+
+add_docstr(
+    oneflow.F.cos,
+    r"""
+    cos(x: Tensor) -> Tensor
+
+    Returns a new tensor with the cosine  of the elements of :attr:`input`.
+    
+    .. math::
+        \text{y}_{i} = \cos(\text{x}_{i})
+
+    Args:
+        x (Tensor): the input tensor.
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import oneflow.experimental as flow
+        >>> import numpy as np
+        >>> flow.enable_eager_execution()
+        >>> x = np.array([1.4309,  1.2706, -0.8562,  0.9796])
+        >>> x = flow.Tensor(x, dtype=flow.float32)
+        >>> y = flow.F.cos(x)
+        >>> y
+        tensor([0.1394, 0.2957, 0.6553, 0.5574], dtype=oneflow.float32)
+
+""",
+)
diff --git a/oneflow/python/framework/docstr/random.py b/oneflow/python/framework/docstr/random.py
new file mode 100644
index 000000000..17bfca5f3
--- /dev/null
+++ b/oneflow/python/framework/docstr/random.py
@@ -0,0 +1,58 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow
+from oneflow.python.framework.docstr.utils import add_docstr
+
+add_docstr(
+    oneflow.F.bernoulli,
+    r"""
+    bernoulli(input, *, generator=None, out=None)
+    
+    This operator returns a Tensor with binaray random numbers (0 / 1) from a Bernoulli distribution.
+
+    Args:
+        input (Tensor): the input tensor of probability values for the Bernoulli distribution
+        generator: (Generator, optional) a pseudorandom number generator for sampling
+        out (Tensor, optional): the output tensor.
+
+    Shape:
+        - Input: :math:`(*)`. Input can be of any shape
+        - Output: :math:`(*)`. Output is of the same shape as input
+
+    For example:
+
+    .. code-block:: python
+
+        >>> import numpy as np
+        >>> import oneflow.experimental as flow
+        >>> flow.enable_eager_execution()
+
+        >>> arr = np.array(
+        ...    [
+        ...        [1.0, 1.0, 1.0],
+        ...        [1.0, 1.0, 1.0],
+        ...        [1.0, 1.0, 1.0],
+        ...    ]
+        ... )
+        >>> x = flow.Tensor(arr)
+        >>> y = flow.F.bernoulli(x)
+        >>> y
+        tensor([[1., 1., 1.],
+                [1., 1., 1.],
+                [1., 1., 1.]], dtype=oneflow.float32)
+
+    """,
+)
diff --git a/oneflow/python/framework/docstr/utils.py b/oneflow/python/framework/docstr/utils.py
new file mode 100644
index 000000000..7e5d335c8
--- /dev/null
+++ b/oneflow/python/framework/docstr/utils.py
@@ -0,0 +1,26 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+_function_docstr = {}
+
+
+def add_docstr(fun, docstr: str):
+    _function_docstr[fun] = docstr
+
+
+def register_docstr():
+    for fun, docstr in _function_docstr.items():
+        setattr(fun, "__doc__", docstr)
diff --git a/oneflow/python/test/modules/test_functional_docstr.py b/oneflow/python/test/modules/test_functional_docstr.py
new file mode 100644
index 000000000..2fc9023b4
--- /dev/null
+++ b/oneflow/python/test/modules/test_functional_docstr.py
@@ -0,0 +1,70 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+import inspect
+from collections import OrderedDict
+
+import oneflow as flow
+from oneflow.python.framework.functional import Function
+from test_util import GenArgList
+
+
+def _is_oneflow_functional(object):
+    return isinstance(object, Function)
+
+
+def _run_functional_doctest(
+    test_case,
+    globs=None,
+    verbose=None,
+    optionflags=0,
+    raise_on_error=True,
+    module=flow.F,
+):
+    import doctest
+
+    parser = doctest.DocTestParser()
+    if raise_on_error:
+        runner = doctest.DebugRunner(verbose=verbose, optionflags=optionflags)
+    else:
+        runner = doctest.DocTestRunner(verbose=verbose, optionflags=optionflags)
+
+    r = inspect.getmembers(flow.F, _is_oneflow_functional)
+    for name, fun in r:
+        if (
+            fun.__doc__ is not None
+        ):  # TODO(yaochi) None value of __doc__ will not be allowed
+            print("test on docstr of: ", ".".join([module.__name__, name]))
+            test = parser.get_doctest(fun.__doc__, {}, __name__, __file__, 0)
+            runner.run(test)
+
+
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestFunctionalDocstrModule(flow.unittest.TestCase):
+    def test_functional_docstr(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["module"] = [flow.F, flow.experimental.F]
+        for arg in GenArgList(arg_dict):
+            _run_functional_doctest(
+                test_case, raise_on_error=True, verbose=None, module=arg[0]
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab