diff --git a/official/cv/resnet/src/dataset.py b/official/cv/resnet/src/dataset.py index 40f1b143606cc85361d3b1c4b6e2f39d942df297..77809beb7135f291b2b735bda1dcc1426b32e26c 100644 --- a/official/cv/resnet/src/dataset.py +++ b/official/cv/resnet/src/dataset.py @@ -127,8 +127,12 @@ def create_dataset2(dataset_path, do_train, batch_size=32, train_image_size=224, trans_norm = [ds.vision.c_transforms.Normalize(mean=mean, std=std), ds.vision.c_transforms.HWC2CHW()] type_cast_op = ds.transforms.c_transforms.TypeCast(ms.int32) - - data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=get_num_parallel_workers(12)) + if device_num == 1: + trans_work_num = 24 + else: + trans_work_num = 12 + data_set = data_set.map(operations=trans, input_columns="image", + num_parallel_workers=get_num_parallel_workers(trans_work_num)) data_set = data_set.map(operations=trans_norm, input_columns="image", num_parallel_workers=get_num_parallel_workers(12)) # only enable cache for eval