diff --git a/official/cv/retinaface_resnet50/src/config.py b/official/cv/retinaface_resnet50/src/config.py
index 02553fc8efd5681822911e484ebcce24874e34bc..473124ccf11eed0cb9fe28861478af77c42c54f7 100644
--- a/official/cv/retinaface_resnet50/src/config.py
+++ b/official/cv/retinaface_resnet50/src/config.py
@@ -20,7 +20,7 @@ cfg_res50 = {
     'class_weight': 1.0,
     'landm_weight': 1.0,
     'batch_size': 8,
-    'num_workers': 8,
+    'num_workers': 4,
     'num_anchor': 29126,
     'ngpu': 4,
     'image_size': 840,
diff --git a/official/cv/retinaface_resnet50/src/dataset.py b/official/cv/retinaface_resnet50/src/dataset.py
index 18a3126b411d393387a93e1fc05eb5f78f8ed92e..7a108e5883a6ef12792a4e1267d4809464f4627f 100644
--- a/official/cv/retinaface_resnet50/src/dataset.py
+++ b/official/cv/retinaface_resnet50/src/dataset.py
@@ -135,17 +135,34 @@ def create_dataset(data_dir, cfg, batch_size=32, repeat_num=1, shuffle=True, mul
     aug = preproc(cfg['image_size'])
     encode = bbox_encode(cfg)
 
-    def union_data(image, annot):
+    def read_data_from_dataset(image, annot):
         i, a = read_dataset(image, annot)
-        i, a = aug(i, a)
-        out = encode(i, a)
+        return i, a
 
+    def augmentation(image, annot):
+        i, a = aug(image, annot)
+        return i, a
+
+    def encode_data(image, annot):
+        out = encode(image, annot)
         return out
 
+    de_dataset = de_dataset.map(input_columns=["image", "annotation"],
+                                output_columns=["image", "annotation"],
+                                column_order=["image", "annotation"],
+                                operations=read_data_from_dataset,
+                                python_multiprocessing=multiprocessing,
+                                num_parallel_workers=num_worker)
+    de_dataset = de_dataset.map(input_columns=["image", "annotation"],
+                                output_columns=["image", "annotation"],
+                                column_order=["image", "annotation"],
+                                operations=augmentation,
+                                python_multiprocessing=multiprocessing,
+                                num_parallel_workers=num_worker)
     de_dataset = de_dataset.map(input_columns=["image", "annotation"],
                                 output_columns=["image", "truths", "conf", "landm"],
                                 column_order=["image", "truths", "conf", "landm"],
-                                operations=union_data,
+                                operations=encode_data,
                                 python_multiprocessing=multiprocessing,
                                 num_parallel_workers=num_worker)