Unverified commit f7d805d2, authored by i-robot, committed by Gitee

!3739 fix transformer network default config yaml file selection

Merge pull request !3739 from xubangduo/fixI5P7FG
parents 36073fef 41e7ddb2
@@ -98,12 +98,12 @@ bash scripts/run_eval.sh GPU [DEVICE_ID] [MINDRECORD_DATA] [CKPT_PATH] [CONFIG_P

```bash
# Train 8p with Ascend
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config_large.yaml file.
#          Set "distribute=True" in the default_config_large.yaml file.
#          Set "dataset_path='/cache/data'" in the default_config_large.yaml file.
#          Set "epoch_size: 52" in the default_config_large.yaml file.
#          (optional) Set "checkpoint_url='s3://dir_to_your_pretrained/'" in the default_config_large.yaml file.
#          Set any other parameters you need in the default_config_large.yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "distribute=True" on the website UI interface.
#          Add "dataset_path=/cache/data" on the website UI interface.
@@ -124,11 +124,11 @@ bash scripts/run_eval.sh GPU [DEVICE_ID] [MINDRECORD_DATA] [CKPT_PATH] [CONFIG_P
#
# Train 1p with Ascend
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config_large.yaml file.
#          Set "dataset_path='/cache/data'" in the default_config_large.yaml file.
#          Set "epoch_size: 52" in the default_config_large.yaml file.
#          (optional) Set "checkpoint_url='s3://dir_to_your_pretrained/'" in the default_config_large.yaml file.
#          Set any other parameters you need in the default_config_large.yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "dataset_path='/cache/data'" on the website UI interface.
#          Add "epoch_size: 52" on the website UI interface.
@@ -148,11 +148,11 @@ bash scripts/run_eval.sh GPU [DEVICE_ID] [MINDRECORD_DATA] [CKPT_PATH] [CONFIG_P
#
# Eval 1p with Ascend
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config_large.yaml file.
#          Set "checkpoint_url='s3://dir_to_your_trained_model/'" in the base_config.yaml file.
#          Set "checkpoint='./transformer/transformer_trained.ckpt'" in the default_config_large.yaml file.
#          Set "dataset_path='/cache/data'" in the default_config_large.yaml file.
#          Set any other parameters you need in the default_config_large.yaml file.
#       b. Add "enable_modelarts=True" on the website UI interface.
#          Add "checkpoint_url='s3://dir_to_your_trained_model/'" on the website UI interface.
#          Add "checkpoint='./transformer/transformer_trained.ckpt'" on the website UI interface.
@@ -284,7 +284,7 @@ options:

#### Running Options

```text
default_config_large.yaml:
    transformer_network             version of Transformer model: base | large, default is large
    init_loss_scale_value           initial value of loss scale: N, default is 2^10
    scale_factor                    factor used to update loss scale: N, default is 2
@@ -354,7 +354,7 @@ Parameters for learning rate:

## [Training Process](#contents)

- Set options in `default_config_large.yaml`, including loss_scale, learning rate, and network hyperparameters. Click [here](https://www.mindspore.cn/tutorials/en/master/advanced/dataset.html) for more information about the dataset.
- Run `run_standalone_train.sh` for non-distributed training of the Transformer model.
@@ -402,7 +402,7 @@ Parameters for learning rate:

- Export your model to ONNX:

  ```bash
  python export.py --device_target GPU --config default_config_large.yaml --model_file /path/to/transformer.ckpt --file_name /path/to/transformer.onnx --file_format ONNX
  ```

- Run ONNX evaluation:
@@ -459,7 +459,7 @@ Inference result is saved in current path, 'output_file' will generate in path s

#### Training Performance

| Parameters                 | Ascend                      | GPU                         |
| -------------------------- | --------------------------- | --------------------------- |
| Resource                   | Ascend 910; OS Euler2.8     | GPU (Tesla V100 SXM2)       |
| Uploaded Date              | 07/05/2021 (month/day/year) | 12/21/2021 (month/day/year) |
| MindSpore Version          | 1.3.0                       | 1.5.0                       |

@@ -473,11 +473,12 @@ Inference result is saved in current path, 'output_file' will generate in path s

| Params (M)                 | 213.7                       | 213.7                       |
| Checkpoint for inference   | 2.4G (.ckpt file)           | 2.4G (.ckpt file)           |
| Scripts                    | [Transformer scripts](https://gitee.com/mindspore/models/tree/master/official/nlp/transformer) |
| Model Version              | large                       | large                       |
#### Evaluation Performance

| Parameters        | Ascend                      | GPU                         |
| ----------------- | --------------------------- | --------------------------- |
| Resource          | Ascend 910; OS Euler2.8     | GPU (Tesla V100 SXM2)       |
| Uploaded Date     | 07/05/2021 (month/day/year) | 12/21/2021 (month/day/year) |
| MindSpore Version | 1.3.0                       | 1.5.0                       |

@@ -485,6 +486,7 @@ Inference result is saved in current path, 'output_file' will generate in path s

| batch_size        | 1                           | 1                           |
| outputs           | BLEU score                  | BLEU score                  |
| Accuracy          | BLEU=28.7                   | BLEU=24.4                   |
| Model Version     | large                       | large                       |
## [Description of Random Situation](#contents)

@@ -494,7 +496,7 @@ There are three random situations:

- Initialization of some model weights.
- Dropout operations.

Some seeds have already been set in train.py to avoid the randomness of dataset shuffle and weight initialization. If you want to disable dropout, set the corresponding dropout_prob parameter to 0 in default_config_large.yaml.
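Zeroing the dropout entries can also be done programmatically. A minimal sketch, assuming a flat `key: value` layout and illustrative key names (the real default_config_large.yaml may structure these differently):

```python
# Sketch: set every *dropout_prob entry in a flat "key: value" config to 0.
# Key names here are illustrative, not the exact schema of the YAML file.
config_text = """\
hidden_dropout_prob: 0.3
attention_probs_dropout_prob: 0.3
epoch_size: 52
"""

def disable_dropout(text):
    out = []
    for line in text.splitlines():
        key, _, value = line.partition(":")
        if key.strip().endswith("dropout_prob"):
            value = " 0"  # disable dropout for deterministic runs
        out.append(key + ":" + value)
    return "\n".join(out)

print(disable_dropout(config_text))
```

Only the dropout keys are rewritten; everything else (such as `epoch_size`) passes through untouched.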
## [ModelZoo Homepage](#contents)
......
@@ -99,12 +99,12 @@ bash scripts/run_eval.sh GPU [DEVICE_ID] [MINDRECORD_DATA] [CKPT_PATH] [CONFIG_P

```python
# Train on ModelArts with 8 devices
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config_large.yaml file.
#          Set "distribute=True" in the default_config_large.yaml file.
#          Set "dataset_path='/cache/data'" in the default_config_large.yaml file.
#          Set "epoch_size: 52" in the default_config_large.yaml file.
#          (optional) Set "checkpoint_url='s3://dir_to_your_pretrained/'" in the default_config_large.yaml file.
#          Set any other parameters you need in the default_config_large.yaml file.
#       b. Set "enable_modelarts=True" on the website UI interface.
#          Set "distribute=True" on the website UI interface.
#          Set "dataset_path=/cache/data" on the website UI interface.
@@ -125,11 +125,11 @@ bash scripts/run_eval.sh GPU [DEVICE_ID] [MINDRECORD_DATA] [CKPT_PATH] [CONFIG_P
#
# Train on ModelArts with a single device
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config_large.yaml file.
#          Set "dataset_path='/cache/data'" in the default_config_large.yaml file.
#          Set "epoch_size: 52" in the default_config_large.yaml file.
#          (optional) Set "checkpoint_url='s3://dir_to_your_pretrained/'" in the default_config_large.yaml file.
#          Set any other parameters you need in the default_config_large.yaml file.
#       b. Set "enable_modelarts=True" on the website UI interface.
#          Set "dataset_path='/cache/data'" on the website UI interface.
#          Set "epoch_size: 52" on the website UI interface.
@@ -149,11 +149,11 @@ bash scripts/run_eval.sh GPU [DEVICE_ID] [MINDRECORD_DATA] [CKPT_PATH] [CONFIG_P
#
# Evaluate on ModelArts with a single device
# (1) Perform a or b.
#       a. Set "enable_modelarts=True" in the default_config_large.yaml file.
#          Set "checkpoint_url='s3://dir_to_your_trained_model/'" in the default_config_large.yaml file.
#          Set "checkpoint='./transformer/transformer_trained.ckpt'" in the default_config_large.yaml file.
#          Set "dataset_path='/cache/data'" in the default_config_large.yaml file.
#          Set any other parameters you need in the default_config_large.yaml file.
#       b. Set "enable_modelarts=True" on the website UI interface.
#          Set "checkpoint_url='s3://dir_to_your_trained_model/'" on the website UI interface.
#          Set "checkpoint='./transformer/transformer_trained.ckpt'" on the website UI interface.
@@ -285,7 +285,7 @@ options:

#### Running Options

```text
default_config_large.yaml:
    transformer_network             version of Transformer model: base | large, default is large
    init_loss_scale_value           initial value of loss scale: N, default is 2^10
    scale_factor                    factor used to update loss scale: N, default is 2
@@ -356,7 +356,7 @@ Parameters for learning rate:

### Training Process

- Set options in `default_config_large.yaml`, including loss_scale, learning rate, and network hyperparameters. Click [here](https://www.mindspore.cn/tutorials/zh-CN/master/advanced/dataset.html) for more information about the dataset.
- Run `run_standalone_train.sh` for single-device training of the Transformer model.
@@ -433,7 +433,7 @@ bash run_infer_310.sh [MINDIR_PATH] [NEED_PREPROCESS] [DEVICE_ID] [CONFIG_PATH]

#### Training Performance

| Parameters               | Ascend                  | GPU                   |
| ------------------------ | ----------------------- | --------------------- |
| Resource                 | Ascend 910; OS Euler2.8 | GPU (Tesla V100 SXM2) |
| Uploaded Date            | 2021-07-05              | 2021-12-21            |
| MindSpore Version        | 1.3.0                   | 1.5.0                 |

@@ -446,7 +446,8 @@ bash run_infer_310.sh [MINDIR_PATH] [NEED_PREPROCESS] [DEVICE_ID] [CONFIG_PATH]

| Loss                     | 2.8                     | 2.9                   |
| Params (M)               | 213.7                   | 213.7                 |
| Checkpoint for inference | 2.4G (.ckpt file)       | 2.4G                  |
| Scripts                  | [Transformer scripts](https://gitee.com/mindspore/models/tree/master/official/nlp/transformer) |
| Model Version            | large                   | large                 |
#### Evaluation Performance

@@ -459,6 +460,7 @@ bash run_infer_310.sh [MINDIR_PATH] [NEED_PREPROCESS] [DEVICE_ID] [CONFIG_PATH]

| batch_size    | 1          | 1          |
| Outputs       | BLEU score | BLEU score |
| Accuracy      | BLEU=28.7  | BLEU=24.4  |
| Model Version | large      | large      |
## Description of Random Situation

@@ -468,7 +470,7 @@ bash run_infer_310.sh [MINDIR_PATH] [NEED_PREPROCESS] [DEVICE_ID] [CONFIG_PATH]

- Initialization of some model weights
- Dropout operations

Some seeds have already been set in train.py to avoid the randomness of dataset shuffling and weight initialization. To disable dropout, set the corresponding dropout_prob parameter to 0 in default_config_large.yaml.

## ModelZoo Homepage
......
@@ -115,7 +115,7 @@ bash wmt16_en_de.sh

paste newstest2014.tok.bpe.32000.en newstest2014.tok.bpe.32000.de > test.all
```

Change bucket in default_config_large.yaml to bucket: [128]

```text
# create_data.py

@@ -134,7 +134,7 @@ bucket: [128]

python3 create_data.py --input_file ./infer/data/data/test.all --vocab_file ./infer/data/data/vocab.bpe.32000 --output_file ./infer/data/data/newstest2014-l128-mindrecord --num_splits 1 --max_seq_length 128 --clip_to_max_len True
```

Change the following parameters in default_config_large.yaml:

```text
#eval_config/cfg edict
......
#!/bin/bash
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -16,9 +16,9 @@

if [ $# != 5 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_eval.sh DEVICE_TARGET DEVICE_ID MINDRECORD_DATA CKPT_PATH CONFIG_PATH"
echo "for example: bash run_eval.sh Ascend 0 /your/path/evaluation.mindrecord /your/path/checkpoint_file ./default_config_large_gpu.yaml"
echo "Note: set the checkpoint and dataset path in default_config_large.yaml"
echo "=============================================================================================================="
exit 1;
fi
......
#!/bin/bash
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -31,7 +31,7 @@ get_real_path(){

ONNX_MODEL=$(get_real_path $1)
MINDRECORD_DATA=$(get_real_path $2)
CONFIG_PATH=${3:-"./default_config_large.yaml"}
CONFIG_PATH=$(get_real_path $CONFIG_PATH)
DEVICE_TARGET=${4:-"GPU"}
DEVICE_ID=${5:-0}
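The `${3:-"./default_config_large.yaml"}` expansions above substitute a default whenever a trailing positional argument is absent. As a rough Python analogue of the same pattern (the argument order mirrors the script; the helper name is illustrative):

```python
def with_defaults(argv, defaults):
    """Return argv padded with defaults for missing trailing positions,
    emulating shell ${N:-default} positional-parameter expansion."""
    return [argv[i] if i < len(argv) else defaults[i]
            for i in range(len(defaults))]

# Mirrors run_eval_onnx.sh: ONNX_MODEL and MINDRECORD_DATA are required,
# while CONFIG_PATH, DEVICE_TARGET, and DEVICE_ID fall back to defaults.
defaults = [None, None, "./default_config_large.yaml", "GPU", "0"]
print(with_defaults(["model.onnx", "data.mindrecord"], defaults))
# → ['model.onnx', 'data.mindrecord', './default_config_large.yaml', 'GPU', '0']
```

Passing an explicit third argument overrides the default, just as in the shell script.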
......
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -39,7 +39,7 @@ class Config:

        return self.__str__()

def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config_large.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

@@ -115,7 +115,7 @@ def get_config():

    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config_large.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
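The two-stage pattern here — pre-parse `--config_path` with `parse_known_args`, then register the YAML values as argparse defaults so explicit CLI flags still win — can be sketched as follows. This is a simplified stand-in: `parse_yaml` is replaced by a plain dict, and the config keys are illustrative.

```python
import argparse

def get_config(argv, yaml_defaults):
    """Two-stage parse: find the config path first, then let its values
    act as defaults that explicit CLI flags can still override."""
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--config_path", type=str, default="default_config_large.yaml")
    path_args, _ = pre.parse_known_args(argv)
    # In the real module, yaml_defaults would come from parse_yaml(path_args.config_path).
    parser = argparse.ArgumentParser(parents=[pre])
    for key, value in yaml_defaults.items():
        # Each YAML entry becomes a CLI flag whose default is the YAML value.
        parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser.parse_args(argv)

cfg = get_config(["--epoch_size", "26"], {"epoch_size": 52, "transformer_network": "large"})
print(cfg.epoch_size, cfg.transformer_network)  # 26 large
```

An unspecified flag keeps its YAML default, while `--epoch_size 26` overrides the value from the file; this is why the choice of default config YAML (the subject of this commit) determines every unoverridden parameter.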
......