diff --git a/official/nlp/pangu_alpha/src/utils.py b/official/nlp/pangu_alpha/src/utils.py
index a07ed7d2f87707ed00583a265e06d6681dd849c0..57caa2b2b5c4c818fbf40d14eb8db7d32669ffd9 100644
--- a/official/nlp/pangu_alpha/src/utils.py
+++ b/official/nlp/pangu_alpha/src/utils.py
@@ -394,6 +394,9 @@ def add_training_params(opt):
     opt.add_argument("--data_column_name",
                      type=str, default="input_ids",
                      help="Column name of datasets")
+    opt.add_argument("--micro_batch_interleaved",
+                     type=int, default=2,
+                     help="Parallel split num of batch size. default 2")
diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index 847ea9b657a501f0bfc5a340be51773eac49d248..db51564d4cfc247bcd18f2457c0778d99b00f77f 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -32,7 +32,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 import mindspore.common.dtype as mstype
 from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._cost_model_context import _set_multi_subgraphs
-from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
+from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved
 from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net
@@ -133,13 +133,14 @@ def run_train(args_opt):
     model_parallel_num = args_opt.op_level_model_parallel_num
     data_parallel_num = int(device_num / model_parallel_num)
     batch_size = args_opt.per_batch_size * data_parallel_num
+    micro_batch_interleaved = args_opt.micro_batch_interleaved
     parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num,
                                                   pipeline_stage=args_opt.stage_num,
                                                   micro_batch_num=args_opt.micro_size,
                                                   optimizer_shard=bool(args_opt.optimizer_shard),
                                                   vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True,
                                                   gradient_aggregation_group=args_opt.gradient_aggregation_group)
-    config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
+    config = PanguAlphaConfig(batch_size=batch_size // micro_batch_interleaved, num_heads=args_opt.num_heads,
                               hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
                               vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
                               ffn_hidden_size=args_opt.embedding_size * 4, eod_token=bool(args_opt.eod_reset),
@@ -153,7 +154,8 @@ def run_train(args_opt):
     # Define network
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
-    pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
+    pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+                                                      micro_batch_interleaved)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
     print("=====args_opt is: ", args_opt, flush=True)
     # Warm-up and cosine decay learning rate
@@ -172,7 +174,7 @@ def run_train(args_opt):
     loss_scale_value = math.pow(2, 32)
     epoch_num = args_opt.epoch_size
     # Dataset loading mindrecord files
-    ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0,
+    ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0,
                         eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
                         device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num)
     actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size)
@@ -182,7 +184,7 @@
         pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
         config=config)
     if args_opt.train_and_eval_mode:
-        ds_eval = create_dataset(config.batch_size, data_path=eval_cache_url,
+        ds_eval = create_dataset(config.batch_size * micro_batch_interleaved, data_path=eval_cache_url,
                                  data_start_index=0, eod_reset=config.eod_reset,
                                  full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
                                  device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num,
@@ -353,6 +355,22 @@ def restore_exception_checkpoint(args_param, sink_size, dataset, model, network,
     return True
 
 
+def set_pipeline_parallel_context(args_opt):
+    r"""Set pipeline parallel context."""
+    D.init()
+    device_num = D.get_group_size()
+    rank_id = D.get_rank()
+    print("rank_id is {}, device_num is {}".format(rank_id, device_num))
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(
+        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
+        full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
+        device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
+        pipeline_stages=args_opt.stage_num)
+    set_algo_parameters(elementwise_op_strategy_follow=True)
+    _set_multi_subgraphs()
+    return rank_id, device_num
+
 def run_train_pipeline(args_opt):
     r"""The main training process in pipeline."""
     # Set hccl connect time
@@ -360,22 +378,10 @@
     context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(variable_memory_max_size="30GB")
+    rank_id = int(os.getenv("RANK_ID"))
+    device_num = 1
     if args_opt.distribute == "true":
-        D.init()
-        device_num = D.get_group_size()
-        rank_id = D.get_rank()
-        print("rank_id is {}, device_num is {}".format(rank_id, device_num))
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(
-            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
-            full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
-            device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
-            pipeline_stages=args_opt.stage_num)
-        set_algo_parameters(elementwise_op_strategy_follow=True)
-        _set_multi_subgraphs()
-    else:
-        rank_id = int(os.getenv("RANK_ID"))
-        device_num = 1
+        rank_id, device_num = set_pipeline_parallel_context(args_opt)
     # copy data from the cloud to the /cache/Data
     cache_url = '/cache/Data/'
     eval_cache_url = '/cache/EvalData/'
@@ -392,7 +398,7 @@
         raise ValueError("The dp must large than 1 when applying optimizer shard.")
     per_batch_size = args_opt.per_batch_size
     batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
-
+    micro_batch_interleaved = args_opt.micro_batch_interleaved
     parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
                                                   model_parallel=model_parallel_num,
                                                   pipeline_stage=args_opt.stage_num,
@@ -400,7 +406,7 @@
                                                   optimizer_shard=bool(args_opt.optimizer_shard),
                                                   vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True)
-    config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num,
+    config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num // micro_batch_interleaved,
                               num_heads=args_opt.num_heads,
                               hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
                               vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
                               ffn_hidden_size=args_opt.embedding_size * 4,
@@ -411,7 +417,8 @@
     print("===config is: ", config, flush=True)
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
-    pangu_alpha_with_loss_net = PipelineCell(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+    pangu_alpha_with_loss_net = PipelineCell(MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+                                                                   micro_batch_interleaved),
                                              config.parallel_config.micro_batch_num)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
     print("=====args_opt is: ", args_opt, flush=True)
@@ -427,7 +434,8 @@
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
-    ds = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=cache_url,
+    ds = create_dataset(config.batch_size * parallel_config.micro_batch_num * micro_batch_interleaved,
+                        data_path=cache_url,
                         device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True,
                         data_start_index=0, full_batch=context.get_auto_parallel_context("full_batch"),
@@ -443,7 +451,8 @@
     pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
         pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
     if args_opt.train_and_eval_mode:
-        ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=eval_cache_url,
+        ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num * micro_batch_interleaved,
+                                 data_path=eval_cache_url,
                                  device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True,
                                  data_start_index=0, full_batch=bool(args_opt.full_batch),
                                  column_name=args_opt.data_column_name,
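
Reviewer note, not part of the patch: a minimal sketch of the batch-size bookkeeping this diff relies on. It assumes the MicroBatchInterleaved wrapper imported above (mindspore.nn.wrap.cell_wrapper.MicroBatchInterleaved) is available; LossNet, per_step_batch and config_batch are illustrative names and do not come from the repository. The wrapper slices the leading (batch) dimension of each input into micro_batch_interleaved sub-batches and accumulates the wrapped cell's output, which is why the patch builds PanguAlphaConfig with batch_size // micro_batch_interleaved while create_dataset keeps delivering the full per-step batch (config.batch_size * micro_batch_interleaved).

    # Sketch only: LossNet stands in for PanGUAlphaWithLoss; the real training graph also
    # wraps the result in _VirtualDatasetCell / PipelineCell as the diff shows.
    import numpy as np
    import mindspore as ms
    from mindspore import nn, ops, Tensor
    from mindspore.nn.wrap.cell_wrapper import MicroBatchInterleaved

    class LossNet(nn.Cell):
        """Toy loss cell: reduces the input to a scalar so the wrapper can accumulate it."""
        def construct(self, input_ids):
            return ops.ReduceMean()(input_ids.astype(ms.float32))

    micro_batch_interleaved = 2                                # value of the new --micro_batch_interleaved flag
    per_step_batch = 8                                         # what create_dataset delivers per step
    config_batch = per_step_batch // micro_batch_interleaved   # what PanguAlphaConfig.batch_size receives

    # Each call splits the batch of 8 into 2 interleaved sub-batches of 4 and accumulates the per-slice losses.
    net = MicroBatchInterleaved(LossNet(), micro_batch_interleaved)
    data = Tensor(np.arange(per_step_batch * 16).reshape(per_step_batch, 16), ms.int32)
    print(config_batch, net(data))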