Unverified commit 95e1f27c authored by Lyon, committed by GitHub

Dev merge modules part2 (#4776)


* add logsoftmax module

* add sigmoid module

* add maskedfill module

* format

* add unsqueeze module

* add eq module

* add arange module

* fix arange module

* add softmax module

* refine

* update logsoftmax module

* update eq module

* update eq module

* refine

* refine according to comments

* refine softmax test case

* fix unsqueeze module

* refine according to comments

* refine

* format

* format

Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 3bae5a15
@@ -18,17 +18,25 @@ import oneflow._oneflow_internal
 from oneflow.python.nn.module import Module
 from oneflow.python.oneflow_export import oneflow_export
 from oneflow.python.framework.tensor import register_tensor_op
+from typing import Optional
 
 
-@oneflow_export("nn.Sigmoid")
-class Sigmoid(Module):
-    def __init__(self):
-        super().__init__()
-        self._op = flow.builtin_op("sigmoid").Input("in").Output("out").Build()
+def _softmax_need_transpose(x, axis):
+    assert type(axis) is int
+    dim_num = len(x.shape)
+    assert dim_num >= 2
+    if axis < 0:
+        axis += dim_num
+    assert axis >= 0
+    assert axis < dim_num
 
-    def forward(self, x):
-        res = self._op(x)[0]
-        return res
+    need_transpose = False
+    permute = list(range(dim_num))
+    if axis != dim_num - 1:
+        need_transpose = True
+        permute[axis] = permute[-1]
+        permute[-1] = axis
+    return need_transpose, permute
 
 
 @oneflow_export("nn.ReLU")
@@ -199,7 +207,7 @@ def gelu_op(x):
Args:
x (oneflow.Tensor): Input Tensor
Returns:
oneflow.Tensor: A Tensor.
@@ -216,8 +224,221 @@ def gelu_op(x):
gelu = flow.nn.GELU()
out = gelu(input)
# out [-0.15426877, 0., 0.34573123]
"""
return GELU()(x)
@oneflow_export("nn.Sigmoid")
class Sigmoid(Module):
r"""Applies the element-wise function:
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
- Output: :math:`(N, *)`, same shape as the input
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
x = flow.Tensor(
np.array(
[
[0.81733328, 0.43621480, 0.10351428],
[-1.15555191, -0.67776406, 0.27372134],
]
)
)
m = flow.nn.Sigmoid() # or y = flow.sigmoid(x)
y = m(x)
# [[0.69366997, 0.60735673, 0.52585548],
# [0.23947647, 0.33676055, 0.56800622]]
"""
def __init__(self):
super().__init__()
self._op = flow.builtin_op("sigmoid").Input("in").Output("out").Build()
def forward(self, x):
return self._op(x)[0]
@oneflow_export("sigmoid")
@register_tensor_op("sigmoid")
def sigmoid_op(x):
r"""Applies the element-wise function:
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
- Output: :math:`(N, *)`, same shape as the input
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
x = flow.Tensor(
np.array(
[
[0.81733328, 0.43621480, 0.10351428],
[-1.15555191, -0.67776406, 0.27372134],
]
)
)
y = x.sigmoid()
# [[0.69366997, 0.60735673, 0.52585548],
# [0.23947647, 0.33676055, 0.56800622]]
"""
return Sigmoid()(x)
@oneflow_export("nn.Softmax")
@oneflow_export("softmax")
class Softmax(Module):
def __init__(self, dim: Optional[int] = None):
super().__init__()
self.axis = -1 if dim is None else dim
self._op = flow.builtin_op("softmax").Input("in").Output("out").Build()
def forward(self, x):
need_transpose, permute = _softmax_need_transpose(x, self.axis)
if need_transpose:
x = x.transpose(perm=permute)
res = self._op(x)[0]
if need_transpose:
res = res.transpose(perm=permute)
return res
@oneflow_export("softmax")
def softmax_op(tensor, dim=None):
r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
lie in the range [0,1] and sum to 1.
Softmax is defined as:
.. math::
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    When the input Tensor is a sparse tensor, the unspecified
    values are treated as ``-inf``.
Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
- Output: :math:`(*)`, same shape as the input
Returns:
a Tensor of the same dimension and shape as the input with
values in the range [0, 1]
Args:
dim (int): A dimension along which Softmax will be computed (so every slice
along dim will sum to 1).
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
m = flow.nn.Softmax(dim = 2)
x = flow.Tensor(
np.array(
[[[[-0.46716809, 0.40112534, 0.61984003],
[-1.31244969, -0.42528763, 1.47953856]]],
[[[ 1.02978742, -0.49383053, 1.88214159],
[ 1.35351622, -1.46251285, -1.40751374]]]]
)
)
y = m(x)
# [[[[0.6995764 0.6955959 0.29740235]
# [0.3004236 0.30440408 0.7025977 ]]]
# [[[0.4197673 0.7248568 0.96407217]
# [0.58023274 0.27514324 0.03592779]]]]
"""
return Softmax(dim)(tensor)
@oneflow_export("nn.LogSoftmax")
class LogSoftmax(Module):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor.
The LogSoftmax formulation can be simplified as:
.. math::
\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
Args:
dim (int): A dimension along which LogSoftmax will be computed.
Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
- Output: :math:`(N, *)`, same shape as the input
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
m = flow.nn.LogSoftmax(dim=1)
x = flow.Tensor(
np.array(
[[ 0.4296, -1.1957, 2.5463],
[ 1.2552, -1.5747, 0.6923]]
)
)
y = m(x)
# [[-2.251349 -3.8766491 -0.13464898]
# [-0.48770458 -3.3176045 -1.0506046 ]]
"""
def __init__(
self, dim: Optional[int] = 1,
):
super().__init__()
self.dim = dim
def __setstate__(self, state):
self.__dict__.update(state)
if not hasattr(self, "dim"):
self.dim = None
def forward(self, x):
need_transpose, permute = _softmax_need_transpose(x, self.dim)
if need_transpose:
x = x.transpose(perm=permute)
x = flow.softmax(x)
res = flow.log(x)
if need_transpose:
res = res.transpose(perm=permute)
return res
def extra_repr(self):
return "dim={dim}".format(dim=self.dim)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
class Arange(Module):
def __init__(self, start, end, step=1) -> None:
super().__init__()
self.start = 0 if start is None else start
self.end = 1 if end is None else end
self.step = step
self.dtype = flow.int64 # "Only support dtype: `flow.int64` for now!"
        assert self.end > self.start, "end should be larger than start"
        assert self.step <= self.end - self.start, "step is illegal"
        assert type(self.start) == int, "Param `start` should be an int"
        assert type(self.end) == int, "Param `end` should be an int"
        assert type(self.step) == int, "Param `step` should be an int"
        # TODO(zhaoluyang): set the dtype attr in forward() once the bug is fixed
self._op_arange = (
flow.builtin_op("range").Output("out").Attr("dtype", self.dtype).Build()
)
def forward(self):
return self._op_arange(start=self.start, delta=self.step, limit=self.end)[0]
@oneflow_export("arange")
def arange_op(start=0, end=1, step=1):
r"""
    Returns a 1-D tensor of size :math:`\left\lceil \frac{\text{end} - \text{start}}{\text{step}} \right\rceil`
    with values from the interval ``[start, end)`` taken with common difference
    :attr:`step` beginning from :attr:`start`.
.. math::
\text{out}_{i+1} = \text{out}_i + \text{step}.
Args:
        start (int): the starting value for the set of points. Default: ``0``.
        end (int): the ending value for the set of points
        step (int): the gap between each pair of adjacent points. Default: ``1``.
    Keyword args:
        dtype: currently only ``flow.int64`` is supported.
For example:
.. code-block:: python
import oneflow as flow
y = flow.arange(0, 5)
# [0, 1, 2, 3, 4]
"""
return Arange(start, end, step)()
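A quick sanity check of the documented size formula against NumPy (hypothetical snippet, using the same values as the tests further down):

import math
import numpy as np

start, end, step = 0, 100, 3
size = math.ceil((end - start) / step)  # ceil(100 / 3) = 34
assert size == len(np.arange(start, end, step))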
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
class Eq(Module):
def __init__(self) -> None:
super().__init__()
self.eq_op = (
flow.builtin_op("broadcast_equal").Input("x").Input("y").Output("z").Build()
)
    def forward(self, input, other):
        if isinstance(other, flow.Tensor):
            for i in range(len(input.size())):
                assert (
                    input.shape[i] == other.shape[i] or other.shape[i] == 1
                ), "The second tensor's shape should be broadcastable with the first argument."
        elif isinstance(other, (int, float)):
            raise NotImplementedError(
                "Unsupported data type: int and float are not supported yet!"
            )
        else:
            raise NotImplementedError(
                "Unsupported data type. The second argument must be a tensor whose shape is broadcastable with the first argument."
            )
return self.eq_op(input, other)[0]
@oneflow_export("eq", "equal")
@register_tensor_op("eq")
def eq_op(input, other):
r"""
Computes element-wise equality.
    The second argument must be a tensor whose shape is broadcastable with the first argument (scalar numbers are not supported yet).
Args:
input (Tensor): the tensor to compare
        other (Tensor): the tensor to compare with
Returns:
A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
input = flow.Tensor(np.array([2, 3, 4, 5]), dtype=flow.float32)
other = flow.Tensor(np.array([2, 3, 4, 1]), dtype=flow.float32)
y = flow.eq(input, other)
# [1 1 1 0]
"""
return Eq()(input, other)
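For reference, the shape rule `Eq.forward` enforces (same rank, and each dimension of the second tensor equal to the first's or 1) can be sketched in plain Python. This is a simplification, since full NumPy-style broadcasting also allows the ranks to differ:

def broadcastable_to(x_shape, y_shape):
    # y broadcasts to x when ranks match and every dim of y
    # equals the matching dim of x or is 1.
    return len(x_shape) == len(y_shape) and all(
        yd == xd or yd == 1 for xd, yd in zip(x_shape, y_shape)
    )

assert broadcastable_to((2, 3, 4), (2, 1, 4))
assert not broadcastable_to((2, 3, 4), (2, 2, 4))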
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
class MaskedFill(Module):
def __init__(self, value) -> None:
super().__init__()
self.value = value
self._where_op = (
flow.builtin_op("where")
.Input("condition")
.Input("x")
.Input("y")
.Output("out")
.Build()
)
def forward(self, input, mask):
in_shape = tuple(input.shape)
value_like_x = flow.Tensor(*in_shape)
value_like_x.fill_(self.value)
return self._where_op(mask, value_like_x, input)[0]
@oneflow_export("tmp.masked_fill")
@register_tensor_op("masked_fill")
def masked_fill_op(tensor, mask, value):
r"""
    Fills elements of the input tensor with :attr:`value` where :attr:`mask` is True.
    The shape of :attr:`mask` must be broadcastable with the shape of the input tensor.
Args:
        mask (Tensor): the boolean mask
        value (float): the value to fill in with
For example:
.. code-block:: python
import oneflow as flow
import numpy as np
in_arr = np.array(
[[[-0.13169311, 0.97277078, 1.23305363, 1.56752789],
[-1.51954275, 1.87629473, -0.53301206, 0.53006478],
[-1.38244183, -2.63448052, 1.30845795, -0.67144869]],
[[ 0.41502161, 0.14452418, 0.38968 , -1.76905653],
[ 0.34675095, -0.7050969 , -0.7647731 , -0.73233418],
[-1.90089858, 0.01262963, 0.74693893, 0.57132389]]]
)
        fill_value = 8.7654321  # an arbitrary fill value, e.g. -1e9 or 3.1415
input = flow.Tensor(in_arr, dtype=flow.float32)
mask = flow.Tensor((in_arr > 0).astype(np.int8), dtype=flow.int)
output = input.masked_fill(mask, fill_value)
# [[[-0.13169311 8.765432 8.765432 8.765432 ]
# [-1.5195427 8.765432 -0.53301203 8.765432 ]
# [-1.3824419 -2.6344805 8.765432 -0.6714487 ]]
# [[ 8.765432 8.765432 8.765432 -1.7690566 ]
# [ 8.765432 -0.7050969 -0.7647731 -0.7323342 ]
# [-1.9008986 8.765432 8.765432 8.765432 ]]]
"""
return MaskedFill(value)(tensor, mask)
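`MaskedFill.forward` materializes a tensor filled with `value` and selects with the `where` op; the NumPy equivalent it mirrors is a one-liner (a sketch, assuming a 0/1 integer mask as in the example above):

import numpy as np

x = np.random.randn(2, 3)
mask = (x > 0).astype(np.int8)
out = np.where(mask, 8.7654321, x)  # value where mask is set, x elsewhere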
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export
from oneflow.python.framework.tensor import register_tensor_op
class Unsqueeze(Module):
def __init__(self, dim: int = 0) -> None:
super().__init__()
self.dim = dim
self._op = flow.builtin_op("expand_dims").Input("in").Output("out").Build()
    def forward(self, input):
        assert (
            -(1 + input.ndimension()) <= self.dim <= input.ndimension()
        ), "dim should be within the range [-input.ndimension() - 1, input.ndimension() + 1)"
        # Normalize a negative dim into a local variable so repeated
        # calls do not mutate the module's state.
        dim = self.dim
        if dim < 0:
            dim = 1 + input.ndimension() + dim
        return self._op(input, axis=dim)[0]
@oneflow_export("unsqueeze")
@register_tensor_op("unsqueeze")
def unsqueeze_op(input, dim):
r"""Returns a new tensor with a dimension of size one inserted at the
specified position.
The returned tensor shares the same underlying data with this tensor.
A :attr:`dim` value within the range `[-input.ndimension() - 1, input.ndimension() + 1)`
can be used. Negative :attr:`dim` will correspond to :meth:`unsqueeze`
applied at :attr:`dim` = ``dim + input.ndimension() + 1``.
Args:
        input (Tensor): the input tensor
dim (int): the index at which to insert the singleton dimension
For example:
.. code-block:: python
import numpy as np
import oneflow as flow
x = flow.Tensor(np.random.rand(2, 3, 4))
y = x.unsqueeze(2)
"""
return Unsqueeze(dim)(input)
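The negative-dim rule above maps `dim` to `dim + input.ndimension() + 1`; a quick check of that arithmetic against `np.expand_dims` (hypothetical snippet):

import numpy as np

x = np.random.rand(2, 3, 4)
for dim in range(-4, 4):  # the valid range [-ndim - 1, ndim + 1)
    norm = dim if dim >= 0 else dim + x.ndim + 1
    assert np.expand_dims(x, dim).shape == np.expand_dims(x, norm).shape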
@@ -119,5 +119,121 @@ class TestGeLU(flow.unittest.TestCase):
test_case.assertTrue(np.allclose(y.numpy(), z, rtol=1e-4, atol=1e-4))
def numpy_sigmoid(x):
return 1.0 / (1 + np.exp(-x))
def numpy_softmax(x, axis):
x = x - x.max(axis=axis, keepdims=True)
y = np.exp(x)
return y / y.sum(axis=axis, keepdims=True)
def numpy_logsoftmax(x, dim):
e_x = np.exp(x - np.max(x, axis=dim, keepdims=True))
return np.log(e_x / e_x.sum(axis=dim, keepdims=True))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestSigmoidModule(flow.unittest.TestCase):
def test_sigmoid(test_case):
m = flow.nn.Sigmoid()
input_arr = np.random.randn(2, 3, 4, 5)
x = flow.Tensor(input_arr)
y = m(x)
y2 = flow.sigmoid(x)
y3 = x.sigmoid()
output = numpy_sigmoid(input_arr)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
test_case.assertTrue(np.allclose(y2.numpy(), output, rtol=1e-05))
test_case.assertTrue(np.allclose(y3.numpy(), output, rtol=1e-05))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestSoftmaxModule(flow.unittest.TestCase):
def test_softmax(test_case):
axis = 0
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(2, 3, 4, 5)
x = flow.Tensor(arr)
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
def test_softmax_dim_1(test_case):
axis = 1
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(9, 7, 8, 16)
x = flow.Tensor(arr)
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
def test_softmax_dim_2(test_case):
axis = 2
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(2, 5, 6, 3)
x = flow.Tensor(arr)
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
def test_softmax_dim_3(test_case):
axis = 3
m = flow.nn.Softmax(dim=axis)
arr = np.random.randn(1, 3, 4, 7)
x = flow.Tensor(arr)
y = m(x)
output = numpy_softmax(arr, axis)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
        axis2 = -1
        m2 = flow.nn.Softmax(dim=axis2)
        y2 = m2(x)
        output2 = numpy_softmax(arr, axis2)
        test_case.assertTrue(np.allclose(y2.numpy(), output2, rtol=1e-05))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestLogSoftmaxModule(flow.unittest.TestCase):
def test_logsoftmax(test_case):
dim = 1
m = flow.nn.LogSoftmax(dim)
input_arr = np.random.randn(4, 7)
x = flow.Tensor(input_arr)
y = m(x)
output = numpy_logsoftmax(input_arr, dim)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
def test_logsoftmax_dim_2(test_case):
dim = 2
m = flow.nn.LogSoftmax(dim)
input_arr = np.random.randn(3, 4, 5)
x = flow.Tensor(input_arr)
y = m(x)
output = numpy_logsoftmax(input_arr, dim)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
def test_logsoftmax_dim_3(test_case):
dim = 3
m = flow.nn.LogSoftmax(dim)
input_arr = np.random.randn(8, 9, 7, 3)
x = flow.Tensor(input_arr)
y = m(x)
output = numpy_logsoftmax(input_arr, dim)
test_case.assertTrue(np.allclose(y.numpy(), output, rtol=1e-05))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestArange(flow.unittest.TestCase):
def test_arange(test_case):
np_out = np.arange(5)
of_out = flow.arange(0, end=5)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
np_out2 = np.arange(0, 20, 2)
of_out2 = flow.arange(0, 20, step=2)
test_case.assertTrue(np.allclose(of_out2.numpy(), np_out2))
def test_arange_v2(test_case):
np_out = np.arange(20)
of_out = flow.arange(start=0, end=20)
test_case.assertTrue(np.allclose(of_out.numpy(), np_out))
np_out2 = np.arange(0, 100, 3)
of_out2 = flow.arange(start=0, end=100, step=3)
test_case.assertTrue(np.allclose(of_out2.numpy(), np_out2))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestEq(flow.unittest.TestCase):
def test_eq(test_case):
        arr1 = np.array([2, 3, 4, 5])
arr2 = np.array([2, 3, 4, 1])
input = flow.Tensor(arr1, dtype=flow.float32)
other = flow.Tensor(arr2, dtype=flow.float32)
of_out = flow.eq(input, other)
of_out2 = flow.equal(input, other)
np_out = np.equal(arr1, arr2)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
test_case.assertTrue(np.array_equal(of_out2.numpy(), np_out))
def test_eq_tensor_function(test_case):
arr1 = np.random.randint(1, 10, size=(2, 3, 4, 5))
arr2 = np.random.randint(1, 10, size=(2, 3, 4, 5))
input = flow.Tensor(arr1, dtype=flow.float32)
other = flow.Tensor(arr2, dtype=flow.float32)
of_out = input.eq(other)
np_out = np.equal(arr1, arr2)
test_case.assertTrue(np.array_equal(of_out.numpy(), np_out))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestMaskedFill(flow.unittest.TestCase):
def test_masked_fill(test_case):
input_arr = np.array(
[
[
[-0.13169311, 0.97277078, 1.23305363, 1.56752789],
[-1.51954275, 1.87629473, -0.53301206, 0.53006478],
[-1.38244183, -2.63448052, 1.30845795, -0.67144869],
],
[
[0.41502161, 0.14452418, 0.38968, -1.76905653],
[0.34675095, -0.7050969, -0.7647731, -0.73233418],
[-1.90089858, 0.01262963, 0.74693893, 0.57132389],
],
]
)
output = np.array(
[
[
[-0.1316931, 8.7654321, 8.7654321, 8.7654321],
[-1.5195428, 8.7654321, -0.5330121, 8.7654321],
[-1.3824418, -2.6344805, 8.7654321, -0.6714487],
],
[
[8.7654321, 8.7654321, 8.7654321, -1.7690565],
[8.7654321, -0.7050969, -0.7647731, -0.7323342],
[-1.9008986, 8.7654321, 8.7654321, 8.7654321],
],
]
)
        fill_value = 8.7654321  # an arbitrary fill value, e.g. -1e9 or 3.14
input = flow.Tensor(input_arr, dtype=flow.float32)
mask = flow.Tensor((input_arr > 0).astype(np.int8), dtype=flow.int)
of_out = input.masked_fill(mask, value=fill_value)
test_case.assertTrue(np.allclose(of_out.numpy(), output))
        input2 = flow.Tensor(input_arr, dtype=flow.float32)
        mask2 = flow.Tensor((input_arr > 0).astype(np.int8), dtype=flow.int)
        of_out2 = flow.tmp.masked_fill(input2, mask2, value=fill_value)
test_case.assertTrue(np.allclose(of_out2.numpy(), output))
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import oneflow as flow
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestUnsqueeze(flow.unittest.TestCase):
def test_unsqueeze(test_case):
np_arr = np.random.rand(2, 6, 9, 3)
x = flow.Tensor(np_arr)
y = flow.unsqueeze(x, dim=1)
output = np.expand_dims(np_arr, axis=1)
test_case.assertTrue(np.allclose(output, y.numpy(), rtol=1e-05))
def test_unsqueeze_tensor_function(test_case):
np_arr = np.random.rand(2, 3, 4)
x = flow.Tensor(np_arr)
y = x.unsqueeze(dim=2)
output = np.expand_dims(np_arr, axis=2)
test_case.assertTrue(np.allclose(output, y.numpy(), rtol=1e-05))
def test_unsqueeze_different_dim(test_case):
np_arr = np.random.rand(4, 5, 6, 7)
x = flow.Tensor(np_arr)
for axis in range(-5, 5):
y = flow.unsqueeze(x, dim=axis)
output = np.expand_dims(np_arr, axis=axis)
test_case.assertTrue(np.allclose(output, y.numpy(), rtol=1e-05))
if __name__ == "__main__":
unittest.main()