diff --git a/official/nlp/ernie/README_CN.md b/official/nlp/ernie/README_CN.md index a21a440fb357348feb62aa9b85082f79543e64b2..1abbe73f11c3202b0dc5506a5b5bbd61f137d0f2 100644 --- a/official/nlp/ernie/README_CN.md +++ b/official/nlp/ernie/README_CN.md @@ -10,11 +10,11 @@ - [环境要求](#环境要求) - [快速入门](#快速入门) - [脚本说明](#脚本说明) - - [脚本和样例代码](#脚本和样例代码) + - [脚本和代码](#脚本和代码) - [选项及参数](#选项及参数) - [选项](#选项) - [参数](#参数) - - [训练过程](#训练过程) + - [预训练过程](#预训练过程) - [用法](#用法) - [下载数据集并预处理](#下载数据集并预处理) - [Ascend处理器上运行](#ascend处理器上运行) @@ -24,13 +24,26 @@ - [Ascend处理器上运行单卡微调](#ascend处理器上运行单卡微调) - [Ascend处理器上单机多卡微调](#ascend处理器上单机多卡微调) - [Ascend处理器上运行微调后的模型评估](#ascend处理器上运行微调后的模型评估) + - [导出onnx模型](#导出onnx模型) + - [onnx模型导出](#onnx模型导出) + - [onnx模型评估](#onnx模型评估) + - [基于chnsenticorp数据集进行onnx评估](#基于chnsenticorp数据集进行onnx评估) + - [基于xnli数据集进行onnx评估](#基于xnli数据集进行onnx评估) + - [基于dbqa数据集进行onnx评估](#基于dbqa数据集进行onnx评估) - [导出mindir模型](#导出mindir模型) - [推理过程](#推理过程) - [用法](#用法-2) - [结果](#结果) - - [精度与性能](#精度与性能) - - [推理性能](#推理性能) + - [模型描述](#模型描述) + - [精度与性能](#精度与性能) + - [推理性能](#推理性能) + - [命名实体识别任务](#命名实体识别任务) + - [情感分析任务](#情感分析任务) + - [自然语言接口](#自然语言接口) + - [问答](#问答) + - [阅读理解](#阅读理解) - [ModelZoo主页](#modelzoo主页) +- [FAQ](#faq) <!-- /TOC --> @@ -77,9 +90,9 @@ bash scripts/download_datasets.sh pretrain # 将数据集转为MindRecord # 预训练数据集 -bash scripts/convert_pretrain_dataset.sh /path/zh_wiki/ /path/zh_wiki/mindrecord/ +bash scripts/convert_pretrain_datasets.sh /path/zh_wiki/ /path/zh_wiki/mindrecord/ # 微调数据集 -bash scripts/convert_finetune_dataset.sh /path/msra_ner/ /path/msra_ner/mindrecord/ msra_ner +bash scripts/convert_finetune_datasets.sh /path/msra_ner/ /path/msra_ner/mindrecord/ msra_ner # 单机运行预训练示例 bash scripts/run_standalone_pretrain_ascend.sh 0 1 /path/cn-wiki-128 @@ -112,6 +125,7 @@ bash scripts/run_distribute_finetune.sh rank_table.json xnli ├─export.sh # 导出模型中间表示脚本,如MindIR ├─migrate_pretrained_models.sh # 在x86设备上将Paddle预训练权重参数转为MindSpore权重参数脚本 ├─run_distribute_finetune.sh # Ascend设备上多卡运行微调任务脚本 + ├─run_eval_onnx.sh # 对导出的onnx模型评估脚本 ├─run_finetune_eval.sh # Ascend设备上测试微调结果脚本 ├─run_infer_310.sh # Ascend 310设备推理脚本 ├─run_standalone_finetune.sh # Ascend设备上单卡运行微调任务脚本 @@ -138,7 +152,8 @@ bash scripts/run_distribute_finetune.sh rank_table.json xnli ├─run_ernie_classifier.py # 分类器任务的微调和评估网络 ├─run_ernie_mrc.py # 阅读理解任务的微调和评估网络 ├─run_ernie_ner.py # NER任务的微调和评估网络 - └─run_ernie_pretrain.py # 预训练网络 + ├─run_ernie_pretrain.py # 预训练网络 + └─run_eval_onnx.py # 评估onnx模型 ``` ## 选项及参数 @@ -215,15 +230,19 @@ bash scripts/download_datasets.sh pretrain 然后进行数据预处理,对文本进行分词,并随机mask词语: ```bash -bash scripts/convert_pretrain_dataset.sh /path/zh_wiki/ /path/zh_wiki/mindrecord/ +bash scripts/convert_pretrain_datasets.sh /path/zh_wiki/ /path/zh_wiki/mindrecord/ ``` > **注意:** +> > 1. 维基百科文本抽取依赖`wikiextractor`,数据预处理依赖结巴分词和`OpenCC`繁简体转换,可以通过以下命令安装依赖: ->```bash -> pip install -r requirements.txt ->``` +> +> ```bash +> pip install -r requirements.txt +> ``` +> > 2. 若需要使用私有词典进行分词,可修改`scr/pretrain_reader.py`中`get_word_segs`方法。 + #### Ascend处理器上运行 ```bash @@ -284,6 +303,7 @@ bash scripts/run_distribute_finetune.sh [RANK_TABLE_FILE] [TASK_TYPE] 以上命令后台运行,您可以在{task_type}_train_log.txt中查看训练日志。 > **注意:** +> > 1. `rank_table.json`可以通过`/etc/hccn.conf`获取加速卡IP进行配置。 > 2. `drcd, cmrc`数据集评估需要使用`nltk`,请通过以下命令安装并下载依赖库,然后运行微调脚本。 @@ -324,6 +344,93 @@ F1 0.920507 {"exact_match": 84.13970798740338, "f1": 90.52935807300771} ``` +## 导出onnx模型 + +### finetune数据集转化为mindspore格式 + +```shell +bash scripts/convert_finetune_datasets.sh [DATASET_PATH] [OUTPUT_PATH] [TASK_TYPE] +# bash scripts/convert_finetune_datasets.sh data/chnsenticorp/ data/chnsenticorp/ chnsenticorp +# bash scripts/convert_finetune_datasets.sh data/xnli/ data/xnli/ xnli +# bash scripts/convert_finetune_datasets.sh data/nlpcc-dbqa/ data/nlpcc-dbqa/ dbqa +``` + +[DATASET_PATH]:数据集路径; [OUTPUT_PATH]:转换后输出路径; [TASK_TYPE]:任务类型 + +### onnx模型导出 + +```bash +python export.py --ckpt_file [CKPT_PATH] --file_format onnx --file_name [FILE_NAME] --device_target GPU --task_type [TASK_TYPE] --number_labels [NUMBER_LABELS] +# chnsenticorp数据集 +# python export.py --ckpt_file ./save_models/chnsenticorp-0-3_300.ckpt --file_format ONNX --file_name ernie_finetune --device_target GPU --task_type chnsenticorp --number_labels 3 + +# xnli数据集 +# python export.py --ckpt_file ./save_models/ernie_ascend_v170_xnli_official_nlp_acc77.94.ckpt --file_format ONNX --device_target GPU --task_type xnli --number_labels 3 + +# dbqa数据集 +# python export.py --ckpt_file ./save_models/ernie_ascend_v170_bdqa_official_nlp_F1score84.96.ckpt --file_format ONNX --device_target GPU --task_type dbqa --number_labels 2 +``` + +[CKPT_PATH]:为ckpt路径; + +[FILE_NAME]:为导出onnx文件名称; + +[TASK_TYPE]:任务数据集名称,包括`[chnsenticorp, xnli, dbqa]` + +[NUMBER_LABELS]:分类任务label数量, chnsenticorp、xnli为3, dbqa为2 + +### onnx模型评估 + +```bash +bash scripts/run_eval_onnx.sh [TASK_TYPE] +# bash scripts/run_eval_onnx.sh chnsenticorp +# TASK_TYPE including [chnsenticorp, xnli, dbqa] +``` + +以上命令后台运行,您可以在`./ms_log/[TASK_TYPE]_onnx_log.txt`查看运行结果。 + +onnx模型路径为 `ONNX_PATH=${CUR_DIR}/ernie_finetune.onnx` + +eval数据集路径为 `EVAL_DATA_PATH=${CUR_DIR}/data/${TASK_TYPE}/${TASK_TYPE}_test.mindrecord` + 如:`home/models/official/nlp/ernie/data/chnsenticorp/chnsenticorp_test.mindrecord` + +具体请看 scripts/run_eval_onnx.sh 脚本 + +#### 基于chnsenticorp数据集进行onnx评估 + +```bash +bash scripts/run_eval_onnx.sh chnsenticorp + +cat ms_log/chnsenticorp_onnx_log.txt +============================================================== +acc_num 1150 , total_num 1200, accuracy 0.958333 +============================================================== +``` + +#### 基于xnli数据集进行onnx评估 + +```bash +bash scripts/run_eval_onnx.sh xnli + +cat ms_log/xnli_onnx_log.txt +============================================================== +acc_num 3902 , total_num 5010, accuracy 0.778842 +============================================================== +``` + +#### 基于dbqa数据集进行onnx评估 + +```bash +bash scripts/run_eval_onnx.sh dbqa + +cat ms_log/dbqa_onnx_log.txt +============================================================== +Precision 0.804701 +Recall 0.897794 +F1 0.848703 +============================================================== +``` + ## 导出mindir模型 ```bash diff --git a/official/nlp/ernie/requirements.txt b/official/nlp/ernie/requirements.txt index 6e682e5e1249910f762ebd4e521b6ddd0d80234c..095d2da4907c3c1c0d147806165bfa39ea75d734 100644 --- a/official/nlp/ernie/requirements.txt +++ b/official/nlp/ernie/requirements.txt @@ -2,4 +2,5 @@ paddlepaddle nltk wikiextractor jieba -opencc-python-reimplemented \ No newline at end of file +opencc-python-reimplemented +onnxruntime-gpu \ No newline at end of file diff --git a/official/nlp/ernie/run_eval_onnx.py b/official/nlp/ernie/run_eval_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..f33eb9db127ae58d9b0a8ef912175297153481dc --- /dev/null +++ b/official/nlp/ernie/run_eval_onnx.py @@ -0,0 +1,133 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +ERNIE task classifier onnx script. +''' + +import os +import argparse +from src.dataset import create_finetune_dataset +from src.assessment_method import Accuracy, F1 +from mindspore import Tensor, dtype +import onnxruntime as ort + + +def eval_result_print(assessment_method="accuracy", callback=None): + """ print eval result """ + if assessment_method == "accuracy": + print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, + callback.acc_num / callback.total_num)) + elif assessment_method == "f1": + print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) + print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) + print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1]") + +def create_session(checkpoint_path, target_device): + """Create ONNX runtime session""" + if target_device == 'GPU': + providers = ['CUDAExecutionProvider'] + elif target_device == 'CPU': + providers = ['CPUExecutionProvider'] + else: + raise ValueError( + f'Unsupported target device {target_device}, ' + f'Expected one of: "CPU", "GPU"' + ) + session = ort.InferenceSession(checkpoint_path, providers=providers) + input_name_0 = session.get_inputs()[0].name + input_name_1 = session.get_inputs()[1].name + input_name_2 = session.get_inputs()[2].name + output_name_0 = session.get_outputs()[0].name + return session, input_name_0, input_name_1, input_name_2, output_name_0 + + +def do_eval_onnx(dataset=None, onnx_file_name="", num_class=3, assessment_method="accuracy", target_device="GPU"): + """ do eval for onnx model""" + if assessment_method == "accuracy": + callback = Accuracy() + elif assessment_method == "f1": + callback = F1(num_class) + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1]") + + columns_list = ["input_ids", "input_mask", "token_type_id", "label_ids"] + + if not os.path.exists(onnx_file_name): + raise ValueError("ONNX file not exists, please check onnx file has been saved and whether the " + "export_file_name is correct.") + + session, input_name_0, input_name_1, input_name_2, output_name_0 = create_session(onnx_file_name, target_device) + + + for data in dataset.create_dict_iterator(num_epochs=1): + input_data = [] + for i in columns_list: + input_data.append(data[i]) + input_ids, input_mask, token_type_id, label_ids = input_data + + x0 = input_ids.asnumpy() + x1 = input_mask.asnumpy() + x2 = token_type_id.asnumpy() + + result = session.run([output_name_0], {input_name_0: x0, input_name_1: x1, input_name_2: x2}) + logits = Tensor(result[0], dtype.float32) + callback.update(logits, label_ids) + + print("==============================================================") + eval_result_print(assessment_method, callback) + print("==============================================================") + + +def run_classifier_onnx(): + """run classifier task for onnx model""" + args_opt = parse_args() + if args_opt.eval_data_file_path == "": + raise ValueError("'eval_data_file_path' must be set when do onnx evaluation task") + assessment_method = args_opt.assessment_method.lower() + ds = create_finetune_dataset(batch_size=args_opt.eval_batch_size, + repeat_count=1, + data_file_path=args_opt.eval_data_file_path, + do_shuffle=(args_opt.eval_data_shuffle.lower() == "true")) + do_eval_onnx(ds, args_opt.onnx_file, args_opt.number_labels, assessment_method, args_opt.device_target) + + +def parse_args(): + """set and check parameters.""" + parser = argparse.ArgumentParser(description="run classifier") + parser.add_argument("--task_type", type=str, default="chnsenticorp", choices=["chnsenticorp", "xnli", "dbqa"], + help="Task type, default is chnsenticorp") + parser.add_argument("--assessment_method", type=str, default="accuracy", choices=["accuracy", "f1"], + help="Assessment method") + parser.add_argument("--device_target", type=str, default="GPU", choices=["Ascend", "GPU"], + help="Device type, default is Ascend") + parser.add_argument('--onnx_file', type=str, + default="./ernie_finetune.onnx", + help='Onnx file path') + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--number_labels", type=int, default=3, help="The number of class, default is 3.") + parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"], + help="Enable eval data shuffle, default is false") + parser.add_argument("--eval_batch_size", type=int, default=1, help="Eval batch size, default is 1") + parser.add_argument("--eval_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + args_opt = parser.parse_args() + + return args_opt + + +if __name__ == "__main__": + run_classifier_onnx() diff --git a/official/nlp/ernie/scripts/run_eval_onnx.sh b/official/nlp/ernie/scripts/run_eval_onnx.sh new file mode 100644 index 0000000000000000000000000000000000000000..5cfdde2ff765662fd344e36eb2a3ff7a7dfaf6a2 --- /dev/null +++ b/official/nlp/ernie/scripts/run_eval_onnx.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ $# -ne 1 ] +then + echo "==============================================================================================================" + echo "Please run the script as: " + echo "bash run_eval_onnx.sh [TASK_TYPE]" + echo "for example: bash scripts/run_eval_onnx.sh chnsenticorp" + echo "TASK_TYPE including [chnsenticorp, xnli, dbqa]" + echo "==============================================================================================================" +exit 1 +fi + +TASK_TYPE=$1 +DEVICE_ID=0 +CUR_DIR=`pwd` +mkdir -p ms_log +ONNX_PATH=${CUR_DIR}/ernie_finetune.onnx +DATA_PATH=${CUR_DIR}/data +GLOG_log_dir=${CUR_DIR}/ms_log + +case $TASK_TYPE in + "chnsenticorp") + PY_NAME=run_eval_onnx + NUM_LABELS=3 + EVAL_BATCH_SIZE=1 + EVAL_DATA_PATH="${DATA_PATH}/chnsenticorp/chnsenticorp_test.mindrecord" + ASSESSMENT_METHOD="accuracy" + ;; + + "xnli") + PY_NAME=run_eval_onnx + NUM_LABELS=3 + EVAL_BATCH_SIZE=1 + EVAL_DATA_PATH="${DATA_PATH}/xnli/xnli_test.mindrecord" + ASSESSMENT_METHOD="accuracy" + ;; + + "dbqa") + PY_NAME=run_eval_onnx + NUM_LABELS=2 + EVAL_BATCH_SIZE=1 + EVAL_DATA_PATH="${DATA_PATH}/nlpcc-dbqa/dbqa_test.mindrecord" + ASSESSMENT_METHOD="f1" + ;; +esac + +python ${CUR_DIR}/${PY_NAME}.py \ + --task_type=${TASK_TYPE} \ + --device_target="GPU" \ + --device_id=${DEVICE_ID} \ + --number_labels=${NUM_LABELS} \ + --eval_data_shuffle="false" \ + --eval_batch_size=${EVAL_BATCH_SIZE} \ + --eval_data_file_path=${EVAL_DATA_PATH} \ + --onnx_file=${ONNX_PATH} \ + --assessment_method=${ASSESSMENT_METHOD} > ${GLOG_log_dir}/${TASK_TYPE}_onnx_log.txt 2>&1 & diff --git a/official/nlp/ernie/src/dataset.py b/official/nlp/ernie/src/dataset.py index d672304a316e6ab75dfc1a10c2e8ab69ef8ec27d..5aa74a06353c84c17dca69f501ba3ddb7569bd04 100644 --- a/official/nlp/ernie/src/dataset.py +++ b/official/nlp/ernie/src/dataset.py @@ -55,6 +55,7 @@ def create_finetune_dataset(batch_size=1, do_shuffle=True): """create finetune or evaluation dataset""" type_cast_op = C.TypeCast(mstype.int32) + data_set = ds.MindDataset(data_file_path, columns_list=["input_ids", "input_mask", "token_type_id", "label_ids"], shuffle=do_shuffle, @@ -79,6 +80,7 @@ def create_mrc_dataset(batch_size=1, drop_reminder=False): """create finetune or evaluation dataset""" type_cast_op = C.TypeCast(mstype.int32) + if is_training: data_set = ds.MindDataset(data_file_path, columns_list=["input_ids", "input_mask", "token_type_id", diff --git a/official/nlp/ernie/src/finetune_task_reader.py b/official/nlp/ernie/src/finetune_task_reader.py index 868397d7065b6f5effe53d8171cb1c1eb237dc46..fdef43a87f5abdb1d08227097c0269d85f725416 100644 --- a/official/nlp/ernie/src/finetune_task_reader.py +++ b/official/nlp/ernie/src/finetune_task_reader.py @@ -23,7 +23,7 @@ from collections import namedtuple import numpy as np from mindspore.mindrecord import FileWriter from mindspore.log import logging -from src.tokenizer import FullTokenizer, convert_to_unicode, tokenize_chinese_chars +from .tokenizer import FullTokenizer, convert_to_unicode, tokenize_chinese_chars def csv_reader(fd, delimiter='\t'): """ diff --git a/official/nlp/ernie/src/pretrain_reader.py b/official/nlp/ernie/src/pretrain_reader.py index b42898f7fa313cef6c93a2402b4456e6002b6d11..e200b66cebde6646068ad971677f3ccc1dbcc96e 100644 --- a/official/nlp/ernie/src/pretrain_reader.py +++ b/official/nlp/ernie/src/pretrain_reader.py @@ -24,8 +24,8 @@ import jieba import numpy as np from opencc import OpenCC from mindspore.mindrecord import FileWriter -from src.tokenizer import convert_to_unicode, CharTokenizer -from src.utils import get_file_list +from .tokenizer import convert_to_unicode, CharTokenizer +from .utils import get_file_list class ErnieDataReader: """Ernie data reader"""