diff --git a/official/nlp/gnmt_v2/README.md b/official/nlp/gnmt_v2/README.md
index 44971b55898f7379a9e67f7c5998acfce29f3a93..caf0b4d8a1406be7ed228132fd68af3487580070 100644
--- a/official/nlp/gnmt_v2/README.md
+++ b/official/nlp/gnmt_v2/README.md
@@ -13,10 +13,11 @@
- [Configuration File](#configuration-file)
- [Training Process](#training-process)
- [Inference Process](#inference-process)
+    - [ONNX Export and Evaluation](#onnx-export-and-evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Random Situation Description](#random-situation-description)
- [Others](#others)
- [ModelZoo HomePage](#modelzoo-homepage)
@@ -188,7 +190,7 @@ The GNMT network script and code result are as follows:
│ ├──local_adapter.py // Local adapter
│ ├──moxing_adapter.py // Moxing adapter for ModelArts
├── src
- │ ├──__init__.py // User interface.
+ │ ├──__init__.py // User interface.
│ ├──dataset
│ ├──__init__.py // User interface.
│ ├──base.py // Base class of data loader.
@@ -222,7 +224,8 @@ The GNMT network script and code result are as follows:
│ ├──optimizer.py // Optimizer.
├── scripts
│ ├──run_distributed_train_ascend.sh // Shell script for distributed train on ascend.
- │ ├──run_distributed_train_gpu.sh // Shell script for distributed train on GPU.
+ │ ├──run_distributed_train_gpu.sh // Shell script for distributed train on GPU.
+ │ ├──run_onnx_eval_gpu.sh // Shell script for ONNX eval on GPU.
│ ├──run_standalone_eval_ascend.sh // Shell script for standalone eval on ascend.
│ ├──run_standalone_eval_gpu.sh // Shell script for standalone eval on GPU.
  │   ├──run_standalone_train_ascend.sh   // Shell script for standalone train on ascend.
@@ -233,6 +236,7 @@ The GNMT network script and code result are as follows:
├── default_test_config_gpu.yaml // Configurations for eval on GPU.
├── create_dataset.py // Dataset preparation.
├── eval.py // Infer API entry.
+ ├── eval_onnx.py // ONNX infer API entry.
  ├── export.py                            // Export checkpoint file into air or onnx models.
├── mindspore_hub_conf.py // Hub config.
├── pip-requirements.txt // Requirements of third party package for modelarts.
@@ -374,6 +378,38 @@ For more configuration details, please refer the script `./default_config.yaml`
 The `TEST_DATASET` is the path of the inference dataset, and `EXISTED_CKPT_PATH` is the path of the model file generated during training.
 The `VOCAB_ADDR` is the vocabulary path, `BPE_CODE_ADDR` is the BPE codes path, and `TEST_TARGET` is the path of the reference answers.
+## ONNX Export and Evaluation
+
+- Export your model to ONNX:
+
+ ```bash
+ python export.py --config_path default_test_config_gpu.yaml --existed_ckpt /path/to/checkpoint.ckpt --file_name /path/to/exported.onnx --file_format ONNX
+ ```
+
+- Run ONNX evaluation:
+
+ ```bash
+ python eval_onnx.py --config_path default_test_config_gpu.yaml --test_dataset /path/to/newstest2014.en.mindrecord --file_name /path/to/exported.onnx --vocab /path/to/vocab.bpe.32000 --bpe_codes /path/to/bpe.32000 --test_tgt /path/to/newstest2014.de
+
+ # or
+
+ cd scripts
+ bash run_onnx_eval_gpu.sh /path/to/newstest2014.en.mindrecord /path/to/exported.onnx /path/to/vocab.bpe.32000 /path/to/bpe.32000 /path/to/newstest2014.de
+ ```
+
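+- Optionally, you can sanity-check the exported model before running the full evaluation. Below is a minimal sketch using `onnxruntime`; the model path is a placeholder:
+
+    ```python
+    import onnxruntime as ort
+
+    # Load the exported graph on GPU (requires the onnxruntime-gpu package).
+    session = ort.InferenceSession("/path/to/exported.onnx", providers=["CUDAExecutionProvider"])
+
+    # eval_onnx.py expects exactly two inputs: the source token ids and their mask.
+    for model_input in session.get_inputs():
+        print(model_input.name, model_input.shape, model_input.type)
+    ```
+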
# [Model Description](#contents)
## Performance
diff --git a/official/nlp/gnmt_v2/eval_onnx.py b/official/nlp/gnmt_v2/eval_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b6383ec6606702a7736f5f97535f76c3b43f98
--- /dev/null
+++ b/official/nlp/gnmt_v2/eval_onnx.py
@@ -0,0 +1,121 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ONNX evaluation"""
+import pickle
+import time
+
+import numpy as np
+import onnxruntime as ort
+from src.dataset import load_dataset
+from src.gnmt_model.bleu_calculate import bleu_calculate
+from src.dataset.tokenizer import Tokenizer
+from src.utils.get_config import get_config
+
+from model_utils.config import config as default_config
+
+
+def create_session(checkpoint_path, target_device):
+ """Create ONNX runtime session"""
+ if target_device == 'GPU':
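+        # The CUDA provider needs the onnxruntime-gpu package (listed in requirements.txt).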
+ providers = ['CUDAExecutionProvider']
+ elif target_device in ('CPU', 'Ascend'):
+ providers = ['CPUExecutionProvider']
+ else:
+ raise ValueError(f"Unsupported target device '{target_device}'. Expected one of: 'CPU', 'GPU', 'Ascend'")
+ session = ort.InferenceSession(checkpoint_path, providers=providers)
+ input_names = [x.name for x in session.get_inputs()]
+ return session, input_names
+
+
+def infer(config):
+ """Run inference"""
+ session, [ids_name, mask_name] = create_session(config.file_name, config.device_target)
+ eval_dataset = load_dataset(data_files=config.test_dataset,
+ batch_size=config.batch_size,
+ sink_mode=config.dataset_sink_mode,
+ drop_remainder=False,
+ is_translate=True,
+ shuffle=False)
+
+ predictions = []
+ source_sentences = []
+
+ batch_index = 1
+ pad_idx = 0
+ sos_idx = 2
+ eos_idx = 3
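+    # The exported ONNX graph has a fixed batch size, so a short final batch is
+    # padded with dummy <sos><eos> rows and the extra predictions are dropped later.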
+ source_ids_pad = np.tile(np.array([[sos_idx, eos_idx] + [pad_idx] * (config.seq_length - 2)], np.int32),
+ [config.batch_size, 1])
+ source_mask_pad = np.tile(np.array([[1, 1] + [0] * (config.seq_length - 2)], np.int32),
+ [config.batch_size, 1])
+ for batch in eval_dataset.create_dict_iterator(output_numpy=True):
+ source_sentences.append(batch["source_eos_ids"])
+ source_ids = batch["source_eos_ids"]
+ source_mask = batch["source_eos_mask"]
+
+ active_num = source_ids.shape[0]
+ if active_num < config.batch_size:
+ source_ids = np.concatenate((source_ids, source_ids_pad[active_num:, :]))
+ source_mask = np.concatenate((source_mask, source_mask_pad[active_num:, :]))
+
+ start_time = time.time()
+ [predicted_ids] = session.run(None, {ids_name: source_ids, mask_name: source_mask})
+
+ print(f" | BatchIndex = {batch_index}, Batch size: {config.batch_size}, active_num={active_num}, "
+ f"Time cost: {time.time() - start_time}.")
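+        # Keep predictions only for the real sentences, discarding the padded rows.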
+ if active_num < config.batch_size:
+ predicted_ids = predicted_ids[:active_num, :]
+ batch_index = batch_index + 1
+ predictions.append(predicted_ids)
+
+ output = []
+    for inputs, batch_out in zip(source_sentences, predictions):
+        # Keep only the top beam when predictions come back as [batch, beam, length].
+        if batch_out.ndim == 3:
+            batch_out = batch_out[:, 0]
+        for i, _ in enumerate(batch_out):
+            example = {
+                "source": inputs[i].tolist(),
+                "prediction": batch_out[i].tolist()
+            }
+            output.append(example)
+
+ return output
+
+
+def run_onnx_eval():
+ """ONNX eval"""
+ config = get_config(default_config)
+ result = infer(config)
+
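+    # Persist the predictions; bleu_calculate reads them back from this file.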
+ with open(config.output, "wb") as f:
+ pickle.dump(result, f, 1)
+
+ result_npy_addr = config.output
+ vocab = config.vocab
+ bpe_codes = config.bpe_codes
+ test_tgt = config.test_tgt
+ tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
+ scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
+    print(f"BLEU scores: {scores}")
+
+
+if __name__ == '__main__':
+ run_onnx_eval()
diff --git a/official/nlp/gnmt_v2/export.py b/official/nlp/gnmt_v2/export.py
index 7880fe2002cb46b838308d92338696652be3e766..bc28710b4fbc3a356d648118ebffb269c45acd4c 100644
--- a/official/nlp/gnmt_v2/export.py
+++ b/official/nlp/gnmt_v2/export.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -88,11 +88,11 @@ def modelarts_pre_process():
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
'''run export.'''
- context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend",
- reserve_class_name_in_scope=False)
-
config = get_config(default_config)
+ context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=config.device_target,
+ reserve_class_name_in_scope=False)
+
tfm_model = GNMT(config=config,
is_training=False,
use_one_hot_embeddings=False)
diff --git a/official/nlp/gnmt_v2/requirements.txt b/official/nlp/gnmt_v2/requirements.txt
index 28f9e9f07218aaf684451cf03f3002c2919dbf60..f6b33f8c7a46e4078fa785d25cc3d7a61bfbd281 100644
--- a/official/nlp/gnmt_v2/requirements.txt
+++ b/official/nlp/gnmt_v2/requirements.txt
@@ -3,3 +3,4 @@ pyyaml
subword-nmt==0.3.7
sacrebleu==1.4.14
sacremoses==0.0.35
+onnxruntime-gpu
diff --git a/official/nlp/gnmt_v2/scripts/run_onnx_eval_gpu.sh b/official/nlp/gnmt_v2/scripts/run_onnx_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..db8436703452a8ad9c2a80d85fb40baa7dd5b43d
--- /dev/null
+++ b/official/nlp/gnmt_v2/scripts/run_onnx_eval_gpu.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash run_onnx_eval_gpu.sh TEST_DATASET ONNX_CKPT_PATH \
+ VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
+echo "for example:"
+echo "bash run_onnx_eval_gpu.sh \
+ /home/workspace/dataset_menu/newstest2014.en.mindrecord \
+ /home/workspace/gnmt_v2/exported.onnx \
+ /home/workspace/wmt16_de_en/vocab.bpe.32000 \
+ /home/workspace/wmt16_de_en/bpe.32000 \
+ /home/workspace/wmt16_de_en/newstest2014.de"
+echo "It is better to use absolute paths."
+echo "=============================================================================================================="
+
+TEST_DATASET=${1:?Missing test dataset}
+ONNX_CKPT_PATH=${2:?Missing ONNX checkpoint path}
+VOCAB_ADDR=${3:?Missing vocabulary path}
+BPE_CODE_ADDR=${4:?Missing BPE codes path}
+TEST_TARGET=${5:?Missing test target}
+
+current_exec_path=$(pwd)
+echo ${current_exec_path}
+
+
+export GLOG_v=2
+
+if [ -d "eval" ];
+then
+ rm -rf ./eval
+fi
+mkdir ./eval
+cp ../*.py ./eval
+cp ../*.yaml ./eval
+cp -r ../src ./eval
+cp -r ../model_utils ./eval
+cd ./eval || exit
+echo "start for evaluation"
+env > env.log
+
+config_path="${current_exec_path}/eval/default_test_config_gpu.yaml"
+echo "config path is : ${config_path}"
+
+python eval_onnx.py \
+ --config_path=$config_path \
+ --test_dataset=$TEST_DATASET \
+ --file_name=$ONNX_CKPT_PATH \
+ --vocab=$VOCAB_ADDR \
+ --bpe_codes=$BPE_CODE_ADDR \
+ --test_tgt=$TEST_TARGET >onnx_eval.log 2>&1 &
+cd ..
diff --git a/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py b/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py
index 7f4c91ab3486f85e24ccf1a7d15d1aadf166b38b..8edafd4286f48d031c361b22551ecdf92817a2a3 100644
--- a/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py
+++ b/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py
@@ -1,4 +1,4 @@
-# Copyright 2020-2021 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.