diff --git a/official/cv/retinaface_resnet50/src/config.py b/official/cv/retinaface_resnet50/src/config.py
index 02553fc8efd5681822911e484ebcce24874e34bc..473124ccf11eed0cb9fe28861478af77c42c54f7 100644
--- a/official/cv/retinaface_resnet50/src/config.py
+++ b/official/cv/retinaface_resnet50/src/config.py
@@ -20,7 +20,7 @@ cfg_res50 = {
     'class_weight': 1.0,
     'landm_weight': 1.0,
     'batch_size': 8,
-    'num_workers': 8,
+    'num_workers': 4,
     'num_anchor': 29126,
     'ngpu': 4,
     'image_size': 840,
diff --git a/official/cv/retinaface_resnet50/src/dataset.py b/official/cv/retinaface_resnet50/src/dataset.py
index 18a3126b411d393387a93e1fc05eb5f78f8ed92e..7a108e5883a6ef12792a4e1267d4809464f4627f 100644
--- a/official/cv/retinaface_resnet50/src/dataset.py
+++ b/official/cv/retinaface_resnet50/src/dataset.py
@@ -135,17 +135,34 @@ def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, mul
     aug = preproc(cfg['image_size'])
     encode = bbox_encode(cfg)
 
-    def union_data(image, annot):
+    def read_data_from_dataset(image, annot):
         i, a = read_dataset(image, annot)
-        i, a = aug(i, a)
-        out = encode(i, a)
+        return i, a
+    def augmentation(image, annot):
+        i, a = aug(image, annot)
+        return i, a
+
+    def encode_data(image, annot):
+        out = encode(image, annot)
         return out
 
+    de_dataset = de_dataset.map(input_columns=["image", "annotation"],
+                                output_columns=["image", "annotation"],
+                                column_order=["image", "annotation"],
+                                operations=read_data_from_dataset,
+                                python_multiprocessing=multiprocessing,
+                                num_parallel_workers=num_worker)
+    de_dataset = de_dataset.map(input_columns=["image", "annotation"],
+                                output_columns=["image", "annotation"],
+                                column_order=["image", "annotation"],
+                                operations=augmentation,
+                                python_multiprocessing=multiprocessing,
+                                num_parallel_workers=num_worker)
     de_dataset = de_dataset.map(input_columns=["image", "annotation"],
                                 output_columns=["image", "truths", "conf", "landm"],
                                 column_order=["image", "truths", "conf", "landm"],
-                                operations=union_data,
+                                operations=encode_data,
                                 python_multiprocessing=multiprocessing,
                                 num_parallel_workers=num_worker)
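
Not part of the patch: below is a minimal, self-contained sketch of the pattern the dataset.py change introduces, splitting one composite per-sample callable into three chained map() stages so each stage gets its own pool of parallel workers. The decode/augment/encode_targets helpers and the toy NumpySlicesDataset input are illustrative assumptions, not the repo's code.

# Toy stand-ins for the real read/augment/encode callables; the names and
# data here are illustrative only, not the repo's functions.
import numpy as np
import mindspore.dataset as ds

def decode(x):
    # stands in for read_data_from_dataset: produce a float sample
    return x.astype(np.float32)

def augment(x):
    # stands in for augmentation: any per-sample transform
    return x * 2

def encode_targets(x):
    # stands in for encode_data: final per-sample target encoding
    return x + 1

data = ds.NumpySlicesDataset(np.arange(8).reshape(4, 2),
                             column_names=["image"], shuffle=False)
# Each stage is a separate map() call with its own parallel workers,
# mirroring the structure the patch gives create_dataset().
data = data.map(operations=decode, input_columns=["image"], num_parallel_workers=2)
data = data.map(operations=augment, input_columns=["image"], num_parallel_workers=2)
data = data.map(operations=encode_targets, input_columns=["image"], num_parallel_workers=2)

for row in data.create_dict_iterator(output_numpy=True):
    print(row["image"])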