From 6bc0596be99c332a8b8a6983fe1b4e698c231550 Mon Sep 17 00:00:00 2001 From: Shijie <821898965@qq.com> Date: Sat, 17 Jul 2021 08:20:53 +0800 Subject: [PATCH] Add scale size for resize (#5509) * add scale and new_size * add and refine test_case * fix testcase Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/python/nn/modules/dataset.py | 22 +- .../python/test/modules/image_test_util.py | 174 +++++++++++ oneflow/python/test/modules/test_dataset.py | 2 +- .../python/test/modules/test_image_resize.py | 282 ++++++++++++++++++ .../test/modules/test_resnet50_with_bn.py | 2 +- .../test/modules/test_resnet50_without_bn.py | 2 +- 6 files changed, 472 insertions(+), 12 deletions(-) create mode 100644 oneflow/python/test/modules/image_test_util.py create mode 100644 oneflow/python/test/modules/test_image_resize.py diff --git a/oneflow/python/nn/modules/dataset.py b/oneflow/python/nn/modules/dataset.py index 9b3fa9a4b..8b35ed603 100644 --- a/oneflow/python/nn/modules/dataset.py +++ b/oneflow/python/nn/modules/dataset.py @@ -382,13 +382,6 @@ class ImageResize(Module): .Attr("interpolation_type", interpolation_type) .Build() ) - # TODO(Liang Depeng) - # scale = flow.tensor_buffer_to_tensor( - # scale, dtype=flow.float32, instance_shape=(2,) - # ) - # new_size = flow.tensor_buffer_to_tensor( - # new_size, dtype=flow.int32, instance_shape=(2,) - # ) else: if ( not isinstance(target_size, (list, tuple)) @@ -417,8 +410,19 @@ class ImageResize(Module): ) def forward(self, input): - res = self._op(input)[0] - return res + res = self._op(input) + res_image = res[0] + if len(res) == 3: + new_size = flow.experimental.tensor_buffer_to_tensor( + res[1], dtype=flow.int32, instance_shape=(2,) + ) + scale = flow.experimental.tensor_buffer_to_tensor( + res[2], dtype=flow.float32, instance_shape=(2,) + ) + else: + new_size = None + scale = res[1] + return res_image, scale, new_size @oneflow_export("tmp.RawDecoder") diff --git a/oneflow/python/test/modules/image_test_util.py b/oneflow/python/test/modules/image_test_util.py new file mode 100644 index 000000000..897f219b3 --- /dev/null +++ b/oneflow/python/test/modules/image_test_util.py @@ -0,0 +1,174 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import numpy as np +import cv2 +import oneflow as flow +import PIL +import random +import os + +global_coco_dict = dict() +default_coco_anno_file = "/dataset/mscoco_2017/annotations/instances_val2017.json" +default_coco_image_dir = "/dataset/mscoco_2017/val2017" + + +def get_coco(anno_file): + global global_coco_dict + + if anno_file not in global_coco_dict: + from pycocotools.coco import COCO + + global_coco_dict[anno_file] = COCO(anno_file) + + return global_coco_dict[anno_file] + + +def random_sample_images_from_coco( + anno_file=default_coco_anno_file, image_dir=default_coco_image_dir, batch_size=2 +): + image_files = [] + image_ids = [] + batch_group_id = -1 + + coco = get_coco(anno_file) + img_ids = coco.getImgIds() + + while len(image_files) < batch_size: + rand_img_id = random.choice(img_ids) + img_h = coco.imgs[rand_img_id]["height"] + img_w = coco.imgs[rand_img_id]["width"] + group_id = int(img_h / img_w) + + if batch_group_id == -1: + batch_group_id = group_id + + if group_id != batch_group_id: + continue + + image_files.append(os.path.join(image_dir, coco.imgs[rand_img_id]["file_name"])) + image_ids.append(rand_img_id) + + assert len(image_files) == len(image_ids) + return image_files, image_ids + + +def read_images_by_cv(image_files, dtype, channels=3): + np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype) + images = [cv2.imread(image_file).astype(np_dtype) for image_file in image_files] + assert all(isinstance(image, np.ndarray) for image in images) + assert all(image.ndim == 3 for image in images) + assert all(image.shape[2] == channels for image in images) + return images + + +def read_images_by_pil(image_files, dtype, channels=3): + image_objs = [PIL.Image.open(image_file) for image_file in image_files] + images = [] + np_dtype = flow.convert_oneflow_dtype_to_numpy_dtype(dtype) + + for im in image_objs: + bands = im.getbands() + band = "".join(bands) + if band == "RGB": + # convert to BGR + images.append(np.asarray(im).astype(np_dtype)[:, :, ::-1]) + elif band == "L": + gs_image = np.asarray(im).astype(np_dtype) + gs_image_shape = gs_image.shape + assert len(gs_image_shape) == 2 + gs_image = gs_image.reshape(gs_image_shape + (1,)) + gs_image = np.broadcast_to(gs_image, shape=gs_image_shape + (3,)) + images.append(gs_image) + elif band == "BGR": + images.append(np.asarray(im).astype(np_dtype)) + else: + raise NotImplementedError + + assert all(isinstance(image, np.ndarray) for image in images) + assert all(image.ndim == 3 for image in images) + assert all(image.shape[2] == channels for image in images) + + return images + + +def infer_images_static_shape(images, channels=3): + image_shapes = [image.shape for image in images] + assert all(image.ndim == 3 for image in images) + assert all(image.shape[2] == channels for image in images) + image_shapes = np.asarray(image_shapes) + + max_h = np.max(image_shapes[:, 0]).item() + max_w = np.max(image_shapes[:, 1]).item() + image_static_shape = (len(images), max_h, max_w, channels) + + group_ids = [] # 0: h < w, 1: h >= w + aspect_ratio_list = [] # shorter / longer + for image_shape in image_shapes: + h, w = image_shape[0:2] + if h < w: + group_id = 0 + aspect_ratio = h / w + else: + group_id = 1 + aspect_ratio = w / h + group_ids.append(group_id) + aspect_ratio_list.append(aspect_ratio) + assert all(group_id == group_ids[0] for group_id in group_ids) + + return image_static_shape, aspect_ratio_list + + +def compute_keep_aspect_ratio_resized_size( + target_size, min_size, max_size, aspect_ratio, resize_side +): + if resize_side == "shorter": + min_res_size = target_size + max_res_size = int(round(min_res_size / aspect_ratio)) + if max_size is not None and max_res_size > max_size: + max_res_size = max_size + min_res_size = int(round(max_res_size * aspect_ratio)) + elif resize_side == "longer": + max_res_size = target_size + min_res_size = int(round(max_res_size * aspect_ratio)) + if min_size is not None and min_res_size < min_size: + min_res_size = min_size + max_res_size = int(round(min_res_size / aspect_ratio)) + else: + raise NotImplementedError + + return (min_res_size, max_res_size) + + +def infer_keep_aspect_ratio_resized_images_static_shape( + target_size, + min_size, + max_size, + aspect_ratio_list, + resize_side="shorter", + channels=3, +): + resized_size_list = [] + for aspect_ratio in aspect_ratio_list: + resized_size_list.append( + compute_keep_aspect_ratio_resized_size( + target_size, min_size, max_size, aspect_ratio, resize_side + ) + ) + + res_min_size, res_max_size = max( + resized_size_list, key=lambda size: size[0] * size[1] + ) + return (res_min_size, res_max_size, channels) diff --git a/oneflow/python/test/modules/test_dataset.py b/oneflow/python/test/modules/test_dataset.py index f35e706dd..541cc1c65 100644 --- a/oneflow/python/test/modules/test_dataset.py +++ b/oneflow/python/test/modules/test_dataset.py @@ -73,7 +73,7 @@ class TestOFRecordModule(flow.unittest.TestCase): gt_np = cv2.imread("/dataset/imagenette/ofrecord/gt_tensor_buffer_image.png") test_case.assertTrue(np.array_equal(image_raw_buffer_nd, gt_np)) - image = resize(image_raw_buffer) + image = resize(image_raw_buffer)[0] resized_image_raw_buffer_nd = image.numpy()[0] gt_np = cv2.imread( diff --git a/oneflow/python/test/modules/test_image_resize.py b/oneflow/python/test/modules/test_image_resize.py new file mode 100644 index 000000000..cb9d66e4c --- /dev/null +++ b/oneflow/python/test/modules/test_image_resize.py @@ -0,0 +1,282 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +import cv2 +import numpy as np +import oneflow.experimental as flow +import oneflow.experimental.nn as nn +import image_test_util + + +def _of_image_resize( + image_list, + dtype=flow.float32, + origin_dtype=flow.float32, + channels=3, + keep_aspect_ratio=False, + target_size=None, + min_size=None, + max_size=None, + resize_side="shorter", + interpolation_type="bilinear", +): + assert isinstance(image_list, (list, tuple)) + assert all(isinstance(image, np.ndarray) for image in image_list) + assert all(image.ndim == 3 for image in image_list) + assert all(image.shape[2] == channels for image in image_list) + + res_image_list = [] + res_size_list = [] + res_scale_list = [] + image_resize_module = nn.image.Resize( + target_size=target_size, + min_size=min_size, + max_size=max_size, + keep_aspect_ratio=keep_aspect_ratio, + resize_side=resize_side, + dtype=dtype, + interpolation_type=interpolation_type, + channels=channels, + ) + for image in image_list: + tensor_dtype = dtype if keep_aspect_ratio else origin_dtype + input = flow.Tensor( + np.expand_dims(image, axis=0), dtype=tensor_dtype, device=flow.device("cpu") + ) + image_buffer = flow.tensor_to_tensor_buffer(input, instance_dims=3) + res_image, scale, new_size = image_resize_module(image_buffer) + res_image = res_image.numpy() + scale = scale.numpy() + if not keep_aspect_ratio: + new_size = np.asarray([(target_size, target_size)]) + else: + new_size = new_size.numpy() + res_image_list.append(res_image[0]) + res_size_list.append(new_size[0]) + res_scale_list.append(scale[0]) + return (res_image_list, res_scale_list, res_size_list) + + +def _get_resize_size_and_scale( + w, + h, + target_size, + min_size=None, + max_size=None, + keep_aspect_ratio=True, + resize_side="shorter", +): + if keep_aspect_ratio: + assert isinstance(target_size, int) + aspect_ratio = float(min((w, h))) / float(max((w, h))) + ( + min_res_size, + max_res_size, + ) = image_test_util.compute_keep_aspect_ratio_resized_size( + target_size, min_size, max_size, aspect_ratio, resize_side + ) + if w < h: + res_w = min_res_size + res_h = max_res_size + else: + res_w = max_res_size + res_h = min_res_size + + else: + assert isinstance(target_size, (list, tuple)) + assert len(target_size) == 2 + assert all(isinstance(size, int) for size in target_size) + res_w, res_h = target_size + + scale_w = res_w / w + scale_h = res_h / h + return (res_w, res_h), (scale_w, scale_h) + + +def _cv_image_resize( + image_list, + target_size, + keep_aspect_ratio=True, + min_size=None, + max_size=None, + resize_side="shorter", + interpolation=cv2.INTER_LINEAR, + dtype=np.float32, +): + res_image_list = [] + res_size_list = [] + res_scale_list = [] + + for image in image_list: + h, w = image.shape[:2] + new_size, scale = _get_resize_size_and_scale( + w, h, target_size, min_size, max_size, keep_aspect_ratio, resize_side + ) + res_image_list.append( + cv2.resize(image.squeeze(), new_size, interpolation=interpolation).astype( + dtype + ) + ) + res_size_list.append(new_size) + res_scale_list.append(scale) + + return res_image_list, res_scale_list, res_size_list + + +def _test_image_resize_with_cv( + test_case, + image_files, + target_size, + min_size=None, + max_size=None, + keep_aspect_ratio=True, + resize_side="shorter", + dtype=flow.float32, + origin_dtype=None, +): + if origin_dtype is None: + origin_dtype = dtype + + image_list = image_test_util.read_images_by_cv(image_files, origin_dtype) + + of_res_images, of_scales, of_new_sizes = _of_image_resize( + image_list=image_list, + dtype=dtype, + origin_dtype=origin_dtype, + keep_aspect_ratio=keep_aspect_ratio, + target_size=target_size, + min_size=min_size, + max_size=max_size, + resize_side=resize_side, + ) + + cv_res_images, cv_scales, cv_new_sizes = _cv_image_resize( + image_list=image_list, + target_size=target_size, + keep_aspect_ratio=keep_aspect_ratio, + min_size=min_size, + max_size=max_size, + resize_side=resize_side, + dtype=flow.convert_oneflow_dtype_to_numpy_dtype(dtype), + ) + + for ( + of_res_image, + cv_res_image, + of_scale, + cv_scale, + of_new_size, + cv_new_size, + ) in zip( + of_res_images, cv_res_images, of_scales, cv_scales, of_new_sizes, cv_new_sizes, + ): + test_case.assertTrue(np.allclose(of_res_image, cv_res_image)) + test_case.assertTrue(np.allclose(of_scale, cv_scale)) + test_case.assertTrue(np.allclose(of_new_size, cv_new_size)) + + +@flow.unittest.skip_unless_1n1d() +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestImageResize(flow.unittest.TestCase): + def test_image_resize_to_fixed_size(test_case): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, image_files, target_size=(224, 224), keep_aspect_ratio=False, + ) + + def test_image_resize_shorter_to_target_size(test_case): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, + image_files, + target_size=800, + keep_aspect_ratio=True, + resize_side="shorter", + ) + + def test_image_resize_longer_to_target_size(test_case): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, + image_files, + target_size=1000, + keep_aspect_ratio=True, + resize_side="longer", + ) + + def test_image_resize_shorter_to_target_size_with_max_size(test_case): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, + image_files, + target_size=800, + max_size=1333, + keep_aspect_ratio=True, + resize_side="shorter", + ) + + def test_image_resize_longer_to_target_size_with_min_size(test_case): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, + image_files, + target_size=1000, + min_size=600, + keep_aspect_ratio=True, + resize_side="longer", + ) + + def test_image_resize_to_fixed_size_with_dtype_uint8(test_case): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, + image_files, + target_size=(1000, 1000), + keep_aspect_ratio=False, + dtype=flow.uint8, + ) + + def test_image_reisze_shorter_to_target_size_with_max_size_with_dtype_uint8( + test_case, + ): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, + image_files, + target_size=1000, + max_size=1600, + keep_aspect_ratio=True, + resize_side="shorter", + dtype=flow.uint8, + ) + + def test_image_resize_uint8_to_float(test_case): + image_files, _ = image_test_util.random_sample_images_from_coco() + _test_image_resize_with_cv( + test_case, + image_files, + target_size=(1000, 1000), + keep_aspect_ratio=False, + dtype=flow.float32, + origin_dtype=flow.uint8, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/modules/test_resnet50_with_bn.py b/oneflow/python/test/modules/test_resnet50_with_bn.py index 50f6ee3b2..41daebabd 100644 --- a/oneflow/python/test/modules/test_resnet50_with_bn.py +++ b/oneflow/python/test/modules/test_resnet50_with_bn.py @@ -82,7 +82,7 @@ class TestResNet50(flow.unittest.TestCase): val_record = record_reader() label = record_label_decoder(val_record) image_raw_buffer = record_image_decoder(val_record) - image = resize(image_raw_buffer) + image = resize(image_raw_buffer)[0] image = crop_mirror_normal(image) image = image.to("cuda") label = label.to("cuda") diff --git a/oneflow/python/test/modules/test_resnet50_without_bn.py b/oneflow/python/test/modules/test_resnet50_without_bn.py index e2cecc59c..0e6c3a595 100644 --- a/oneflow/python/test/modules/test_resnet50_without_bn.py +++ b/oneflow/python/test/modules/test_resnet50_without_bn.py @@ -188,7 +188,7 @@ class TestResNet50(flow.unittest.TestCase): val_record = record_reader() label = record_label_decoder(val_record) image_raw_buffer = record_image_decoder(val_record) - image = resize(image_raw_buffer) + image = resize(image_raw_buffer)[0] image = crop_mirror_normal(image) image = image.to("cuda") label = label.to("cuda") -- GitLab