Unverified commit 1960ae64 authored by Shijie, committed by GitHub

add coco reader module (#5391)


* add coco reader module

* auto format by CI

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
parent 8ae9de23
@@ -25,6 +25,8 @@ from oneflow.python.nn.modules.utils import (
)
from oneflow.python.nn.common_types import _size_1_t, _size_2_t, _size_3_t, _size_any_t
from typing import Optional, List, Tuple, Sequence, Union
import random
import sys
import traceback
@@ -472,3 +474,48 @@ class ImageDecode(Module):
    def forward(self, input):
        return self._op(input)[0]


@oneflow_export("nn.COCOReader")
@experimental_api
class COCOReader(Module):
    def __init__(
        self,
        annotation_file: str,
        image_dir: str,
        batch_size: int,
        shuffle: bool = True,
        random_seed: Optional[int] = None,
        group_by_aspect_ratio: bool = True,
        remove_images_without_annotations: bool = True,
        stride_partition: bool = True,
    ):
        super().__init__()
        if random_seed is None:
            random_seed = random.randrange(sys.maxsize)
        self._op = (
            flow.builtin_op("COCOReader")
            .Output("image")
            .Output("image_id")
            .Output("image_size")
            .Output("gt_bbox")
            .Output("gt_label")
            .Output("gt_segm")
            .Output("gt_segm_index")
            .Attr("session_id", flow.current_scope().session_id)
            .Attr("annotation_file", annotation_file)
            .Attr("image_dir", image_dir)
            .Attr("batch_size", batch_size)
            .Attr("shuffle_after_epoch", shuffle)
            .Attr("random_seed", random_seed)
            .Attr("group_by_ratio", group_by_aspect_ratio)
            .Attr(
                "remove_images_without_annotations", remove_images_without_annotations
            )
            .Attr("stride_partition", stride_partition)
            .Build()
        )

    def forward(self):
        res = self._op()
        return res
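A minimal usage sketch of the reader added above (illustration only, not part of the commit; the dataset paths are placeholders, and eager/experimental mode is assumed to be enabled, as in the test further down in this diff):

# Hypothetical example: point the paths at a local COCO dataset.
reader = flow.nn.COCOReader(
    annotation_file="/path/to/annotations/instances_val2017.json",  # placeholder
    image_dir="/path/to/val2017",  # placeholder
    batch_size=2,
    shuffle=True,
)
# Calling the module runs the builtin op and returns seven tensors,
# in the order declared by the Output(...) calls above.
image, image_id, image_size, gt_bbox, gt_label, gt_segm, gt_segm_index = reader()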
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import math
import os
import cv2
import numpy as np
@@ -96,5 +98,223 @@ class TestOFRecordModule(flow.unittest.TestCase):
        test_case.assertTrue(np.array_equal(image_np, gt_np))


coco_dict = dict()


def _coco(anno_file):
    global coco_dict
    if anno_file not in coco_dict:
        from pycocotools.coco import COCO

        coco_dict[anno_file] = COCO(anno_file)
    return coco_dict[anno_file]
def _get_coco_image_samples(anno_file, image_dir, image_ids):
    coco = _coco(anno_file)
    category_id_to_contiguous_id_map = _get_category_id_to_contiguous_id_map(coco)
    image, image_size = _read_images_with_cv(coco, image_dir, image_ids)
    bbox = _read_bbox(coco, image_ids)
    label = _read_label(coco, image_ids, category_id_to_contiguous_id_map)
    img_segm_poly_list = _read_segm_poly(coco, image_ids)
    poly, poly_index = _segm_poly_list_to_tensor(img_segm_poly_list)
    samples = []
    for im, ims, b, l, p, pi in zip(image, image_size, bbox, label, poly, poly_index):
        samples.append(
            dict(image=im, image_size=ims, bbox=b, label=l, poly=p, poly_index=pi)
        )
    return samples
def _get_category_id_to_contiguous_id_map(coco):
    return {v: i + 1 for i, v in enumerate(coco.getCatIds())}
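# Illustration (not part of the commit): COCO category ids are sparse (they skip
# some numbers), so the map above turns them into contiguous labels starting at 1.
# For example, if coco.getCatIds() returned [1, 2, 4, 7], the resulting map would
# be {1: 1, 2: 2, 4: 3, 7: 4}.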
def _read_images_with_cv(coco, image_dir, image_ids):
    image_files = [
        os.path.join(image_dir, coco.imgs[img_id]["file_name"]) for img_id in image_ids
    ]
    image_size = [
        (coco.imgs[img_id]["height"], coco.imgs[img_id]["width"])
        for img_id in image_ids
    ]
    return (
        [cv2.imread(image_file).astype(np.single) for image_file in image_files],
        image_size,
    )
def _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w):
    x, y, w, h = bbox
    x1, y1 = x, y
    x2 = x1 + max(w - 1, 0)
    y2 = y1 + max(h - 1, 0)
    # clip to image
    x1 = min(max(x1, 0), image_w - 1)
    y1 = min(max(y1, 0), image_h - 1)
    x2 = min(max(x2, 0), image_w - 1)
    y2 = min(max(y2, 0), image_h - 1)
    if x1 >= x2 or y1 >= y2:
        return None
    return [x1, y1, x2, y2]
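# Worked example (illustration only):
#   _bbox_convert_from_xywh_to_xyxy([10, 20, 5, 5], 100, 100) -> [10, 20, 14, 24]
#   _bbox_convert_from_xywh_to_xyxy([99, 10, 0, 5], 100, 100) -> None  (x1 >= x2 after clipping)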
def _read_bbox(coco, image_ids):
    img_bbox_list = []
    for img_id in image_ids:
        anno_ids = coco.getAnnIds(imgIds=[img_id])
        assert len(anno_ids) > 0, "image with id {} has no anno".format(img_id)
        image_h = coco.imgs[img_id]["height"]
        image_w = coco.imgs[img_id]["width"]
        bbox_list = []
        for anno_id in anno_ids:
            anno = coco.anns[anno_id]
            if anno["iscrowd"] != 0:
                continue
            bbox = anno["bbox"]
            assert isinstance(bbox, list)
            bbox_ = _bbox_convert_from_xywh_to_xyxy(bbox, image_h, image_w)
            if bbox_ is not None:
                bbox_list.append(bbox_)
        bbox_array = np.array(bbox_list, dtype=np.single)
        img_bbox_list.append(bbox_array)
    return img_bbox_list
def _read_label(coco, image_ids, category_id_to_contiguous_id_map):
    img_label_list = []
    for img_id in image_ids:
        anno_ids = coco.getAnnIds(imgIds=[img_id])
        assert len(anno_ids) > 0, "image with id {} has no anno".format(img_id)
        label_list = []
        for anno_id in anno_ids:
            anno = coco.anns[anno_id]
            if anno["iscrowd"] != 0:
                continue
            cate_id = anno["category_id"]
            assert isinstance(cate_id, int)
            label_list.append(category_id_to_contiguous_id_map[cate_id])
        label_array = np.array(label_list, dtype=np.int32)
        img_label_list.append(label_array)
    return img_label_list
def _read_segm_poly(coco, image_ids):
    img_segm_poly_list = []
    for img_id in image_ids:
        anno_ids = coco.getAnnIds(imgIds=[img_id])
        assert len(anno_ids) > 0, "img {} has no anno".format(img_id)
        segm_poly_list = []
        for anno_id in anno_ids:
            anno = coco.anns[anno_id]
            if anno["iscrowd"] != 0:
                continue
            segm = anno["segmentation"]
            assert isinstance(segm, list)
            assert len(segm) > 0, str(len(segm))
            assert all([len(poly) > 0 for poly in segm]), str(
                [len(poly) for poly in segm]
            )
            segm_poly_list.append(segm)
        img_segm_poly_list.append(segm_poly_list)
    return img_segm_poly_list
def _segm_poly_list_to_tensor(img_segm_poly_list):
    poly_array_list = []
    poly_index_array_list = []
    for img_idx, segm_poly_list in enumerate(img_segm_poly_list):
        img_poly_elem_list = []
        img_poly_index_list = []
        for obj_idx, poly_list in enumerate(segm_poly_list):
            for poly_idx, poly in enumerate(poly_list):
                img_poly_elem_list.extend(poly)
                for pt_idx, pt in enumerate(poly):
                    if pt_idx % 2 == 0:
                        img_poly_index_list.append([pt_idx / 2, poly_idx, obj_idx])
        img_poly_array = np.array(img_poly_elem_list, dtype=np.single).reshape(-1, 2)
        assert img_poly_array.size > 0, segm_poly_list
        poly_array_list.append(img_poly_array)
        img_poly_index_array = np.array(img_poly_index_list, dtype=np.int32)
        assert img_poly_index_array.size > 0, segm_poly_list
        poly_index_array_list.append(img_poly_index_array)
    return poly_array_list, poly_index_array_list
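# Worked example (illustration only): one image with one object and a single
# triangle polygon given as a flat coordinate list:
#   img_segm_poly_list = [[[[10.0, 10.0, 20.0, 10.0, 15.0, 20.0]]]]
#   -> poly_array_list[0]       == [[10., 10.], [20., 10.], [15., 20.]]
#   -> poly_index_array_list[0] == [[0, 0, 0], [1, 0, 0], [2, 0, 0]]
#      i.e. one (point index, polygon index, object index) row per 2D point.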
@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(
    not flow.unittest.env.eager_execution_enabled(),
    ".numpy() doesn't work in lazy mode",
)
class TestCocoReader(flow.unittest.TestCase):
    def test_coco_reader(test_case):
        anno_file = "/dataset/mscoco_2017/annotations/instances_val2017.json"
        image_dir = "/dataset/mscoco_2017/val2017"
        num_iterations = 100
        coco_reader = flow.nn.COCOReader(
            annotation_file=anno_file,
            image_dir=image_dir,
            batch_size=2,
            shuffle=True,
            stride_partition=True,
        )
        image_decoder = flow.nn.image.decode(dtype=flow.float)
        for i in range(num_iterations):
            (
                image,
                image_id,
                image_size,
                gt_bbox,
                gt_label,
                gt_segm,
                gt_segm_index,
            ) = coco_reader()
            decoded_image = image_decoder(image)
            image_list = decoded_image.numpy()
            image_id = image_id.numpy()
            image_size = image_size.numpy()
            bbox_list = gt_bbox.numpy()
            label_list = gt_label.numpy()
            segm_list = gt_segm.numpy()
            segm_index_list = gt_segm_index.numpy()

            samples = _get_coco_image_samples(anno_file, image_dir, image_id)
            for i, sample in enumerate(samples):
                test_case.assertTrue(np.array_equal(image_list[i], sample["image"]))
                test_case.assertTrue(
                    np.array_equal(image_size[i], sample["image_size"])
                )
                test_case.assertTrue(np.allclose(bbox_list[i], sample["bbox"]))
                cur_label = label_list[i]
                if len(cur_label.shape) == 0:
                    # when cur_label is scalar
                    cur_label = np.array([cur_label])
                test_case.assertTrue(np.array_equal(cur_label, sample["label"]))
                test_case.assertTrue(np.allclose(segm_list[i], sample["poly"]))
                test_case.assertTrue(
                    np.array_equal(segm_index_list[i], sample["poly_index"])
                )


if __name__ == "__main__":
    unittest.main()