# Contents
- [Contents](#contents)
- [Speech Transformer Description](#speech-transformer-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Script Parameters](#training-script-parameters)
- [Running Options](#running-options)
- [Network Parameters](#network-parameters)
- [Dataset Preparation](#dataset-preparation)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Inference Process](#inference-process)
- [Export MindIR](#export-mindir)
- [Infer on Ascend310](#infer-on-ascend310)
- [result](#result)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
## [Speech Transformer Description](#contents)
The standard Transformer sequence-to-sequence (encoder-decoder) architecture is used to solve the speech-to-text problem.
[Paper](https://ieeexplore.ieee.org/document/8682586): Yuanyuan Zhao, Jie Li, Xiaorui Wang, and Yan Li. "The SpeechTransformer for Large-scale Mandarin Chinese Speech Recognition." ICASSP 2019.
## [Model Architecture](#contents)
Specifically, the Transformer contains six encoder modules and six decoder modules. Each encoder module consists of a self-attention layer and a feed-forward layer; each decoder module consists of a self-attention layer, an encoder-decoder attention layer and a feed-forward layer.
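As an illustration, here is a minimal NumPy sketch of the scaled dot-product attention underlying each of these attention layers (a textbook sketch, not the repository's `src/transformer_model.py` implementation):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k, v: [batch, heads, seq_len, d_k]; mask: boolean, True = keep."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)  # [batch, heads, seq, seq]
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # mask out padded/future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v  # weighted sum of values
```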
## [Dataset](#contents)
Note that you can run the scripts with the dataset mentioned in the original paper or one widely used in this domain/network architecture. In the following sections, we introduce how to run the scripts using the AISHELL dataset.
You can download the dataset from this [link](http://www.openslr.org/33/)
The DataSet directory is as follows:
```shell
└── aishell
    ├── conf
    │   └── fbank.conf
    ├── data
    │   └── lang_1char
    │       ├── non_lang_syms.txt
    │       └── train_chars.txt
    └── dump
        ├── dev
        │   ├── data.json
        │   ├── feats.1.ark
        │   ├── feats.1.scp
        │   ├── ......
        │   ├── feats.40.ark
        │   ├── feats.40.scp
        │   ├── feats.scp
        │   └── utt2num_frames
        ├── test
        │   ├── data.json
        │   ├── feats.1.ark
        │   ├── feats.1.scp
        │   ├── ......
        │   ├── feats.40.ark
        │   ├── feats.40.scp
        │   ├── feats.scp
        │   └── utt2num_frames
        └── train
            ├── data.json
            ├── feats.1.ark
            ├── feats.1.scp
            ├── ......
            ├── feats.40.ark
            ├── feats.40.scp
            ├── feats.scp
            └── utt2num_frames
```
## [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/en/master/index.html)
## [Quick Start](#contents)
After dataset preparation, you can start training and evaluation as follows:
(Note that you must specify dataset path and path to char dictionary in `default_config.yaml`)
```bash
# run training example
bash scripts/run_standalone_train_ascend.sh Ascend 0 91 1
# run distributed training example
bash scripts/run_distribute_train_ascend.sh ./train.py 8 279 ./default_config.yaml ./rank_table_8pcs.json
# run evaluation example
bash scripts/run_eval_ascend.sh Ascend 0 /your/path/data.json /your/path/checkpoint_file ./default_config.yaml
```
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```shell
.
└── speech_transformer
├── README.md
├── default_config.yaml
├── eval.py
├── evaluate_cer.py
├── export.py
├── requirements.txt
├── prepare_aishell_data
│ ├── README.md
│ └── convert_kaldi_bins_to_pickle.py
├── scripts
│ ├── run_distribute_train_ascend.sh
│ ├── run_eval_ascend.sh
│ └── run_standalone_train_ascend.sh
├── src
│ ├── beam_search.py
│ ├── dataset.py
│ ├── kaldi_io.py
│ ├── lr_schedule.py
│ ├── model_utils
│ │ ├── config.py
│ │ ├── device_adapter.py
│ │ ├── __init__.py
│ │ ├── local_adapter.py
│ │ └── moxing_adapter.py
│ ├── transformer_for_train.py
│ ├── transformer_model.py
│ └── weight_init.py
└── train.py
```
### [Script Parameters](#contents)
#### Training Script Parameters
```text
usage: train.py [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
[--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--save_checkpoint_path SAVE_CHECKPOINT_PATH]
[--checkpoint_path CHECKPOINT_PATH] [--data_json_path DATA_PATH]
options:
--distribute run training on multiple devices: "true" (training with more than 1 device) | "false", default is "false"
--epoch_size epoch size: N, default is 91
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
--checkpoint_path path to load checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 2500
--save_checkpoint_num number for saving checkpoint files: N, default is 30
--save_checkpoint_path path to save checkpoint files: PATH, default is "./checkpoint/"
--data_json_path path to dataset file: PATH, default is ""
```
#### Running Options
```text
default_config.yaml:
transformer_network version of Transformer model: base | large, default is base
init_loss_scale_value initial value of loss scale: N, default is 2^10
scale_factor factor used to update loss scale: N, default is 2
scale_window steps between two updates of loss scale: N, default is 2000
optimizer optimizer used in the network: Adam, default is "Adam"
data_file data file: PATH
model_file checkpoint file to be loaded: PATH
output_file output file of evaluation: PATH
```
#### Network Parameters
```text
Parameters for dataset and network (Training/Evaluation):
batch_size batch size of input dataset: N, default is 32
seq_length max length of input sequence: N, default is 512
input_feature_size Input feature size: N, default is 320
vocab_size size of each embedding vector: N, default is 4233
hidden_size size of Transformer encoder layers: N, default is 512
num_hidden_layers number of hidden layers: N, default is 6
num_attention_heads number of attention heads: N, default is 8
intermediate_size size of intermediate layer: N, default is 2048
hidden_act activation function used: ACTIVATION, default is "relu"
hidden_dropout_prob dropout probability for TransformerOutput: Q, default is 0.1
attention_probs_dropout_prob dropout probability for TransformerAttention: Q, default is 0.1
max_position_embeddings maximum length of sequences: N, default is 512
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
label_smoothing label smoothing setting: Q, default is 0.1
beam_width beam width setting: N, default is 5
max_decode_length max decode length in evaluation: N, default is 100
length_penalty_weight normalize scores of translations according to their length: Q, default is 1.0
compute_type compute type in Transformer: mstype.float16 | mstype.float32, default is mstype.float16
Parameters for learning rate:
learning_rate value of learning rate: Q
warmup_steps steps of the learning rate warm up: N
start_decay_step step of the learning rate to decay: N
min_lr minimal learning rate: Q
lr_param_k scale factor for learning rate
Note:
for training, set data_json_path to 'your/path/to/egs/aishell/dump/train/deltafalse/data.json'
for evaluation, set data_json_path to 'your/path/to/egs/aishell/dump/test/deltafalse/data.json'
```
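For reference, a sketch of the Noam-style schedule suggested by `lr_param_k` and `warmup_steps` (the exact formula lives in `src/lr_schedule.py`; this is the variant used by the original Speech-Transformer, given here only as an assumption):

```python
def noam_lr(step, d_model=512, warmup_steps=4000, k=0.2):
    """Rise during the first warmup_steps steps, then decay as step**-0.5."""
    return k * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# With the defaults from default_config.yaml the peak is reached around step 4000:
print(noam_lr(4000))  # ~1.4e-4
```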
## [Dataset Preparation](#contents)
Detailed instructions for dataset preparation are given in [prepare_aishell_data/README.md](prepare_aishell_data/README.md).
The dataset is preprocessed following the reference [implementation](https://github.com/kaituoxu/Speech-Transformer): features are extracted with `Kaldi`, and the resulting Kaldi binaries are then converted into Python pickle objects.
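For example, once Kaldi has produced the `dump` directories, the conversion script can be invoked as follows; it writes `pickled_dataset/{train,dev,test}/data.json` under the dataset root:

```bash
python prepare_aishell_data/convert_kaldi_bins_to_pickle.py \
    --processed-dataset-path your/path/to/egs/aishell
```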
## [Training Process](#contents)
- Set options in `default_config.yaml`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn) for more information about the dataset.
- Run `run_standalone_train_ascend.sh` for non-distributed training of Transformer model.
``` bash
bash scripts/run_standalone_train_ascend.sh [DEVICE_TARGET] [DEVICE_ID] [EPOCH_SIZE] [GRADIENT_ACCUMULATE_STEP]
for example: bash run_standalone_train_ascend.sh Ascend 0 91 1
```
- Run `run_distribute_train_ascend.sh` for distributed training of Transformer model.
``` bash
bash scripts/run_distribute_train_ascend.sh [TRAIN_PATH] [DEVICE_NUM] [EPOCH_SIZE] [CONFIG_PATH] [RANK_TABLE_FILE]
for example: bash run_distribute_train_ascend.sh ../train.py 8 120 ../default_config.yaml ../rank_table_8pcs.json
```
**Attention**: data sink mode cannot be used in the transformer since the input data have different sequence lengths.
## [Evaluation Process](#contents)
- Set options in `default_config.yaml`. Make sure the 'data_file', 'model_file' and 'output_file' are set to your own path.
- Run `bash scripts/run_eval_ascend.sh` for evaluation of Transformer model.
```bash
bash scripts/run_eval_ascend.sh [DEVICE_TARGET] [DEVICE_ID] [DATA_JSON_PATH] [CKPT_PATH] [CONFIG_PATH]
for example: bash run_eval_ascend.sh Ascend 0 /your/path/data.json /your/path/checkpoint_file ./default_config.yaml
```
- Calculate Character Error Rate
```bash
python evaluate_cer.py
```
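`evaluate_cer.py` reads the json written by `eval.py` (`output_file` in `default_config.yaml`) and computes CER with `jiwer`. Because transcripts are stored as space-joined single characters, jiwer's word error rate reduces to a character error rate; a small illustration:

```python
import jiwer

ref = " ".join("今天天气很好")  # ground truth, one "word" per character
hyp = " ".join("今天天汽很好")  # hypothesis with one substituted character
print(jiwer.wer(ref, hyp))     # 1 edit / 6 characters ≈ 0.167
```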
## [Inference Process](#contents)
### [Export MindIR](#contents)
```shell
python export.py --model_file [MODEL_FILE] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
MODEL_FILE: path to the trained checkpoint file to export
FILE_NAME: name of the exported file
FILE_FORMAT: format of the exported file, default is "MINDIR"
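For example (paths are placeholders):

```shell
python export.py --model_file /your/path/checkpoint_file --file_name transformer --file_format MINDIR
```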
### [Infer on Ascend310](#contents)
```shell
bash run_infer_310.sh 'your/path/ckpt/Speech_91.mindir' 'your/path/dataset/egs/aishell/dump/test/deltafalse/data.json' 0 'your/path/dataset/egs/aishell/data/lang_1char/train_chars.txt'
```
MINDIR_PATH: path to the exported MindIR file
DATASET_PATH: path to the dataset json file (your/path/dataset/egs/aishell/dump/test/deltafalse/data.json)
DEVICE_ID: device ID, default is 0
CHARS_DICT_PATH: path to the file mapping ids to Chinese characters (your/path/dataset/egs/aishell/data/lang_1char/train_chars.txt)
### [result](#contents)
Inference results are saved in `acc.log`.
## [Model Description](#contents)
### [Performance](#contents)
#### Training Performance
| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | 8x Ascend 910 |
| uploaded Date | 02/16/2022 (month/day/year) |
| MindSpore Version | 1.7.0rc1 |
| Dataset | AISHELL Train |
| Training Parameters | epoch=100, batch_size=32 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| Speed | 291ms/step (8pcs) |
| Total Training time | 6.2 hours (8pcs) |
| Loss | 1.24 |
| Params (M) | 49.6 |
| Checkpoint for inference | 557 MB (.ckpt file) |
| Scripts | [Speech Transformer scripts](scripts) |
#### Evaluation Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Resource | Ascend 910 |
| Uploaded Date | 02/16/2022 (month/day/year) |
| MindSpore Version | 1.7.0rc1 |
| Dataset | AISHELL Test |
| batch_size | 1 |
| outputs | Character Error Rate |
| Accuracy            | CER=11.6 (at epoch 91)      |
## [Description of Random Situation](#contents)
There are three random situations:
- Shuffle of the dataset.
- Initialization of some model weights.
- Dropout operations.
Some seeds have already been set in train.py to avoid the randomness of dataset shuffle and weight initialization. If you want to disable dropout, please set the corresponding dropout_prob parameter to 0 in default_config.yaml.
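A minimal sketch of the kind of seeding involved (the exact calls and seed values live in train.py; the seed value below is only illustrative):

```python
import numpy as np
import mindspore

mindspore.set_seed(1)  # fixes MindSpore ops and weight initialization
np.random.seed(1)      # fixes NumPy-side randomness such as dataset shuffling
```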
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
find_package(gflags REQUIRED)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
std::vector<std::string> GetAllFiles(std::string dir_name);
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name);
#endif
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <map>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(input0_path, ".", "input0 path");
DEFINE_int32(device_id, 0, "device id");
DEFINE_string(precision_mode, "allow_fp32_to_fp16", "precision mode");
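// Helper that loads a MindIR graph and builds a model on the given Ascend 310
// device. Note: main() below performs the same steps inline and does not call this helper.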
int load_model(Model *model, std::vector<MSTensor> *model_inputs, std::string mindir_path, int device_id) {
if (RealPath(mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(device_id);
context->MutableDeviceInfo().push_back(ascend310);
ascend310->SetOpSelectImplMode("high_precision");
ascend310->SetPrecisionMode("allow_fp32_to_fp16");
mindspore::Graph graph;
Serialization::Load(mindir_path, ModelType::kMindIR, &graph);
Status ret = model->Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
*model_inputs = model->GetInputs();
if (model_inputs->empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
} else {
std::cout << "valid model, inputs is not empty." << std::endl;
}
return 0;
}
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
namespace ms = mindspore;
namespace ds = mindspore::dataset;
auto context = std::make_shared<ms::Context>();
auto ascend310_info = std::make_shared<ms::Ascend310DeviceInfo>();
ascend310_info->SetDeviceID(FLAGS_device_id);
context->MutableDeviceInfo().push_back(ascend310_info);
ms::Graph graph;
ms::Status ret = ms::Serialization::Load(FLAGS_mindir_path, ms::ModelType::kMindIR, &graph);
ms::Model model;
std::cout << ret << std::endl;
Status rets = model.Build(ms::GraphCell(graph), context);
if (rets != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> model_inputs = model.GetInputs();
if (model_inputs.empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
}
std::string input_path1 = FLAGS_input0_path;
std::string input_path2 = input_path1;
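// Derive the second input directory by replacing the character at index 21,
// e.g. "./preprocess_Result/00_data" -> "./preprocess_Result/01_data" for the
// default --input0_path passed by run_infer_310.sh (fragile: position-dependent).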
input_path2.replace(21, 1, "1");
auto input0_files = GetAllFiles(input_path1);
auto input1_files = GetAllFiles(input_path2);
std::cout << "input_path1:" << input_path1 << std::endl;
std::cout << "input_path2:" << input_path2 << std::endl;
if (input0_files.empty() || input1_files.empty()) {
std::cout << "ERROR: input data empty." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = input0_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << input0_files[i] << std::endl;
auto input0 = ReadFileToTensor(input0_files[i]);
auto input1 = ReadFileToTensor(input1_files[i]);
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
input0.Data().get(), input0.DataSize());
inputs.emplace_back(model_inputs[1].Name(), model_inputs[1].DataType(), model_inputs[1].Shape(),
input1.Data().get(), input1.DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
int rst = WriteResult(input0_files[i], outputs);
if (rst != 0) {
std::cout << "write result failed." << std::endl;
return rst;
}
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <algorithm>
#include <iostream>
#include "inc/utils.h"
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) {
std::vector<std::vector<std::string>> ret;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
struct dirent *filename;
/* read all the files in the dir ~ */
std::vector<std::string> sub_dirs;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
// get rid of "." and ".."
if (d_name == "." || d_name == ".." || d_name.empty()) {
continue;
}
std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name);
struct stat s;
lstat(dir_path.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
continue;
}
sub_dirs.emplace_back(dir_path);
}
std::sort(sub_dirs.begin(), sub_dirs.end());
(void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret),
[](const std::string &d) { return GetAllFiles(d); });
return ret;
}
std::vector<std::string> GetAllFiles(std::string dir_name) {
struct dirent *filename;
DIR *dir = OpenDir(dir_name);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string d_name = std::string(filename->d_name);
if (d_name == "." || d_name == ".." || d_name.size() <= 3) {
continue;
}
res.emplace_back(std::string(dir_name) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
return res;
}
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
const int INVALID_POINTER = -1;
const int ERROR = -2;
std::cout << "outputs.size(): " << outputs.size() << std::endl;
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE *outputFile = fopen(outFileName.c_str(), "wb");
if (outputFile == nullptr) {
std::cout << "open result file " << outFileName << " failed" << std::endl;
return INVALID_POINTER;
}
size_t size = fwrite(netOutput.get(), sizeof(char), outputSize, outputFile);
if (size != outputSize) {
fclose(outputFile);
outputFile = nullptr;
std::cout << "write result file " << outFileName << " failed, write size[" << size <<
"] is smaller than output size[" << outputSize << "], maybe the disk is full." << std::endl;
return ERROR;
}
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "Pointer file is nullptr" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " is not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << "open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: "/cache/egs/"
train_url: "/cache/CheckPoints/"
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/training"
load_path: "/cache/checkpoint_path"
checkpoint_path: ''
device_target: Ascend
enable_profiling: False
# ==============================================================================
# config/cfg edict
transformer_network: 'base'
init_loss_scale_value: 1024
scale_factor: 2
scale_window: 2000
optimizer: 'Adam'
lr_param_k: 0.2
warmup_steps: 4000
optimizer_adam_beta2: 0.997
# transformer_net_cfg
batch_size: 32
seq_length: 512
input_feature_size: 320
vocab_size: 4233
hidden_size: 512
num_hidden_layers: 6
num_attention_heads: 8
intermediate_size: 2048
hidden_act: "relu"
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
max_position_embeddings: 512
initializer_range: 0.02
label_smoothing: 0.1
dtype: mstype.float32
compute_type: mstype.float16
#eval_config/cfg edict
data_file: '/cache/data'
model_file: './transformer/transformer_trained.ckpt'
output_file: './output_eval.json'
data_json_path: '/path/to/egs/aishell/dump/train/deltafalse/data.json'
chars_dict_path: '/path/to/egs/aishell/data/lang_1char/train_chars.txt'
# transformer_net_cfg
batch_size_ev: 1
hidden_dropout_prob_ev: 0.0
attention_probs_dropout_prob_ev: 0.0
beam_width: 5
max_decode_length: 100
length_penalty_weight: 1.0
# ==============================================================================
# train.py / Argparse init.
distribute: "false"
epoch_size: 91
device_id: 0
device_num: 1
enable_lossscale: "true"
do_shuffle: "true"
enable_save_ckpt: "true"
save_checkpoint_steps: 3753
save_checkpoint_num: 30
save_checkpoint_path: "./"
accumulation_steps: 1
# export.py /eval_config - transformer export
file_name: "transformer"
file_format: 'MINDIR'
#'postprocess / from eval_config'
result_dir: "./result_Files"
#'preprocess / from eval_config'
result_path: "./preprocess_Result/"
# src/process_output.py "rescore nbest with smoothed sentence-level bleu."
vocab_file: ""
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local, it is better to use absolute path'
output_path: 'Training output path for local'
ann_file: 'Ann file, default is val.json.'
device_target: "device where the code will be implemented, default is Ascend"
checkpoint_path: "Checkpoint file path"
data_file: '/your/path/evaluation.mindrecord'
model_file: '/your/path/checkpoint_file'
output_file: './output_eval.txt'
data_json_path: '/path/to/egs/aishell/pickled_dataset/train/data.json or /path/to/egs/aishell/pickled_dataset/test/data.json'
chars_dict_path: '/path/to/egs/aishell/data/lang_1char/train_chars.txt'
distribute: "Run distribute, default is false."
epoch_size: "Epoch size, default is 91(1p),279(8p)"
device_id: "Device id, default is 0."
device_num: "Use device nums, default is 1."
enable_lossscale: "Use lossscale or not, default is true."
do_shuffle: "Enable shuffle for dataset, default is true."
enable_save_ckpt: "Enable save checkpoint, default is true."
save_checkpoint_steps: "Save checkpoint steps, default is 3753."
save_checkpoint_num: "Save checkpoint numbers, default is 30."
save_checkpoint_path: "Save checkpoint file path"
accumulation_steps: "Gradient accumulation steps, default is 1."
file_name: "output file name."
file_format: 'file format'
result_dir: "./result_Files"
result_path: "./preprocess_Result/"
vocab_file: "vocab file path."
input_file: 'Input raw text file (or comma-separated list of files).'
num_splits: 'The MindRecord file will be split into the number of partition.'
clip_to_max_len: 'clip sequences to maximum sequence length.'
max_seq_length: 'Maximum sequence length.'
---
device_target: ["Ascend", "GPU", "CPU"]
file_format: ["AIR", "ONNX", "MINDIR"]
distribute: ['true', 'false']
enable_lossscale: ['true', 'false']
do_shuffle: ['true', 'false']
enable_save_ckpt: ['true', 'false']
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer evaluation script."""
import json
import os
import numpy as np
from mindspore import context
from mindspore import dtype as mstype
from mindspore import nn
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint
from mindspore.train.serialization import load_param_into_net
from tqdm import tqdm
from src.dataset import MsAudioDataset
from src.dataset import create_transformer_dataset
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
from src.model_utils.moxing_adapter import moxing_wrapper
from src.transformer_model import TransformerModel
config.dtype = mstype.float32
config.compute_type = mstype.float16
config.batch_size = config.batch_size_ev
config.hidden_dropout_prob = config.hidden_dropout_prob_ev
config.attention_probs_dropout_prob = config.attention_probs_dropout_prob_ev
class TransformerInferCell(nn.Cell):
"""
Encapsulation class of transformer network infer.
"""
def __init__(self, network):
super(TransformerInferCell, self).__init__(auto_prefix=False)
self.network = network
def construct(self, source_ids, source_mask):
predicted_ids = self.network(source_ids, source_mask)
return predicted_ids
def load_weights(model_path):
"""
Load checkpoint as parameter dict, support both npz file and mindspore checkpoint file.
"""
if model_path.endswith(".npz"):
ms_ckpt = np.load(model_path)
is_npz = True
else:
ms_ckpt = load_checkpoint(model_path)
is_npz = False
weights = {}
for msname in ms_ckpt:
infer_name = msname
if "tfm_decoder" in msname:
infer_name = "tfm_decoder.decoder." + infer_name
if is_npz:
weights[infer_name] = ms_ckpt[msname]
else:
weights[infer_name] = ms_ckpt[msname].data.asnumpy()
weights["tfm_decoder.decoder.tfm_embedding_lookup.embedding_table"] = \
weights["tfm_embedding_lookup.embedding_table"]
parameter_dict = {}
for name in weights:
parameter_dict[name] = Parameter(Tensor(weights[name]), name=name)
return parameter_dict
def modelarts_pre_process():
"""modelarts pre process"""
config.output_file = os.path.join(config.output_path, config.output_file)
config.data_file = os.path.join(config.data_file, config.data_file_name)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_transformer_eval():
"""
Transformer evaluation.
"""
context.set_context(
mode=context.GRAPH_MODE,
device_target=config.device_target,
reserve_class_name_in_scope=False,
device_id=get_device_id(),
)
dataset = create_transformer_dataset(
epoch_count=1,
rank_size=1,
rank_id=0,
do_shuffle='false',
data_json_path=config.data_json_path,
chars_dict_path=config.chars_dict_path,
batch_size=config.batch_size_ev,
)
char_list, _, _ = MsAudioDataset.process_dict(config.chars_dict_path)
tfm_model = TransformerModel(config=config, is_training=False, use_one_hot_embeddings=False)
parameter_dict = load_weights(config.model_file)
load_param_into_net(tfm_model, parameter_dict)
tfm_infer = TransformerInferCell(tfm_model)
model = Model(tfm_infer)
predictions = []
target_sents = []
for batch in tqdm(dataset.create_dict_iterator(output_numpy=True, num_epochs=1), total=dataset.get_dataset_size()):
target_sents.append(batch["target_eos_ids"])
source_feats = Tensor(batch["source_eos_features"], mstype.float32)
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
predicted_ids = model.predict(source_feats, source_mask)
predictions.append(predicted_ids.asnumpy())
result_dict = dict()
sample_num = 0
for batch_out, batch_gt in zip(predictions, target_sents):
for i in range(config.batch_size):
if batch_out.ndim == 3:
batch_out = batch_out[:, 0]
predicted_tokens = [char_list[x] for x in batch_out[i].tolist()]
predict = " ".join(predicted_tokens)
gt_tokens = [char_list[x] for x in batch_gt[i].tolist() if x != -1]
gt = " ".join(gt_tokens)
result_dict[sample_num] = {
'output': predict,
'gt': gt,
}
sample_num += 1
with open(config.output_file, 'w') as file:
json.dump(result_dict, file, indent=2)
if __name__ == "__main__":
run_transformer_eval()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" evaluate CER for model results"""
import json
from pathlib import Path
import jiwer
from src.model_utils.config import config
def main():
"""evaluate CER"""
remove_non_words = jiwer.RemoveKaldiNonWords()
remove_space = jiwer.RemoveWhiteSpace()
preprocessing = jiwer.Compose([remove_non_words, remove_space])
with Path(config.output_file).open('r') as file:
output_data = json.load(file)
total_cer = 0
for sample in output_data.values():
res_text = preprocessing(sample['output'])
res_text = ' '.join(res_text)
gt_text = preprocessing(sample['gt'])
gt_text = ' '.join(gt_text)
cer = jiwer.wer(gt_text, res_text)
total_cer += cer
print('Resulting cer is ', (total_cer / len(output_data.values())) * 100)
if __name__ == '__main__':
main()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" export checkpoint file into models"""
import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import load_param_into_net, export
from src.transformer_model import TransformerModel
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id
from eval import load_weights
config.batch_size = config.batch_size_ev
config.hidden_dropout_prob = config.hidden_dropout_prob_ev
config.attention_probs_dropout_prob = config.attention_probs_dropout_prob_ev
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
def modelarts_pre_process():
"""model arts pre process"""
@moxing_wrapper(pre_process=modelarts_pre_process)
def export_transformer():
""" export_transformer """
tfm_model = TransformerModel(config=config, is_training=False, use_one_hot_embeddings=False)
parameter_dict = load_weights(config.model_file)
load_param_into_net(tfm_model, parameter_dict)
source_ids = Tensor(np.ones((1, config.seq_length, config.input_feature_size)).astype(np.float32))
source_mask = Tensor(np.ones((1, config.seq_length)).astype(np.int32))
export(tfm_model, source_ids, source_mask, file_name=config.file_name, file_format=config.file_format)
if __name__ == '__main__':
export_transformer()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ctc evaluation"""
import os
import json
import argparse
from pathlib import Path
import numpy as np
from src.dataset import MsAudioDataset
import jiwer
parser = argparse.ArgumentParser(description="postprocess")
parser.add_argument('--chars_dict_path', type=str, default='', help='your/path/dataset/lang_1char/train_chars.txt')
args = parser.parse_args()
def run_eval():
'''eval_function'''
path1 = "./result_Files"
path2 = "./preprocess_Result/04_data"
chars_dict_path = args.chars_dict_path
char_list, _, _ = MsAudioDataset.process_dict(chars_dict_path)
file_name1 = os.listdir(path1)
file_name2 = os.listdir(path2)
size = len(file_name2)
label_dict = dict()
sample_num = 0
for i in range(size):
out_f_name = os.path.join(path1, file_name1[i])
out = np.fromfile(out_f_name, np.int32)
out_tokens = []
for x in out.tolist():
if x == 0:
break
out_tokens.append(char_list[x])
out_tokens.append(char_list[2])
out = " ".join(out_tokens)
str1 = file_name1[i]
file_name2_temp = str1[:len(str1)-6]+".bin"
label_f_name = os.path.join(path2, file_name2_temp)
labels = np.fromfile(label_f_name, np.int32)
gt_tokens = [char_list[x] for x in labels.tolist() if x != -1]
gt = " ".join(gt_tokens)
label_dict[sample_num] = {'output': out, 'gt': gt,}
sample_num += 1
with open('./preprocess_Result/labels_dict.json', 'w') as file:
json.dump(label_dict, file, indent=2)
remove_non_words = jiwer.RemoveKaldiNonWords()
remove_space = jiwer.RemoveWhiteSpace()
preprocessing = jiwer.Compose([remove_non_words, remove_space])
with Path('./preprocess_Result/labels_dict.json').open('r') as file:
output_data = json.load(file)
total_cer = 0
for sample in output_data.values():
res_text = preprocessing(sample['output'])
res_text = ' '.join(res_text)
gt_text = preprocessing(sample['gt'])
gt_text = ' '.join(gt_text)
cer = jiwer.wer(gt_text, res_text)
total_cer += cer
print('Resulting cer is ', (total_cer / len(output_data.values())) * 100)
if __name__ == "__main__":
run_eval()
# Dataset preparation
You can download aishell dataset by this [link](http://www.openslr.org/33/)
The AISHELL dataset must be preprocessed with the Kaldi toolkit. You can either:
Option 1: download an AISHELL dataset that has already been processed with Kaldi, from the link below:
https://blog.csdn.net/m0_45973994/article/details/124891389
Option 2: download the Kaldi toolkit and process the data yourself, as follows:
1. cd your/path/to/prepare_aishell_data/
2. Download the aishell dataset ([link](http://www.openslr.org/33/)) and extract it
3. git clone https://github.com/kaituoxu/Speech-Transformer
4. Download Kaldi: git clone https://github.com/kaldi-asr/kaldi.git
5. Copy the parse_options.sh file from your/path/to/prepare_aishell_data/Speech-Transformer-master/egs/aishell/local to your/path/to/prepare_aishell_data/kaldi/egs/aishell/s5/local
6. Replace the run.sh under your/path/to/prepare_aishell_data/kaldi/egs/aishell/s5/ with the run.sh from your/path/to/prepare_aishell_data/Speech-Transformer-master/egs/aishell/
7. Enter kaldi/tools and run make:
7.1 Run tools/extras/check_dependencies.sh
7.2 Install the dependencies it reports, such as automake, autoconf, sox, gfortran and subversion (apt-get install automake autoconf sox gfortran subversion)
7.3 Run tools/extras/install_mkl.sh
7.4 If 7.3 fails, install OpenBLAS instead by running tools/extras/install_openblas.sh
7.5 Install the third-party tool OpenFst with: sudo make openfst
7.6 Run make -j 8 in tools
8. Enter kaldi/src and run make:
8.1 bash configure
8.2 make depend -j 8
8.3 make -j 8
9. Once the build finishes, execute run.sh:
9.1 cd /prepare_aishell_data/kaldi/egs/aishell/s5/
9.2 Change the data path in run.sh to your own path (your/path/to/prepare_aishell_data/data)
9.3 Execute run.sh
9.4 Stage 2 fails because the bc command is missing; install it with apt-get install bc
Note: steps 7 and 8 build the tools required for fbank feature extraction, such as compute-fbank-feats and copy-feats. Without running make, the source tree contains only .cc files, and the fbank extraction step (stage 2 of run.sh) will fail because the compute-fbank-feats and copy-feats commands cannot be found.
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert Kaldi dataset"""
import argparse
import json
import pickle
from copy import deepcopy
from pathlib import Path
import kaldi_io
def convert_to_pickle(data_json_path, new_root_dir):
"""Convert kaldi dataset files"""
with Path(data_json_path).open('r', encoding="utf-8") as file:
dataset_dict = json.load(file)
new_dataset_dict = dict()
for sample_name, sample_info in dataset_dict['utts'].items():
new_sample_info = deepcopy(sample_info)
feature_path = sample_info['input'][0]['feat']
feature = kaldi_io.read_mat(feature_path)
new_feature_path = Path(new_root_dir) / (Path(feature_path).name.replace(':', '_') + '.pickle')
with new_feature_path.open('wb') as file:
pickle.dump(feature, file)
new_sample_info['input'][0]['feat'] = new_feature_path.as_posix()
new_dataset_dict[sample_name] = new_sample_info
with (Path(new_root_dir) / 'data.json').open('w') as file:
json.dump(new_dataset_dict, file, indent=2)
def main():
"""Main function"""
parser = argparse.ArgumentParser()
parser.add_argument('--processed-dataset-path')
args = parser.parse_args()
for dataset_split in ['train', 'dev', 'test']:
json_path = Path(args.processed_dataset_path) / 'dump' / dataset_split / 'deltafalse/data.json'
new_root_dir = Path(args.processed_dataset_path) / 'pickled_dataset' / dataset_split
new_root_dir.mkdir(exist_ok=True, parents=True)
convert_to_pickle(json_path, new_root_dir)
if __name__ == '__main__':
main()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GRU preprocess script."""
import os
import argparse
from src.dataset import create_transformer_dataset
from src.dataset import MsAudioDataset
parser = argparse.ArgumentParser('preprocess')
parser.add_argument('--data_path', type=str, default='', help='eval data dir')
parser.add_argument('--chars_dict_path', type=str, default='', help='your/path/dataset/lang_1char/train_chars.txt')
if __name__ == "__main__":
args = parser.parse_args()
mindrecord_file = args.data_path
if not os.path.exists(mindrecord_file):
print("dataset file {} not exists, please check!".format(mindrecord_file))
raise ValueError(mindrecord_file)
result_path = "./preprocess_Result"
chars_dict_path = args.chars_dict_path
char_list, _, _ = MsAudioDataset.process_dict(chars_dict_path)
test_batch_size = 1
dataset = create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true",
data_json_path=mindrecord_file, chars_dict_path=chars_dict_path, batch_size=1)
source_eos_features_path = os.path.join(result_path, "00_data")
source_eos_mask_path = os.path.join(result_path, "01_data")
target_sos_ids_path = os.path.join(result_path, "02_data")
target_sos_mask_path = os.path.join(result_path, "03_data")
target_eos_ids_path = os.path.join(result_path, "04_data")
target_eos_mask_path = os.path.join(result_path, "05_data")
os.makedirs(source_eos_features_path)
os.makedirs(source_eos_mask_path)
os.makedirs(target_sos_ids_path)
os.makedirs(target_sos_mask_path)
os.makedirs(target_eos_ids_path)
os.makedirs(target_eos_mask_path)
for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
target_eos_ids = data["target_eos_ids"]
file_name = "speech_bs" + str(test_batch_size) + "_" + str(i) + ".bin"
data["source_eos_features"].tofile(os.path.join(source_eos_features_path, file_name))
data["source_eos_mask"].tofile(os.path.join(source_eos_mask_path, file_name))
data["target_sos_ids"].tofile(os.path.join(target_sos_ids_path, file_name))
data["target_sos_mask"].tofile(os.path.join(target_sos_mask_path, file_name))
data["target_eos_ids"].tofile(os.path.join(target_eos_ids_path, file_name))
data["target_eos_mask"].tofile(os.path.join(target_eos_mask_path, file_name))
print("=" * 20, "export bin files finished", "=" * 20)
easydict
pyyaml
jiwer
tqdm
decorator
kaldi_io
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ]; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_ascend.sh [TRAIN_PATH] [DEVICE_NUM] [EPOCH_SIZE] [CONFIG_PATH] [RANK_TABLE_FILE]"
echo "for example: bash run_distribute_train_ascend.sh ../train.py 8 279 ../default_config.yaml ../rank_table_8pcs.json "
echo "It is better to use absolute path."
echo "=============================================================================================================="
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
TRAIN_PATH=$(get_real_path $1)
CONFIG_PATH=$(get_real_path $4)
RANK_TABLE_FILE=$(get_real_path $5)
echo $TRAIN_PATH
echo $CONFIG_PATH
echo $RANK_TABLE_FILE
if [ ! -f $TRAIN_PATH ]; then
echo "error: TRAIN_PATH=$TRAIN_PATH is not a file"
exit 1
fi
if [ ! -f $CONFIG_PATH ]; then
echo "error: CONFIG_PATH=$CONFIG_PATH is not a file"
exit 1
fi
if [ ! -f $RANK_TABLE_FILE ]; then
echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
exit 1
fi
ulimit -u unlimited
export RANK_SIZE=$2
export EPOCH_SIZE=$3
export HCCL_CONNECT_TIMEOUT=6000
export RANK_TABLE_FILE=$RANK_TABLE_FILE
echo $RANK_SIZE
for ((i = 0; i <= $RANK_SIZE - 1; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp ../*.yaml ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py \
--config_path=$CONFIG_PATH \
--distribute="true" \
--device_target="Ascend" \
--epoch_size=$EPOCH_SIZE \
--device_num=$RANK_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--checkpoint_path="" \
--save_checkpoint_num=30 > log.txt 2>&1 &
cd ..
done
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval_ascend.sh DEVICE_TARGET DEVICE_ID DATA_JSON_PATH CKPT_PATH CONFIG_PATH"
echo "for example: bash run_eval_ascend.sh Ascend 0 /your/path/data.json /your/path/checkpoint_file ./default_config.yaml"
echo "Note: set the checkpoint and dataset path in default_config.yaml"
echo "=============================================================================================================="
exit 1;
fi
export DEVICE_TARGET=$1
export CONFIG_PATH=$5
DEVICE_ID=$2
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH3=$(get_real_path $3)
PATH4=$(get_real_path $4)
echo $PATH3
echo $PATH4
python eval.py \
--config_path=$CONFIG_PATH \
--device_target=$DEVICE_TARGET \
--device_id=$DEVICE_ID \
--data_json_path=$PATH3 \
--model_file=$PATH4
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID] [chars_dict_path]
DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
model=$(get_real_path $1)
chars_dict_path=$(get_real_path $4)
dataset_path=$(get_real_path $2)
device_id=0
if [ $# -ge 3 ]; then
device_id=$3
fi
echo "mindir name: "$model
echo "dataset path: "$dataset_path
echo "device id: "$device_id
echo "chars_dict_path: "$chars_dict_path
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
export ASCEND_HOME=/usr/local/Ascend/latest/
export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi
function preprocess_data()
{
if [ -d preprocess_Result ]; then
rm -rf ./preprocess_Result
fi
mkdir preprocess_Result
python3.7 ../preprocess.py --data_path=$dataset_path --chars_dict_path=$chars_dict_path
}
function compile_app()
{
cd ../ascend310_infer/ || exit
bash build.sh &> build.log
}
function infer()
{
cd - || exit
if [ -d result_Files ]; then
rm -rf ./result_Files
fi
if [ -d time_Result ]; then
rm -rf ./time_Result
fi
mkdir result_Files
mkdir time_Result
../ascend310_infer/out/main --mindir_path=$model --input0_path=./preprocess_Result/00_data --device_id=$device_id &> infer.log
}
function cal_acc()
{
python3.7 ../postprocess.py --chars_dict_path=$chars_dict_path &> acc.log
}
preprocess_data
if [ $? -ne 0 ]; then
echo "preprocess dataset failed"
exit 1
fi
compile_app
if [ $? -ne 0 ]; then
echo "compile app code failed"
exit 1
fi
infer
if [ $? -ne 0 ]; then
echo " execute inference failed"
exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
echo "calculate accuracy failed"
exit 1
fi
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]; then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash run_standalone_train_ascend.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE GRADIENT_ACCUMULATE_STEP"
    echo "for example: bash run_standalone_train_ascend.sh GPU 0 91 1"
    echo "It is better to use absolute paths."
    echo "=============================================================================================================="
    exit 1
fi
rm -rf run_standalone_train
mkdir run_standalone_train
cp -rf ./src/ train.py ./*.yaml ./run_standalone_train
cd run_standalone_train || exit
export DEVICE_TARGET=$1
export DEVICE_ID=$2
EPOCH_SIZE=$3
export GRADIENT_ACCUMULATE_STEP=$4
export CUDA_VISIBLE_DEVICES="$2"
python train.py \
    --config_path="./default_config.yaml" \
    --distribute="false" \
    --epoch_size=$EPOCH_SIZE \
    --device_target=$DEVICE_TARGET \
    --enable_save_ckpt="true" \
    --enable_lossscale="true" \
    --do_shuffle="true" \
    --checkpoint_path="" \
    --save_checkpoint_num=30 > log.txt 2>&1 &
cd ..
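
# Training runs in the background; progress and errors are written to
# run_standalone_train/log.txt (see the redirection on the train.py invocation above).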
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer beam search module."""
import numpy as np
from mindspore import dtype as mstype
from mindspore import nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
INF = 1. * 1e9

class LengthPenalty(nn.Cell):
    """
    Normalize scores of translations according to their length.

    Args:
        weight (float): Weight of length penalty. Default: 1.0.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 weight=1.0,
                 compute_type=mstype.float32):
        super(LengthPenalty, self).__init__()
        self.weight = weight
        self.add = P.Add()
        self.pow = P.Pow()
        self.div = P.RealDiv()
        self.cast = P.Cast()
        self.five = Tensor(5.0, mstype.float32)
        self.six = Tensor(6.0, mstype.float32)

    def construct(self, length_tensor):
        """apply length penalty"""
        length_tensor = self.cast(length_tensor, mstype.float32)
        output = self.add(length_tensor, self.five)
        output = self.div(output, self.six)
        output = self.pow(output, self.weight)
        return output
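
# Worked example: the cell computes the GNMT-style normalizer lp(Y) = ((|Y| + 5) / 6)^weight.
# With weight=1.0, a hypothesis of length 7 gets lp = (7 + 5) / 6 = 2.0, so its summed
# log-probability is divided by 2.0 before beams are ranked (see BeamSearchDecoder below).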

class TileBeam(nn.Cell):
    """
    TileBeam.

    Args:
        beam_width (int): beam width setting. Default: 4.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 beam_width,
                 compute_type=mstype.float32):
        super(TileBeam, self).__init__()
        self.beam_width = beam_width
        self.expand = P.ExpandDims()
        self.tile = P.Tile()
        self.reshape = P.Reshape()
        self.shape = P.Shape()

    def construct(self, input_tensor):
        """
        input_tensor: shape [batch, dim1, dim2]
        output_tensor: shape [batch*beam, dim1, dim2]
        """
        shape = self.shape(input_tensor)
        input_tensor = self.expand(input_tensor, 1)
        tile_shape = (1,) + (self.beam_width,)
        for _ in range(len(shape)-1):
            tile_shape = tile_shape + (1,)
        output = self.tile(input_tensor, tile_shape)
        out_shape = (shape[0]*self.beam_width,) + shape[1:]
        output = self.reshape(output, out_shape)
        return output
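
# Shape walk-through: with batch=2, beam_width=4 and an input of shape [2, T, D],
# the tensor is expanded to [2, 1, T, D], tiled to [2, 4, T, D], and reshaped to
# [8, T, D], so every encoder output is duplicated once per beam.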

class Mod(nn.Cell):
    """
    Mod function.

    Args:
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 compute_type=mstype.float32):
        super(Mod, self).__init__()
        self.compute_type = compute_type
        self.floor_div = P.FloorDiv()
        self.sub = P.Sub()
        self.multiply = P.Mul()

    def construct(self, input_x, input_y):
        """apply mod"""
        x = self.floor_div(input_x, input_y)
        x = self.multiply(x, input_y)
        x = self.sub(input_x, x)
        return x
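
# Elementwise modulo via x - y * floor(x / y); e.g. mod(23, 10) = 23 - 10 * 2 = 3.
# A flattened (beam * vocab) top-k index can be split this way: flat index 23 with
# vocab_size=10 corresponds to beam 2, word 3.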

class BeamSearchDecoder(nn.Cell):
    """
    Beam search decoder.

    Args:
        batch_size (int): Batch size of input dataset.
        vocab_size (int): Size of vocabulary.
        decoder (:class:`TransformerDecoderStep`): Decoder module.
        beam_width (int): beam width setting. Default: 4.
        length_penalty_weight (float): Weight of length penalty. Default: 1.0.
        max_decode_length (int): max decode length. Default: 128.
        sos_id (int): Id of sequence start token. Default: 1.
        eos_id (int): Id of sequence end token. Default: 2.
        compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
    """
    def __init__(self,
                 batch_size,
                 vocab_size,
                 decoder,
                 beam_width=4,
                 length_penalty_weight=1.0,
                 max_decode_length=128,
                 sos_id=1,
                 eos_id=2,
                 compute_type=mstype.float32):
        super(BeamSearchDecoder, self).__init__(auto_prefix=False)
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length
        self.decoder = decoder
        self.add = P.Add()
        self.expand = P.ExpandDims()
        self.reshape = P.Reshape()
        self.shape_flat = (-1,)
        self.shape = P.Shape()
        self.zero_tensor = Tensor(np.zeros([batch_size, beam_width]), mstype.float32)
        self.ninf_tensor = Tensor(np.full([batch_size, beam_width], -INF), mstype.float32)
        self.select = P.Select()
        self.flat_shape = (batch_size, beam_width * vocab_size)
        self.topk = P.TopK(sorted=True)
        self.floor_div = P.FloorDiv()
        self.vocab_size_tensor = Tensor(self.vocab_size, mstype.int32)
        self.real_div = P.RealDiv()
        self.mod = Mod()
        self.equal = P.Equal()
        self.eos_ids = Tensor(np.full([batch_size, beam_width], eos_id), mstype.int32)
        beam_ids = np.tile(np.arange(beam_width).reshape((1, beam_width)), [batch_size, 1])
        self.beam_ids = Tensor(beam_ids, mstype.int32)
        batch_ids = np.arange(batch_size*beam_width).reshape((batch_size, beam_width)) // beam_width
        self.batch_ids = Tensor(batch_ids, mstype.int32)
        self.concat = P.Concat(axis=-1)
        self.gather_nd = P.GatherNd()
        self.greater_equal = P.GreaterEqual()
        self.sub = P.Sub()
        self.cast = P.Cast()
        self.zeroslike = P.ZerosLike()
        # init inputs and states
        self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
        self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
        # the first beam starts at score 0 and the rest at -INF, so step 0 expands a single hypothesis
        init_scores = np.tile(np.array([[0.] + [-INF]*(beam_width-1)]), [batch_size, 1])
        self.init_scores = Tensor(init_scores, mstype.float32)
        self.init_finished = Tensor(np.zeros([batch_size, beam_width], dtype=np.bool_))
        self.init_length = Tensor(np.zeros([batch_size, beam_width], dtype=np.int32))
        self.length_penalty = LengthPenalty(weight=length_penalty_weight)
        self.one = Tensor(1, mstype.int32)

    def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
                 state_seq, state_finished, state_length):
        """
        One step for decode
        """
        seq_length = self.shape(enc_states)[1]
        log_probs = self.decoder(cur_input_ids, enc_states, enc_attention_mask, seq_length)
        log_probs = self.reshape(log_probs, (self.batch_size, self.beam_width, self.vocab_size))
        # accumulate scores of the current step onto the beam scores
        total_log_probs = self.add(log_probs, self.expand(state_log_probs, -1))
        # mask finished beams
        mask_tensor = self.select(state_finished, self.ninf_tensor, self.zero_tensor)
        total_log_probs = self.add(total_log_probs, self.expand(mask_tensor, -1))
        # reshape scores to [batch, beam*vocab]
        flat_scores = self.reshape(total_log_probs, self.flat_shape)
        # select topk
        topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
        # decompose each flat index into beam index (topk_indices // vocab_size)
        # and word index (topk_indices % vocab_size)
        temp = topk_indices
        beam_indices = self.zeroslike(topk_indices)
        for _ in range(self.beam_width - 1):
            temp = self.sub(temp, self.vocab_size_tensor)
            res = self.cast(self.greater_equal(temp, 0), mstype.int32)
            beam_indices = beam_indices + res
        word_indices = topk_indices - beam_indices * self.vocab_size_tensor
        # mask finished indices
        beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
        word_indices = self.select(state_finished, self.eos_ids, word_indices)
        topk_scores = self.select(state_finished, state_log_probs, topk_scores)
        # put finished sequences to the end:
        # sort according to scores with -inf for finished beams
        tmp_log_probs = self.select(
            self.equal(word_indices, self.eos_ids),
            self.ninf_tensor,
            topk_scores)
        _, tmp_indices = self.topk(tmp_log_probs, self.beam_width)
        # update
        tmp_gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(tmp_indices, -1)))
        beam_indices = self.gather_nd(beam_indices, tmp_gather_indices)
        word_indices = self.gather_nd(word_indices, tmp_gather_indices)
        topk_scores = self.gather_nd(topk_scores, tmp_gather_indices)
        # generate new beam search states:
        # gather indices for selecting alive beams
        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(beam_indices, -1)))
        # length grows by 1 if the beam was not finished in the previous step
        length_add = self.add(state_length, self.one)
        state_length = self.select(state_finished, state_length, length_add)
        state_length = self.gather_nd(state_length, gather_indices)
        # concat seq
        seq = self.gather_nd(state_seq, gather_indices)
        state_seq = self.concat((seq, self.expand(word_indices, -1)))
        # new finished flag and log_probs
        state_finished = self.equal(word_indices, self.eos_ids)
        state_log_probs = topk_scores
        # generate new inputs and decoder states
        cur_input_ids = self.reshape(state_seq, (self.batch_size*self.beam_width, -1))
        return cur_input_ids, state_log_probs, state_seq, state_finished, state_length
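
    # Note on the re-sorting above: beams that emitted eos keep their frozen score
    # (state_log_probs is re-selected for them) and are pushed to the end of the beam
    # ordering, so alive hypotheses occupy the top slots on the next step.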

    def construct(self, enc_states, enc_attention_mask):
        """Get beam search result."""
        cur_input_ids = self.start_ids
        # beam search states
        state_log_probs = self.init_scores
        state_seq = self.init_seq
        state_finished = self.init_finished
        state_length = self.init_length
        for _ in range(self.max_decode_length):
            # run one step decoder to get outputs of the current step
            # shape [batch*beam, 1, vocab]
            cur_input_ids, state_log_probs, state_seq, state_finished, state_length = self.one_step(
                cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq,
                state_finished, state_length)
        # normalize the accumulated scores by the length penalty
        penalty_len = self.length_penalty(state_length)
        log_probs = self.real_div(state_log_probs, penalty_len)
        # sort according to scores
        _, top_beam_indices = self.topk(log_probs, self.beam_width)
        gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
        # sort sequence
        predicted_ids = self.gather_nd(state_seq, gather_indices)
        # take the first one
        predicted_ids = predicted_ids[::, 0:1:1, ::]
        return predicted_ids
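
# Minimal usage sketch (illustrative only; `decoder_step`, `encoder_out` and `enc_mask`
# are hypothetical names, with `decoder_step` a TransformerDecoderStep cell as referenced
# in the docstring above):
#
#     searcher = BeamSearchDecoder(batch_size=1, vocab_size=vocab_size, decoder=decoder_step,
#                                  beam_width=5, max_decode_length=100, sos_id=1, eos_id=2)
#     predicted_ids = searcher(encoder_out, enc_mask)  # best-beam token ids, led by sos_id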