diff --git a/oneflow/python/framework/tensor.py b/oneflow/python/framework/tensor.py index c6ee560096b6b139bc70a75237e1c9bdd95de73c..ce02de4b7b147428bdd11c7db04fd9f60c854278 100644 --- a/oneflow/python/framework/tensor.py +++ b/oneflow/python/framework/tensor.py @@ -301,6 +301,33 @@ class Tensor: def __deepcopy__(self, memo): TODO() + def __mul__(self, other): + return self.mul(other) + + def __rmul__(self, other): + return self.mul(other) + + def __add__(self, other): + return self.add(other) + + def __radd__(self, other): + return self.add(other) + + def __sub__(self, other): + return self.sub(other) + + def __rsub__(self, other): + return flow.sub(other, self) + + def __truediv__(self, other): + return self.div(other) + + def __rtruediv__(self, other): + return flow.div(other, self) + + def __neg__(self): + return flow.mul(-1, self) + def _determine_if_needed(self, determining_initializer=None): if not self.is_determined: self.determine(determining_initializer) diff --git a/oneflow/python/nn/modules/loss.py b/oneflow/python/nn/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e471f0068a7009cf3824718c7ab05da0cd919dcc --- /dev/null +++ b/oneflow/python/nn/modules/loss.py @@ -0,0 +1,88 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from typing import Optional + +import oneflow as flow +from oneflow.python.oneflow_export import oneflow_export +from oneflow.python.nn.module import Module + + +@oneflow_export("nn.CrossEntropyLoss") +class CrossEntropyLoss(Module): + r""" + Args: + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + For example: + .. code-block:: python + + import oneflow as flow + input = flow.Tensor( + [[-0.1664078, -1.7256707, -0.14690138], + [-0.21474946, 0.53737473, 0.99684894], + [-1.135804, -0.50371903, 0.7645404]], dtype=flow.float32) + target = flow.Tensor(np.array([0, 1, 2]), dtype=flow.int32) + out = flow.nn.CrossEntropyLoss(reduction="none")(input, target) + # out: [0.80199665 1.1166505 0.35826027] + out_sum = flow.nn.CrossEntropyLoss(reduction="sum")(input, target) + # out_sum: [2.2769074] + out_mean = flow.nn.CrossEntropyLoss(reduction="mean")(input, target) + # out_mean: [0.7589692] + + """ + + def __init__( + self, + weight=None, + ignore_index: Optional[int] = None, + reduction: Optional[str] = "mean", + ) -> None: + super().__init__() + if weight is not None: + raise ValueError("Argument weight is not supported yet") + if ignore_index is not None: + raise ValueError("Argument ignore_index is not supported yet") + assert reduction in [ + "sum", + "none", + "mean", + None, + ], "only 'sum', 'mean' and None supported by now" + + self.reduction = reduction + self._op = ( + flow.builtin_op("sparse_softmax_cross_entropy") + .Input("prediction") + .Input("label") + .Output("prob") + .Output("out") + .Build() + ) + + def forward(self, input, target): + prob, out = self._op(input, target, depth=input.shape[len(input.shape) - 1]) + if self.reduction == "mean": + return flow.mean(out) + elif self.reduction == "sum": + return flow.sum(out) + else: + return out diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py index 7384ea6900b02c831ee6b84cbf162b12d0458051..34cbdee759a445c511dc97a3771b2d641bf057ac 100644 --- a/oneflow/python/nn/modules/math_ops.py +++ b/oneflow/python/nn/modules/math_ops.py @@ -13,12 +13,534 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ + +import collections +from typing import Optional, Sequence, Union + import oneflow as flow -from oneflow.python.nn.module import Module from oneflow.python.oneflow_export import oneflow_export +from oneflow.python.nn.module import Module from oneflow.python.framework.tensor import register_tensor_op +def _check_axis(axis, shape): + # TODO(yaochi): refine this function when all related ops in `python/ops/math_ops.py` migrated + if axis is None: + axis = list(range(len(shape))) + + if isinstance(axis, int): + axis = [axis] + + assert isinstance(axis, (list, tuple)), "Invalid axis {}".format(axis) + for x in axis: + if x < 0: + x += len(shape) + assert x >= 0 and x < len(shape), "Invalid axis {}, len(shape): {}".format( + axis, len(shape) + ) + + return axis + + +class Sum(Module): + def __init__( + self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False + ) -> None: + super().__init__() + + self.axis = axis + self.keepdims = keepdims + self._op = ( + flow.builtin_op("reduce_sum") + .Input("input_tensor") + .Output("output_tensor") + .Attr("keepdims", keepdims) + .Build() + ) + + def forward(self, input): + axis_checked = _check_axis(self.axis, input.shape) + if len(axis_checked) == 0: + return input + return self._op(input, axis=axis_checked)[0] + + +@oneflow_export("sum") +@register_tensor_op("sum") +def _sum(input, dim=None, keepdims=False): + r"""Computes the sum of row of elements in a tensor in the given axis, if the axis is None, sum of all elements will be caculated. + For example: + + .. code-block:: python + + import oneflow as flow + input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32) + of_out = flow.sum(input, dim=(2,1)) + + """ + + return Sum(dim, keepdims)(input) + + +class ScalarMul(Module): + def __init__(self, operand) -> None: + super().__init__() + self._op = flow.builtin_op("scalar_mul").Input("in").Output("out") + if isinstance(operand, int): + self._op = ( + self._op.Attr("has_int_operand", True) + .Attr("has_float_operand", False) + .Attr("int_operand", operand) + .Attr("float_operand", 0.0) + .Build() + ) + elif isinstance(operand, float): + self._op = ( + self._op.Attr("has_int_operand", False) + .Attr("has_float_operand", True) + .Attr("int_operand", 0) + .Attr("float_operand", operand) + .Build() + ) + else: + raise ValueError("operand type can only be int or float") + + def forward(self, x): + return self._op(x)[0] + + +class ScalarMulByTensor(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("scalar_mul_by_tensor") + .Input("x") + .Input("scalar") + .Output("y") + .Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +class ElementwiseMul(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("multiply").Input("x").Input("y").Output("out").Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +class BroadcastMul(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("broadcast_mul").Input("x").Input("y").Output("z").Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +@oneflow_export("mul") +@register_tensor_op("mul") +def _mul(x, y): + r"""Computes the multiplication of x by y for each element, scalar and broadcast promotation are supported. + The formula is: + .. math:: + out = x \times y + For example: + + .. code-block:: python + + import oneflow as flow + + # element-wise multiply + x = flow.Tensor(np.random.randn(2,3)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.mul(x,y).numpy() + print(out.shape) # (2,3) + + # scalar mutiply + x = 5 + y = flow.Tensor(np.random.randn(2,3)) + out = flow.mul(x,y).numpy() + print(out.shape) # (2,3) + + # broadcast mutiply + x = flow.Tensor(np.random.randn(1,1)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.mul(x,y).numpy() + print(out.shape) # (2,3) + + """ + + if isinstance(x, (int, float)): + return ScalarMul(x)(y) + elif isinstance(y, (int, float)): + return ScalarMul(y)(x) + elif x.shape == y.shape: + return ElementwiseMul()(x, y) + elif x.shape == (1,): + return ScalarMulByTensor()(y, x) + elif y.shape == (1,): + return ScalarMulByTensor()(x, y) + else: + return BroadcastMul()(x, y) + + +class Mean(Module): + def __init__( + self, + axis: Optional[Union[collections.Sized, int]] = None, + keepdims: bool = False, + ) -> None: + super().__init__() + self.keepdims = keepdims + self.axis = axis + # TODO: add if input.is_dynamic branch like flow.math.reduce_mean + if axis is None: + self.axes = [] + else: + self.axes = list(axis) if isinstance(axis, collections.Sized) else [axis] + + def forward(self, input_tensor): + reduce_sum = flow.sum(input_tensor, dim=self.axis, keepdims=self.keepdims) + reduce_count = 1 + if len(self.axes) == 0: + for dim in input_tensor.shape: + reduce_count *= dim + else: + for i in self.axes: + reduce_count *= input_tensor.shape[i] + return flow.mul(reduce_sum, 1.0 / reduce_count) + + +@oneflow_export("mean") +@register_tensor_op("mean") +def _mean(input_tensor, dim=None, keepdim=False): + r"""Computes the mean of row of elements in a tensor in the given axis, + if the axis is None, mean of all elements will be caculated. + + For example: + + .. code-block:: python + + import oneflow as flow + + input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) + out = flow.mean(input) + # out: [3.5] + print(out.numpy()) + + input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) + out = flow.mean(input, axis=0) + # out: [2.5 3.5 4.5] + print(out.numpy()) + + input = flow.Tensor([[1, 2, 3], [4, 5, 6]]) + out = flow.mean(input, axis=1) + # out: [ 2. 5.] + print(out.numpy()) + + """ + + return Mean(axis=dim, keepdims=keepdim)(input_tensor) + + +class ScalarSubByTensor(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("scalar_sub_by_tensor") + .Input("x") + .Input("scalar") + .Output("y") + .Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +class BroadcastSub(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("broadcast_sub").Input("x").Input("y").Output("z").Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +class ScalarAdd(Module): + def __init__(self, operand) -> None: + super().__init__() + self._op = flow.builtin_op("scalar_add").Input("in").Output("out") + + if isinstance(operand, int): + self._op = ( + self._op.Attr("has_int_operand", True) + .Attr("has_float_operand", False) + .Attr("int_operand", operand) + .Attr("float_operand", 0.0) + .Build() + ) + elif isinstance(operand, float): + self._op = ( + self._op.Attr("has_int_operand", False) + .Attr("has_float_operand", True) + .Attr("int_operand", 0) + .Attr("float_operand", operand) + .Build() + ) + else: + raise ValueError("operand type can only be int or float") + + def forward(self, x): + return self._op(x)[0] + + +@oneflow_export("sub") +@register_tensor_op("sub") +def _sub(x, y): + r"""Computes the subtraction of x by y for each element, scalar and broadcast promotation are supported. + The formula is: + .. math:: + out = x - y + For example: + + .. code-block:: python + + import oneflow as flow + + # element-wise subtract + x = flow.Tensor(np.random.randn(2,3)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.sub(x,y).numpy() + print(out.shape) # (2,3) + + # scalar subtract + x = 5 + y = flow.Tensor(np.random.randn(2,3)) + out = flow.sub(x,y).numpy() + print(out.shape) # (2,3) + + # broadcast subtract + x = flow.Tensor(np.random.randn(1,1)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.sub(x,y).numpy() + print(out.shape) # (2,3) + + """ + + if isinstance(x, (int, float)): + return ScalarAdd(x)(ScalarMul(-1)(y)) + elif isinstance(y, (int, float)): + return ScalarAdd(-1 * y)(x) + elif x.shape == y.shape: + # TODO: add element-wise op + return BroadcastSub()(x, y) + elif y.shape == (1,): + return ScalarSubByTensor()(x, y) + else: + return BroadcastSub()(x, y) + + +class BroadcastDiv(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("broadcast_div").Input("x").Input("y").Output("z").Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +class ScalarDivByTensor(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("scalar_div_by_tensor") + .Input("x") + .Input("scalar") + .Output("y") + .Build() + ) + + def forward(self, x, scalar): + return self._op(x, scalar)[0] + + +@oneflow_export("div") +@register_tensor_op("div") +def _div(x, y): + r"""Computes the division of x by y for each element, scalar and broadcast promotation are supported. + The formula is: + .. math:: + out = \frac{X}{Y} + Args: + x (Union[int, float, flow.Tensor]): X. + y (Union[int, float, flow.Tensor]): Y. + For example: + .. code-block:: python + + import oneflow as flow + + # element-wise divide + x = flow.Tensor(np.random.randn(2,3)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.div(x,y).numpy() + print(out.shape) # (2,3) + + # scalar divide + x = 5 + y = flow.Tensor(np.random.randn(2,3)) + out = flow.div(x,y).numpy() + print(out.shape) # (2,3) + + # broadcast divide + x = flow.Tensor(np.random.randn(1,1)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.div(x,y).numpy() + print(out.shape) # (2,3) + + """ + + if isinstance(x, (int, float)): + return ScalarMul(x)(flow.reciprocal(y)) + elif isinstance(y, (int, float)): + if y == 0 or y == 0.0: + y = 0.0 + else: + y = 1.0 / (float(y)) + return ScalarMul(y)(x) + elif x.shape == y.shape: + return BroadcastDiv()(x, y) + elif y.shape == (1,): + return ScalarDivByTensor()(x, y) + else: + return BroadcastDiv()(x, y) + + +class Reciprocal(Module): + def __init__(self) -> None: + super().__init__() + self._op = flow.builtin_op("reciprocal_no_nan").Input("x").Output("y").Build() + + def forward(self, x): + return self._op(x)[0] + + +@oneflow_export("reciprocal") +@register_tensor_op("reciprocal") +def _reciprocal(x): + r"""Computes the safe reciprocal of x. If x is zero, the reciprocal will + be also set to zero. + + For example: + + .. code-block:: python + + import oneflow as flow + x = flow.Tensor(np.array([[1, 2, 3], [4, 5, 6]])) + out = flow.reciprocal()(x) + # out [[1. 0.5 0.33333334] + [0.25 0.2 0.16666667]] + + """ + + return Reciprocal()(x) + + +class ScalarAddByTensor(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("scalar_add_by_tensor") + .Input("x") + .Input("scalar") + .Output("y") + .Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +class ElementwiseAdd(Module): + def __init__(self) -> None: + super().__init__() + self._op = flow.builtin_op("add_n").Input("in", 2).Output("out").Build() + + def forward(self, x, y): + return self._op(x, y)[0] + + +class BroadcastAdd(Module): + def __init__(self) -> None: + super().__init__() + self._op = ( + flow.builtin_op("broadcast_add").Input("x").Input("y").Output("z").Build() + ) + + def forward(self, x, y): + return self._op(x, y)[0] + + +@oneflow_export("add") +@register_tensor_op("add") +def _add(x, y): + r"""Computes the addition of x by y for each element, scalar and broadcast promotation are supported. + The formula is: + .. math:: + out = x + y + For example: + + .. code-block:: python + + import oneflow as flow + + # element-wise add + x = flow.Tensor(np.random.randn(2,3)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.add(x,y).numpy() + print(out.shape) # (2,3) + + # scalar add + x = 5 + y = flow.Tensor(np.random.randn(2,3)) + out = flow.add(x,y).numpy() + print(out.shape) # (2,3) + + # broadcast add + x = flow.Tensor(np.random.randn(1,1)) + y = flow.Tensor(np.random.randn(2,3)) + out = flow.add(x,y).numpy() + print(out.shape) # (2,3) + + """ + + if isinstance(x, (int, float)): + return ScalarAdd(x)(y) + elif isinstance(y, (int, float)): + return ScalarAdd(y)(x) + elif x.shape == y.shape: + return ElementwiseAdd()(x, y) + elif x.shape == (1,): + return ScalarAddByTensor()(y, x) + elif y.shape == (1,): + return ScalarAddByTensor()(x, y) + else: + return BroadcastAdd()(x, y) + + class Sin(Module): def __init__(self) -> None: super().__init__() @@ -33,25 +555,18 @@ class Sin(Module): def sin_op(tensor): r""" Returns a new tensor with the sine of the elements of :attr:`input`. - .. math:: \text{out}_{i} = \sin(\text{input}_{i}) - Args: input (Tensor) 鈥� the input tensor. - For example: - .. code-block:: python - import oneflow as flow import numpy as np - arr = np.array([-0.5461, 0.1347, -2.7266, -0.2746]) input = flow.Tensor(arr, dtype=flow.float32) output = flow.sin(input) # [-0.51935846 0.13429303 -0.40318328 -0.27116194] - """ return Sin()(tensor) @@ -70,20 +585,14 @@ class Cos(Module): def cos_op(tensor): r""" Returns a new tensor with the cosine of the elements of :attr:`input`. - .. math:: \text{out}_{i} = \cos(\text{input}_{i}) - Args: input (Tensor) 鈥� the input tensor. - For example: - .. code-block:: python - import oneflow as flow import numpy as np - arr = np.array([1.4309, 1.2706, -0.8562, 0.9796]) input = flow.Tensor(arr, dtype=flow.float32) output = flow.cos(input) @@ -107,20 +616,14 @@ class Log(Module): def log_op(tensor): r""" Returns a new tensor with the natural logarithm of the elements of :attr:`input`. - .. math:: y_{i} = \log_{e} (x_{i}) - Args: input (Tensor) 鈥� the input tensor. - For example: - .. code-block:: python - import oneflow as flow import numpy as np - arr = np.random.randn(2, 3, 4, 5) input = flow.Tensor(arr, dtype=flow.float32) output = flow.log(input) diff --git a/oneflow/python/ops/math_ops.py b/oneflow/python/ops/math_ops.py index 5a1f7088ed5a63b4831d4945d0b7af13050bf71b..29c8792d06c540f83cd7c3c0b95832458e8bc036 100644 --- a/oneflow/python/ops/math_ops.py +++ b/oneflow/python/ops/math_ops.py @@ -198,8 +198,6 @@ def subtract( elif x.shape == y.shape: # TODO: add element-wise op return broadcast_sub(x, y, name) - elif x.shape == (1,): - return scalar_sub_by_tensor(y, x, name) elif y.shape == (1,): return scalar_sub_by_tensor(x, y, name) else: @@ -315,8 +313,6 @@ def divide( elif x.shape == y.shape: # TODO: add element-wise op return broadcast_div(x, y, name) - elif x.shape == (1,): - return scalar_div_by_tensor(y, x, name) elif y.shape == (1,): return scalar_div_by_tensor(x, y, name) else: diff --git a/oneflow/python/test/modules/test_add.py b/oneflow/python/test/modules/test_add.py new file mode 100644 index 0000000000000000000000000000000000000000..7b0eb52fbabca0818f4ea8c16787661f12815d06 --- /dev/null +++ b/oneflow/python/test/modules/test_add.py @@ -0,0 +1,59 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestAddModule(flow.unittest.TestCase): + def test_add(test_case): + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.add(x, y) + np_out = np.add(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = 5 + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.add(x, y) + np_out = np.add(x, y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = 5 + of_out = flow.add(x, y) + np_out = np.add(x.numpy(), y) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.array([5])) + of_out = flow.add(x, y) + np_out = np.add(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(1, 1)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.add(x, y) + np_out = np.add(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_avgpool2d_module.py b/oneflow/python/test/modules/test_avgpool2d.py similarity index 100% rename from oneflow/python/test/modules/test_avgpool2d_module.py rename to oneflow/python/test/modules/test_avgpool2d.py diff --git a/oneflow/python/test/modules/test_crossentropyloss.py b/oneflow/python/test/modules/test_crossentropyloss.py new file mode 100644 index 0000000000000000000000000000000000000000..75f23358be6d63b5830da926383407fb1615a1a7 --- /dev/null +++ b/oneflow/python/test/modules/test_crossentropyloss.py @@ -0,0 +1,61 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + +g_test_samples = [ + { + "input": np.array( + [ + [-0.6980871, 0.4765042, -1.969919, 0.28965086, -0.53548324], + [-0.26332688, 0.27541, 0.30080616, 0.09914763, 0.53522176], + [0.7332028, 0.38375184, -0.2831992, -0.9833142, 0.387824], + ] + ), + "target": np.array([3, 3, 4], dtype=np.int32), + "out": np.array([1.1380, 1.7332, 1.4287], dtype=np.float32), + "out_sum": np.array([4.2999], dtype=np.float32), + "out_mean": np.array([1.4333], dtype=np.float32), + } +] + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestCrossEntropyLossModule(flow.unittest.TestCase): + def test_CrossEntropyLoss(test_case): + global g_test_samples + for sample in g_test_samples: + loss = flow.nn.CrossEntropyLoss(reduction=None) + input = flow.Tensor(sample["input"], dtype=flow.float32) + target = flow.Tensor(sample["target"], dtype=flow.int32) + of_out = loss(input, target) + assert np.allclose(of_out.numpy(), sample["out"], 1e-4, 1e-4) + + loss_sum = flow.nn.CrossEntropyLoss(reduction="sum") + of_out_sum = loss_sum(input, target) + assert np.allclose(of_out_sum.numpy(), sample["out_sum"], 1e-4, 1e-4) + + loss_mean = flow.nn.CrossEntropyLoss(reduction="mean") + of_out_mean = loss_mean(input, target) + assert np.allclose(of_out_mean.numpy(), sample["out_mean"], 1e-4, 1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_div.py b/oneflow/python/test/modules/test_div.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6f530f9cb26b2dba3a7e330f0a6dc03d57450a --- /dev/null +++ b/oneflow/python/test/modules/test_div.py @@ -0,0 +1,65 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestDiv(flow.unittest.TestCase): + def test_div(test_case): + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.div(x, y) + np_out = np.divide(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = 5 + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.div(x, y) + np_out = np.divide(x, y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = 5 + of_out = flow.div(x, y) + np_out = np.divide(x.numpy(), y) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.random.randn(1, 1)) + of_out = flow.div(x, y) + np_out = np.divide(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.array([5])) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.div(x, y) + np_out = np.divide(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.array([5])) + of_out = flow.div(x, y) + np_out = np.divide(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_mean.py b/oneflow/python/test/modules/test_mean.py new file mode 100644 index 0000000000000000000000000000000000000000..542fb97bc7682fb8a204fa73455cf23465bbc417 --- /dev/null +++ b/oneflow/python/test/modules/test_mean.py @@ -0,0 +1,39 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestMeanModule(flow.unittest.TestCase): + def test_mean(test_case): + input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) + of_out = flow.mean(input, dim=1) + np_out = np.mean(input.numpy(), axis=1) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) + of_out = flow.mean(input, dim=0) + np_out = np.mean(input.numpy(), axis=0) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_mul.py b/oneflow/python/test/modules/test_mul.py new file mode 100644 index 0000000000000000000000000000000000000000..e7bba0d0843ecaa93e9f4d2404cf4b3dd696080f --- /dev/null +++ b/oneflow/python/test/modules/test_mul.py @@ -0,0 +1,53 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestMulModule(flow.unittest.TestCase): + def test_mul(test_case): + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.mul(x, y) + np_out = np.multiply(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = 5 + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.mul(x, y) + np_out = np.multiply(x, y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = 5 + of_out = flow.mul(x, y) + np_out = np.multiply(x.numpy(), y) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(1, 1)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.mul(x, y) + np_out = np.multiply(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_reciprocal.py b/oneflow/python/test/modules/test_reciprocal.py new file mode 100644 index 0000000000000000000000000000000000000000..ad64f316974ecc295c386b1c6c4c10340ba30d66 --- /dev/null +++ b/oneflow/python/test/modules/test_reciprocal.py @@ -0,0 +1,34 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestReciprocalModule(flow.unittest.TestCase): + def test_reciprocal(test_case): + x = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.reciprocal(x) + np_out = np.reciprocal(x.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_sub.py b/oneflow/python/test/modules/test_sub.py new file mode 100644 index 0000000000000000000000000000000000000000..15f720eb84bc5da4a47b11d39ae1423c33f4a8c8 --- /dev/null +++ b/oneflow/python/test/modules/test_sub.py @@ -0,0 +1,65 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestSubModule(flow.unittest.TestCase): + def test_sub(test_case): + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.sub(x, y) + np_out = np.subtract(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = 5 + y = flow.Tensor(np.random.randn(2, 3)) + of_out = flow.sub(x, y) + np_out = np.subtract(x, y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = 5 + of_out = flow.sub(x, y) + np_out = np.subtract(x.numpy(), y) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + y = flow.Tensor(np.random.randn(1, 1)) + of_out = flow.sub(x, y) + np_out = np.subtract(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.array([5])) + y = flow.Tensor(np.random.randn(1, 1)) + of_out = flow.sub(x, y) + np_out = np.subtract(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(1, 1)) + y = flow.Tensor(np.array([5])) + of_out = flow.sub(x, y) + np_out = np.subtract(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_sum.py b/oneflow/python/test/modules/test_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..3d49b384ebece9d2d9bb2ca3acda62c00073c617 --- /dev/null +++ b/oneflow/python/test/modules/test_sum.py @@ -0,0 +1,51 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import unittest +import numpy as np + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestSumModule(flow.unittest.TestCase): + def test_sum(test_case): + input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) + of_out = flow.sum(input, dim=0) + np_out = np.sum(input.numpy(), axis=0) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) + of_out = flow.sum(input, dim=0) + np_out = np.sum(input.numpy(), axis=0) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) + of_out = flow.sum(input, dim=1) + of_out2 = input.sum(dim=1) + np_out = np.sum(input.numpy(), axis=1) + test_case.assertTrue(np.allclose(of_out2.numpy(), of_out.numpy(), 1e-4, 1e-4)) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32) + of_out = flow.sum(input, dim=(2, 1)) + np_out = np.sum(input.numpy(), axis=(2, 1)) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/tensor/test_tensor.py b/oneflow/python/test/tensor/test_tensor.py index 7f1137af30658de6cbeba201d24e60a53aaf8de4..1974e0567e98bf2db9973f662ad282478f5734a8 100644 --- a/oneflow/python/test/tensor/test_tensor.py +++ b/oneflow/python/test/tensor/test_tensor.py @@ -94,6 +94,97 @@ class TestTensor(flow.unittest.TestCase): test_case.assertTrue(np.array_equal(y.numpy(), 5 * np.ones(y.shape))) test_case.assertTrue(np.array_equal(z.numpy(), 5 * np.ones(z.shape))) + def test_div(test_case): + x = flow.Tensor(np.random.randn(1, 1)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = x / y + np_out = np.divide(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = x / 3 + np_out = np.divide(x.numpy(), 3) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = 3 / x + np_out = np.divide(3, x.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(1)) + of_out = 3 / x + np_out = np.divide(3, x.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + def test_mul(test_case): + x = flow.Tensor(np.random.randn(1, 1)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = x * y + np_out = np.multiply(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = x * 3 + np_out = np.multiply(x.numpy(), 3) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = 3 * x + np_out = np.multiply(3, x.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + def test_add_tensor_method(test_case): + x = flow.Tensor(np.random.randn(1, 1)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = x + y + np_out = np.add(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = x + 3 + np_out = np.add(x.numpy(), 3) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = 3 + x + np_out = np.add(3, x.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + def test_sub_tensor_method(test_case): + x = flow.Tensor(np.random.randn(1, 1)) + y = flow.Tensor(np.random.randn(2, 3)) + of_out = x - y + np_out = np.subtract(x.numpy(), y.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = x - 3 + np_out = np.subtract(x.numpy(), 3) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + x = flow.Tensor(np.random.randn(2, 3)) + of_out = 3 - x + np_out = np.subtract(3, x.numpy()) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + def test_sum(test_case): + input = flow.Tensor(np.random.randn(4, 5, 6), dtype=flow.float32) + of_out = input.sum(dim=(2, 1)) + np_out = np.sum(input.numpy(), axis=(2, 1)) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + def test_mean(test_case): + input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) + of_out = input.mean(dim=0) + np_out = np.mean(input.numpy(), axis=0) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + + def test_neg(test_case): + input = flow.Tensor(np.random.randn(2, 3), dtype=flow.float32) + of_out = -input + np_out = -input.numpy() + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-4, 1e-4)) + if __name__ == "__main__": unittest.main()