Commit 0397c5ac authored by i-robot, committed by Gitee

!74 Check model_version and device_target in CRNN and raise an error on mismatch

Merge pull request !74 from chenhaozhe/fix-crnn-version-dismatch
Parents: 01e5a758, 9222be85
@@ -18,6 +18,7 @@
 - [Distributed Training](#distributed-training)
 - [Evaluation Process](#evaluation-process)
 - [Evaluation](#evaluation)
+- [Evaluation while training](#evaluation-while-training)
 - [Inference Process](#inference-process)
 - [Export MindIR](#export-mindir)
 - [Infer on Ascend310](#infer-on-ascend310)
@@ -43,7 +44,7 @@ CRNN uses a vgg16 structure for feature extraction, appending two-layer
 We provide 2 versions of the network, using different ways to map the hidden size to the number of classes. You can choose between them by setting `model_version` in the config yaml.
 - V1 uses a full connection layer after the RNN part.
-- V2 changes the output feature size of the last RNN so that it directly outputs a feature whose size equals the number of classes.
+- V2 changes the output feature size of the last RNN so that it directly outputs a feature whose size equals the number of classes. V2 also switches to the built-in `LSTM` cell instead of the `DynamicRNN` operator, which makes the model supported on both GPU and Ascend.
 ## [Dataset](#content)
...
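To make the V1/V2 distinction above concrete, here is a minimal sketch of the two head designs in MindSpore. The class names, `hidden_size`, and `num_classes` are illustrative assumptions, not code from this repository:

```python
import mindspore.nn as nn

class HeadV1(nn.Cell):
    """V1 (sketch): the last RNN keeps its hidden size, and a full
    connection layer maps the hidden features to class scores."""
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.fc = nn.Dense(hidden_size, num_classes)

    def construct(self, rnn_features):
        return self.fc(rnn_features)

class HeadV2(nn.Cell):
    """V2 (sketch): the built-in LSTM cell emits num_classes features
    directly, so no trailing full connection layer is needed."""
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size=num_classes)

    def construct(self, sequence):
        out, _ = self.lstm(sequence)
        return out
```

Selecting between the two would then just branch on `config.model_version`, which is what makes the startup check added to `train()` below worthwhile.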
@@ -122,9 +122,9 @@ def get_config():
                         help='Config file path')
     path_args, _ = parser.parse_known_args()
     default, helper, choices = parse_yaml(path_args.config_path)
-    pprint(default)
     args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
     final_config = merge(args, default)
+    pprint(final_config)
     return Config(final_config)

 config = get_config()
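For context on the hunk above: `merge` overlays the parsed CLI arguments onto the yaml defaults, so `pprint(final_config)` now prints the effective configuration after overrides rather than the raw defaults. A minimal sketch of that merge semantics, assuming the repository's `merge` behaves this way:

```python
# Illustrative only: overlay CLI values onto the yaml defaults, key by key.
# Keys the user did not supply on the command line keep their yaml value.
def merge(cli_args, defaults):
    final = dict(defaults)
    for key, value in vars(cli_args).items():
        if value is not None:
            final[key] = value
    return final
```

Under that assumption, an invocation such as `python train.py --config_path=... --device_target=GPU` would override the yaml's `device_target` while leaving every other key at its default.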
@@ -55,6 +55,9 @@ def train():
     device_id = get_device_id()
     context.set_context(device_id=device_id)
+    if config.model_version == 'V1' and config.device_target != 'Ascend':
+        raise ValueError("model version V1 is only supported on Ascend, please check the config.")
     # lr_scale = 1
     if config.run_distribute:
         if config.device_target == 'Ascend':
...
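The point of the added guard: a V1 model requested on a non-Ascend target now fails fast at startup instead of failing later inside graph compilation, since V1 depends on the Ascend-only `DynamicRNN` operator. A hypothetical standalone form of the rule, with the table of supported pairs inferred from the README text above rather than taken from the repository:

```python
# Hypothetical helper mirroring the check added in train(); the supported
# pairs below are inferred from the README, not from the repository.
SUPPORTED_TARGETS = {
    'V1': {'Ascend'},         # V1 relies on the Ascend-only DynamicRNN operator
    'V2': {'Ascend', 'GPU'},  # V2 uses the built-in LSTM cell
}

def check_version_target(model_version, device_target):
    if device_target not in SUPPORTED_TARGETS.get(model_version, set()):
        raise ValueError(
            f"model version {model_version} is not supported on "
            f"{device_target}, please check the config.")

check_version_target('V2', 'GPU')  # OK
check_version_target('V1', 'GPU')  # raises ValueError
```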