Skip to content
Snippets Groups Projects
Commit 129c05a1 authored by sun_zhongqian's avatar sun_zhongqian
Browse files

remove redundant code

parent 9989a82c
No related branches found
No related tags found
No related merge requests found
......@@ -95,15 +95,15 @@ def train_and_eval(config):
if config.full_batch:
context.set_auto_parallel_context(full_batch=True)
ds.config.set_seed(1)
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
ds_train = create_dataset(data_path, train_mode=True,
batch_size=batch_size*get_group_size(), data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
ds_eval = create_dataset(data_path, train_mode=False,
batch_size=batch_size*get_group_size(), data_type=dataset_type)
else:
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
ds_train = create_dataset(data_path, train_mode=True,
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
ds_eval = create_dataset(data_path, train_mode=False,
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment