diff --git a/official/nlp/pangu_alpha/src/utils.py b/official/nlp/pangu_alpha/src/utils.py index 57caa2b2b5c4c818fbf40d14eb8db7d32669ffd9..a07ed7d2f87707ed00583a265e06d6681dd849c0 100644 --- a/official/nlp/pangu_alpha/src/utils.py +++ b/official/nlp/pangu_alpha/src/utils.py @@ -394,9 +394,6 @@ def add_training_params(opt): opt.add_argument("--data_column_name", type=str, default="input_ids", help="Column name of datasets") - opt.add_argument("--micro_batch_interleaved", - type=int, default=2, - help="Parallel split num of batch size. default 2") diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py index 7ada03c2db5f0674090d52a5cfe6b0c5afdd3d75..a70ad082a1bedc673a61273ccd67ce5ebb85ea8d 100644 --- a/official/nlp/pangu_alpha/train.py +++ b/official/nlp/pangu_alpha/train.py @@ -30,7 +30,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell import mindspore.common.dtype as mstype from mindspore.parallel import set_algo_parameters from mindspore.parallel._cost_model_context import _set_multi_subgraphs -from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved +from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net @@ -103,6 +103,7 @@ def set_parallel_context(args_opt): enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt') set_algo_parameters(elementwise_op_strategy_follow=True) _set_multi_subgraphs() + return rank, device_num def run_train(args_opt): @@ -115,7 +116,7 @@ def run_train(args_opt): rank = 0 device_num = 1 if args_opt.distribute == "true": - set_parallel_context(args_opt) + rank, device_num = set_parallel_context(args_opt) context.set_context(save_graphs=False, save_graphs_path="./graphs_of_device_id_" + str(rank)) # copy data from the cloud to the /cache/Data cache_url = '/cache/Data/' @@ -130,14 +131,13 @@ def run_train(args_opt): model_parallel_num = args_opt.op_level_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) batch_size = args_opt.per_batch_size * data_parallel_num - micro_batch_interleaved = args_opt.micro_batch_interleaved parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num, pipeline_stage=args_opt.stage_num, micro_batch_num=args_opt.micro_size, optimizer_shard=bool(args_opt.optimizer_shard), vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True, gradient_aggregation_group=args_opt.gradient_aggregation_group) - config = PanguAlphaConfig(batch_size=batch_size // micro_batch_interleaved, num_heads=args_opt.num_heads, + config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers, ffn_hidden_size=args_opt.embedding_size * 4, eod_token=bool(args_opt.eod_reset), @@ -151,8 +151,7 @@ def run_train(args_opt): # Define network pangu_alpha = PanguAlphaModel(config=config) loss = CrossEntropyLoss(config.parallel_config.dp_mp_config) - pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss), - micro_batch_interleaved) + pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss) pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net) print("=====args_opt is: ", args_opt, flush=True) # Warm-up and cosine decay learning rate @@ -171,7 +170,7 @@ def run_train(args_opt): loss_scale_value = math.pow(2, 32) epoch_num = args_opt.epoch_size # Dataset loading mindrecord files - ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0, + ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id, device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num) actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size) @@ -181,7 +180,7 @@ def run_train(args_opt): pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True, config=config) if args_opt.train_and_eval_mode: - ds_eval = create_dataset(config.batch_size * micro_batch_interleaved, data_path=eval_cache_url, + ds_eval = create_dataset(config.batch_size, data_path=eval_cache_url, data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id, device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num,