diff --git a/official/cv/crnn/README.md b/official/cv/crnn/README.md index 922bd409ac52e96aceba732b0993540d228eaa73..8fc7069a6597bccb83ffa09ecf34a78fe8aa3c2c 100644 --- a/official/cv/crnn/README.md +++ b/official/cv/crnn/README.md @@ -18,6 +18,7 @@ - [Distributed Training](#distributed-training) - [Evaluation Process](#evaluation-process) - [Evaluation](#evaluation) + - [Evaluation while training](#evaluation-while-training) - [Inference Process](#inference-process) - [Export MindIR](#export-mindir) - [Infer on Ascend310](#infer-on-ascend310) @@ -43,7 +44,7 @@ CRNN use a vgg16 structure for feature extraction, the appending with two-layer We provide 2 versions of network using different ways to transfer the hidden size to class numbers. You could choose different version by modifying the `model_version` in config yaml. - V1 using an full connection after the RNN parts. -- V2 change the output feature size of the last RNN, to output a feature with the same size of class numbers. +- V2 change the output feature size of the last RNN, to output a feature with the same size of class numbers. V2 also switch to the builtin `LSTM` cell instead of operator `DynamicRNN`, which make the model being supported both on GPU and Ascend. ## [Dataset](#content) diff --git a/official/cv/crnn/src/model_utils/config.py b/official/cv/crnn/src/model_utils/config.py index efc856cf0cf389f539efa236c1bcb638cff853aa..259a4a98a7852933d4367d067648ebf8540ead10 100644 --- a/official/cv/crnn/src/model_utils/config.py +++ b/official/cv/crnn/src/model_utils/config.py @@ -122,9 +122,9 @@ def get_config(): help='Config file path') path_args, _ = parser.parse_known_args() default, helper, choices = parse_yaml(path_args.config_path) - pprint(default) args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) final_config = merge(args, default) + pprint(final_config) return Config(final_config) config = get_config() diff --git a/official/cv/crnn/train.py b/official/cv/crnn/train.py index c7afc2ebca1593eb29375f34bffd5a97e664cef2..fb0f33fb857efa3c2563327a43bfa625b3aba98e 100644 --- a/official/cv/crnn/train.py +++ b/official/cv/crnn/train.py @@ -55,6 +55,9 @@ def train(): device_id = get_device_id() context.set_context(device_id=device_id) + if config.model_version == 'V1' and config.device_target != 'Ascend': + raise ValueError("model version V1 is only supported on Ascend, pls check the config.") + # lr_scale = 1 if config.run_distribute: if config.device_target == 'Ascend':