diff --git a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py index 85707de766470ee0c92326f757e4c4e8e9067c58..227bf48a3526a8b3bdd5abc69be6b8245e8659fb 100644 --- a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py +++ b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py @@ -95,15 +95,15 @@ def train_and_eval(config): if config.full_batch: context.set_auto_parallel_context(full_batch=True) ds.config.set_seed(1) - ds_train = create_dataset(data_path, train_mode=True, epochs=1, + ds_train = create_dataset(data_path, train_mode=True, batch_size=batch_size*get_group_size(), data_type=dataset_type) - ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + ds_eval = create_dataset(data_path, train_mode=False, batch_size=batch_size*get_group_size(), data_type=dataset_type) else: - ds_train = create_dataset(data_path, train_mode=True, epochs=1, + ds_train = create_dataset(data_path, train_mode=True, batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size(), data_type=dataset_type) - ds_eval = create_dataset(data_path, train_mode=False, epochs=1, + ds_eval = create_dataset(data_path, train_mode=False, batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size(), data_type=dataset_type) print("ds_train.size: {}".format(ds_train.get_dataset_size()))