Skip to content
Snippets Groups Projects
Commit 21d0c322 authored by huangxinjing's avatar huangxinjing
Browse files

load_ckpt_skip

parent b77e8eb5
No related branches found
No related tags found
No related merge requests found
...@@ -104,7 +104,7 @@ As the format of the downstream tasks can be various, the `preprocess.py` provid ...@@ -104,7 +104,7 @@ As the format of the downstream tasks can be various, the `preprocess.py` provid
Suppose the text data is under the `./data` and **each text file ends with 'txt'**, we can run the following command to generate the mindrecord files with seq_length=1025. Suppose the text data is under the `./data` and **each text file ends with 'txt'**, we can run the following command to generate the mindrecord files with seq_length=1025.
```bash ```bash
python -m src.preprocess --input_glob data/*.txt --tokenizer gpt --eot 50256 --data_column_name input_ids --seq_length 1025 python -m src.preprocess --input_glob 'data/*.txt' --tokenizer gpt --eot 50256 --data_column_name input_ids --seq_length 1025
``` ```
The script will chunk the each line with 1025 tokens. For the chunk with no more 1025 tokens, the chunk will be ignored. The script will chunk the each line with 1025 tokens. For the chunk with no more 1025 tokens, the chunk will be ignored.
......
...@@ -215,9 +215,10 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch): ...@@ -215,9 +215,10 @@ def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
f"{ckpt_name}*.ckpt") f"{ckpt_name}*.ckpt")
ckpt_files = glob.glob(ckpt_pattern) ckpt_files = glob.glob(ckpt_pattern)
if not ckpt_files: if not ckpt_files:
raise ValueError(f"There is no ckpt file in {args_param.load_ckpt_path}, " print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
f"pre_trained is unsupported, current ckpt_files found is {ckpt_files} " f"current ckpt_files found is {ckpt_files} "
f"with pattern {ckpt_pattern}") f"with pattern {ckpt_pattern}, so skip the loading.")
return
ckpt_files.sort(key=os.path.getmtime, reverse=True) ckpt_files.sort(key=os.path.getmtime, reverse=True)
time_stamp = datetime.datetime.now() time_stamp = datetime.datetime.now()
print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading", print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment