From 21d0c32283e4718cd89e16fc12613ec5c20784f5 Mon Sep 17 00:00:00 2001
From: huangxinjing <huangxinjing@huawei.com>
Date: Thu, 14 Oct 2021 09:30:34 +0800
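Subject: [PATCH] pangu_alpha: skip checkpoint loading when no ckpt file is found

restore_checkpoint() in train.py now prints a message and returns when the
glob pattern matches no checkpoint files, instead of raising ValueError.

The README preprocess example now quotes the --input_glob pattern
('data/*.txt') so that the shell does not expand it before it reaches
the script.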

---
 official/nlp/pangu_alpha/README.md | 2 +-
 official/nlp/pangu_alpha/train.py  | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/official/nlp/pangu_alpha/README.md b/official/nlp/pangu_alpha/README.md
index fd27204df..593f6ed08 100644
--- a/official/nlp/pangu_alpha/README.md
+++ b/official/nlp/pangu_alpha/README.md
@@ -104,7 +104,7 @@ As the format of the downstream tasks can be various, the `preprocess.py` provid
 Suppose the text data is under the `./data` and **each text file ends with 'txt'**, we can run the following command to generate the mindrecord files with seq_length=1025.
 
 ```bash
-python -m src.preprocess --input_glob  data/*.txt --tokenizer gpt --eot 50256 --data_column_name input_ids --seq_length 1025
+python -m src.preprocess --input_glob  'data/*.txt' --tokenizer gpt --eot 50256 --data_column_name input_ids --seq_length 1025
 ```
 
 The script will chunk the each line with 1025 tokens. For the chunk with no more 1025 tokens, the chunk will be ignored.
diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py
index a70ad082a..ec957db17 100644
--- a/official/nlp/pangu_alpha/train.py
+++ b/official/nlp/pangu_alpha/train.py
@@ -215,9 +215,10 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
                                 f"{ckpt_name}*.ckpt")
     ckpt_files = glob.glob(ckpt_pattern)
     if not ckpt_files:
-        raise ValueError(f"There is no ckpt file in {args_param.load_ckpt_path}, "
-                         f"pre_trained is unsupported, current ckpt_files found is {ckpt_files} "
-                         f"with pattern {ckpt_pattern}")
+        print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
+              f"current ckpt_files found is {ckpt_files} "
+              f"with pattern {ckpt_pattern}, so skip the loading.")
+        return
     ckpt_files.sort(key=os.path.getmtime, reverse=True)
     time_stamp = datetime.datetime.now()
     print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading",
-- 
GitLab