diff --git a/official/nlp/pangu_alpha/src/callbacks.py b/official/nlp/pangu_alpha/src/callbacks.py
index 24e7ca008ffbd7054e85b38a05add1c891065f68..448ef881170aef31365ed1072a80765fcdbf3e86 100644
--- a/official/nlp/pangu_alpha/src/callbacks.py
+++ b/official/nlp/pangu_alpha/src/callbacks.py
@@ -16,9 +16,9 @@
 Callbacks
 """
-
 import time
 import math
+import numpy as np
 from mindspore.train.callback import Callback
 from mindspore import context
 from mindspore.context import ParallelMode
@@ -30,14 +30,16 @@ class LossCallBack(Callback):
     If the loss in NAN or INF terminating training.
     """
 
-    def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1):
+    def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1,
+                 is_last_stage=True):
         super(LossCallBack, self).__init__()
         self._dataset_size = dataset_size
         self.local_rank = local_rank
         self.has_trained_epoch = has_trained_epoch
         self.has_trained_step = has_trained_step
         self.micro_size = micro_size
-        print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
+        self.is_last_stage = is_last_stage
+        print("Loaded trained epoch: {} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
 
     def step_end(self, run_context):
         """
@@ -50,8 +52,11 @@ class LossCallBack(Callback):
         if percent == 0:
             epoch_num -= 1
         date = time.asctime(time.localtime(time.time()))
-        loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size
-        print("time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}".
+        loss_value = 'no loss for this stage'
+        if self.is_last_stage:
+            loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size
+            loss_value = np.mean(loss_value)
+        print("time: {} local_rank: {}, epoch: {}, step: {}, loss is {}, overflow is {}, loss scale is {}".
               format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
                      cb_params.cur_step_num + int(self.has_trained_step), loss_value,
                      cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy()))
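Why the `is_last_stage` gate: under pipeline parallelism only the last stage computes the loss, so on every other stage `net_outputs[0]` carries nothing meaningful and the callback now prints a placeholder instead. A minimal sketch of the new reporting logic in plain numpy; the loss values, `micro_size`, and the stand-in for `cb_params.net_outputs[0]` are all made up for illustration:

```python
import numpy as np

# Stand-ins for what the callback receives; values are hypothetical.
# With MicroBatchInterleaved, the network output accumulates the loss of
# `micro_size` interleaved slices, so the callback divides, then averages.
micro_size = 2
net_loss = np.array([8.42, 8.58])      # made-up per-slice accumulated losses

is_last_stage = True                   # only the last pipeline stage owns the loss
loss_value = 'no loss for this stage'  # what every other stage prints
if is_last_stage:
    loss_value = np.mean(net_loss / micro_size)

print("loss is {}".format(loss_value))  # -> loss is 4.25
```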
diff --git a/official/nlp/pangu_alpha/src/pangu_alpha.py b/official/nlp/pangu_alpha/src/pangu_alpha.py
index b4f8442aa9453843e2a0779dc2533ecee2ae2488..a172be739a2f1669048524595b56f17be68b4fe1 100644
--- a/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -37,7 +37,7 @@ class EmbeddingLayer(nn.Cell):
         self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
                                              embedding_size=config.hidden_size,
                                              param_init=initializer("normal",
                                                                     [config.vocab_size, config.hidden_size],
-                                                                    dtype=config.param_init_type),
+                                                                    dtype=mstype.float32),
                                              parallel_config=config.parallel_config.embedding_dp_mp_config)
         copied_parallel_config = copy.deepcopy(config.parallel_config)
         copied_parallel_config.vocab_emb_dp = True
@@ -45,7 +45,7 @@
                                                  embedding_size=config.hidden_size,
                                                  param_init=initializer("normal",
                                                                         [config.seq_length, config.hidden_size],
-                                                                        dtype=config.param_init_type),
+                                                                        dtype=mstype.float32),
                                                  parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.add = P.Add().shard(
             ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1)))
@@ -262,7 +262,7 @@ class PanguAlpha_Model(Cell):
                                                   embedding_size=config.hidden_size,
                                                   param_init=initializer("normal",
                                                                          [config.seq_length, config.hidden_size],
                                                                          dtype=config.param_init_type),
+                                                                         dtype=mstype.float32),
                                                   parallel_config=copied_parallel_config.embedding_dp_mp_config)
         self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
         if config.parallel_config.pipeline_stage > 1:
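The three hunks above make the same change in three places: the word, position, and top-query embedding tables are now always initialized in fp32, even when the rest of the model is built with `param_init_type=fp16`. A minimal sketch of the effect, using toy shapes rather than the real PanGu-Alpha config values:

```python
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer

# Toy shapes, not the real vocab_size/hidden_size.
vocab_size, hidden_size = 16, 8

# The patch pins embedding tables to float32 regardless of
# config.param_init_type, so the initial weights keep full precision;
# the compute dtype of the rest of the network is unaffected.
table = initializer("normal", [vocab_size, hidden_size], dtype=mstype.float32)
print(table.dtype)  # Float32, independent of param_init_type
```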
diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index a66613252463e6533e085fab41567cc0ab04a875..d36a52554ba86a91f89423065cd4ee3160300311 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -112,8 +112,7 @@ def run_train(args_opt):
     r"""The main training process."""
     os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"
     # Set execution mode
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    context.set_context(variable_memory_max_size="30GB")
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, variable_memory_max_size="30GB")
     # Set parallel context
     rank = 0
     device_num = 1
@@ -177,16 +176,16 @@
                                       param_init_type=config.param_init_type)
     else:
         optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
-    # Initial scaling sens
-    loss_scale_value = math.pow(2, 32)
     epoch_num = args_opt.epoch_size
     # Dataset loading mindrecord files
     ds = create_dataset(config.batch_size * micro_batch_interleaved, data_path=cache_url, data_start_index=0,
                         eod_reset=config.eod_reset, full_batch=bool(args_opt.full_batch), eod_id=args_opt.eod_id,
                         device_num=device_num, rank=rank, column_name=args_opt.data_column_name, epoch=epoch_num)
-    actual_epoch_num = int(epoch_num * ds.get_dataset_size() / args_opt.sink_size)
-    callback = [TimeMonitor(args_opt.sink_size), LossCallBack(args_opt.sink_size, rank, 0, 0)]
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
+    step_per_epoch = ds.get_dataset_size()
+    actual_epoch_num = int(epoch_num * step_per_epoch / args_opt.sink_size)
+    loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)
+    callback = [TimeMonitor(args_opt.sink_size), loss_callback]
+    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=math.pow(2, 32), scale_factor=2, scale_window=1000)
     pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell(
         pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
         config=config)
@@ -207,9 +206,8 @@
     if not flag:
         restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads,
                            epoch=actual_epoch_num)
-
-    callback = [TimeMonitor(args_opt.sink_size), LossCallBack(args_opt.sink_size, rank, args_opt.has_trained_epoches,
-                                                              args_opt.has_trained_steps)]
+    loss_callback.has_trained_epoch = args_opt.has_trained_epoches
+    loss_callback.has_trained_step = args_opt.has_trained_steps
     add_checkpoint_callback_policy(args_opt, callback, rank)
     if args_opt.incremental_training:
         strategy = model.infer_train_layout(train_dataset=ds, sink_size=args_opt.sink_size)
@@ -219,7 +217,7 @@
                          range(0, 512)]
         print(f"Loading from path {ckpt_file_list[0]}", flush=True)
         load_distributed_checkpoint(model.train_network, ckpt_file_list, strategy)
-    print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
+    print("Dataset size: {}, actual_epoch_num: {}".format(step_per_epoch, actual_epoch_num), flush=True)
     model.train(actual_epoch_num, ds, callbacks=callback, sink_size=args_opt.sink_size, dataset_sink_mode=True)
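One subtlety in the hunks above: `LossCallBack` now receives `step_per_epoch` (the real dataset size) instead of `sink_size`, while `model.train` still counts "epochs" in sink iterations, which is why the requested epoch count is rescaled. A worked example of that rescaling, with assumed numbers; only the formula itself comes from the patch:

```python
# Assumed values for illustration.
epoch_num = 1          # args_opt.epoch_size
step_per_epoch = 4000  # ds.get_dataset_size()
sink_size = 2          # args_opt.sink_size

# With dataset sinking enabled, Model.train treats one "epoch" as
# sink_size steps, so the requested epochs become sink iterations.
actual_epoch_num = int(epoch_num * step_per_epoch / sink_size)
print(actual_epoch_num)  # 2000 sink iterations of 2 steps each
```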
@@ -417,8 +415,9 @@ def run_train_pipeline(args_opt):
     model_parallel_num = args_opt.op_level_model_parallel_num
     stage_device_num = int(device_num / args_opt.stage_num)
     data_parallel_num = int(stage_device_num / model_parallel_num)
+    is_last_stage = (rank_id // stage_device_num) == args_opt.stage_num - 1
     if data_parallel_num <= 1 and args_opt.optimizer_shard == 1:
-        raise ValueError("The dp must large than 1 when applying optimizer shard.")
+        raise ValueError("The data parallel number must be larger than 1 when applying optimizer shard.")
     per_batch_size = args_opt.per_batch_size
     batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
     micro_batch_interleaved = args_opt.micro_batch_interleaved
@@ -435,14 +434,14 @@
                           param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
                           enable_offload=bool(args_opt.opt_offload),
                           parallel_config=parallel_config)
-    print("===config is: ", config, flush=True)
+    print("[Config] is: ", config, flush=True)
     pangu_alpha = PanguAlphaModel(config=config)
     loss = CrossEntropyLoss(config.parallel_config.dp_mp_config)
     pangu_alpha_with_loss_net = PipelineCell(MicroBatchInterleaved(PanGUAlphaWithLoss(config, pangu_alpha, loss),
                                                                    micro_batch_interleaved),
                                              config.parallel_config.micro_batch_num)
     pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss_net)
-    print("=====args_opt is: ", args_opt, flush=True)
+    print("[args_opt] is: ", args_opt, flush=True)
     lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                       warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
     params = pangu_alpha.infer_param_pipeline_stage()
@@ -464,8 +463,9 @@
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size
     actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
-    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id,
-                                                         micro_size=parallel_config.micro_batch_num)]
+    loss_callback = LossCallBack(step_per_epoch, rank_id, is_last_stage=is_last_stage,
+                                 micro_size=parallel_config.micro_batch_num * micro_batch_interleaved)
+    callback = [TimeMonitor(callback_size), loss_callback]
     loss_scale_value = math.pow(2, 32)
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
     pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
@@ -492,8 +492,8 @@
     if not flag:
         restore_checkpoint(args_opt, callback_size, ds, model, pangu_alpha_with_grads,
                            epoch=actual_epoch_num)
-    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, args_opt.has_trained_epoches,
-                                                         args_opt.has_trained_steps)]
+    loss_callback.has_trained_epoch = args_opt.has_trained_epoches
+    loss_callback.has_trained_step = args_opt.has_trained_steps
     add_checkpoint_callback_policy(args_opt, callback, rank_id)
     model.train(actual_epoch_num, ds, callbacks=callback,
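For reference, the `is_last_stage` test added in `run_train_pipeline` is pure rank arithmetic: devices are assigned to pipeline stages in contiguous blocks of `stage_device_num`, so only ranks in the final block report a loss. A small sketch with an assumed 8-device, 4-stage layout (not a configuration taken from the patch):

```python
# Assumed topology for illustration: 8 devices split into 4 pipeline stages.
device_num = 8
stage_num = 4
stage_device_num = device_num // stage_num  # 2 devices per stage

for rank_id in range(device_num):
    # Same test the patch adds: the last contiguous block of ranks
    # forms the last stage and is the only one that owns the loss.
    is_last_stage = (rank_id // stage_device_num) == stage_num - 1
    print(rank_id, is_last_stage)  # True only for ranks 6 and 7
```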