From c3e86ec27f016e62bf2b06d9be0bfe40a9176d5e Mon Sep 17 00:00:00 2001
From: z00512154 <zhoulili20@huawei.com>
Date: Fri, 2 Sep 2022 17:14:47 +0800
Subject: [PATCH] fix bert bug: handle num_samples "None" and mindrecord path

The YAML configs set num_samples to None, which is loaded as the string
"None" rather than Python's None, so the old `num_samples is not None`
check always passed and MindDataset received an invalid num_samples.
Compare against the string instead, branch the MindDataset construction
on it in both the pretrain and eval dataset helpers, and build the
mindrecord output path with os.path.join instead of string concatenation.
Also correct the script name in the GPU distributed pretrain usage text.

---
 official/nlp/bert/pretrain_config.yaml        |  2 +-
 .../bert/pretrain_config_Ascend_Boost.yaml    |  2 +-
 .../nlp/bert/pretrain_config_Ascend_Thor.yaml |  2 +-
 .../run_distributed_pretrain_for_gpu.sh       |  4 +--
 official/nlp/bert/src/dataset.py              | 29 +++++++++++++------
 .../tools/parallel_tfrecord_to_mindrecord.py  |  3 +-
 6 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/official/nlp/bert/pretrain_config.yaml b/official/nlp/bert/pretrain_config.yaml
index 55603ee15..d4469561b 100644
--- a/official/nlp/bert/pretrain_config.yaml
+++ b/official/nlp/bert/pretrain_config.yaml
@@ -32,7 +32,7 @@ save_checkpoint_num: 1
 data_dir: ''
 schema_dir: ''
 dataset_format: "mindrecord"
-num_samples: None # is the option which could be set by user to specify steps
+num_samples: None # optional; set by the user to limit the number of samples (and thus steps) when bert_network is base
 
 # ==============================================================================
 # pretrain related
diff --git a/official/nlp/bert/pretrain_config_Ascend_Boost.yaml b/official/nlp/bert/pretrain_config_Ascend_Boost.yaml
index 77be3be95..0a9680235 100644
--- a/official/nlp/bert/pretrain_config_Ascend_Boost.yaml
+++ b/official/nlp/bert/pretrain_config_Ascend_Boost.yaml
@@ -32,7 +32,7 @@ save_checkpoint_num: 1
 data_dir: ''
 schema_dir: ''
 dataset_format: "mindrecord"
-num_samples: None # is the option which could be set by user to specify steps
+num_samples: None # optional; set by the user to limit the number of samples (and thus steps) when bert_network is base
 
 # ==============================================================================
 # pretrain related
diff --git a/official/nlp/bert/pretrain_config_Ascend_Thor.yaml b/official/nlp/bert/pretrain_config_Ascend_Thor.yaml
index 666ab1831..31ac77f8a 100644
--- a/official/nlp/bert/pretrain_config_Ascend_Thor.yaml
+++ b/official/nlp/bert/pretrain_config_Ascend_Thor.yaml
@@ -32,7 +32,7 @@ save_checkpoint_num: 5
 data_dir: ''
 schema_dir: ''
 dataset_format: "mindrecord"
-num_samples: None # is the option which could be set by user to specify steps
+num_samples: None # optional; set by the user to limit the number of samples (and thus steps) when bert_network is base
 
 # ==============================================================================
 # pretrain related
diff --git a/official/nlp/bert/scripts/run_distributed_pretrain_for_gpu.sh b/official/nlp/bert/scripts/run_distributed_pretrain_for_gpu.sh
index 770dab311..a03d18387 100644
--- a/official/nlp/bert/scripts/run_distributed_pretrain_for_gpu.sh
+++ b/official/nlp/bert/scripts/run_distributed_pretrain_for_gpu.sh
@@ -16,8 +16,8 @@
 
 echo "=============================================================================================================="
 echo "Please run the script as: "
-echo "bash scripts/run_distributed_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR"
-echo "for example: bash scripts/run_distributed_pretrain.sh 8 40 /path/zh-wiki/ [/path/Schema.json](optional)"
+echo "bash scripts/run_distributed_pretrain_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR"
+echo "for example: bash scripts/run_distributed_pretrain_for_gpu.sh 8 40 /path/zh-wiki/ [/path/Schema.json](optional)"
 echo "It is better to use absolute path."
echo "==============================================================================================================" diff --git a/official/nlp/bert/src/dataset.py b/official/nlp/bert/src/dataset.py index f1277646b..01f9659e6 100644 --- a/official/nlp/bert/src/dataset.py +++ b/official/nlp/bert/src/dataset.py @@ -126,7 +126,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, (dataset_format == "mindrecord" and "mindrecord" in file_name and "mindrecord.db" not in file_name): data_files.append(os.path.join(data_dir, file_name)) if dataset_format == "mindrecord": - if num_samples is not None: + if str(num_samples).lower() != "none": data_set = ds.MindDataset(data_files, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], @@ -279,10 +279,15 @@ def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schem else: data_files.append(data_dir) if dataset_format == "mindrecord": - data_set = ds.MindDataset(data_files, - columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", - "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], - num_samples=num_samples) + if str(num_samples).lower() != "none": + data_set = ds.MindDataset(data_files, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], + num_samples=num_samples) + else: + data_set = ds.MindDataset(data_files, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) elif dataset_format == "tfrecord": data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", @@ -312,10 +317,16 @@ def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schem eval_ds.use_sampler(sampler) else: if dataset_format == "mindrecord": - eval_ds = ds.MindDataset(data_files, - columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", - "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], - num_shards=device_num, shard_id=rank) + if str(num_samples).lower() != "none": + eval_ds = ds.MindDataset(data_files, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], + num_shards=device_num, shard_id=rank, num_samples=num_samples) + else: + eval_ds = ds.MindDataset(data_files, + columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", + "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], + num_shards=device_num, shard_id=rank) elif dataset_format == "tfrecord": eval_ds = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", diff --git a/official/nlp/bert/src/tools/parallel_tfrecord_to_mindrecord.py b/official/nlp/bert/src/tools/parallel_tfrecord_to_mindrecord.py index c65d5f8e5..20f6e2085 100644 --- a/official/nlp/bert/src/tools/parallel_tfrecord_to_mindrecord.py +++ b/official/nlp/bert/src/tools/parallel_tfrecord_to_mindrecord.py @@ -22,7 +22,8 @@ def tf_2_mr(item): item_path = item if not os.path.exists(args.output_mindrecord_dir): os.makedirs(args.output_mindrecord_dir, exist_ok=True) - mindrecord_path = args.output_mindrecord_dir + item[item.rfind('/') + 
1:item.rfind('.')] + '.mindrecord' + mindrecord_path = os.path.join(args.output_mindrecord_dir, + item[item.rfind('/') + 1:item.rfind('.')] + '.mindrecord') print("Start convert {} to {}.".format(item_path, mindrecord_path)) writer = FileWriter(file_name=mindrecord_path, shard_num=1, overwrite=True) nlp_schema = {"input_ids": {"type": "int64", "shape": [-1]}, -- GitLab
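
Note for reviewers (not part of the patch): the string comparison in the new
guard is needed because YAML has no `None` literal -- only `null`, `~`, or an
empty value load as Python's None, so `num_samples: None` in these configs
arrives as the string "None" and the old `num_samples is not None` check
always passed. A minimal standalone sketch of the failure and the fix,
assuming PyYAML (the repo's config wrapper may normalize values differently):

    import yaml

    # `None` is a plain scalar in YAML, so it loads as the string "None".
    cfg = yaml.safe_load('num_samples: None')
    num_samples = cfg['num_samples']

    print(repr(num_samples))        # 'None'
    print(num_samples is not None)  # True -- the old guard never fired

    # The patch's guard treats "None" (any case) as "not set":
    if str(num_samples).lower() != "none":
        kwargs = {"num_samples": num_samples}
    else:
        kwargs = {}

One way to avoid the duplicated MindDataset calls the patch introduces would
be to build such a kwargs dict once and splat it into a single call, e.g.
ds.MindDataset(data_files, columns_list=..., **kwargs); the explicit branches
in the patch instead match the style already used elsewhere in dataset.py.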