Commit 0bd98a87 authored by i-robot, committed by Gitee

!1103 [Model training] Support the pangu_alpha "last words" (exception checkpoint) example on the master branch

Merge pull request !1103 from Atlas_hrp/master
parents 87109b0a 1bc2bda0
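In effect, the change lets run_train first look for per-rank breakpoint ("last words") checkpoints that an earlier run saved when it hit an exception, and only fall back to the normal pre-trained restore when that fails. The behaviour is driven by two environment variables read by the new helpers; the sketch below is not part of the commit and uses purely illustrative values:

import os

# RESTORE_RANKS: comma-separated ranks whose rank_N/*_breakpoint.ckpt files should
# be checked; "-1" disables exception restore entirely (illustrative value).
os.environ["RESTORE_RANKS"] = "0,1,2,3"
# RESTORE_RANKS_MAP: JSON string mapping rank-id keys to the rank directory whose
# breakpoint checkpoint should actually be loaded (illustrative value).
os.environ["RESTORE_RANKS_MAP"] = '{"0,1,2,3": 0}'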
@@ -17,9 +17,11 @@ PanguAlpha train script
 """
 import datetime
+import json
+import glob
 import os
 import math
 from mindspore import context
 from mindspore.train.model import Model
 import mindspore.communication.management as D
@@ -191,7 +193,11 @@ def run_train(args_opt):
     else:
         model = Model(pangu_alpha_with_grads)
     if args_opt.pre_trained:
-        restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)
+        flag = restore_exception_checkpoint(args_opt, args_opt.sink_size, ds, model,
+                                            pangu_alpha_with_grads, epoch=actual_epoch_num)
+        if not flag:
+            restore_checkpoint(args_opt, args_opt.sink_size, ds, model,
+                               pangu_alpha_with_grads, epoch=actual_epoch_num)
     callback = [TimeMonitor(args_opt.sink_size), LossCallBack(args_opt.sink_size, rank, args_opt.has_trained_epoches,
                                                               args_opt.has_trained_steps)]
@@ -236,6 +242,117 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
     load_param_into_net(network, param_dict)
+
+
+def get_exception_checkpoints(args_param):
+    r"""
+    Collect the latest exception (breakpoint) checkpoint file for each rank listed in RESTORE_RANKS.
+    """
+    print("======start exception checkpoint", flush=True)
+    restore_ranks = os.getenv("RESTORE_RANKS")
+    if not restore_ranks:
+        return None
+    restore_rank_list = list(map(int, restore_ranks.split(",")))
+    ckpt_file_list = []
+    ckpt_name = args_param.ckpt_name_prefix
+    for ckpt_rank in restore_rank_list:
+        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
+                                    f"rank_{ckpt_rank}",
+                                    f"{ckpt_name}*_breakpoint.ckpt")
+        ckpt_files = glob.glob(ckpt_pattern)
+        if not ckpt_files:
+            print(
+                f"There is no ckpt file in {args_param.save_checkpoint_path}, "
+                f"current ckpt_files found is {ckpt_files} "
+                f"with pattern {ckpt_pattern}, so skip the loading.")
+            return None
+        # keep only the most recently written breakpoint checkpoint for this rank
+        ckpt_files.sort(key=os.path.getmtime, reverse=True)
+        ckpt_file_list.append(ckpt_files[0])
+    print(f"checkpoint file {ckpt_file_list}")
+    return ckpt_file_list
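As a rough usage sketch (not part of the commit, and assuming the function above is in scope, e.g. inside train.py), the helper only reads two attributes from args_param plus the RESTORE_RANKS variable; the paths and prefix below are hypothetical:

import os
from types import SimpleNamespace

# Hypothetical values: ranks 0 and 1 are expected to have rank_N/*_breakpoint.ckpt
# files under save_checkpoint_path; the helper returns None if any rank has none.
os.environ["RESTORE_RANKS"] = "0,1"
stub_args = SimpleNamespace(save_checkpoint_path="/path/to/ckpt",
                            ckpt_name_prefix="pangu")
print(get_exception_checkpoints(stub_args))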
+
+
+def check_exception_checkpoints(ckpt_file_list):
+    """
+    Check exception checkpoints size.
+    Args:
+        ckpt_file_list: exception checkpoints
+
+    Returns: result of exception checkpoints size check.
+    """
+    ckpt_size_list = []
+    for ckpt_file in ckpt_file_list:
+        ckpt_size_list.append(os.path.getsize(ckpt_file))
+    if len(set(ckpt_size_list)) > 1:
+        return False
+    return True
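The size comparison is presumably a cheap consistency check: a process that dies mid-save can leave a truncated file, so breakpoint checkpoints of differing sizes are treated as unusable. A tiny illustration with made-up sizes:

# Hypothetical file sizes for three per-rank breakpoint checkpoints, one truncated.
sizes = [1048576, 1048576, 524288]
print(len(set(sizes)) > 1)  # True, so check_exception_checkpoints would return False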
+
+
+def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
+    """
+    Restore an exception (breakpoint) checkpoint into the training model.
+    Args:
+        args_param: model training parameters
+        sink_size: model training sink size
+        dataset: dataset used for training
+        model: model
+        network: pangu_alpha network
+        epoch: training epoch
+
+    Returns: whether the exception checkpoint was loaded successfully.
+    """
+    if os.getenv("RESTORE_RANKS") == "-1":
+        return False
+    ckpt_file_list = get_exception_checkpoints(args_param)
+    restore_flag = False
+    if ckpt_file_list:
+        restore_flag = check_exception_checkpoints(ckpt_file_list)
+    if not restore_flag:
+        return False
+    ckpt_name = args_param.ckpt_name_prefix
+    restore_ranks_map = os.getenv("RESTORE_RANKS_MAP")
+    if not restore_ranks_map:
+        return False
+    try:
+        print("whether run into load process")
+        restore_ranks_map_json = json.loads(restore_ranks_map)
+        map_rank_id = D.get_rank()
+        for key in restore_ranks_map_json.keys():
+            if str(D.get_rank()) in key:
+                map_rank_id = restore_ranks_map_json.get(key)
+        print(f"loading map rank id {map_rank_id}")
+        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
+                                    f"rank_{map_rank_id}",
+                                    f"{ckpt_name}*breakpoint.ckpt")
+        ckpt_files = glob.glob(ckpt_pattern)
+        ckpt_files.sort(key=os.path.getmtime, reverse=True)
+        print(f" checkpoint files {ckpt_files[0]}")
+        param_dict = load_checkpoint(ckpt_files[0])
+        print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
+        if param_dict.get("epoch_num") and param_dict.get("step_num"):
+            args_param.has_trained_epoches = int(
+                param_dict["epoch_num"].data.asnumpy())
+            args_param.has_trained_steps = int(
+                param_dict["step_num"].data.asnumpy())
+        # Load checkpoint parameters into the network
+        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
+        load_param_into_net(network, param_dict)
+    except TypeError:
+        return False
+    else:
+        return True
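One point worth noting in the lookup above: the current rank is matched against the RESTORE_RANKS_MAP keys by substring, so the keys need to be written so that rank ids stay unambiguous (for example, "1" also occurs in "10,11"). A small self-contained sketch, not part of the commit, with hypothetical values:

import json

restore_ranks_map = '{"0,1,2,3": 0, "4,5,6,7": 4}'  # hypothetical mapping
mapping = json.loads(restore_ranks_map)

current_rank = 6          # stand-in for D.get_rank()
map_rank_id = current_rank
for key in mapping:
    if str(current_rank) in key:
        map_rank_id = mapping[key]
print(map_rank_id)        # 4, so rank 6 would load rank_4/*breakpoint.ckpt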
 def run_train_pipeline(args_opt):
     r"""The main training process in pipeline."""
     # Set hccl connect time