From 21d0c32283e4718cd89e16fc12613ec5c20784f5 Mon Sep 17 00:00:00 2001 From: huangxinjing <huangxinjing@huawei.com> Date: Thu, 14 Oct 2021 09:30:34 +0800 Subject: [PATCH] load_ckpt_skip --- official/nlp/pangu_alpha/README.md | 2 +- official/nlp/pangu_alpha/train.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/official/nlp/pangu_alpha/README.md b/official/nlp/pangu_alpha/README.md index fd27204df..593f6ed08 100644 --- a/official/nlp/pangu_alpha/README.md +++ b/official/nlp/pangu_alpha/README.md @@ -104,7 +104,7 @@ As the format of the downstream tasks can be various, the `preprocess.py` provid Suppose the text data is under the `./data` and **each text file ends with 'txt'**, we can run the following command to generate the mindrecord files with seq_length=1025. ```bash -python -m src.preprocess --input_glob data/*.txt --tokenizer gpt --eot 50256 --data_column_name input_ids --seq_length 1025 +python -m src.preprocess --input_glob 'data/*.txt' --tokenizer gpt --eot 50256 --data_column_name input_ids --seq_length 1025 ``` The script will chunk the each line with 1025 tokens. For the chunk with no more 1025 tokens, the chunk will be ignored. diff --git a/official/nlp/pangu_alpha/train.py b/official/nlp/pangu_alpha/train.py index a70ad082a..ec957db17 100644 --- a/official/nlp/pangu_alpha/train.py +++ b/official/nlp/pangu_alpha/train.py @@ -215,9 +215,10 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch): f"{ckpt_name}*.ckpt") ckpt_files = glob.glob(ckpt_pattern) if not ckpt_files: - raise ValueError(f"There is no ckpt file in {args_param.load_ckpt_path}, " - f"pre_trained is unsupported, current ckpt_files found is {ckpt_files} " - f"with pattern {ckpt_pattern}") + print(f"There is no ckpt file in {args_param.save_checkpoint_path}, " + f"current ckpt_files found is {ckpt_files} " + f"with pattern {ckpt_pattern}, so skip the loading.") + return ckpt_files.sort(key=os.path.getmtime, reverse=True) time_stamp = datetime.datetime.now() print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading", -- GitLab