diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index 9cafbe69515f922be5a3bf15b6c65b9d38ab8c28..6706456a025e6c5b45e3fab8ac5a08e0de8c8480 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -237,6 +237,7 @@ Experimental features .. autofunction:: oneflow.experimental.Tensor.diag .. autofunction:: oneflow.experimental.nn.GroupNorm .. autofunction:: oneflow.experimental.nn.ZeroPad2d +.. autofunction:: oneflow.experimental.nn.image.flip .. autofunction:: oneflow.experimental.tensor_buffer_to_tensor .. autofunction:: oneflow.experimental.tensor_to_tensor_buffer .. autofunction:: oneflow.experimental.Tensor.type_as diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index eff0d32cb47e067b4196e44b5fc1f56bf0c82292..9dc3814ba69ce852d8174c7664d917980ea43e86 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -128,6 +128,10 @@ signature: "Tensor ReciprocalNoNan(Tensor x)" bind_python: True +- name: "image_flip" + signature: "Tensor ImageFlip(Tensor x, Tensor flip_code)" + bind_python: True + - name: "sin" signature: "Tensor Sin(Tensor x)" bind_python: True diff --git a/oneflow/core/functional/impl/dataset_functor.cpp b/oneflow/core/functional/impl/dataset_functor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..09d38466684543bae08b7e40eb430c8dd6eecfe1 --- /dev/null +++ b/oneflow/core/functional/impl/dataset_functor.cpp @@ -0,0 +1,51 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/scalar.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class ImageFlipFuntor { + public: + ImageFlipFuntor() { + op_ = CHECK_JUST( + one::OpBuilder("image_flip").Input("in").Input("flip_code").Output("out").Build()); + } + Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, + const std::shared_ptr<one::Tensor>& flip_code) const { + return OpInterpUtil::Dispatch<Tensor>(*op_, {x, flip_code}); + } + + private: + std::shared_ptr<OpExpr> op_; +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::ImageFlipFuntor>("ImageFlip"); }; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/python/nn/modules/dataset.py b/oneflow/python/nn/modules/dataset.py index f37c2a9b0cf5a0686a4d7fe20b89c19b10dbcb49..2f18db4721c0688ce53530b6c1ff22a23263e08e 100644 --- a/oneflow/python/nn/modules/dataset.py +++ b/oneflow/python/nn/modules/dataset.py @@ -458,6 +458,69 @@ def get_ofrecord_handle( )() +@oneflow_export("nn.image.flip") +@experimental_api +class ImageFlip(Module): + r"""This operator flips the images. + + The flip code corresponds to the different flip mode: + + 0 (0x00): Non Flip + + 1 (0x01): Horizontal Flip + + 16 (0x10): Vertical Flip + + 17 (0x11): Both Horizontal and Vertical Flip + + Args: + images: The input images. + flip_code: The flip code. + + Returns: + The result image. + + For example: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow.experimental as flow + >>> import oneflow.experimental.nn as nn + >>> flow.enable_eager_execution() + + >>> arr = np.array([ + ... [[[1, 2, 3], [3, 2, 1]], + ... [[2, 3, 4], [4, 3, 2]]], + ... [[[3, 4, 5], [5, 4, 3]], + ... [[4, 5, 6], [6, 5, 4]]]]) + >>> image_tensors = flow.Tensor(arr, device=flow.device("cpu")) + >>> image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3) + >>> output = nn.image.flip(1)(image_tensor_buffer).numpy() + >>> output[0] + array([[[3., 2., 1.], + [1., 2., 3.]], + <BLANKLINE> + [[4., 3., 2.], + [2., 3., 4.]]], dtype=float32) + >>> output[1] + array([[[5., 4., 3.], + [3., 4., 5.]], + <BLANKLINE> + [[6., 5., 4.], + [4., 5., 6.]]], dtype=float32) + """ + + def __init__(self, flip_code): + super().__init__() + self.flip_code = flip_code + + def forward(self, images): + flip_codes = flow.Tensor([self.flip_code] * images.shape[0], dtype=flow.int8) + + return flow.F.image_flip(images, flip_codes) + + @oneflow_export("nn.image.decode") @experimental_api class ImageDecode(Module): @@ -557,3 +620,9 @@ class ImageBatchAlign(Module): def forward(self, input): return self._op(input)[0] + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/oneflow/python/test/modules/test_image_flip.py b/oneflow/python/test/modules/test_image_flip.py new file mode 100644 index 0000000000000000000000000000000000000000..008efac15faee2c8dfdaa1f35a3b5d48fbf43262 --- /dev/null +++ b/oneflow/python/test/modules/test_image_flip.py @@ -0,0 +1,84 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import cv2 +import numpy as np +import oneflow.experimental as flow + + +def _of_image_flip(images, image_static_shape, flip_code): + image_tensors = flow.Tensor(images, dtype=flow.float, device=flow.device("cpu")) + image_tensor_buffer = flow.tensor_to_tensor_buffer(image_tensors, instance_dims=3) + flip_images = flow.nn.image.flip(flip_code)(image_tensor_buffer) + return flip_images.numpy() + + +def _read_images_by_cv(image_files): + images = [cv2.imread(image_file).astype(np.single) for image_file in image_files] + return [np.expand_dims(image, axis=0) for image in images] + + +def _get_images_static_shape(images): + image_shapes = [image.shape for image in images] + image_static_shape = np.amax(image_shapes, axis=0) + assert isinstance( + image_static_shape, np.ndarray + ), "image_shapes: {}, image_static_shape: {}".format( + str(image_shapes), str(image_static_shape) + ) + image_static_shape = image_static_shape.tolist() + assert image_static_shape[0] == 1, str(image_static_shape) + image_static_shape[0] = len(image_shapes) + return image_static_shape + + +def _compare_image_flip_with_cv(test_case, image_files): + images = _read_images_by_cv(image_files) + assert all([len(image.shape) == 4 for image in images]) + + image_static_shape = _get_images_static_shape(images) + image_paddings = np.zeros(tuple(image_static_shape)) + for idx, image in enumerate(images): + image_paddings[ + idx, : image.shape[1], : image.shape[2], : image.shape[3] + ] = image + + flip_images = _of_image_flip(image_paddings, image_static_shape, 1) + + for image, flip_image in zip(image_paddings, flip_images): + exp_flip_image = cv2.flip(image.squeeze(), 1) + + test_case.assertTrue(np.allclose(exp_flip_image, flip_image)) + + +@flow.unittest.skip_unless_1n1d() +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestImageFlip(flow.unittest.TestCase): + def test_image_flip(test_case): + _compare_image_flip_with_cv( + test_case, + [ + "/dataset/mscoco_2017/val2017/000000000139.jpg", + "/dataset/mscoco_2017/val2017/000000000632.jpg", + ], + ) + + +if __name__ == "__main__": + unittest.main()