Commit 05c36fce authored by jonyguo

reconstruct batch param for md

parent f467d207
@@ -550,7 +550,8 @@ def create_maskrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id
                     column_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                     python_multiprocessing=False,
                     num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([config.max_instance_count, None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([config.max_instance_count, None, None], 0)})
     else:
         ds = ds.map(operations=compose_map_func,
...
@@ -565,7 +565,8 @@ def create_maskrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id
                     column_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                     python_multiprocessing=False,
                     num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([config.max_instance_count, None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([config.max_instance_count, None, None], 0)})
     else:
         ds = ds.map(operations=compose_map_func,
...
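For context, `batch()` in newer MindSpore releases no longer accepts `pad_info`; padding a ragged column is now done with `padded_batch()`. A minimal, self-contained sketch of the same pattern, assuming MindSpore >= 2.0 (the toy generator and shapes below are invented for illustration, not taken from the Mask R-CNN pipeline):

import numpy as np
import mindspore.dataset as ds

# Toy source: each sample carries a variable-length "mask" of shape (n, 4, 4).
def gen():
    for n in (1, 3, 2, 2):
        yield (np.zeros((n, 4, 4), dtype=np.float32),)

data = ds.GeneratorDataset(gen, column_names=["mask"])

# Old style (pre-2.0): ds.batch(batch_size, drop_remainder=True, pad_info=...)
# New style: padded_batch() owns the pad_info argument.
data = data.padded_batch(2, drop_remainder=True,
                         pad_info={"mask": ([3, None, None], 0)})

for row in data.create_dict_iterator(output_numpy=True):
    print(row["mask"].shape)  # every batch is padded to (2, 3, 4, 4)

The pad_info entry pads the first mask dimension to a fixed size (3 here, config.max_instance_count in the model code), pads the None dimensions to the per-batch maximum, and fills with 0.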
@@ -89,17 +89,17 @@ def _load_dataset(dataset_path, batch_size, rank_size=None, rank_id=None, shuffl
                           per_batch_map=collate,
                           input_columns=["truth", "input_ids"],
                           output_columns=["input_ids", "attention_mask", "position_ids", "truth"],
-                          column_order=["input_ids", "attention_mask", "position_ids", "loss_mask", "labels"],
                           num_parallel_workers=4,
                           drop_remainder=drop_remainder)
+        data = data.project(["input_ids", "attention_mask", "position_ids", "loss_mask", "labels"])
     else:
         data = data.batch(batch_size,
                           per_batch_map=collate,
                           input_columns=["truth", "input_ids"],
                           output_columns=["input_ids", "attention_mask", "position_ids", "truth"],
-                          column_order=["input_ids", "attention_mask", "position_ids", "loss_mask", "labels", "truth"],
                           num_parallel_workers=4,
                           drop_remainder=drop_remainder)
+        data = data.project(["input_ids", "attention_mask", "position_ids", "loss_mask", "labels", "truth"])
     return data
...
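Likewise, `column_order` is no longer a `batch()` keyword; selecting and ordering the final columns is done with a follow-up `project()` call. A hedged sketch under that assumption (the `collate` below is a stand-in that simply forwards and augments its inputs, not this model's real collate function):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(4):
        yield (np.array([i], np.int32), np.array([i, i + 1], np.int32))

data = ds.GeneratorDataset(gen, column_names=["truth", "input_ids"])

# Stand-in collate: 2 input columns in, 3 output columns out.
def collate(truth, input_ids, batch_info):
    loss_mask = [np.ones_like(x) for x in input_ids]
    return (input_ids, truth, loss_mask)

data = data.batch(2,
                  per_batch_map=collate,
                  input_columns=["truth", "input_ids"],
                  output_columns=["input_ids", "truth", "loss_mask"],
                  drop_remainder=True)
# column_order is gone: pick/reorder the columns the training step actually needs.
data = data.project(["input_ids", "loss_mask"])

print(data.get_col_names())  # ['input_ids', 'loss_mask']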
@@ -131,9 +131,10 @@ def train_net(distribute, imagenet, epochs):
         resize_fuc = bicubic()
         train_de_dataset = train_de_dataset.project(columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"])
         train_de_dataset = train_de_dataset.batch(args.batch_size,
+                                                  drop_remainder=True,
+                                                  per_batch_map=resize_fuc.forward,
                                                   input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"],
-                                                  output_columns=["LR", "HR", "idx", "filename"],
-                                                  drop_remainder=True, per_batch_map=resize_fuc.forward)
+                                                  output_columns=["LR", "HR", "idx", "filename"])
     train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
     net_work = IPT(args)
...
@@ -68,9 +68,10 @@ def train_net(distribute, imagenet):
         resize_fuc = bicubic()
         train_de_dataset = train_de_dataset.batch(
             args.batch_size,
+            drop_remainder=True,
+            per_batch_map=resize_fuc.forward,
             input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
-            output_columns=["LR", "HR", "idx", "filename"], drop_remainder=True,
-            per_batch_map=resize_fuc.forward)
+            output_columns=["LR", "HR", "idx", "filename"])
     else:
         train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
...
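These two IPT hunks only regroup keyword arguments (`drop_remainder` and `per_batch_map` now sit before the column lists), so behaviour is unchanged. The notable part of the call is that the per-batch map consumes more columns than it emits; a toy sketch of that shape of pipeline (`make_pair` and the generator are illustrative stand-ins, not the IPT `bicubic().forward`):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(4):
        yield (np.full((8, 8), i, np.float32),   # stand-in "HR"
               np.full((4, 4), i, np.float32),   # stand-in "LRx2"
               np.array(i, np.int32))            # stand-in "idx_in"

data = ds.GeneratorDataset(gen, column_names=["HR", "LRx2", "idx_in"])

# Three input columns go in, two output columns come back.
def make_pair(hr, lrx2, idx_in, batch_info):
    return (lrx2, hr)

data = data.batch(2,
                  drop_remainder=True,
                  per_batch_map=make_pair,
                  input_columns=["HR", "LRx2", "idx_in"],
                  output_columns=["LR", "HR"])

for row in data.create_dict_iterator(output_numpy=True):
    print(row["LR"].shape, row["HR"].shape)  # (2, 4, 4) (2, 8, 8)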
@@ -483,7 +483,8 @@ def create_yolact_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id=0
                     python_multiprocessing=False,
                     num_parallel_workers=8)
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([cfg['max_instance_count'], None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([cfg['max_instance_count'], None, None], 0)})
     else:
         ds = ds.map(operations=compose_map_func,
...
@@ -170,9 +170,9 @@ def run_train():
     random_input_dropout = RandomInputDropout()
     train_ds = train_ds.batch(batch_size=args.batch_size,
+                              drop_remainder=True,
                               per_batch_map=random_input_dropout,
                               input_columns=["data", "label"],
-                              drop_remainder=True,
                               num_parallel_workers=num_workers,
                               python_multiprocessing=True)
...
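Here too the edit only moves `drop_remainder` up the argument list. Note that `per_batch_map` accepts any callable, such as the `RandomInputDropout()` instance used here. A small sketch with an invented `ScaleBatch` class (illustrative only); with no `output_columns` given, the batched columns keep their original names:

import numpy as np
import mindspore.dataset as ds

class ScaleBatch:
    """Toy per-batch callable: halves "data", passes "label" through."""
    def __call__(self, data, label, batch_info):
        return ([x * 0.5 for x in data], label)

def gen():
    for i in range(4):
        yield (np.ones((3,), np.float32) * i, np.array(i, np.int32))

train_ds = ds.GeneratorDataset(gen, column_names=["data", "label"])
train_ds = train_ds.batch(batch_size=2,
                          drop_remainder=True,
                          per_batch_map=ScaleBatch(),
                          input_columns=["data", "label"])

for row in train_ds.create_dict_iterator(output_numpy=True):
    print(row["data"], row["label"])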
@@ -519,7 +519,8 @@ def create_textfusenet_dataset(mindrecord_file, batch_size=2, device_num=1, rank
                     column_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                     python_multiprocessing=False,
                     num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([config.max_instance_count, None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([config.max_instance_count, None, None], 0)})
     else:
         ds = ds.map(operations=compose_map_func,
...