diff --git a/official/cv/resnet/src/dataset.py b/official/cv/resnet/src/dataset.py
index 40f1b143606cc85361d3b1c4b6e2f39d942df297..77809beb7135f291b2b735bda1dcc1426b32e26c 100644
--- a/official/cv/resnet/src/dataset.py
+++ b/official/cv/resnet/src/dataset.py
@@ -127,8 +127,12 @@ def create_dataset2(dataset_path, do_train, batch_size=32, train_image_size=224,
     trans_norm = [ds.vision.c_transforms.Normalize(mean=mean, std=std), ds.vision.c_transforms.HWC2CHW()]
 
     type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32)
-
-    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=get_num_parallel_workers(12))
+    if device_num == 1:
+        trans_work_num = 24
+    else:
+        trans_work_num = 12
+    data_set = data_set.map(operations=trans, input_columns="image",
+                            num_parallel_workers=get_num_parallel_workers(trans_work_num))
     data_set = data_set.map(operations=trans_norm, input_columns="image",
                             num_parallel_workers=get_num_parallel_workers(12))
     # only enable cache for eval