diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index e119ff9066a48e8b20f1efde32d4959c7f7488bc..847ea9b657a501f0bfc5a340be51773eac49d248 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -17,9 +17,11 @@ PanguAlpha train script
 """
 
 import datetime
+import json
 import glob
 import os
 import math
+
 from mindspore import context
 from mindspore.train.model import Model
 import mindspore.communication.management as D
@@ -191,7 +193,11 @@ def run_train(args_opt):
     else:
         model = Model(pangu_alpha_with_grads)
     if args_opt.pre_trained:
-        restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)
+        flag = restore_exception_checkpoint(args_opt, args_opt.sink_size, ds, model,
+                                            pangu_alpha_with_grads, epoch=actual_epoch_num)
+        if not flag:
+            restore_checkpoint(args_opt, args_opt.sink_size, ds, model,
+                               pangu_alpha_with_grads, epoch=actual_epoch_num)
 
     callback = [TimeMonitor(args_opt.sink_size), LossCallBack(args_opt.sink_size, rank, args_opt.has_trained_epoches,
                                                               args_opt.has_trained_steps)]
@@ -236,6 +242,117 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
     load_param_into_net(network, param_dict)
 
 
+def get_exception_checkpoints(args_param):
+    r"""
+    Load checkpoint process.
+    """
+    print("======start exception checkpoint", flush=True)
+    restore_ranks = os.getenv("RESTORE_RANKS")
+    if not restore_ranks:
+        return None
+
+    restore_rank_list = list(map(int, restore_ranks.split(",")))
+    ckpt_file_list = []
+    ckpt_name = args_param.ckpt_name_prefix
+    for ckpt_rank in restore_rank_list:
+        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
+                                    f"rank_{ckpt_rank}",
+                                    f"{ckpt_name}*_breakpoint.ckpt")
+        ckpt_files = glob.glob(ckpt_pattern)
+        if not ckpt_files:
+            print(
+                f"There is no ckpt file in {args_param.save_checkpoint_path}, "
+                f"current ckpt_files found is {ckpt_files} "
+                f"with pattern {ckpt_pattern}, so skip the loading.")
+            return None
+        ckpt_files.sort(key=os.path.getmtime, reverse=True)
+        ckpt_file_list.append(ckpt_files[0])
+    print(f"checkpoint file {ckpt_file_list}")
+    return ckpt_file_list
+
+
+def check_exception_checkpoints(ckpt_file_list):
+    """
+    Check exception checkpoints size.
+    Args:
+        ckpt_file_list: exception checkpoints
+
+    Returns: result of exception checkpoints size check.
+
+    """
+    ckpt_size_list = []
+    for ckpt_file in ckpt_file_list:
+        ckpt_size_list.append(os.path.getsize(ckpt_file))
+
+    if len(set(ckpt_size_list)) > 1:
+        return False
+
+    return True
+
+
+def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
+    """
+    Restore exception checkpoint to training model.
+    Args:
+        args_param: model training parameters
+        sink_size: model training sink size
+        dataset: dataset used for training
+        model: model
+        network: pangu_alpha network
+        epoch: training epoch
+
+    Returns: load exception checkpont success or not.
+
+    """
+    if os.getenv("RESTORE_RANKS") == "-1":
+        return False
+
+    ckpt_file_list = get_exception_checkpoints(args_param)
+
+    restore_flag = False
+    if ckpt_file_list:
+        restore_flag = check_exception_checkpoints(ckpt_file_list)
+
+    if not restore_flag:
+        return False
+
+    ckpt_name = args_param.ckpt_name_prefix
+    restore_ranks_map = os.getenv("RESTORE_RANKS_MAP")
+    if not restore_ranks_map:
+        return False
+
+    try:
+        print("whether run into load process")
+        restore_ranks_map_json = json.loads(restore_ranks_map)
+        map_rank_id = D.get_rank()
+        for key in restore_ranks_map_json.keys():
+            if str(D.get_rank()) in key:
+                map_rank_id = restore_ranks_map_json.get(key)
+
+        print(f"loading map rank id {map_rank_id}")
+        ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
+                                    f"rank_{map_rank_id}",
+                                    f"{ckpt_name}*breakpoint.ckpt")
+        ckpt_files = glob.glob(ckpt_pattern)
+        ckpt_files.sort(key=os.path.getmtime, reverse=True)
+        print(f" checkpoint files {ckpt_files[0]}")
+        param_dict = load_checkpoint(ckpt_files[0])
+        print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
+        if param_dict.get("epoch_num") and param_dict.get("step_num"):
+            args_param.has_trained_epoches = int(
+                param_dict["epoch_num"].data.asnumpy())
+            args_param.has_trained_steps = int(
+                param_dict["step_num"].data.asnumpy())
+
+        # Load checkpoint files
+        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
+        load_param_into_net(network, param_dict)
+    except TypeError:
+        return False
+    else:
+        return True
+
+
 def run_train_pipeline(args_opt):
     r"""The main training process in pipeline."""
     # Set hccl connect time