diff --git a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
index 85707de766470ee0c92326f757e4c4e8e9067c58..227bf48a3526a8b3bdd5abc69be6b8245e8659fb 100644
--- a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
+++ b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
@@ -95,15 +95,15 @@ def train_and_eval(config):
     if config.full_batch:
         context.set_auto_parallel_context(full_batch=True)
         ds.config.set_seed(1)
-        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+        ds_train = create_dataset(data_path, train_mode=True,
                                   batch_size=batch_size*get_group_size(), data_type=dataset_type)
-        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+        ds_eval = create_dataset(data_path, train_mode=False,
                                  batch_size=batch_size*get_group_size(), data_type=dataset_type)
     else:
-        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+        ds_train = create_dataset(data_path, train_mode=True,
                                   batch_size=batch_size, rank_id=get_rank(),
                                   rank_size=get_group_size(), data_type=dataset_type)
-        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+        ds_eval = create_dataset(data_path, train_mode=False,
                                  batch_size=batch_size, rank_id=get_rank(),
                                  rank_size=get_group_size(), data_type=dataset_type)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))