Commit 345f14a4 authored by lilei

modify pangu model

parent d186202c
@@ -394,9 +394,6 @@ def add_training_params(opt):
     opt.add_argument("--data_column_name",
                      type=str, default="input_ids",
                      help="Column name of datasets")
-    opt.add_argument("--micro_batch_interleaved",
-                     type=int, default=2,
-                     help="Parallel split num of batch size. default 2")

@@ -30,7 +30,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 import mindspore.common.dtype as mstype
 from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._cost_model_context import _set_multi_subgraphs
-from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, MicroBatchInterleaved
+from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
 from mindspore.parallel.nn import TransformerOpParallelConfig, CrossEntropyLoss
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_distributed_checkpoint, load_checkpoint, load_param_into_net
@@ -103,6 +103,7 @@ def set_parallel_context(args_opt):
         enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file='strategy.ckpt')
     set_algo_parameters(elementwise_op_strategy_follow=True)
     _set_multi_subgraphs()
+    return rank, device_num


 def run_train(args_opt):
@@ -115,7 +116,7 @@ def run_train(args_opt):
     rank = 0
     device_num = 1
     if args_opt.distribute == "true":
-        set_parallel_context(args_opt)
+        rank, device_num = set_parallel_context(args_opt)
     context.set_context(save_graphs=False, save_graphs_path="./graphs_of_device_id_" + str(rank))
     # copy data from the cloud to the /cache/Data
     cache_url = '/cache/Data/'
@@ -130,14 +131,13 @@ def run_train(args_opt):
     model_parallel_num = args_opt.op_level_model_parallel_num
     data_parallel_num = int(device_num / model_parallel_num)
     batch_size = args_opt.per_batch_size * data_parallel_num
-    micro_batch_interleaved = args_opt.micro_batch_interleaved
     parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num, model_parallel=model_parallel_num,
                                                   pipeline_stage=args_opt.stage_num,
                                                   micro_batch_num=args_opt.micro_size,
                                                   optimizer_shard=bool(args_opt.optimizer_shard),
                                                   vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True,
                                                   gradient_aggregation_group=args_opt.gradient_aggregation_group)
-    config = PanguAlphaConfig(batch_size=batch_size // micro_batch_interleaved, num_heads=args_opt.num_heads,
+    config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
                               hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
                               vocab_size=args_opt.vocab_size, num_layers=args_opt.num_layers,
                               ffn_hidden_size=args_opt.embedding_size * 4, eod_token=bool(args_opt.eod_reset),
@@ -151,8 +151,7 @@ def run_train(args_opt):
     # Define network
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
-    pangu_alpha_with_loss_net = MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
-                                                      micro_batch_interleaved)
+    pangu_alpha_with_loss_net = PanGUAlphaWithLoss(config, pangu_alpha, loss)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
     print("=====args_opt is: ", args_opt, flush=True)
     # Warm-up and cosine decay learning rate
@@ -171,7 +170,7 @@ def run_train(args_opt):
     loss_scale_value = math.pow(2, 32)
     epoch_num = args_opt.epoch_size
     # Dataset loading mindrecord files
-    ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0,
+    ds = create_dataset(config.batch_size, data_path=cache_url, data_start_index=0,
                         eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
                         device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num)
     actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size)
@@ -181,7 +180,7 @@ def run_train(args_opt):
         pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
         config=config)
     if args_opt.train_and_eval_mode:
-        ds_eval = create_dataset(config.batch_size * micro_batch_interleaved, data_path=eval_cache_url,
+        ds_eval = create_dataset(config.batch_size, data_path=eval_cache_url,
                                  data_start_index=0, eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch),
                                  eod_id=args_opt.eod_id, device_num=device_num, rank=rank,
                                  column_name=args_opt.data_column_name, epoch=epoch_num,
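The subtle part of this change is the batch-size bookkeeping. Before the commit, PanguAlphaConfig received batch_size // micro_batch_interleaved while create_dataset was called with config.batch_size * micro_batch_interleaved, so the data pipeline still delivered the full global batch and the MicroBatchInterleaved wrapper sliced it inside the loss network; after the commit, the config and the dataset share one batch size and the loss network receives the batch whole. A minimal sketch of that arithmetic follows, in plain Python with illustrative numbers (8 samples per device, 16-way data parallelism, the old default interleave of 2; none of these values come from this commit).

# Sketch of the batch-size bookkeeping before and after this commit.
# Plain Python, no MindSpore needed; the numeric values are illustrative only.

def batch_sizes_before(per_batch_size, data_parallel_num, micro_batch_interleaved):
    """Pre-commit wiring: the model config holds a reduced batch size and the
    dataset multiplies it back, so the MicroBatchInterleaved wrapper can split
    each incoming batch into micro_batch_interleaved slices."""
    batch_size = per_batch_size * data_parallel_num
    config_batch_size = batch_size // micro_batch_interleaved          # PanguAlphaConfig(batch_size=...)
    dataset_batch_size = config_batch_size * micro_batch_interleaved   # create_dataset(...)
    return config_batch_size, dataset_batch_size


def batch_sizes_after(per_batch_size, data_parallel_num):
    """Post-commit wiring: no interleaving, so the config and the dataset
    share a single batch size."""
    batch_size = per_batch_size * data_parallel_num
    return batch_size, batch_size


if __name__ == "__main__":
    print(batch_sizes_before(8, 16, 2))   # (64, 128): model config sees 64, pipeline feeds 128
    print(batch_sizes_after(8, 16))       # (128, 128): one shared batch size

Assuming batch_size divides evenly by the interleave factor, the batch delivered by create_dataset stays the same; what the commit removes is the interleaved slicing of that batch inside the loss network.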