From c8b46e6eb66132a30648fe99f047899fd3ef3fe3 Mon Sep 17 00:00:00 2001
From: jonyguo <guozhijian@huawei.com>
Date: Mon, 20 Jun 2022 10:28:05 +0800
Subject: [PATCH] fix: retinaface_resnet50 intermittent (probabilistic) failure

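Split the combined union_data map, which read, augmented, and encoded
each sample inside a single python_multiprocessing callable, into three
separate map stages (read_data_from_dataset, augmentation, encode_data),
and reduce num_workers from 8 to 4 so that each worker process runs a
lighter callable.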
---
 official/cv/retinaface_resnet50/src/config.py |  2 +-
 .../cv/retinaface_resnet50/src/dataset.py     | 25 ++++++++++++++++---
 2 files changed, 22 insertions(+), 5 deletions(-)

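Note (below the --- cut line, not part of the commit message): a minimal,
untested sketch for exercising the split pipeline, assuming create_dataset
returns the mapped dataset, the working directory is
official/cv/retinaface_resnet50, and the data path is a placeholder:

    from src.config import cfg_res50
    from src.dataset import create_dataset

    # placeholder path to the local training data
    ds = create_dataset('/path/to/train_data', cfg_res50,
                        batch_size=cfg_res50['batch_size'])
    print(ds.get_dataset_size())  # sanity-check that the pipeline builds
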
diff --git a/official/cv/retinaface_resnet50/src/config.py b/official/cv/retinaface_resnet50/src/config.py
index 02553fc8e..473124ccf 100644
--- a/official/cv/retinaface_resnet50/src/config.py
+++ b/official/cv/retinaface_resnet50/src/config.py
@@ -20,7 +20,7 @@ cfg_res50 = {
     'class_weight': 1.0,
     'landm_weight': 1.0,
     'batch_size': 8,
-    'num_workers': 8,
+    'num_workers': 4,
     'num_anchor': 29126,
     'ngpu': 4,
     'image_size': 840,
diff --git a/official/cv/retinaface_resnet50/src/dataset.py b/official/cv/retinaface_resnet50/src/dataset.py
index 18a3126b4..7a108e588 100644
--- a/official/cv/retinaface_resnet50/src/dataset.py
+++ b/official/cv/retinaface_resnet50/src/dataset.py
@@ -135,17 +135,34 @@ def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, mul
     aug = preproc(cfg['image_size'])
     encode = bbox_encode(cfg)
 
-    def union_data(image, annot):
+    def read_data_from_dataset(image, annot):
         i, a = read_dataset(image, annot)
-        i, a = aug(i, a)
-        out = encode(i, a)
+        return i, a
 
+    def augmentation(image, annot):
+        i, a = aug(image, annot)
+        return i, a
+
+    def encode_data(image, annot):
+        out = encode(image, annot)
         return out
 
+    de_dataset = de_dataset.map(input_columns=["image", "annotation"],
+                                output_columns=["image", "annotation"],
+                                column_order=["image", "annotation"],
+                                operations=read_data_from_dataset,
+                                python_multiprocessing=multiprocessing,
+                                num_parallel_workers=num_worker)
+    de_dataset = de_dataset.map(input_columns=["image", "annotation"],
+                                output_columns=["image", "annotation"],
+                                column_order=["image", "annotation"],
+                                operations=augmentation,
+                                python_multiprocessing=multiprocessing,
+                                num_parallel_workers=num_worker)
     de_dataset = de_dataset.map(input_columns=["image", "annotation"],
                                 output_columns=["image", "truths", "conf", "landm"],
                                 column_order=["image", "truths", "conf", "landm"],
-                                operations=union_data,
+                                operations=encode_data,
                                 python_multiprocessing=multiprocessing,
                                 num_parallel_workers=num_worker)
 
-- 
GitLab