diff --git a/oneflow/python/nn/modules/dropout.py b/oneflow/python/nn/modules/dropout.py
new file mode 100644
index 0000000000000000000000000000000000000000..7be0815a3ac1d85411b8da4703df1f2fdcc16d61
--- /dev/null
+++ b/oneflow/python/nn/modules/dropout.py
@@ -0,0 +1,121 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import sys
+import random
+import oneflow as flow
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export
+import oneflow.python.framework.id_util as id_util
+
+
+class _DropoutNd(Module):
+ __constants__ = ["p", "inplace"]
+ p: float
+ inplace: bool
+
+ def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
+ super(_DropoutNd, self).__init__()
+ assert inplace is False, "inplace=True is not supported yet!"
+ if p < 0 or p > 1:
+ raise ValueError(
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
+ )
+ self.p = p
+ self.inplace = inplace
+
+ def extra_repr(self) -> str:
+ return "p={}, inplace={}".format(self.p, self.inplace)
+
+
+@oneflow_export("nn.Dropout")
+class Dropout(_DropoutNd):
+ r"""During training, randomly zeroes some of the elements of the input
+ tensor with probability :attr:`p` using samples from a Bernoulli
+ distribution. Each channel will be zeroed out independently on every forward
+ call.
+
+ This has proven to be an effective technique for regularization and
+ preventing the co-adaptation of neurons as described in the paper
+ `Improving neural networks by preventing co-adaptation of feature
+ detectors`_ .
+
+ Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
+ training. This means that during evaluation the module simply computes an
+ identity function.
+
+ Args:
+ p: probability of an element to be zeroed. Default: 0.5
+ inplace: If set to ``True``, will do this operation in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(*)`. Input can be of any shape
+ - Output: :math:`(*)`. Output is of the same shape as input
+
+ For example:
+
+ .. code-block:: python
+
+ import numpy as np
+ import oneflow as flow
+
+ m = flow.nn.Dropout(p=0.5)
+ arr = np.array(
+ [
+ [-0.7797, 0.2264, 0.2458, 0.4163],
+ [0.4299, 0.3626, -0.4892, 0.4141],
+ [-1.4115, 1.2183, -0.5503, 0.6520],
+ ]
+ )
+ x = flow.Tensor(arr)
+ y = m(x)
+ # likely output:
+ # [[-0. 0. 0.4916 0.8326]
+ # [ 0.8598 0. -0. 0.8282]
+ # [-2.823 2.4366 -0. 1.304 ]]
+
+ """
+
+ def __init__(self, p: float = 0.5, inplace: bool = False):
+ _DropoutNd.__init__(self, p, inplace)
+
+ if self.p == 1.0:
+ scale = 1
+ else:
+ scale = float(1.0 / (1.0 - self.p))
+
+ seed = random.randint(-sys.maxsize, sys.maxsize)
+ self._op = (
+ flow.builtin_op("dropout")
+ .Input("in")
+ .Input("mask")
+ .Output("out")
+ .Attr("scale", scale)
+ .Build()
+ )
+ self._mask_op = (
+ flow.builtin_op("random_mask_like")
+ .Input("like")
+ .Output("out")
+ .Attr("rate", self.p)
+ .Attr("seed", seed)
+ .Build()
+ )
+
+ def forward(self, x):
+ if self.p == 0.0:
+ return x
+ mask = self._mask_op(x)[0]
+ return self._op(x, mask)[0]
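
For reference, the scaling above implements "inverted dropout": surviving
elements are scaled by 1/(1-p) at training time so the expected activation is
unchanged. A minimal NumPy sketch of that math (an illustration of the
technique, not the OneFlow kernel; assumes 0 <= p < 1):

    import numpy as np

    def inverted_dropout(x, p=0.5, rng=None):
        # Bernoulli keep-mask: each element survives with probability (1 - p)
        rng = rng or np.random.default_rng(0)
        mask = rng.random(x.shape) >= p
        # Scale survivors by 1/(1-p) so the expected output equals x
        return x * mask / (1.0 - p)

    x = np.ones((3, 4), dtype=np.float32)
    print(inverted_dropout(x, p=0.5))  # entries are either 0.0 or 2.0
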
diff --git a/oneflow/python/nn/modules/linear.py b/oneflow/python/nn/modules/linear.py
index a779cb2e7cc969e20b22b2d76b00e89ea8cd8ddd..4877b6ff56173ef1c716fd6bfcab85d3c61ab70f 100644
--- a/oneflow/python/nn/modules/linear.py
+++ b/oneflow/python/nn/modules/linear.py
@@ -26,6 +26,10 @@ import math
class Identity(Module):
"""A placeholder identity operator that is argument-insensitive.
+ Args:
+ args: any argument (unused)
+ kwargs: any keyword argument (unused)
+
For example:
.. code-block:: python
@@ -42,7 +46,7 @@ class Identity(Module):
"""
- def __init__(self):
+ def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, input: Tensor) -> Tensor:
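
The point of accepting *args and **kwargs is that Identity can stand in for
any layer without touching the call site that constructs it. A hedged usage
sketch (the Linear swap below is illustrative, not from this PR):

    import numpy as np
    import oneflow as flow

    # Disable a projection layer while keeping the constructor call unchanged
    use_projection = False
    proj = flow.nn.Linear(64, 64) if use_projection else flow.nn.Identity(64, 64)

    x = flow.Tensor(np.random.rand(2, 64))
    y = proj(x)  # shape (2, 64) either way
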
diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py
index 34cbdee759a445c511dc97a3771b2d641bf057ac..68c027d796f92daf0a182c917f9b38cfb7b8a949 100644
--- a/oneflow/python/nn/modules/math_ops.py
+++ b/oneflow/python/nn/modules/math_ops.py
@@ -631,3 +631,161 @@ def log_op(tensor):
"""
return Log()(tensor)
+
+
+class Subtract(Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, x, y):
+ if isinstance(x, (int, float)):
+ return ScalarAdd(x)(-1 * y)
+ elif isinstance(y, (int, float)):
+ return ScalarAdd(-1 * y)(x)
+ elif x.shape == y.shape:
+ # TODO: add element-wise op
+ return BroadcastSub()(x, y)
+ elif x.shape == (1,):
+ return ScalarSubByTensor()(y, x)
+ elif y.shape == (1,):
+ return ScalarSubByTensor()(x, y)
+ else:
+ return BroadcastSub()(x, y)
+
+
+class Sqrt(Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.sqrt_op = flow.builtin_op("sqrt").Input("x").Output("y").Build()
+
+ def forward(self, input):
+ return self.sqrt_op(input)[0]
+
+
+@oneflow_export("sqrt")
+@register_tensor_op("sqrt")
+def sqrt_op(input):
+ r"""Returns a new tensor with the square-root of the elements of :attr:`input`.
+
+ .. math::
+ \text{out}_{i} = \sqrt{\text{input}_{i}}
+
+ Args:
+ input (Tensor): the input tensor.
+
+ For example:
+
+ .. code-block:: python
+
+ import oneflow as flow
+ import numpy as np
+
+ arr = np.random.randn(3, 2, 5, 7)
+ input = flow.Tensor(arr)
+ output = flow.sqrt(input)
+ # output is equal to np.sqrt(arr)
+ """
+ return Sqrt()(input)
+
+
+class Square(Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.square_op = flow.builtin_op("square").Input("x").Output("y").Build()
+
+ def forward(self, input):
+ return self.square_op(input)[0]
+
+
+@oneflow_export("square")
+@register_tensor_op("square")
+def square_op(input):
+ r"""Returns a new tensor with the square of the elements of :attr:`input`.
+
+ .. math::
+ \text{out}_{i} = \text{input}_{i} \times \text{input}_{i}
+
+ Args:
+ input (Tensor): the input tensor.
+
+ For example:
+
+ .. code-block:: python
+
+ import oneflow as flow
+ import numpy as np
+
+ arr = np.random.randn(3, 2, 5, 7)
+ input = flow.Tensor(arr)
+ output = flow.square(input)
+ # output is equal to np.square(arr)
+ """
+ return Square()(input)
+
+
+class Std(Module):
+ def __init__(self, dim=None, unbiased=True, keepdim=False) -> None:
+ super().__init__()
+ assert unbiased is True, "Only unbiased=True is supported for now!"
+ self.unbiased = unbiased
+ self.keepdim = keepdim
+ self.dim = dim
+ self.reduce_count = 1
+ self.square_op = Square()
+ self.sqrt_op = Sqrt()
+ self.subtract_op = Subtract()
+
+ def forward(self, x):
+ self.axis = _check_axis(self.dim, x.shape)
+ if isinstance(self.axis, list) and len(self.axis) == 0:
+ return flow.tmp.zeros(size=x.shape)
+ else:
+ self.reduce_count = 1
+ if len(self.axis) == 0:
+ self.reduce_count = x.nelement()
+ else:
+ for i in self.axis:
+ self.reduce_count *= x.shape[i]
+
+ # std = sqrt(E[x^2] - E[x]^2) over the reduced elements
+ mean_of_square = Sum(self.axis, self.keepdim)(self.square_op(x)) / self.reduce_count
+ square_of_mean = self.square_op(Sum(self.axis, self.keepdim)(x) / self.reduce_count)
+ return self.sqrt_op(self.subtract_op(mean_of_square, square_of_mean))
+
+
+@oneflow_export("tmp.std")
+@register_tensor_op("std")
+def std_op(tensor, dim, unbiased=True, keepdim=False):
+ r"""
+ Returns the standard-deviation of each row of the :attr:`input` tensor in the
+ dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
+ reduce over all of them.
+
+ If :attr:`keepdim` is ``True``, the output tensor is of the same size as input except in
+ the dimension(s) :attr:`dim` where it is of size 1. Otherwise, :attr:`dim` is squeezed,
+ resulting in the output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+
+ If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated
+ via the biased estimator. Otherwise, Bessel's correction will be used.
+
+ Args:
+ input (Tensor): the input tensor.
+ dim (int or tuple of ints): the dimension or dimensions to reduce.
+ unbiased (bool): whether to use the unbiased estimation or not.
+ keepdim (bool): whether the output tensor has :attr:`dim` retained or not.
+
+ For example:
+
+ .. code-block:: python
+
+ import oneflow as flow
+ import numpy as np
+
+ arr = np.random.randn(2, 3, 4, 5)
+ input = flow.Tensor(arr)
+ output = flow.std(input, dim=2)
+
+ # output is equal to np.std(arr, axis=2)
+
+ """
+ return Std(dim, unbiased, keepdim)(tensor)
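
The Sum/Square/Sqrt composition in Std.forward evaluates
sqrt(E[x^2] - E[x]^2) over the reduced axes; the tests compare it against
np.std, whose default (ddof=0) computes the same quantity. A single-axis
NumPy sketch of that arithmetic (illustrative, not the OneFlow ops):

    import numpy as np

    def std_via_moments(x, axis):
        # sqrt(E[x^2] - E[x]^2), mirroring the Sum/Square/Sqrt composition above
        reduce_count = x.shape[axis]
        mean_of_square = np.sum(np.square(x), axis=axis) / reduce_count
        square_of_mean = np.square(np.sum(x, axis=axis) / reduce_count)
        return np.sqrt(mean_of_square - square_of_mean)

    arr = np.random.randn(2, 3, 4, 5)
    print(np.allclose(std_via_moments(arr, axis=2), np.std(arr, axis=2)))  # True
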
diff --git a/oneflow/python/nn/modules/normalization.py b/oneflow/python/nn/modules/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca91c51e2f93ee923c853bc3ff2d8bdcbc73ed2e
--- /dev/null
+++ b/oneflow/python/nn/modules/normalization.py
@@ -0,0 +1,165 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import oneflow as flow
+from oneflow.python.nn import init
+from oneflow.python.nn.module import Module
+from oneflow.python.oneflow_export import oneflow_export
+from oneflow.python.framework.tensor import Tensor
+from typing import Tuple, Union
+
+_shape_t = Union[int, Tuple[int, ...], flow.Size]
+
+
+@oneflow_export("nn.LayerNorm")
+class LayerNorm(Module):
+ r"""Applies Layer Normalization over a mini-batch of inputs as described in
+ the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
+
+ .. math::
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated separately over the last certain
+ number of dimensions, which have to be of the shape specified by :attr:`normalized_shape`.
+ :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+ :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
+ The standard-deviation is calculated via the biased estimator, i.e. with the
+ denominator ``n`` rather than ``n - 1``.
+
+ .. note::
+ Unlike Batch Normalization and Instance Normalization, which applies
+ scalar scale and bias for each entire channel/plane with the
+ :attr:`affine` option, Layer Normalization applies per-element scale and
+ bias with :attr:`elementwise_affine`.
+ This layer uses statistics computed from input data in both training and
+ evaluation modes.
+
+ Args:
+ normalized_shape (int or list or flow.Size): input shape from an expected input
+ of size
+
+ .. math::
+ [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]
+
+ If a single integer is used, it is treated as a singleton list, and this module will
+ normalize over the last dimension, which is expected to be of that specific size.
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ elementwise_affine: a boolean value that when set to ``True``, this module
+ has learnable per-element affine parameters initialized to ones (for weights)
+ and zeros (for biases). Default: ``True``.
+ Shape:
+ - Input: :math:`(N, *)`
+ - Output: :math:`(N, *)` (same shape as input)
+
+ For example:
+
+ .. code-block:: python
+
+ import numpy as np
+ import oneflow as flow
+
+ input_arr = np.array(
+ [
+ [
+ [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]],
+ [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]],
+ ],
+ [
+ [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]],
+ [[1.07541728, 0.11008703], [0.26361224, -0.48663723]],
+ ],
+ ],
+ dtype=np.float32,
+ )
+
+ x = flow.Tensor(input_arr)
+ m = flow.nn.LayerNorm(2)
+ y = m(x)
+
+ # [[[[ 0.99997395 -0.99997395]
+ # [-0.999947 0.999947 ]]
+
+ # [[-0.99995947 0.9999595 ]
+ # [ 0.99998796 -0.99998796]]]
+
+ # [[[-0.9998348 0.99983454]
+ # [ 0.9999913 -0.9999913 ]]
+
+ # [[ 0.99997866 -0.99997854]
+ # [ 0.9999645 -0.9999645 ]]]]
+
+ """
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
+ normalized_shape: Tuple[int, ...]
+ eps: float
+ elementwise_affine: bool
+
+ def __init__(
+ self,
+ normalized_shape: _shape_t,
+ eps: float = 1e-5,
+ elementwise_affine: bool = True,
+ ) -> None:
+ super(LayerNorm, self).__init__()
+ if isinstance(normalized_shape, int):
+ # mypy error: incompatible types in assignment
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
+
+ self.epsilon = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = flow.nn.Parameter(flow.Tensor(*self.normalized_shape))
+ self.bias = flow.nn.Parameter(flow.Tensor(*self.normalized_shape))
+ else:
+ self.register_parameter("weight", None)
+ self.register_parameter("bias", None)
+ self.reset_parameters()
+ # The first axis to normalize over; recomputed in forward, defaults to 1.
+ self.begin_norm_axis = 1
+ # The axis at which the affine params start; -1 in 'NCHW' format, -2 in 'NHWC' format.
+ self.begin_params_axis = -1
+
+ self._op = (
+ flow.builtin_op("layer_norm")
+ .Input("x")
+ .Output("y")
+ .Output("mean")
+ .Output("inv_variance")
+ .Attr("center", False)
+ .Attr("scale", False)
+ .Attr("begin_params_axis", self.begin_params_axis)
+ .Attr("epsilon", self.epsilon)
+ .Build()
+ )
+
+ def reset_parameters(self) -> None:
+ if self.elementwise_affine:
+ init.ones_(self.weight)
+ init.zeros_(self.bias)
+
+ def forward(self, x):
+ assert len(x.shape) > len(
+ self.normalized_shape
+ ), "Input tensor dim must greater than normalized dim!"
+ self.begin_norm_axis = len(x.shape) - len(self.normalized_shape)
+ res = self._op(x, begin_norm_axis=self.begin_norm_axis)[0]
+ return res
+
+ def extra_repr(self) -> str:
+ return (
+ "{normalized_shape}, eps={eps}, "
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
+ )
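
As a sanity check for the begin_norm_axis logic, layer normalization over the
trailing normalized_shape dimensions can be written directly in NumPy (biased
variance, same eps default). A sketch of the reference math, not the
layer_norm kernel:

    import numpy as np

    def layer_norm_ref(x, normalized_ndim, eps=1e-5):
        # Normalize over the last `normalized_ndim` axes, i.e.
        # begin_norm_axis = x.ndim - normalized_ndim, as in forward() above
        axes = tuple(range(x.ndim - normalized_ndim, x.ndim))
        mean = x.mean(axis=axes, keepdims=True)
        var = x.var(axis=axes, keepdims=True)  # biased (ddof=0) variance
        return (x - mean) / np.sqrt(var + eps)

    x = np.random.randn(2, 2, 2, 2).astype(np.float32)
    print(layer_norm_ref(x, normalized_ndim=1).shape)  # (2, 2, 2, 2)
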
diff --git a/oneflow/python/test/modules/test_dropout.py b/oneflow/python/test/modules/test_dropout.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c3e73b088fcb1161b54956d86055c18c78ebb98
--- /dev/null
+++ b/oneflow/python/test/modules/test_dropout.py
@@ -0,0 +1,58 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+
+import numpy as np
+import oneflow as flow
+
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestDropout(flow.unittest.TestCase):
+ def test_dropout(test_case):
+ input_arr = np.array(
+ [
+ [-0.7797, 0.2264, 0.2458, 0.4163],
+ [0.4299, 0.3626, -0.4892, 0.4141],
+ [-1.4115, 1.2183, -0.5503, 0.6520],
+ ]
+ )
+ m = flow.nn.Dropout(p=0)
+ x = flow.Tensor(input_arr)
+ y = m(x)
+ test_case.assertTrue(np.allclose(y.numpy(), input_arr))
+
+ def test_dropout_special_case(test_case):
+ input_arr = np.array(
+ [
+ [-0.7797, 0.2264, 0.2458, 0.4163],
+ [0.4299, 0.3626, -0.4892, 0.4141],
+ [-1.4115, 1.2183, -0.5503, 0.6520],
+ ]
+ )
+ output = np.array(
+ [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
+ )
+ m = flow.nn.Dropout(p=1.0)
+ x = flow.Tensor(input_arr)
+ y = m(x)
+ test_case.assertTrue(np.allclose(y.numpy(), output))
+
+
+if __name__ == "__main__":
+ unittest.main()
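
The two tests above pin down the deterministic endpoints p=0 and p=1; the
output for 0 < p < 1 is random, so a statistical check is the natural
extension. A hedged sketch of such a test (not part of this PR):

    import numpy as np
    import oneflow as flow

    def zeroed_fraction(p, n=100000):
        m = flow.nn.Dropout(p=p)
        x = flow.Tensor(np.ones(n, dtype=np.float32))
        return np.mean(m(x).numpy() == 0.0)

    # Roughly a fraction p of the elements should be zeroed, e.g.
    # assert abs(zeroed_fraction(0.3) - 0.3) < 0.01
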
diff --git a/oneflow/python/test/modules/test_linear.py b/oneflow/python/test/modules/test_linear.py
index 2c61366cfaa0b19a652fc2b8179bb3c4cc3c5f9e..1c2ca69b077722b978711bd6848fcf235d81bbe9 100644
--- a/oneflow/python/test/modules/test_linear.py
+++ b/oneflow/python/test/modules/test_linear.py
@@ -81,5 +81,17 @@ class TestLinear(flow.unittest.TestCase):
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestIdentity(flow.unittest.TestCase):
+ def test_identity(test_case):
+ m = flow.nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
+ x = flow.Tensor(np.random.rand(2, 3, 4, 5))
+ y = m(x)
+ test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/oneflow/python/test/modules/test_math_ops.py b/oneflow/python/test/modules/test_math_ops.py
index 891dba2f93350425077b2444f1dd9ba38cf6cf2d..28b8765548655c0bc2a7e3de4b81cd79eb36ae37 100644
--- a/oneflow/python/test/modules/test_math_ops.py
+++ b/oneflow/python/test/modules/test_math_ops.py
@@ -22,20 +22,25 @@ import numpy as np
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
-class TestMathModule(flow.unittest.TestCase):
+class TestSin(flow.unittest.TestCase):
def test_sin(test_case):
- input = flow.Tensor(np.random.randn(2, 6, 5, 3), dtype=flow.float32)
+ input = flow.Tensor(np.random.randn(2, 6, 5, 3))
of_out = flow.sin(input)
np_out = np.sin(input.numpy())
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
- test_case.assertTrue(np.allclose(input.sin().numpy(), np_out, 1e-5, 1e-5))
- arr = np.array([-0.5461, 0.1347, -2.7266, -0.2746])
- input2 = flow.Tensor(arr, dtype=flow.float32)
- np_out2 = np.array([-0.51935846, 0.13429303, -0.40318328, -0.27116194])
- of_out2 = flow.sin(input2)
- test_case.assertTrue(np.allclose(of_out2.numpy(), np_out2, 1e-5, 1e-5))
+ def test_sin_tensor_function(test_case):
+ input = flow.Tensor(np.random.randn(8, 11, 9, 7))
+ of_out = input.sin()
+ np_out = np.sin(input.numpy())
+ test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestCos(flow.unittest.TestCase):
def test_cos(test_case):
input = flow.Tensor(np.random.randn(1, 3, 6), dtype=flow.float32)
of_out = flow.cos(input)
@@ -43,12 +48,19 @@ class TestMathModule(flow.unittest.TestCase):
test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
test_case.assertTrue(np.allclose(input.cos().numpy(), np_out, 1e-5, 1e-5))
- arr = np.array([1.4309, 1.2706, -0.8562, 0.9796])
- input2 = flow.Tensor(arr, dtype=flow.float32)
- np_out2 = np.array([0.13944048, 0.29570782, 0.6553126, 0.5573547])
- of_out2 = flow.cos(input2)
- test_case.assertTrue(np.allclose(of_out2.numpy(), np_out2))
+ def test_cos_tensor_function(test_case):
+ arr = np.random.randn(4, 5, 6, 7)
+ input = flow.Tensor(arr, dtype=flow.float32)
+ np_out = np.cos(arr)
+ of_out = input.cos()
+ test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestLog(flow.unittest.TestCase):
def test_log(test_case):
input = flow.Tensor(np.random.randn(2, 3, 4, 5), dtype=flow.float32)
of_out = flow.log(input)
@@ -58,12 +70,81 @@ class TestMathModule(flow.unittest.TestCase):
)
test_case.assertTrue(np.allclose(input.log().numpy(), np_out, equal_nan=True))
+ def test_log_nan_value(test_case):
arr = np.array([-0.7168, -0.5471, -0.8933, -1.4428, -0.1190])
- input2 = flow.Tensor(arr, dtype=flow.float32)
+ input = flow.Tensor(arr, dtype=flow.float32)
np_out = np.full((5,), np.nan)
- of_out2 = flow.log(input2)
+ of_out = flow.log(input)
test_case.assertTrue(
- np.allclose(of_out2.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
+ np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
+ )
+
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestStd(flow.unittest.TestCase):
+ def test_std(test_case):
+ np_arr = np.random.randn(2, 3, 4, 5)
+ input = flow.Tensor(np_arr)
+ of_out = flow.tmp.std(input, dim=2)
+ np_out = np.std(np_arr, axis=2)
+ test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+ def test_std_tensor_function(test_case):
+ np_arr = np.random.randn(9, 8, 7, 6)
+ input = flow.Tensor(np_arr)
+ of_out = input.std(dim=1, keepdim=False)
+ np_out = np.std(np_arr, axis=1)
+ test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5))
+
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestSqrt(flow.unittest.TestCase):
+ def test_sqrt(test_case):
+ input_arr = np.random.randn(3, 2, 5, 7)
+ np_out = np.sqrt(input_arr)
+ x = flow.Tensor(input_arr)
+ of_out = flow.sqrt(input=x)
+ test_case.assertTrue(
+ np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
+ )
+
+ def test_sqrt_tensor_function(test_case):
+ input_arr = np.random.randn(1, 6, 3, 8)
+ np_out = np.sqrt(input_arr)
+ x = flow.Tensor(input_arr)
+ of_out = x.sqrt()
+ test_case.assertTrue(
+ np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
+ )
+
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestSquare(flow.unittest.TestCase):
+ def test_square(test_case):
+ input_arr = np.random.randn(9, 4, 5, 6)
+ np_out = np.square(input_arr)
+ x = flow.Tensor(input_arr)
+ of_out = flow.square(x)
+ test_case.assertTrue(
+ np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
+ )
+
+ def test_square_tensor_function(test_case):
+ input_arr = np.random.randn(2, 7, 7, 3)
+ np_out = np.square(input_arr)
+ x = flow.Tensor(input_arr)
+ of_out = x.square()
+ test_case.assertTrue(
+ np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5, equal_nan=True)
)
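
A note on the equal_nan=True flags in the sqrt tests above: np.random.randn
yields negative values, and the square root of a negative number is NaN for
both NumPy and flow.sqrt, so NaNs must compare equal for the assertions to
pass. A minimal illustration:

    import numpy as np

    out = np.sqrt(np.array([-1.0, 4.0]))  # [nan, 2.0], with a RuntimeWarning
    print(np.allclose(out, out))                  # False: NaN != NaN by default
    print(np.allclose(out, out, equal_nan=True))  # True
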
diff --git a/oneflow/python/test/modules/test_normalization.py b/oneflow/python/test/modules/test_normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0905e08ef3e56f9ab28bb132da48a3bed02d082b
--- /dev/null
+++ b/oneflow/python/test/modules/test_normalization.py
@@ -0,0 +1,126 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+
+import numpy as np
+import oneflow as flow
+
+
+@unittest.skipIf(
+ not flow.unittest.env.eager_execution_enabled(),
+ ".numpy() doesn't work in lazy mode",
+)
+class TestLayerNorm(flow.unittest.TestCase):
+ def test_layernorm(test_case):
+ input_arr = np.array(
+ [
+ [
+ [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]],
+ [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]],
+ ],
+ [
+ [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]],
+ [[1.07541728, 0.11008703], [0.26361224, -0.48663723]],
+ ],
+ ],
+ dtype=np.float32,
+ )
+ output = np.array(
+ [
+ [
+ [[-0.0544118, -1.0509688], [-0.2696846, 0.4295622]],
+ [[-1.2834904, -0.4838651], [2.0891891, 0.6236691]],
+ ],
+ [
+ [[-0.8555527, -0.3554582], [0.4930190, -1.6948260]],
+ [[1.8035311, 0.4155158], [0.6362644, -0.4424936]],
+ ],
+ ],
+ dtype=np.float32,
+ )
+
+ x = flow.Tensor(input_arr)
+ m = flow.nn.LayerNorm(x.size()[1:])
+ y = m(x)
+ test_case.assertTrue(np.allclose(y.numpy(), output))
+
+ def test_layernorm_v2(test_case):
+ input_arr = np.array(
+ [
+ [
+ [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]],
+ [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]],
+ ],
+ [
+ [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]],
+ [[1.07541728, 0.11008703], [0.26361224, -0.48663723]],
+ ],
+ ],
+ dtype=np.float32,
+ )
+ output = np.array(
+ [
+ [
+ [[0.3406544, -1.5249983], [-0.0623574, 1.2467014]],
+ [[-1.2004623, -0.5688803], [1.4634399, 0.3059027]],
+ ],
+ [
+ [[-0.3180245, 0.3122248], [1.3815271, -1.3757277]],
+ [[1.4972910, -0.2341234], [0.0412391, -1.3044068]],
+ ],
+ ],
+ dtype=np.float32,
+ )
+ x = flow.Tensor(input_arr)
+ m = flow.nn.LayerNorm([2, 2], eps=1e-5)
+ y = m(x)
+ test_case.assertTrue(np.allclose(y.numpy(), output))
+
+ def test_layernorm_v3(test_case):
+ input_arr = np.array(
+ [
+ [
+ [[-0.16046895, -1.03667831], [-0.34974465, 0.26505867]],
+ [[-1.24111986, -0.53806001], [1.72426331, 0.43572459]],
+ ],
+ [
+ [[-0.77390957, -0.42610624], [0.16398858, -1.35760343]],
+ [[1.07541728, 0.11008703], [0.26361224, -0.48663723]],
+ ],
+ ],
+ dtype=np.float32,
+ )
+ output = np.array(
+ [
+ [
+ [[0.9999740, -0.9999740], [-0.9999470, 0.9999470]],
+ [[-0.9999595, 0.9999595], [0.9999880, -0.9999880]],
+ ],
+ [
+ [[-0.9998344, 0.9998341], [0.9999914, -0.9999914]],
+ [[0.9999787, -0.9999787], [0.9999645, -0.9999645]],
+ ],
+ ],
+ dtype=np.float32,
+ )
+ x = flow.Tensor(input_arr)
+ m = flow.nn.LayerNorm(2, elementwise_affine=True)
+ y = m(x)
+ test_case.assertTrue(np.allclose(y.numpy(), output))
+
+
+if __name__ == "__main__":
+ unittest.main()
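
Expected arrays like the ones above can be regenerated with a plain NumPy
reference; a sketch reusing the layer_norm_ref helper outlined after the
normalization.py diff (illustrative, not part of this PR):

    import numpy as np

    def layer_norm_ref(x, normalized_ndim, eps=1e-5):
        axes = tuple(range(x.ndim - normalized_ndim, x.ndim))
        mean = x.mean(axis=axes, keepdims=True)
        var = x.var(axis=axes, keepdims=True)
        return (x - mean) / np.sqrt(var + eps)

    # e.g. the expectation for test_layernorm_v3 (normalized_shape = 2, the last dim):
    # layer_norm_ref(input_arr, normalized_ndim=1)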