From 1960ae644e3e002833edb8b631caae3093bf1446 Mon Sep 17 00:00:00 2001
From: Shijie <821898965@qq.com>
Date: Tue, 6 Jul 2021 13:01:56 +0800
Subject: [PATCH] add coco reader module (#5391)

* add coco reader module

* auto format by CI

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
---
 oneflow/python/nn/modules/dataset.py        |  47 +++++
 oneflow/python/test/modules/test_dataset.py | 220 ++++++++++++++++++++
 2 files changed, 267 insertions(+)

diff --git a/oneflow/python/nn/modules/dataset.py b/oneflow/python/nn/modules/dataset.py
index 155586087..82fe5480f 100644
--- a/oneflow/python/nn/modules/dataset.py
+++ b/oneflow/python/nn/modules/dataset.py
@@ -25,6 +25,8 @@ from oneflow.python.nn.modules.utils import (
 )
 from oneflow.python.nn.common_types import _size_1_t, _size_2_t, _size_3_t, _size_any_t
 from typing import Optional, List, Tuple, Sequence, Union
+import random
+import sys
 import traceback
 
 
@@ -472,3 +474,48 @@ class ImageDecode(Module):
 
     def forward(self, input):
         return self._op(input)[0]
+
+
+@oneflow_export("nn.COCOReader")
+@experimental_api
+class COCOReader(Module):
+    def __init__(
+        self,
+        annotation_file: str,
+        image_dir: str,
+        batch_size: int,
+        shuffle: bool = True,
+        random_seed: Optional[int] = None,
+        group_by_aspect_ratio: bool = True,
+        remove_images_without_annotations: bool = True,
+        stride_partition: bool = True,
+    ):
+        super().__init__()
+        if random_seed is None:
+            random_seed = random.randrange(sys.maxsize)
+        self._op = (
+            flow.builtin_op("COCOReader")
+            .Output("image")
+            .Output("image_id")
+            .Output("image_size")
+            .Output("gt_bbox")
+            .Output("gt_label")
+            .Output("gt_segm")
+            .Output("gt_segm_index")
+            .Attr("session_id", flow.current_scope().session_id)
+            .Attr("annotation_file", annotation_file)
+            .Attr("image_dir", image_dir)
+            .Attr("batch_size", batch_size)
+            .Attr("shuffle_after_epoch", shuffle)
+            .Attr("random_seed", random_seed)
+            .Attr("group_by_ratio", group_by_aspect_ratio)
+            .Attr(
+                "remove_images_without_annotations", remove_images_without_annotations
+            )
+            .Attr("stride_partition", stride_partition)
+            .Build()
+        )
+
+    def forward(self):
+        res = self._op()
+        return res
diff --git a/oneflow/python/test/modules/test_dataset.py b/oneflow/python/test/modules/test_dataset.py
index 579fe81db..f35e706dd 100644
--- a/oneflow/python/test/modules/test_dataset.py
+++ b/oneflow/python/test/modules/test_dataset.py
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 import unittest
+import math
+import os
 
 import cv2
 import numpy as np
@@ -96,5 +98,223 @@ class TestOFRecordModule(flow.unittest.TestCase):
         test_case.assertTrue(np.array_equal(image_np, gt_np))
 
 
+coco_dict = dict()
+
+
+def _coco(anno_file):
+    global coco_dict
+
+    if anno_file not in coco_dict:
+        from pycocotools.coco import COCO
+
+        coco_dict[anno_file] = COCO(anno_file)
+
+    return coco_dict[anno_file]
+
+
+def _get_coco_image_samples(anno_file, image_dir, image_ids):
+    coco = _coco(anno_file)
+    category_id_to_contiguous_id_map = _get_category_id_to_contiguous_id_map(coco)
+    image, image_size = _read_images_with_cv(coco, image_dir, image_ids)
+    bbox = _read_bbox(coco, image_ids)
+    label = _read_label(coco, image_ids, category_id_to_contiguous_id_map)
+    img_segm_poly_list = _read_segm_poly(coco, image_ids)
+    poly, poly_index = _segm_poly_list_to_tensor(img_segm_poly_list)
+    samples = []
+    for im, ims, b, l, p, pi in zip(image, image_size, bbox, label, poly, poly_index):
+        samples.append(
+            dict(image=im, image_size=ims, bbox=b, label=l, poly=p, poly_index=pi)
+        )
+    return samples
+
+
+def _get_category_id_to_contiguous_id_map(coco):
+    return {v: i + 1 for i, v in enumerate(coco.getCatIds())}
+
+
+def _read_images_with_cv(coco, image_dir, image_ids):
+    image_files = [
+        os.path.join(image_dir, coco.imgs[img_id]["file_name"]) for img_id in image_ids
+    ]
+    image_size = [
+        (coco.imgs[img_id]["height"], coco.imgs[img_id]["width"])
+        for img_id in image_ids
+    ]
+    return (
+        [cv2.imread(image_file).astype(np.single) for image_file in image_files],
+        image_size,
+    )
+
+
+def _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w):
+    x, y, w, h = bbox
+    x1, y1 = x, y
+    x2 = x1 + max(w - 1, 0)
+    y2 = y1 + max(h - 1, 0)
+
+    # clip to image
+    x1 = min(max(x1, 0), image_w - 1)
+    y1 = min(max(y1, 0), image_h - 1)
+    x2 = min(max(x2, 0), image_w - 1)
+    y2 = min(max(y2, 0), image_h - 1)
+
+    if x1 >= x2 or y1 >= y2:
+        return None
+
+    return [x1, y1, x2, y2]
+
+
+def _read_bbox(coco, image_ids):
+    img_bbox_list = []
+    for img_id in image_ids:
+        anno_ids = coco.getAnnIds(imgIds=[img_id])
+        assert len(anno_ids) > 0, "image with id {} has no anno".format(img_id)
+        image_h = coco.imgs[img_id]["height"]
+        image_w = coco.imgs[img_id]["width"]
+
+        bbox_list = []
+        for anno_id in anno_ids:
+            anno = coco.anns[anno_id]
+            if anno["iscrowd"] != 0:
+                continue
+
+            bbox = anno["bbox"]
+            assert isinstance(bbox, list)
+            bbox_ = _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w)
+            if bbox_ is not None:
+                bbox_list.append(bbox_)
+
+        bbox_array = np.array(bbox_list, dtype=np.single)
+        img_bbox_list.append(bbox_array)
+
+    return img_bbox_list
+
+
+def _read_label(coco, image_ids, category_id_to_contiguous_id_map):
+    img_label_list = []
+    for img_id in image_ids:
+        anno_ids = coco.getAnnIds(imgIds=[img_id])
+        assert len(anno_ids) > 0, "image with id {} has no anno".format(img_id)
+
+        label_list = []
+        for anno_id in anno_ids:
+            anno = coco.anns[anno_id]
+            if anno["iscrowd"] != 0:
+                continue
+            cate_id = anno["category_id"]
+            isinstance(cate_id, int)
+            label_list.append(category_id_to_contiguous_id_map[cate_id])
+        label_array = np.array(label_list, dtype=np.int32)
+        img_label_list.append(label_array)
+    return img_label_list
+
+
+def _read_segm_poly(coco, image_ids):
+    img_segm_poly_list = []
+    for img_id in image_ids:
+        anno_ids = coco.getAnnIds(imgIds=[img_id])
+        assert len(anno_ids) > 0, "img {} has no anno".format(img_id)
+
+        segm_poly_list = []
+        for anno_id in anno_ids:
+            anno = coco.anns[anno_id]
+            if anno["iscrowd"] != 0:
+                continue
+            segm = anno["segmentation"]
+            assert isinstance(segm, list)
+            assert len(segm) > 0, str(len(segm))
+            assert all([len(poly) > 0 for poly in segm]), str(
+                [len(poly) for poly in segm]
+            )
+            segm_poly_list.append(segm)
+
+        img_segm_poly_list.append(segm_poly_list)
+
+    return img_segm_poly_list
+
+
+def _segm_poly_list_to_tensor(img_segm_poly_list):
+    poly_array_list = []
+    poly_index_array_list = []
+    for img_idx, segm_poly_list in enumerate(img_segm_poly_list):
+        img_poly_elem_list = []
+        img_poly_index_list = []
+
+        for obj_idx, poly_list in enumerate(segm_poly_list):
+            for poly_idx, poly in enumerate(poly_list):
+                img_poly_elem_list.extend(poly)
+                for pt_idx, pt in enumerate(poly):
+                    if pt_idx % 2 == 0:
+                        img_poly_index_list.append([pt_idx / 2, poly_idx, obj_idx])
+
+        img_poly_array = np.array(img_poly_elem_list, dtype=np.single).reshape(-1, 2)
+        assert img_poly_array.size > 0, segm_poly_list
+        poly_array_list.append(img_poly_array)
+
+        img_poly_index_array = np.array(img_poly_index_list, dtype=np.int32)
+        assert img_poly_index_array.size > 0, segm_poly_list
+        poly_index_array_list.append(img_poly_index_array)
+
+    return poly_array_list, poly_index_array_list
+
+
+@flow.unittest.skip_unless_1n1d()
+@unittest.skipIf(
+    not flow.unittest.env.eager_execution_enabled(),
+    ".numpy() doesn't work in lazy mode",
+)
+class TestCocoReader(flow.unittest.TestCase):
+    def test_coco_reader(test_case):
+        anno_file = "/dataset/mscoco_2017/annotations/instances_val2017.json"
+        image_dir = "/dataset/mscoco_2017/val2017"
+        num_iterations = 100
+
+        coco_reader = flow.nn.COCOReader(
+            annotation_file=anno_file,
+            image_dir=image_dir,
+            batch_size=2,
+            shuffle=True,
+            stride_partition=True,
+        )
+        image_decoder = flow.nn.image.decode(dtype=flow.float)
+
+        for i in range(num_iterations):
+            (
+                image,
+                image_id,
+                image_size,
+                gt_bbox,
+                gt_label,
+                gt_segm,
+                gt_segm_index,
+            ) = coco_reader()
+
+            decoded_image = image_decoder(image)
+            image_list = decoded_image.numpy()
+            image_id = image_id.numpy()
+            image_size = image_size.numpy()
+            bbox_list = gt_bbox.numpy()
+            label_list = gt_label.numpy()
+            segm_list = gt_segm.numpy()
+            segm_index_list = gt_segm_index.numpy()
+
+            samples = _get_coco_image_samples(anno_file, image_dir, image_id)
+            for i, sample in enumerate(samples):
+                test_case.assertTrue(np.array_equal(image_list[i], sample["image"]))
+                test_case.assertTrue(
+                    np.array_equal(image_size[i], sample["image_size"])
+                )
+                test_case.assertTrue(np.allclose(bbox_list[i], sample["bbox"]))
+                cur_label = label_list[i]
+                if len(cur_label.shape) == 0:
+                    # when cur_label is scalar
+                    cur_label = np.array([cur_label])
+                test_case.assertTrue(np.array_equal(cur_label, sample["label"]))
+                test_case.assertTrue(np.allclose(segm_list[i], sample["poly"]))
+                test_case.assertTrue(
+                    np.array_equal(segm_index_list[i], sample["poly_index"])
+                )
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab