Commit ce61e5c4 authored by i-robot, committed by Gitee

!1243 modify pangu model for master

Merge pull request !1243 from lilei/modify_pangu_model
parents f3810dd8 8745a07c
@@ -394,6 +394,9 @@ def add_training_params(opt):
opt.add_argument("--data_column_name",
type=str, default="input_ids",
help="Column name of datasets")
opt.add_argument("--micro_batch_interleaved",
type=int, default=2,
help="Parallel split num of batch size. default 2")
@@ -32,7 +32,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
- from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
+ from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved
from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net
@@ -133,13 +133,14 @@ def run_train(args_opt):
model_parallel_num = args_opt.op_level_model_parallel_num
data_parallel_num = int(device_num / model_parallel_num)
batch_size = args_opt.per_batch_size * data_parallel_num
+ micro_batch_interleaved = args_opt.micro_batch_interleaved
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num,
pipeline_stage=args_opt.stage_num,
micro_batch_num=args_opt.micro_size,
optimizer_shard=bool(args_opt.optimizer_shard),
vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True,
gradient_aggregation_group=args_opt.gradient_aggregation_group)
- config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
+ config = PanguAlphaConfig(batch_size=batch_size // micro_batch_interleaved, num_heads=args_opt.num_heads,
hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
ffn_hidden_size=args_opt.embedding_size * 4, eod_token=bool(args_opt.eod_reset),
@@ -153,7 +154,8 @@ def run_train(args_opt):
# Define network
pangu_alpha = PanguAlphaModel(config=config)
loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
- pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
+ pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+ micro_batch_interleaved)
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
print("=====args_opt is: ", args_opt, flush=True)
# Warm-up and cosine decay learning rate
@@ -172,7 +174,7 @@ def run_train(args_opt):
loss_scale_value = math.pow(2, 32)
epoch_num = args_opt.epoch_size
# Dataset loading mindrecord files
- ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0,
+ ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0,
eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num)
actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size)
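Bookkeeping note for the non-pipeline path: PanguAlphaConfig now receives batch_size // micro_batch_interleaved (the size of one interleaved slice), while create_dataset receives config.batch_size * micro_batch_interleaved, so the batch actually drawn from the dataset each step is unchanged. A small sanity check with made-up numbers (not values from the repository):

per_batch_size = 8
data_parallel_num = 4
micro_batch_interleaved = 2

global_batch_size = per_batch_size * data_parallel_num            # 32, as computed in run_train
config_batch_size = global_batch_size // micro_batch_interleaved  # 16, one interleaved slice
dataset_batch_size = config_batch_size * micro_batch_interleaved  # 32, passed to create_dataset

# The dataset still yields the full batch; MicroBatchInterleaved splits it
# back into micro_batch_interleaved slices inside the network.
assert dataset_batch_size == global_batch_size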
@@ -182,7 +184,7 @@ def run_train(args_opt):
pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
config=config)
if args_opt.train_and_eval_mode:
- ds_eval = create_dataset(config.batch_size, data_path=eval_cache_url,
+ ds_eval = create_dataset(config.batch_size * micro_batch_interleaved, data_path=eval_cache_url,
data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
eod_id=args_opt.eod_id, device_num=device_num, rank=rank,
column_name=args_opt.data_column_name, epoch=epoch_num,
@@ -353,6 +355,22 @@ def restore_exception_checkpoint(args_param, sink_size, dataset, model, network,
return True
+ def set_pipeline_parallel_context(args_opt):
+ r"""Set pipeline parallel context."""
+ D.init()
+ device_num = D.get_group_size()
+ rank_id = D.get_rank()
+ print("rank_id is {}, device_num is {}".format(rank_id, device_num))
+ context.reset_auto_parallel_context()
+ context.set_auto_parallel_context(
+ parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
+ full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
+ device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
+ pipeline_stages=args_opt.stage_num)
+ set_algo_parameters(elementwise_op_strategy_follow=True)
+ _set_multi_subgraphs()
+ return rank_id, device_num
def run_train_pipeline(args_opt):
r"""The main training process in pipeline."""
# Set hccl connect time
@@ -360,22 +378,10 @@ def run_train_pipeline(args_opt):
context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(variable_memory_max_size="30GB")
+ rank_id = int(os.getenv("RANK_ID"))
+ device_num = 1
if args_opt.distribute == "true":
- D.init()
- device_num = D.get_group_size()
- rank_id = D.get_rank()
- print("rank_id is {}, device_num is {}".format(rank_id, device_num))
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(
- parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
- full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
- device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
- pipeline_stages=args_opt.stage_num)
- set_algo_parameters(elementwise_op_strategy_follow=True)
- _set_multi_subgraphs()
- else:
- rank_id = int(os.getenv("RANK_ID"))
- device_num = 1
+ rank_id, device_num = set_pipeline_parallel_context(args_opt)
# copy data from the cloud to the /cache/Data
cache_url = '/cache/Data/'
eval_cache_url = '/cache/EvalData/'
@@ -392,7 +398,7 @@ def run_train_pipeline(args_opt):
raise ValueError("The dp must large than 1 when applying optimizer shard.")
per_batch_size = args_opt.per_batch_size
batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
+ micro_batch_interleaved = args_opt.micro_batch_interleaved
parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
model_parallel=model_parallel_num,
pipeline_stage=args_opt.stage_num,
@@ -400,7 +406,7 @@ def run_train_pipeline(args_opt):
optimizer_shard=bool(args_opt.optimizer_shard),
vocab_emb_dp=bool(args_opt.word_emb_dp),
recompute=True)
- config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num,
+ config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num // micro_batch_interleaved,
num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,
seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size,
num_layers=args_opt.num_layers, ffn_hidden_size=args_opt.embedding_size * 4,
@@ -411,7 +417,8 @@ def run_train_pipeline(args_opt):
print("===config is: ", config, flush=True)
pangu_alpha = PanguAlphaModel(config=config)
loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
- pangu_alpha_with_loss_net = PipelineCell(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+ pangu_alpha_with_loss_net = PipelineCell(MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
+ micro_batch_interleaved),
config.parallel_config.micro_batch_num)
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
print("=====args_opt is: ", args_opt, flush=True)
@@ -427,7 +434,8 @@ def run_train_pipeline(args_opt):
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
- ds = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=cache_url,
+ ds = create_dataset(config.batch_size * parallel_config.micro_batch_num * micro_batch_interleaved,
+ data_path=cache_url,
device_num=stage_device_num,
rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
full_batch=context.get_auto_parallel_context("full_batch"),
@@ -443,7 +451,8 @@ def run_train_pipeline(args_opt):
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
if args_opt.train_and_eval_mode:
- ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num, data_path=eval_cache_url,
+ ds_eval = create_dataset(config.batch_size * parallel_config.micro_batch_num * micro_batch_interleaved,
+ data_path=eval_cache_url,
device_num=stage_device_num, rank=rank_id % stage_device_num, eod_reset=True,
data_start_index=0, full_batch=bool(args_opt.full_batch),
column_name=args_opt.data_column_name,