Skip to content
Snippets Groups Projects
Unverified Commit 3cfe08b3 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!3584 fix bert bug

Merge pull request !3584 from 周莉莉/ISSUE
parents f1d9d687 c3e86ec2
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,7 @@ save_checkpoint_num: 1 ...@@ -32,7 +32,7 @@ save_checkpoint_num: 1
data_dir: '' data_dir: ''
schema_dir: '' schema_dir: ''
dataset_format: "mindrecord" dataset_format: "mindrecord"
num_samples: None # is the option which could be set by user to specify steps num_samples: None # is the option which could be set by user to specify steps when bert_network is base
# ============================================================================== # ==============================================================================
# pretrain related # pretrain related
......
...@@ -32,7 +32,7 @@ save_checkpoint_num: 1 ...@@ -32,7 +32,7 @@ save_checkpoint_num: 1
data_dir: '' data_dir: ''
schema_dir: '' schema_dir: ''
dataset_format: "mindrecord" dataset_format: "mindrecord"
num_samples: None # is the option which could be set by user to specify steps num_samples: None # is the option which could be set by user to specify steps when bert_network is base
# ============================================================================== # ==============================================================================
# pretrain related # pretrain related
......
...@@ -32,7 +32,7 @@ save_checkpoint_num: 5 ...@@ -32,7 +32,7 @@ save_checkpoint_num: 5
data_dir: '' data_dir: ''
schema_dir: '' schema_dir: ''
dataset_format: "mindrecord" dataset_format: "mindrecord"
num_samples: None # is the option which could be set by user to specify steps num_samples: None # is the option which could be set by user to specify steps when bert_network is base
# ============================================================================== # ==============================================================================
# pretrain related # pretrain related
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the script as: " echo "Please run the script as: "
echo "bash scripts/run_distributed_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR" echo "bash scripts/run_distributed_pretrain_for_gpu.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash scripts/run_distributed_pretrain.sh 8 40 /path/zh-wiki/ [/path/Schema.json](optional)" echo "for example: bash scripts/run_distributed_pretrain_for_gpu.sh 8 40 /path/zh-wiki/ [/path/Schema.json](optional)"
echo "It is better to use absolute path." echo "It is better to use absolute path."
echo "==============================================================================================================" echo "=============================================================================================================="
......
...@@ -126,7 +126,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, ...@@ -126,7 +126,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
(dataset_format == "mindrecord" and "mindrecord" in file_name and "mindrecord.db" not in file_name): (dataset_format == "mindrecord" and "mindrecord" in file_name and "mindrecord.db" not in file_name):
data_files.append(os.path.join(data_dir, file_name)) data_files.append(os.path.join(data_dir, file_name))
if dataset_format == "mindrecord": if dataset_format == "mindrecord":
if num_samples is not None: if str(num_samples).lower() != "none":
data_set = ds.MindDataset(data_files, data_set = ds.MindDataset(data_files,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
...@@ -279,10 +279,15 @@ def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schem ...@@ -279,10 +279,15 @@ def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schem
else: else:
data_files.append(data_dir) data_files.append(data_dir)
if dataset_format == "mindrecord": if dataset_format == "mindrecord":
data_set = ds.MindDataset(data_files, if str(num_samples).lower() != "none":
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", data_set = ds.MindDataset(data_files,
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
num_samples=num_samples) "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
num_samples=num_samples)
else:
data_set = ds.MindDataset(data_files,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
elif dataset_format == "tfrecord": elif dataset_format == "tfrecord":
data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
...@@ -312,10 +317,16 @@ def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schem ...@@ -312,10 +317,16 @@ def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schem
eval_ds.use_sampler(sampler) eval_ds.use_sampler(sampler)
else: else:
if dataset_format == "mindrecord": if dataset_format == "mindrecord":
eval_ds = ds.MindDataset(data_files, if str(num_samples).lower() != "none":
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", eval_ds = ds.MindDataset(data_files,
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
num_shards=device_num, shard_id=rank) "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
num_shards=device_num, shard_id=rank, num_samples=num_samples)
else:
eval_ds = ds.MindDataset(data_files,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
num_shards=device_num, shard_id=rank)
elif dataset_format == "tfrecord": elif dataset_format == "tfrecord":
eval_ds = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, eval_ds = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
......
...@@ -22,7 +22,8 @@ def tf_2_mr(item): ...@@ -22,7 +22,8 @@ def tf_2_mr(item):
item_path = item item_path = item
if not os.path.exists(args.output_mindrecord_dir): if not os.path.exists(args.output_mindrecord_dir):
os.makedirs(args.output_mindrecord_dir, exist_ok=True) os.makedirs(args.output_mindrecord_dir, exist_ok=True)
mindrecord_path = args.output_mindrecord_dir + item[item.rfind('/') + 1:item.rfind('.')] + '.mindrecord' mindrecord_path = os.path.join(args.output_mindrecord_dir,
item[item.rfind('/') + 1:item.rfind('.')] + '.mindrecord')
print("Start convert {} to {}.".format(item_path, mindrecord_path)) print("Start convert {} to {}.".format(item_path, mindrecord_path))
writer = FileWriter(file_name=mindrecord_path, shard_num=1, overwrite=True) writer = FileWriter(file_name=mindrecord_path, shard_num=1, overwrite=True)
nlp_schema = {"input_ids": {"type": "int64", "shape": [-1]}, nlp_schema = {"input_ids": {"type": "int64", "shape": [-1]},
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment