diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index e119ff9066a48e8b20f1efde32d4959c7f7488bc..847ea9b657a501f0bfc5a340be51773eac49d248 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -17,9 +17,11 @@ PanguAlpha train script
 """
 
 import datetime
+import json
 import glob
 import os
 import math
+
 from mindspore import context
 from mindspore.train.model import Model
 import mindspore.communication.management as D
@@ -191,7 +193,11 @@ def run_train(args_opt):
     else:
         model = Model(pangu_alpha_with_grads)
     if args_opt.pre_trained:
-        restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)
+        flag = restore_exception_checkpoint(args_opt, args_opt.sink_size, ds, model,
+                                            pangu_alpha_with_grads, epoch=actual_epoch_num)
+        if not flag:
+            restore_checkpoint(args_opt, args_opt.sink_size, ds, model,
+                               pangu_alpha_with_grads, epoch=actual_epoch_num)
 
     callback = [TimeMonitor(args_opt.sink_size), LossCallBack(args_opt.sink_size, rank, args_opt.has_trained_epoches,
                                                               args_opt.has_trained_steps)]
@@ -236,6 +242,117 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
     load_param_into_net(network, param_dict)
 
 
+def get_exception_checkpoints(args_param):
+    r"""
+    Collect the latest exception (breakpoint) checkpoint file for each rank listed in RESTORE_RANKS.
+    """
+    print("======start exception checkpoint", flush=True)
+    restore_ranks = os.getenv("RESTORE_RANKS")
+    if not restore_ranks:
+        return None
+
+    restore_rank_list = list(map(int, restore_ranks.split(",")))
+    ckpt_file_list = []
+    ckpt_name = args_param.ckpt_name_prefix
+    for ckpt_rank in restore_rank_list:
+        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
+                                    f"rank_{ckpt_rank}",
+                                    f"{ckpt_name}*_breakpoint.ckpt")
+        ckpt_files = glob.glob(ckpt_pattern)
+        if not ckpt_files:
+            print(
+                f"There is no ckpt file in {args_param.save_checkpoint_path}, "
+                f"current ckpt_files found is {ckpt_files} "
+                f"with pattern {ckpt_pattern}, so skip the loading.")
+            return None
+        ckpt_files.sort(key=os.path.getmtime, reverse=True)
+        ckpt_file_list.append(ckpt_files[0])
+    print(f"checkpoint file {ckpt_file_list}")
+    return ckpt_file_list
+
+
+def check_exception_checkpoints(ckpt_file_list):
+    """
+    Check exception checkpoints size.
+    Args:
+        ckpt_file_list: exception checkpoints
+
+    Returns: result of exception checkpoints size check.
+
+    """
+    ckpt_size_list = []
+    for ckpt_file in ckpt_file_list:
+        ckpt_size_list.append(os.path.getsize(ckpt_file))
+
+    if len(set(ckpt_size_list)) > 1:
+        return False
+
+    return True
+
+
+def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
+    """
+    Restore exception checkpoint to training model.
+    Args:
+        args_param: model training parameters
+        sink_size: model training sink size
+        dataset: dataset used for training
+        model: model
+        network: pangu_alpha network
+        epoch: training epoch
+
+    Returns: load exception checkpoint success or not.
+
+    """
+    if os.getenv("RESTORE_RANKS") == "-1":
+        return False
+
+    ckpt_file_list = get_exception_checkpoints(args_param)
+
+    restore_flag = False
+    if ckpt_file_list:
+        restore_flag = check_exception_checkpoints(ckpt_file_list)
+
+    if not restore_flag:
+        return False
+
+    ckpt_name = args_param.ckpt_name_prefix
+    restore_ranks_map = os.getenv("RESTORE_RANKS_MAP")
+    if not restore_ranks_map:
+        return False
+
+    try:
+        print("whether run into load process")
+        restore_ranks_map_json = json.loads(restore_ranks_map)
+        map_rank_id = D.get_rank()
+        for key in restore_ranks_map_json.keys():
+            if str(D.get_rank()) in key:
+                map_rank_id = restore_ranks_map_json.get(key)
+
+        print(f"loading map rank id {map_rank_id}")
+        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
+                                    f"rank_{map_rank_id}",
+                                    f"{ckpt_name}*breakpoint.ckpt")
+        ckpt_files = glob.glob(ckpt_pattern)
+        ckpt_files.sort(key=os.path.getmtime, reverse=True)
+        print(f" checkpoint files {ckpt_files[0]}")
+        param_dict = load_checkpoint(ckpt_files[0])
+        print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
+        if param_dict.get("epoch_num") and param_dict.get("step_num"):
+            args_param.has_trained_epoches = int(
+                param_dict["epoch_num"].data.asnumpy())
+            args_param.has_trained_steps = int(
+                param_dict["step_num"].data.asnumpy())
+
+        # Load checkpoint files
+        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
+        load_param_into_net(network, param_dict)
+    except TypeError:
+        return False
+    else:
+        return True
+
+
 def run_train_pipeline(args_opt):
     r"""The main training process in pipeline."""
     # Set hccl connect time