Skip to content
Snippets Groups Projects
Commit 3fa4d810 authored by anzhengqi's avatar anzhengqi
Browse files

modify yolov3_resnet18 dataset scripts

parent 15b2586e
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,6 @@ from __future__ import division
import os
import numpy as np
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
from PIL import Image
import mindspore.dataset as de
from mindspore.mindrecord import FileWriter
......@@ -32,13 +31,9 @@ def preprocess_fn(image, box, is_training):
"""Preprocess function for dataset."""
config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
do_hsv = False
max_boxes = 20
num_classes = ConfigYOLOV3ResNet18.num_classes
def _rand(a=0., b=1.):
return np.random.rand() * (b - a) + a
def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
"""Get true boxes."""
num_layers = anchors.shape[0] // 3
......@@ -145,14 +140,14 @@ def preprocess_fn(image, box, is_training):
if not is_training:
return _infer_data(image, image_size, box)
flip = _rand() < .5
flip = np.random.rand() < 0.5
# correct boxes
box_data = np.zeros((max_boxes, 5))
while True:
# Prevent the situation that all boxes are eliminated
new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
_rand(1 - jitter, 1 + jitter)
scale = _rand(0.25, 2)
new_ar = float(w) / float(h) * np.random.uniform(1 - jitter, 1 + jitter) / \
np.random.uniform(1 - jitter, 1 + jitter)
scale = np.random.uniform(0.25, 2)
if new_ar < 1:
nh = int(scale * h)
......@@ -161,8 +156,8 @@ def preprocess_fn(image, box, is_training):
nw = int(scale * w)
nh = int(nw / new_ar)
dx = int(_rand(0, w - nw))
dy = int(_rand(0, h - nh))
dx = int(np.random.uniform(0, w - nw))
dy = int(np.random.uniform(0, h - nh))
if len(box) >= 1:
t_box = box.copy()
......@@ -195,8 +190,7 @@ def preprocess_fn(image, box, is_training):
image = image.transpose(Image.FLIP_LEFT_RIGHT)
# convert image to gray or not
gray = _rand() < .25
if gray:
if np.random.rand() < 0.25:
image = image.convert('L').convert('RGB')
# when the channels of image is 1
......@@ -206,21 +200,7 @@ def preprocess_fn(image, box, is_training):
image = np.concatenate([image, image, image], axis=-1)
# distort image
hue = _rand(-hue, hue)
sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
image_data = image / 255.
if do_hsv:
x = rgb_to_hsv(image_data)
x[..., 0] += hue
x[..., 0][x[..., 0] > 1] -= 1
x[..., 0][x[..., 0] < 0] += 1
x[..., 1] *= sat
x[..., 2] *= val
x[x > 1] = 1
x[x < 0] = 0
image_data = hsv_to_rgb(x) # numpy array, 0 to 1
image_data = image_data.astype(np.float32)
image_data = image.astype(np.float32) / 255.
# preprocess bounding boxes
bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
......@@ -294,13 +274,14 @@ def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix,
writer.commit()
def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=1, device_num=1, rank=0,
def create_yolo_dataset(mindrecord_dir, batch_size=32, device_num=1, rank=0,
is_training=True, num_parallel_workers=8):
"""Create YOLOv3 dataset with MindDataset."""
de.config.set_prefetch_size(64)
ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
num_parallel_workers=2, shuffle=is_training)
decode = C.Decode()
ds = ds.map(operations=decode, input_columns=["image"])
ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
if is_training:
......@@ -309,9 +290,8 @@ def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=1, device_num=
output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
column_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=1)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_num)
else:
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "annotation"],
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment