Unverified commit a56728f7, authored by i-robot and committed by Gitee

!3513 fix bert bug

Merge pull request !3513 from 周莉莉/bert
parents 3549a44d b422e71f
@@ -269,7 +269,8 @@ For ner or classification task, schema file contains ["input_ids", "input_mask",
For squad task, training: schema file contains ["start_positions", "end_positions", "input_ids", "input_mask", "segment_ids"], evaluation: schema file contains ["input_ids", "input_mask", "segment_ids"].
`numRows` is the only option which could be set by user, other values must be set according to the dataset.
`numRows` is the only option in the schema file that can be set by the user when dataset_format is tfrecord; other values must be set according to the dataset.
`num_samples` is the only option in the yaml file that can be set by the user when dataset_format is mindrecord; other values must be set according to the dataset.
For example, the schema file of cn-wiki-128 dataset for pretraining shows as follows:
{
......
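To make the two knobs concrete: below is a minimal sketch of capping a mindrecord dataset with `num_samples` (file name and value are illustrative, not from the source):

import mindspore.dataset as ds

# cap the number of rows read from a mindrecord file; this plays the same role
# for mindrecord that numRows plays in a tfrecord schema file
data_set = ds.MindDataset(["cn-wiki-128.mindrecord"], num_samples=1000)
print(data_set.get_dataset_size())  # at most 1000 rows before batching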
@@ -272,7 +272,8 @@ For ner or classification task, schema file contains ["input_ids", "input_mask",
For squad task, training: schema file contains ["start_positions", "end_positions", "input_ids", "input_mask", "segment_ids"], evaluation: schema file contains ["input_ids", "input_mask", "segment_ids"].
`numRows` is the only option which could be set by user, other values must be set according to the dataset.
`numRows` is the only option in the schema file that can be set by the user when dataset_format is tfrecord; other values must be set according to the dataset.
`num_samples` is the only option in the yaml file that can be set by the user when dataset_format is mindrecord; other values must be set according to the dataset.
For example, the schema file of cn-wiki-128 dataset for pretraining shows as follows:
{
......
@@ -32,6 +32,7 @@ save_checkpoint_num: 1
data_dir: ''
schema_dir: ''
dataset_format: "mindrecord"
num_samples: None # can be set by the user to limit the number of samples and thus the training steps
# ==============================================================================
# pretrain related
......
@@ -32,6 +32,7 @@ save_checkpoint_num: 1
data_dir: ''
schema_dir: ''
dataset_format: "mindrecord"
num_samples: None # can be set by the user to limit the number of samples and thus the training steps
# ==============================================================================
# pretrain related
......
@@ -32,6 +32,7 @@ save_checkpoint_num: 5
data_dir: ''
schema_dir: ''
dataset_format: "mindrecord"
num_samples: None # can be set by the user to limit the number of samples and thus the training steps
# ==============================================================================
# pretrain related
......
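The same `num_samples: None` line is added to each of the pretrain yaml configs above. Here is a minimal sketch of consuming it, assuming a plain PyYAML loader rather than the repo's own config helper (the path and the normalization step are assumptions); note that an unquoted None in YAML is loaded as the string "None", so it is normalized explicitly:

import yaml

with open("pretrain_base_config.yaml") as f:  # hypothetical config path
    cfg = yaml.safe_load(f)

num_samples = cfg.get("num_samples")
if num_samples in (None, "None"):  # PyYAML loads unquoted None as the string "None"
    num_samples = None
else:
    num_samples = int(num_samples)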
@@ -231,7 +231,7 @@ def run_pretrain():
logger.info("save checkpoint steps: {}".format(cfg.save_checkpoint_steps))
ds = create_bert_dataset(device_num, rank, cfg.do_shuffle, cfg.data_dir, cfg.schema_dir, cfg.batch_size,
cfg.bucket_list, cfg.dataset_format)
cfg.bucket_list, cfg.dataset_format, cfg.num_samples)
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
new_repeat_count = cfg.epoch_size * ds.get_dataset_size() // cfg.data_sink_steps
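Because `new_repeat_count` is computed from the dataset size, capping the dataset with `num_samples` shrinks the schedule proportionally. A worked example with illustrative numbers (not from the source):

# suppose num_samples caps the dataset so ds.get_dataset_size() == 500,
# with cfg.epoch_size == 40 and cfg.data_sink_steps == 100:
new_repeat_count = 40 * 500 // 100  # = 200 data-sink iterations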
@@ -262,7 +262,7 @@ def run_pretrain():
if cfg.train_with_eval == 'true':
net_eval = BertPretrainEval(bert_net_cfg, network=net_with_loss.bert)
eval_ds = create_eval_dataset(cfg.batch_size, device_num, rank, cfg.eval_data_dir, cfg.schema_dir,
cfg.dataset_format)
cfg.dataset_format, cfg.num_samples)
model = Model(net_with_grads, eval_network=net_eval, metrics={'bert_acc': BertMetric(cfg.batch_size)})
eval_callback = EvalCallBack(model, eval_ds, device_num * cfg.batch_size, cfg.eval_samples)
callback.append(eval_callback)
......
@@ -116,7 +116,7 @@ class BucketDatasetGenerator:
def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None, batch_size=32,
bucket_list=None, dataset_format="mindrecord"):
bucket_list=None, dataset_format="mindrecord", num_samples=None):
"""create train dataset"""
# apply repeat operations
files = os.listdir(data_dir)
@@ -126,11 +126,17 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
(dataset_format == "mindrecord" and "mindrecord" in file_name and "mindrecord.db" not in file_name):
data_files.append(os.path.join(data_dir, file_name))
if dataset_format == "mindrecord":
data_set = ds.MindDataset(data_files,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=ds.Shuffle.FILES if do_shuffle == "true" else False,
num_shards=device_num, shard_id=rank)
if num_samples is not None:
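# num_samples caps the rows read; file shuffling is disabled in this branch,
# presumably so the capped subset stays deterministic across runs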
data_set = ds.MindDataset(data_files,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=False, num_shards=device_num, shard_id=rank, num_samples=num_samples)
else:
data_set = ds.MindDataset(data_files,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=ds.Shuffle.FILES if do_shuffle == "true" else False,
num_shards=device_num, shard_id=rank)
elif dataset_format == "tfrecord":
data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
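Stepping out of the diff for a moment, a minimal usage sketch of the updated signature (paths and values are illustrative, not from the source):

# full dataset, shuffled at file granularity
full_ds = create_bert_dataset(device_num=8, rank=0, do_shuffle="true",
                              data_dir="/data/cn-wiki-128", dataset_format="mindrecord")
# fixed 1000-sample subset, unshuffled
small_ds = create_bert_dataset(device_num=8, rank=0, data_dir="/data/cn-wiki-128",
                               dataset_format="mindrecord", num_samples=1000)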
@@ -261,7 +267,7 @@ def create_squad_dataset(batch_size=1, data_file_path=None, schema_file_path=Non
def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schema_dir=None,
dataset_format="mindrecord"):
dataset_format="mindrecord", num_samples=None):
"""create evaluation dataset"""
data_files = []
if os.path.isdir(data_dir):
@@ -275,7 +281,8 @@ def create_eval_dataset(batchsize=32, device_num=1, rank=0, data_dir=None, schema_dir=None,
if dataset_format == "mindrecord":
data_set = ds.MindDataset(data_files,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
num_samples=num_samples)
elif dataset_format == "tfrecord":
data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
......
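The evaluation path accepts the same cap, which keeps train-with-eval runs bounded. A hedged sketch of wiring it up (paths and values are illustrative):

eval_ds = create_eval_dataset(batchsize=32, device_num=8, rank=0,
                              data_dir="/data/cn-wiki-128-eval",  # hypothetical path
                              dataset_format="mindrecord", num_samples=1000)
# at most 1000 samples are read before batching, so each evaluation pass stays cheap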
@@ -210,7 +210,7 @@ Inference result is saved in current path, you can find result like this in acc.
| Resource | CentOS 8.2; Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G |
| MindSpore Version | 1.3.0 |
| Dataset | DIV2K |
| Training Parameters | ddbpn: epoch=2000, batch_size=16; dbpngan: epoch=1000, batch_size=4 |
| Training Parameters | ddbpn: epoch=2000, batch_size=16; dbpngan: epoch=1100, batch_size=4 |
| Optimizer | Adam |
| Loss Function | BCELoss MSELoss VGGLoss |
| outputs | super-resolution pictures |
......
@@ -52,7 +52,7 @@ def get_args(is_gan=False):
# additional parameters
parser.add_argument('--sens', type=float, default=1024.0)
if is_gan:
parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for')
parser.add_argument('--nEpochs', type=int, default=1100, help='number of epochs to train for')
parser.add_argument('--batchSize', type=int, default=4, choices=[4, 8, 16], help='training batch size')
parser.add_argument('--patch_size', type=int, default=60, choices=[40, 60], help='Size of cropped HR image')
parser.add_argument('--pretrained_iter', type=int, default=100, help='number of epochs to train for')
......
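With the default bumped to 1100, a run that relies on the new default and one that reproduces the old schedule would look like this (the script name is an assumption):

# python train_gan.py                 # GAN training for 1100 epochs (new default)
# python train_gan.py --nEpochs 1000  # reproduces the old schedule
args = get_args(is_gan=True)
print(args.nEpochs)  # 1100 unless overridden on the command line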