From c8b46e6eb66132a30648fe99f047899fd3ef3fe3 Mon Sep 17 00:00:00 2001 From: jonyguo <guozhijian@huawei.com> Date: Mon, 20 Jun 2022 10:28:05 +0800 Subject: [PATCH] fix: retinaface_resnet50 probability failure --- official/cv/retinaface_resnet50/src/config.py | 2 +- .../cv/retinaface_resnet50/src/dataset.py | 25 ++++++++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/official/cv/retinaface_resnet50/src/config.py b/official/cv/retinaface_resnet50/src/config.py index 02553fc8e..473124ccf 100644 --- a/official/cv/retinaface_resnet50/src/config.py +++ b/official/cv/retinaface_resnet50/src/config.py @@ -20,7 +20,7 @@ cfg_res50 = { 'class_weight': 1.0, 'landm_weight': 1.0, 'batch_size': 8, - 'num_workers': 8, + 'num_workers': 4, 'num_anchor': 29126, 'ngpu': 4, 'image_size': 840, diff --git a/official/cv/retinaface_resnet50/src/dataset.py b/official/cv/retinaface_resnet50/src/dataset.py index 18a3126b4..7a108e588 100644 --- a/official/cv/retinaface_resnet50/src/dataset.py +++ b/official/cv/retinaface_resnet50/src/dataset.py @@ -135,17 +135,34 @@ def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, mul aug = preproc(cfg['image_size']) encode = bbox_encode(cfg) - def union_data(image, annot): + def read_data_from_dataset(image, annot): i, a = read_dataset(image, annot) - i, a = aug(i, a) - out = encode(i, a) + return i, a + def augmentation(image, annot): + i, a = aug(image, annot) + return i, a + + def encode_data(image, annot): + out = encode(image, annot) return out + de_dataset = de_dataset.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + column_order=["image", "annotation"], + operations=read_data_from_dataset, + python_multiprocessing=multiprocessing, + num_parallel_workers=num_worker) + de_dataset = de_dataset.map(input_columns=["image", "annotation"], + output_columns=["image", "annotation"], + column_order=["image", "annotation"], + operations=augmentation, + python_multiprocessing=multiprocessing, + num_parallel_workers=num_worker) de_dataset = de_dataset.map(input_columns=["image", "annotation"], output_columns=["image", "truths", "conf", "landm"], column_order=["image", "truths", "conf", "landm"], - operations=union_data, + operations=encode_data, python_multiprocessing=multiprocessing, num_parallel_workers=num_worker) -- GitLab