From 05c36fce1be727da5156b5932c5cf39e75c9e23e Mon Sep 17 00:00:00 2001
From: jonyguo <guozhijian@huawei.com>
Date: Thu, 22 Sep 2022 11:42:06 +0800
Subject: [PATCH] reconstruct batch param for md
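
MindSpore's dataset API no longer accepts pad_info or column_order as
parameters of Dataset.batch(). Move the pad_info callers to the
dedicated padded_batch() operation, replace column_order with an
explicit project() call after batching, and reorder the keyword
arguments of the remaining batch() calls to follow the updated
batch() signature order.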

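For reference, a minimal sketch of the padding migration; ds,
batch_size, and max_count below are illustrative placeholders, not
code from this repo:

    # Old API: batch() padded variable-shaped columns via pad_info.
    ds = ds.batch(batch_size, drop_remainder=True,
                  pad_info={"mask": ([max_count, None, None], 0)})

    # New API: padding moves to the dedicated padded_batch() operation
    # with the same pad_info semantics.
    ds = ds.padded_batch(batch_size, drop_remainder=True,
                         pad_info={"mask": ([max_count, None, None], 0)})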
---
 official/cv/maskrcnn/src/dataset.py             | 3 ++-
 official/cv/maskrcnn_mobilenetv1/src/dataset.py | 3 ++-
 official/nlp/cpm/train.py                       | 4 ++--
 research/cv/IPT/train.py                        | 5 +++--
 research/cv/IPT/train_finetune.py               | 5 +++--
 research/cv/Yolact++/src/dataset.py             | 3 ++-
 research/cv/pointnet2/train.py                  | 2 +-
 research/cv/textfusenet/src/dataset.py          | 3 ++-
 8 files changed, 17 insertions(+), 11 deletions(-)
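
Note: the column_order removal follows the same pattern in every
affected file; a minimal sketch, using collate and column names
borrowed from the cpm change below as placeholders:

    # Old API: batch() selected and ordered output columns via column_order.
    data = data.batch(batch_size,
                      per_batch_map=collate,
                      input_columns=["truth", "input_ids"],
                      output_columns=["input_ids", "attention_mask"],
                      column_order=["input_ids", "attention_mask"])

    # New API: batch() no longer takes column_order; an explicit
    # project() call selects and orders the columns instead.
    data = data.batch(batch_size,
                      per_batch_map=collate,
                      input_columns=["truth", "input_ids"],
                      output_columns=["input_ids", "attention_mask"])
    data = data.project(["input_ids", "attention_mask"])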

diff --git a/official/cv/maskrcnn/src/dataset.py b/official/cv/maskrcnn/src/dataset.py
index b3b1b9c13..1529ce9a7 100644
--- a/official/cv/maskrcnn/src/dataset.py
+++ b/official/cv/maskrcnn/src/dataset.py
@@ -550,7 +550,8 @@ def create_maskrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id
                     column_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                     python_multiprocessing=False,
                     num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([config.max_instance_count, None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([config.max_instance_count, None, None], 0)})
 
     else:
         ds = ds.map(operations=compose_map_func,
diff --git a/official/cv/maskrcnn_mobilenetv1/src/dataset.py b/official/cv/maskrcnn_mobilenetv1/src/dataset.py
index 0d05c8c10..eb065636f 100644
--- a/official/cv/maskrcnn_mobilenetv1/src/dataset.py
+++ b/official/cv/maskrcnn_mobilenetv1/src/dataset.py
@@ -565,7 +565,8 @@ def create_maskrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id
                     column_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                     python_multiprocessing=False,
                     num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([config.max_instance_count, None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([config.max_instance_count, None, None], 0)})
 
     else:
         ds = ds.map(operations=compose_map_func,
diff --git a/official/nlp/cpm/train.py b/official/nlp/cpm/train.py
index ca9b4d249..881ac3c4c 100644
--- a/official/nlp/cpm/train.py
+++ b/official/nlp/cpm/train.py
@@ -89,17 +89,17 @@ def _load_dataset(dataset_path, batch_size, rank_size=None, rank_id=None, shuffl
                           per_batch_map=collate,
                           input_columns=["truth", "input_ids"],
                           output_columns=["input_ids", "attention_mask", "position_ids", "truth"],
-                          column_order=["input_ids", "attention_mask", "position_ids", "loss_mask", "labels"],
                           num_parallel_workers=4,
                           drop_remainder=drop_remainder)
+        data = data.project(["input_ids", "attention_mask", "position_ids", "loss_mask", "labels"])
     else:
         data = data.batch(batch_size,
                           per_batch_map=collate,
                           input_columns=["truth", "input_ids"],
                           output_columns=["input_ids", "attention_mask", "position_ids", "truth"],
-                          column_order=["input_ids", "attention_mask", "position_ids", "loss_mask", "labels", "truth"],
                           num_parallel_workers=4,
                           drop_remainder=drop_remainder)
+        data = data.project(["input_ids", "attention_mask", "position_ids", "loss_mask", "labels", "truth"])
 
     return data
 
diff --git a/research/cv/IPT/train.py b/research/cv/IPT/train.py
index db6b66bfd..c6424fe96 100644
--- a/research/cv/IPT/train.py
+++ b/research/cv/IPT/train.py
@@ -131,9 +131,10 @@ def train_net(distribute, imagenet, epochs):
     resize_fuc = bicubic()
     train_de_dataset = train_de_dataset.project(columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"])
     train_de_dataset = train_de_dataset.batch(args.batch_size,
+                                              drop_remainder=True,
+                                              per_batch_map=resize_fuc.forward,
                                               input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"],
-                                              output_columns=["LR", "HR", "idx", "filename"],
-                                              drop_remainder=True, per_batch_map=resize_fuc.forward)
+                                              output_columns=["LR", "HR", "idx", "filename"])
     train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
 
     net_work = IPT(args)
diff --git a/research/cv/IPT/train_finetune.py b/research/cv/IPT/train_finetune.py
index 486172611..11b184c0d 100644
--- a/research/cv/IPT/train_finetune.py
+++ b/research/cv/IPT/train_finetune.py
@@ -68,9 +68,10 @@ def train_net(distribute, imagenet):
         resize_fuc = bicubic()
         train_de_dataset = train_de_dataset.batch(
             args.batch_size,
+            drop_remainder=True,
+            per_batch_map=resize_fuc.forward,
             input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
-            output_columns=["LR", "HR", "idx", "filename"], drop_remainder=True,
-            per_batch_map=resize_fuc.forward)
+            output_columns=["LR", "HR", "idx", "filename"])
     else:
         train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
 
diff --git a/research/cv/Yolact++/src/dataset.py b/research/cv/Yolact++/src/dataset.py
index e7b277e3b..02c95b566 100644
--- a/research/cv/Yolact++/src/dataset.py
+++ b/research/cv/Yolact++/src/dataset.py
@@ -483,7 +483,8 @@ def create_yolact_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id=0
                     python_multiprocessing=False,
                     num_parallel_workers=8)
 
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([cfg['max_instance_count'], None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([cfg['max_instance_count'], None, None], 0)})
 
     else:
         ds = ds.map(operations=compose_map_func,
diff --git a/research/cv/pointnet2/train.py b/research/cv/pointnet2/train.py
index f0a1d95a4..16558007e 100644
--- a/research/cv/pointnet2/train.py
+++ b/research/cv/pointnet2/train.py
@@ -170,9 +170,9 @@ def run_train():
     random_input_dropout = RandomInputDropout()
 
     train_ds = train_ds.batch(batch_size=args.batch_size,
+                              drop_remainder=True,
                               per_batch_map=random_input_dropout,
                               input_columns=["data", "label"],
-                              drop_remainder=True,
                               num_parallel_workers=num_workers,
                               python_multiprocessing=True)
 
diff --git a/research/cv/textfusenet/src/dataset.py b/research/cv/textfusenet/src/dataset.py
index a406408b3..70b9cd602 100755
--- a/research/cv/textfusenet/src/dataset.py
+++ b/research/cv/textfusenet/src/dataset.py
@@ -519,7 +519,8 @@ def create_textfusenet_dataset(mindrecord_file, batch_size=2, device_num=1, rank
                     column_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                     python_multiprocessing=False,
                     num_parallel_workers=num_parallel_workers)
-        ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([config.max_instance_count, None, None], 0)})
+        ds = ds.padded_batch(batch_size, drop_remainder=True,
+                             pad_info={"mask": ([config.max_instance_count, None, None], 0)})
 
     else:
         ds = ds.map(operations=compose_map_func,
-- 
GitLab