# Contents
- [HyperText Overview](#hypertext-overview)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [TNEWS Dataset](#tnews-dataset)
- [IFLYTEK Dataset](#iflytek-dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [TNEWS](#tnews)
- [IFLYTEK](#iflytek)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Evaluation Performance](#evaluation-performance)
- [HyperText on TNEWS](#hypertext-on-tnews)
- [HyperText on IFLYTEK](#hypertext-on-iflytek)
- [Export Process](#export-process)
- [ModelZoo Homepage](#modelzoo-homepage)
# HyperText Overview
Natural language data exhibit tree-like hierarchical structures, such as the hypernym-hyponym relations in WordNet. Since hyperbolic space is naturally suited to modeling tree-like hierarchical data, the original authors proposed a new model named HyperText, which endows FastText with hyperbolic geometry for efficient text classification. Empirically, HyperText outperforms FastText on a range of text classification tasks with substantially fewer parameters.
Paper: [HyperText: Endowing FastText with Hyperbolic Geometry](https://arxiv.org/abs/2010.16143)
# Model Architecture
HyperText is built on the Poincaré ball model of hyperbolic space. It first captures the latent hierarchical structure in natural language sentences with Poincaré-ball embeddings of words and ngrams, then uses the Einstein midpoint as the pooling method to emphasize semantically specific words (which carry more information but occur less frequently than common words), and finally applies a Möbius linear transformation as the hyperbolic classifier.
![img](./model.png)
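For reference, the two hyperbolic operations at the heart of the model are the Einstein midpoint (the pooling step, computed in the Klein model) and Möbius addition (used inside the Möbius linear layer). In their standard textbook form, with curvature c:

$$\gamma_i = \frac{1}{\sqrt{1 - c\lVert x_i\rVert^2}}, \qquad \mathrm{mid}_c(x_1,\dots,x_N) = \frac{\sum_{i=1}^{N} \gamma_i x_i}{\sum_{i=1}^{N} \gamma_i}$$

$$x \oplus_c y = \frac{(1 + 2c\langle x,y\rangle + c\lVert y\rVert^2)\,x + (1 - c\lVert x\rVert^2)\,y}{1 + 2c\langle x,y\rangle + c^2\lVert x\rVert^2\lVert y\rVert^2}$$

The operators in src/poincare.py below follow these definitions up to numerical-stability clamping and minor variants (for example, its Lorentz factor omits the square root).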
# Dataset
## TNEWS Dataset
Download: [TNEWS](https://bj.bcebos.com/paddlehub-dataset/tnews.tar.gz)
Extract it to data/tnews
## IFLYTEK Dataset
Download: [IFLYTEK](https://bj.bcebos.com/paddlehub-dataset/iflytek.tar.gz)
Extract it to data/iflytek
```text
│── data
│   │── iflytek_public    # preprocessed dataset
│   │── iflytek           # raw dataset
│   │── tnews_public      # preprocessed dataset
│   │── tnews             # raw dataset
```
# Environment Requirements
- Hardware (Ascend/GPU)
    - Set up the hardware environment with Ascend or GPU processors.
- Framework
    - [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information about MindSpore, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# Quick Start
After installing MindSpore via the official website and downloading the datasets, you can start training and evaluation as follows:
- Running on Ascend or GPU
```shell
# data preprocessing example
cd ./scripts
bash data_process.sh [DATA_DIR] [OUT_DATA_DIR] [DATASET_TYPE]
# standalone training example
cd ./scripts
bash run_standalone_train.sh [DATASET_DIR] [DATASET_TYPE] [DEVICE]
# evaluation example
cd ./scripts
bash run_eval.sh [DATASET_DIR] [DATASET_TYPE] [MODEL_PATH] [DEVICE]
```
- DEVICE should be one of ["Ascend", "GPU"].
# Script Description
## Script and Sample Code
```text
│── HyperText
│   │── README.md                      # description of HyperText
│   │── scripts
│   │   │── run_standalone_train.sh    # shell script for standalone training
│   │   │── run_eval.sh                # shell script for evaluation
│   │   │── data_process.sh            # shell script for data preprocessing
│   │   │── run_gpu_distributed_train  # shell script for distributed training on GPU
│   │── output                         # output files: saved models, training logs, evaluation logs
│   │── src
│   │   │── config.py                  # parameter configuration
│   │   │── dataset.py                 # dataset loading
│   │   │── data_preprocessing.py      # dataset preprocessing
│   │   │── hypertext.py               # model definition
│   │   │── hypertext_train.py         # model training wrappers
│   │   │── math_utils.py              # math utilities
│   │   │── mobius_linear.py           # Mobius linear layer
│   │   │── poincare.py                # Poincare ball operators
│   │   │── radam_optimizer.py         # optimizer
│   │── train.py                       # training script
│   │── eval.py                        # evaluation script
│   │── create_dataset.py              # data preprocessing script
│   │── export.py                      # export checkpoint files to AIR/MINDIR
```
## Script Parameters
Parameters for both datasets can be configured in src/config.py.
### TNEWS
```text
num_epochs      2          # number of epochs
batch_size      32         # batch size
max_length      40         # maximum text length
learning_rate   0.011      # learning rate
embed           20         # embedding dimension
bucket          1500000    # word and ngram vocab size
wordNgrams      2          # wordNgram order
eval_step       100        # evaluation interval in steps
min_freq        1          # minimum word frequency
lr_decay_rate   0.96       # learning-rate decay rate
```
### IFLYTEK
```text
num_epochs      2          # number of epochs
batch_size      32         # batch size
max_length      1000       # maximum text length
learning_rate   0.013      # learning rate
embed           80         # embedding dimension
bucket          2000000    # word and ngram vocab size
wordNgrams      2          # wordNgram order
eval_step       50         # evaluation interval in steps
min_freq        1          # minimum word frequency
lr_decay_rate   0.94       # learning-rate decay rate
```
For more configuration details, see src/config.py.
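As a minimal usage sketch (based on the Config class in src/config.py shown further below; the paths here are illustrative), the per-dataset presets are applied like this:

```python
from src.config import Config

# Build a configuration and switch it to the TNEWS preset.
config = Config('./data/tnews_public', './output', 'GPU')
config.useTnews()              # sets max_length=40, embed=20, bucket=1500000, ...
print(config.learning_rate)    # 0.011
```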
# Training Process
## Training
- Preprocess the raw datasets:
```bash
# tnews
python create_dataset.py --data_dir /data/tnews --out_data_dir /data/tnews_public --datasetType tnews
# iflytek
python create_dataset.py --data_dir /data/iflytek --out_data_dir /data/iflytek_public --datasetType iflytek
```
- Running on Ascend
```bash
# tnews
python train.py --datasetdir ./data/tnews_public --datasetType tnews --device Ascend
# iflytek
python train.py --datasetdir ./data/iflytek_public --datasetType iflytek --device Ascend
```
- Running on GPU
```bash
# tnews
python train.py --datasetdir ./data/tnews_public --datasetType tnews --device GPU
# iflytek
python train.py --datasetdir ./data/iflytek_public --datasetType iflytek --device GPU
```
## Distributed Training
- Running on GPU
```bash
# tnews
mpirun -n {device_num} python train.py --datasetdir ./data/tnews_public --datasetType tnews --device GPU --run_distribute True
# iflytek
mpirun -n {device_num} python train.py --datasetdir ./data/iflytek_public --datasetType iflytek --device GPU --run_distribute True
```
Here {device_num} is the number of devices to run in parallel, e.g. 4 or 8.
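Under the hood (see build_dataloader in src/dataset.py and train_paralle in train.py below), each rank reads its own shard of the dataset. A minimal sketch, assuming communication has already been initialized with mindspore.communication.init:

```python
from mindspore.communication import management as MultiDevice

# Shard the dataset so that every rank sees a distinct slice of the data.
data_iter = build_dataloader(train_data, config.batch_size, config.max_length,
                             rank_size=MultiDevice.get_group_size(),
                             rank_id=MultiDevice.get_rank())
```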
# Evaluation Process
## Evaluation
Place the checkpoint files produced by training in the ./output/ directory.
- Running on Ascend
```bash
# tnews
python eval.py --datasetdir ./data/tnews_public --modelPath ./output/hypertext_tnews.ckpt --datasetType tnews --device Ascend
# iflytek
python eval.py --datasetdir ./data/iflytek_public --datasetType iflytek --modelPath ./output/hypertext_iflytek.ckpt --device Ascend
```
- Running on GPU
```bash
# tnews
python eval.py --datasetdir ./data/tnews_public --modelPath ./output/hypertext_tnews.ckpt --datasetType tnews --device GPU
# iflytek
python eval.py --datasetdir ./data/iflytek_public --datasetType iflytek --modelPath ./output/hypertext_iflytek.ckpt --device GPU
```
### Evaluation Performance
#### HyperText on TNEWS
| Parameters | Ascend 910 | GPU |
| ------------------- | -------------------------------------- | -------------------------------------- |
| Model Version | HyperText | HyperText |
| Resources | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB; OS Euler2.8 | GPU (Tesla V100 SXM2); CPU 2.1GHz, 24 cores; memory 128 GB |
| Uploaded Date | 2021-11-18 | 2021-11-18 |
| MindSpore Version | 1.3.0 | 1.5.0 |
| Dataset | tnews | tnews |
| Training Parameters | epoch=2, batch_size=32 | epoch=2, batch_size=32 |
| Optimizer | RAdam | RAdam |
| Loss Function | SoftmaxCrossEntropyWithLogits | SoftmaxCrossEntropyWithLogits |
| Outputs | accuracy | accuracy |
| Loss | 0.9087 | 0.905 |
| Speed | 1958.810 ms/step (8 devices) | 315.949 ms/step (8 devices) |
#### HyperText on IFLYTEK
| Parameters | Ascend 910 | GPU |
| -------------------------- | -------------------------------------- | -------------------------------------- |
| Model Version | HyperText | HyperText |
| Resources | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB; OS Euler2.8 | GPU (Tesla V100 SXM2); CPU 2.1GHz, 24 cores; memory 128 GB |
| Uploaded Date | 2021-11-18 | 2021-11-18 |
| MindSpore Version | 1.3.0 | 1.5.0 |
| Dataset | iflytek | iflytek |
| Training Parameters | epoch=2, batch_size=32 | epoch=2, batch_size=32 |
| Optimizer | RAdam | RAdam |
| Loss Function | SoftmaxCrossEntropyWithLogits | SoftmaxCrossEntropyWithLogits |
| Outputs | accuracy | accuracy |
| Loss | 0.57 | 0.5776 |
| Speed | 395.895 ms/step (8 devices) | 597.672 ms/step (8 devices) |
Multi-device accuracy on tnews: 0.8833
Multi-device accuracy on iflytek: 0.556
# Export Process
A MINDIR file can be exported with the following commands:
- Running on Ascend
```shell
# tnews
python export.py --modelPath ./output/hypertext_tnews.ckpt --datasetType tnews --device Ascend
# iflytek
python export.py --modelPath ./output/hypertext_iflytek.ckpt --datasetType iflytek --device Ascend
```
- Running on GPU
```shell
# tnews
python export.py --modelPath ./output/hypertext_tnews.ckpt --datasetType tnews --device GPU
# iflytek
python export.py --modelPath ./output/hypertext_iflytek.ckpt --datasetType iflytek --device GPU
```
# ModelZoo Homepage
Please check out the official [homepage](https://gitee.com/mindspore/models).
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
preprocess corpus and obtain mindrecord file.
"""
import argparse
import os
from src.data_preprocessing import changeIflytek, changeTnews
parser = argparse.ArgumentParser(description='preprocess corpus and obtain mindrecord.')
parser.add_argument('--data_dir', type=str, default='/data/tnews/', help='the directory of data.')
parser.add_argument('--out_data_dir', type=str, default='/data/tnews_public/',
help='the directory of file processing output ')
parser.add_argument('--datasetType', default='tnews', type=str, help='iflytek/tnews')
args = parser.parse_args()
def create_dir_not_exist(path):
if not os.path.exists(path):
os.mkdir(path)
create_dir_not_exist(args.out_data_dir)
if args.datasetType == 'iflytek':
changeIflytek(args.data_dir, args.out_data_dir)
if args.datasetType == 'tnews':
changeTnews(args.data_dir, args.out_data_dir)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval file"""
import argparse
from mindspore import load_checkpoint, load_param_into_net, context
from mindspore.ops import Squeeze, Argmax
from mindspore.common import dtype as mstype
from mindspore import numpy as mnp
from src.config import Config
from src.dataset import build_dataset, build_dataloader
from src.hypertext import HModel
parser = argparse.ArgumentParser(description='HyperText Text Classification')
parser.add_argument('--model', type=str, default='HyperText',
help='HyperText')
parser.add_argument('--modelPath', default='./output/hypertext_iflytek.ckpt', type=str, help='save model path')
parser.add_argument('--datasetdir', default='./data/iflytek_public', type=str,
help='dataset dir iflytek_public tnews_public')
parser.add_argument('--batch_size', default=32, type=int, help='batch_size')
parser.add_argument('--datasetType', default='iflytek', type=str, help='iflytek/tnews')
parser.add_argument('--device', default='GPU', type=str, help='device GPU Ascend')
args = parser.parse_args()
config = Config(args.datasetdir, None, args.device)
if args.datasetType == 'tnews':
config.useTnews()
else:
config.useIflyek()
if config.device == 'GPU':
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
elif config.device == 'Ascend':
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
vocab, train_data, dev_data, test_data = build_dataset(config, use_word=True, min_freq=int(config.min_freq))
test_iter = build_dataloader(test_data, config.batch_size, config.max_length)
config.n_vocab = len(vocab)
model_path = args.modelPath
hmodel = HModel(config).to_float(mstype.float16)
param_dict = load_checkpoint(model_path)
load_param_into_net(hmodel, param_dict)
squ = Squeeze(-1)
argmax = Argmax(output_type=mstype.int32)
cur, total = 0, 0
print('----------start test model-------------')
for d in test_iter.create_dict_iterator():
hmodel.set_train(False)
out = hmodel(d['ids'], d['ngrad_ids'])
predict = argmax(out)
acc = predict == squ(d['label'])
acc = mnp.array(acc, dtype=mnp.float16)
cur += (mnp.sum(acc, -1))
total += len(acc)
print('acc:', cur / total)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air models"""
import argparse
import numpy as np
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import ArgMaxWithValue
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import export
from src.config import Config
from src.hypertext import HModel
parser = argparse.ArgumentParser(description="hypertext export")
parser.add_argument('--modelPath', default='./output/hypertext_iflytek.ckpt', type=str, help='save model path')
parser.add_argument('--datasetType', default='iflytek', type=str, help='iflytek/tnews')
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument("--file_name",
type=str,
default="hypertext",
help="output file name.")
parser.add_argument('--device', default='GPU', type=str, help='device GPU Ascend')
args = parser.parse_args()
config = Config(None, None, args.device)
if args.datasetType == 'tnews':
config.useTnews()
config.n_vocab = 147919 # vocab size
config.num_classes = 15 # label size
else:
config.useIflyek()
config.n_vocab = 118133 # vocab size
config.num_classes = 119 # label size
if config.device == 'GPU':
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
elif config.device == 'Ascend':
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
class HyperTextTextInferExportCell(Cell):
"""
HyperText network infer.
"""
def __init__(self, network):
"""init fun"""
super(HyperTextTextInferExportCell, self).__init__(auto_prefix=False)
self.network = network
self.argmax = ArgMaxWithValue(axis=1, keep_dims=True)
def construct(self, x1, x2):
"""construct hypertexttext infer cell"""
predicted_idx = self.network(x1, x2)
predicted_idx = self.argmax(predicted_idx)
return predicted_idx
def run_export():
hmodel = HModel(config)
param_dict = load_checkpoint(args.modelPath)
load_param_into_net(hmodel, param_dict)
file_name = args.file_name + '_' + args.datasetType
ht_infer = HyperTextTextInferExportCell(hmodel)
x1 = Tensor(np.ones((args.batch_size, config.max_length)).astype(np.int32))
x2 = Tensor(np.ones((args.batch_size, config.max_length)).astype(np.int32))
export(ht_infer, x1, x2, file_name=file_name, file_format='MINDIR')
if __name__ == '__main__':
run_export()
research/nlp/hypertext/model.png (binary image, 307 KiB)
tqdm
pkuseg
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
DATA_DIR=$1
OUT_DATA_DIR=$2
DATASET_TYPE=$3
dir="../output"
if [ ! -d "$dir" ]; then
    mkdir "$dir"
fi
python ../create_dataset.py --data_dir $DATA_DIR --out_data_dir $OUT_DATA_DIR --datasetType $DATASET_TYPE &> ../output/data_process_log &
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
DATASET_DIR=$1
DATASET_TYPE=$2
MODEL_PATH=$3
DEVICE=$4
dir="../output"
if [ ! -d "$dir" ]; then
    mkdir "$dir"
fi
python ../eval.py --datasetdir $DATASET_DIR --datasetType $DATASET_TYPE --modelPath $MODEL_PATH --device $DEVICE &> ../output/eval_log &
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
DATASET_DIR=$1
DATASET_TYPE=$2
DEVICE=$3
dir="../output"
if [ ! -d "$dir" ]; then
    mkdir "$dir"
fi
python ../train.py --datasetdir $DATASET_DIR --datasetType $DATASET_TYPE --device $DEVICE &> ../output/standalone_log &
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config file"""
import os
class Config:
"""hyperparameter configuration"""
def __init__(self, datasetdir, outputdir, device):
self.model_name = 'HyperText'
if datasetdir:
self.train_path = os.path.join(datasetdir, 'train.txt')
self.dev_path = os.path.join(datasetdir, 'dev.txt')
self.test_path = os.path.join(datasetdir, 'test.txt')
self.vocab_path = os.path.join(datasetdir, 'vocab.txt')
self.labels_path = os.path.join(datasetdir, 'labels.txt')
self.class_list = []
if outputdir:
self.save_path = os.path.join(outputdir, self.model_name + '.ckpt')
self.log_path = os.path.join(outputdir, self.model_name + '.log')
if not os.path.exists(outputdir):
os.makedirs(outputdir)
self.device = device
self.dropout = 0.5
self.outputdir = outputdir
self.num_classes = len(self.class_list) # label number
self.n_vocab = 0
self.num_epochs = 30
self.wordNgrams = 2
self.batch_size = 32
self.max_length = 1000
self.learning_rate = 1e-2
self.bucket = 20000 # word and ngram vocab size
self.lr_decay_rate = 0.96
def useTnews(self):
"""use tnew config"""
self.dropout = 0.0
self.num_classes = len(self.class_list) # label number
self.n_vocab = 0
self.wordNgrams = 2
self.datasetType = 'tnews'
self.max_length = 40
self.embed = 20
self.eval_step = 100
self.min_freq = 1
self.learning_rate = 0.011
self.bucket = 1500000 # word and ngram vocab size
self.lr_decay_rate = 0.96
def useIflyek(self):
"""use iflytek config"""
self.dropout = 0.0
self.datasetType = 'iflytek'
self.num_classes = len(self.class_list) # label number
self.n_vocab = 0
self.wordNgrams = 2
self.max_length = 1000
self.embed = 80
self.eval_step = 50
self.min_freq = 1
self.learning_rate = 0.013
self.bucket = 2000000 # word and ngram vocab size
self.lr_decay_rate = 0.94
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataprocess"""
import os
import pkuseg
from tqdm import tqdm
seg = pkuseg.pkuseg()
current_path = os.path.abspath(os.path.dirname(os.getcwd()))
def changeListToText(content):
"""change list to text"""
wordList = seg.cut(content)
res = ""
for item in wordList:
res = res + ' ' + item
return res
def changeIflytek(in_data_dir='', out_data_dir=''):
"""change iflytek"""
changeList = ['dev.txt', 'train.txt', 'test.txt']
for name in changeList:
print(name)
data = []
with open(current_path + in_data_dir + "/" + name, 'r', encoding='utf-8') as f:
line = f.readline()
while line:
spData = line.split('_!_')
content = spData[1].strip('\n').replace('\t', '')
data.append({'content': content, 'label': spData[0]})
line = f.readline()
with open(current_path + out_data_dir + "/" + name, "w", encoding='utf-8') as f:
for d in tqdm(data):
content = changeListToText(d['content'])
f.write(content + '\t' + d['label'] + '\n')
def changeTnews(in_data_dir='', out_data_dir=''):
"""change tnews"""
changeDict = {'toutiao_category_dev.txt': 'dev.txt', 'toutiao_category_train.txt': 'train.txt',
'toutiao_category_test.txt': 'test.txt'}
for k in changeDict:
print(k)
print(changeDict[k])
data = []
with open(current_path + in_data_dir + "/" + k, 'r', encoding='utf-8') as f:
line = f.readline()
while line:
spData = line.split('_!_')
content = spData[3].strip('\n').replace('\t', '')
data.append({'content': content, 'label': spData[1]})
line = f.readline()
with open(current_path + out_data_dir + "/" + changeDict[k], "w", encoding='utf-8') as f:
for d in tqdm(data):
content = changeListToText(d['content'])
f.write(content + '\t' + d['label'] + '\n')
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utils"""
import os
import time
import random
from datetime import timedelta
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
MAX_VOCAB_SIZE = 5000000
UNK, PAD = '<UNK>', '<PAD>'
def hash_str(gram_str):
"""hash fun"""
gram_bytes = bytes(gram_str, encoding='utf-8')
hash_size = 18446744073709551616
h = 2166136261
for gram in gram_bytes:
h = h ^ gram
h = (h * 1677619) % hash_size
return h
def addWordNgrams(hash_list, n, bucket):
"""add word grams"""
ngram_hash_list = []
len_hash_list = len(hash_list)
for index, hash_val in enumerate(hash_list):
bound = min(len_hash_list, index + n)
for i in range(index + 1, bound):
hash_val = hash_val * 116049371 + hash_list[i]
ngram_hash_list.append(hash_val % bucket)
return ngram_hash_list
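# Example with hypothetical hash values: for hash_list=[h1, h2, h3] and n=2,
# addWordNgrams hashes the bigrams (h1, h2) and (h2, h3), so every adjacent
# token pair is mapped to a reproducible ngram id in [0, bucket).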
def build_vocab(file_path, tokenizer, max_size, min_freq):
    """build vocab"""
    vocab_dic = {}
    label_set = set()
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            line_splits = line.split("\t")
            if len(line_splits) != 2:
                print(line)
                continue
            content, label = line_splits
            label_set.add(label.strip())
            for word in tokenizer(content.strip()):
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
    vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]
    vocab_list = [[PAD, 111101], [UNK, 111100]] + vocab_list
    vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
    base_datapath = os.path.dirname(file_path)
    with open(os.path.join(base_datapath, "vocab.txt"), "w", encoding="utf-8") as f:
        for w, c in vocab_list:
            f.write(str(w) + " " + str(c) + "\n")
        # append two dummy demo tokens
        f.write("4654654#%$#%$#" + " " + str(1) + "\n")
        f.write("46#$%54#%$#%$#" + " " + str(1) + "\n")
    with open(os.path.join(base_datapath, "labels.txt"), "w", encoding="utf-8") as fr:
        labels_list = list(label_set)
        labels_list.sort()
        for l in labels_list:
            fr.write(l + "\n")
    return vocab_dic, list(label_set)


def _pad(data, pad_id, width=-1):
    """pad function"""
    if width == -1:
        width = max(len(d) for d in data)
    rtn_data = [d + [pad_id] * (width - len(d)) for d in data]
    return rtn_data


def load_vocab(vocab_path, max_size, min_freq):
    """load vocab"""
    vocab = {}
    with open(vocab_path, 'r', encoding="utf-8") as fhr:
        for line in fhr:
            line = line.strip()
            line = line.split(' ')
            if len(line) != 2:
                continue
            token, count = line
            vocab[token] = int(count)
    sorted_tokens = sorted([item for item in vocab.items() if item[1] >= min_freq], key=lambda x: x[1], reverse=True)
    sorted_tokens = sorted_tokens[:max_size]
    all_tokens = [[PAD, 0], [UNK, 0]] + sorted_tokens
    vocab = {item[0]: i for i, item in enumerate(all_tokens)}
    return vocab


def load_labels(label_path):
    """load labels"""
    labels = []
    with open(label_path, 'r', encoding="utf-8") as fhr:
        for line in fhr:
            line = line.strip()
            if line not in labels:
                labels.append(line)
    return labels
def build_dataset(config, use_word, min_freq=5):
    """build dataset"""
    print("use min words freq:%d" % min_freq)
    if use_word:
        tokenizer = lambda x: x.split(' ')  # word-level
    else:
        tokenizer = lambda x: [y for y in x]  # char-level
    _ = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=min_freq)
    vocab = load_vocab(config.vocab_path, max_size=MAX_VOCAB_SIZE, min_freq=min_freq)
    print("Vocab size:", len(vocab))
    labels = load_labels(config.labels_path)
    print("label size:", len(labels))
    train = TextDataset(
        file_path=config.train_path,
        vocab=vocab,
        labels=labels,
        tokenizer=tokenizer,
        wordNgrams=config.wordNgrams,
        buckets=config.bucket,
        device=config.device,
        max_length=config.max_length,
        nraws=80000,
        shuffle=True
    )
    dev = TextDataset(
        file_path=config.dev_path,
        vocab=vocab,
        labels=labels,
        tokenizer=tokenizer,
        wordNgrams=config.wordNgrams,
        buckets=config.bucket,
        device=config.device,
        max_length=config.max_length,
        nraws=80000,
        shuffle=False
    )
    test = TextDataset(
        file_path=config.test_path,
        vocab=vocab,
        labels=labels,
        tokenizer=tokenizer,
        wordNgrams=config.wordNgrams,
        buckets=config.bucket,
        device=config.device,
        max_length=config.max_length,
        nraws=80000,
        shuffle=False
    )
    config.class_list = labels
    config.num_classes = len(labels)
    return vocab, train, dev, test
class TextDataset:
    """textdataset struct"""

    def __init__(self, file_path, vocab, labels, tokenizer, wordNgrams,
                 buckets, device, max_length=32, nraws=80000, shuffle=False):
        file_raws = 0
        with open(file_path, 'r', encoding="utf-8") as f:
            for _ in f:
                file_raws += 1
        self.file_path = file_path
        self.file_raws = file_raws
        if file_raws < 200000:
            self.nraws = file_raws
        else:
            self.nraws = nraws
        self.shuffle = shuffle
        self.vocab = vocab
        self.labels = labels
        self.tokenizer = tokenizer
        self.wordNgrams = wordNgrams
        self.buckets = buckets
        self.max_length = max_length
        self.device = device

    def process_oneline(self, line):
        """process one line: tokenize, hash ngrams, and map tokens/labels to ids"""
        line = line.strip()
        content, label = line.split('\t')
        if not content:
            content = "0"
        tokens = self.tokenizer(content.strip())
        seq_len = len(tokens)
        if seq_len > self.max_length:
            tokens = tokens[:self.max_length]
        token_hash_list = [hash_str(token) for token in tokens]
        ngram = addWordNgrams(token_hash_list, self.wordNgrams, self.buckets)
        ngram_pad_size = int((self.wordNgrams - 1) * (self.max_length - self.wordNgrams / 2))
        if len(ngram) > ngram_pad_size:
            ngram = ngram[:ngram_pad_size]
        tokens_to_id = [self.vocab.get(token, self.vocab.get(UNK)) for token in tokens]
        y = self.labels.index(label.strip())
        return tokens_to_id, ngram, y

    def initial(self):
        """read the first nraws lines into memory and preprocess them"""
        self.finput = open(self.file_path, 'r', encoding="utf-8")
        self.samples = list()
        for _ in range(self.nraws):
            line = self.finput.readline()
            if line:
                preprocess_data = self.process_oneline(line)
                self.samples.append(preprocess_data)
            else:
                break
        self.current_sample_num = len(self.samples)
        self.index = list(range(self.current_sample_num))
        if self.shuffle:
            random.shuffle(self.samples)
        self.finput.close()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        one_sample = self.samples[idx]
        return one_sample
class DataGenerator:
    """data generator"""

    def __init__(self, dataset, max_length):
        """init function"""
        self.ids = []
        self.ngrad_ids = []
        self.label = []
        for i in range(len(dataset)):
            ids_item = dataset[i][0]
            ids_item = self.padding(ids_item, max_length)
            ngradids_item = dataset[i][1]
            ngradids_item = self.padding(ngradids_item, max_length)
            self.ids.append(np.array(ids_item))
            self.ngrad_ids.append(np.array(ngradids_item))
            self.label.append(np.array([dataset[i][2]]))

    def __getitem__(self, item):
        return self.ids[item], self.ngrad_ids[item], self.label[item]

    def __len__(self):
        return len(self.ids)

    def padding(self, mylist, maxlen):
        """truncate or zero-pad a list to maxlen"""
        if len(mylist) > maxlen:
            return mylist[:maxlen]
        return mylist + [0] * (maxlen - len(mylist))


def build_dataloader(dataset, batch_size, max_length, shuffle=False, rank_size=1, rank_id=0, num_parallel_workers=4):
    """build data loader"""
    type_cast_op = C.TypeCast(mstype.int32)
    dataset.initial()
    datagenerator = DataGenerator(dataset, max_length)
    d_iter = ds.GeneratorDataset(datagenerator, ["ids", "ngrad_ids", "label"], shuffle=shuffle, num_shards=rank_size,
                                 shard_id=rank_id, num_parallel_workers=num_parallel_workers)
    d_iter = d_iter.map(operations=type_cast_op, input_columns="ids")
    d_iter = d_iter.map(operations=type_cast_op, input_columns="ngrad_ids")
    d_iter = d_iter.map(operations=type_cast_op, input_columns="label")
    d_iter = d_iter.batch(batch_size, drop_remainder=True)
    return d_iter
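# Example (illustrative): build_dataloader(train, batch_size=32, max_length=40)
# returns batches with int32 columns 'ids', 'ngrad_ids' and 'label',
# dropping the last incomplete batch.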
def get_time_dif(start_time):
    """get elapsed time as a timedelta"""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hypertext"""
import math
import numpy as np
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits, Dropout
from mindspore.ops import Fill, Squeeze, Concat
from src.mobius_linear import MobiusLinear
from src.poincare import EinsteinMidpoint, Logmap0
class HModel(Cell):
"""hypertext model"""
def __init__(self, config):
super(HModel, self).__init__()
self.cat = Concat(axis=1)
self.loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
self.config = config
self.c_seed = 1.0
self.fill = Fill()
self.squeeze = Squeeze(1)
self.min_norm = 1e-15
num_input_fmaps = config.embed
num_output_fmaps = config.n_vocab
receptive_field_size = 1
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
std = 1.0 * math.sqrt(2.0 / float(fan_in + fan_out))
emb = Tensor(np.random.normal(0, std, (config.n_vocab, config.embed)), mstype.float32)
self.embedding = Parameter(emb, requires_grad=True)
num_input_fmaps = config.embed
num_output_fmaps = config.bucket
receptive_field_size = 1
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
std = 1.0 * math.sqrt(2.0 / float(fan_in + fan_out))
emb_wordngram = Tensor(np.random.normal(0, std, (config.bucket, config.embed)), mstype.float32)
self.embedding_wordngram = Parameter(emb_wordngram)
self.config_dropout = config.dropout
if config.dropout != 0.0:
self.dropout = Dropout(config.dropout)
self.hyperLinear = MobiusLinear(config.embed,
config.num_classes, c=self.c_seed)
self.einstein_midpoint = EinsteinMidpoint(self.min_norm)
self.logmap0 = Logmap0(self.min_norm)
def construct(self, x1, x2):
"""class construction"""
out_word = self.embedding[x1]
out_wordngram = self.embedding_wordngram[x2]
out = self.cat([out_word, out_wordngram])
if self.config_dropout != 0.0:
out = self.dropout(out)
out = self.einstein_midpoint(out, c=self.c_seed)
out = self.hyperLinear(out)
out = self.logmap0(out, self.c_seed)
return out
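# Forward-pass shape sketch (assuming id inputs of shape [batch, max_length]):
# word ids and ngram ids -> embeddings concatenated along the token axis
# [batch, 2 * max_length, embed] -> Einstein-midpoint pooling [batch, embed]
# -> Mobius linear [batch, num_classes] -> logmap0 back to Euclidean logits.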
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hypertext train model"""
from mindspore.ops import Squeeze, Argmax, Cast
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.parameter import ParameterTuple
import mindspore.numpy as mnp
import mindspore.common.dtype as mstype
from mindspore import nn, save_checkpoint
from mindspore.train.callback import Callback
from src.hypertext import HModel
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad = C.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor], clipped gradients.
"""
if clip_type not in (0, 1):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class HModelWithLoss(nn.Cell):
"""loss model"""
def __init__(self, config):
"""init"""
super(HModelWithLoss, self).__init__()
self.hmodel = HModel(config).to_float(mstype.float16)
self.loss_func = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
self.squeeze = Squeeze(axis=1)
def construct(self, x1, x2, label):
"""class construction"""
out = self.hmodel(x1, x2)
label = self.squeeze(label)
predict_score = self.loss_func(out, label)
return predict_score
class EvalCallBack(Callback):
"""eval"""
def __init__(self, model, eval_dataset, epoch_per_eval, save_ckpk):
"""init function"""
self.model = model
self.eval_dataset = eval_dataset
self.epoch_per_eval = epoch_per_eval
self.save_ckpk = save_ckpk
self.dev_curr = 0
def step_end(self, run_context):
"""per setp to eval"""
cb_param = run_context.original_args()
cur_step = cb_param.cur_step_num
if cur_step % (self.epoch_per_eval) == 0:
print(cur_step)
acc = self.eval_net()
print(acc)
if acc > 0.5:
if self.dev_curr < acc:
self.dev_curr = acc
save_checkpoint(self.model, self.save_ckpk)
def eval_net(self):
"""eval net"""
squ = Squeeze(-1)
argmax = Argmax(output_type=mstype.int32)
cur, total = 0, 0
print('----------start eval model-------------')
net_work = self.model
n = 0
for d in self.eval_dataset.create_dict_iterator():
if n == 200:
break
n += 1
net_work.set_train(False)
out = net_work(d['ids'], d['ngrad_ids'])
predict = argmax(out)
acc = predict == squ(d['label'])
acc = mnp.array(acc, dtype=mnp.float16)
cur += (mnp.sum(acc, -1))
total += len(acc)
return cur / total
class HModelTrainOneStepCell(nn.Cell):
"""train loss"""
def __init__(self, network, optimizer, sens=1.0):
"""init fun"""
super(HModelTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True)
self.sens = sens
self.hyper_map = C.HyperMap()
self.cast = Cast()
self.hyper_map = C.HyperMap()
self.cast = Cast()
def set_sens(self, value):
"""set sense"""
self.sens = value
def construct(self, x1, x2, label):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(x1, x2, label)
gradient_function = self.grad(self.network, weights)
grads = gradient_function(x1, x2, label)
return F.depend(loss, self.optimizer(grads))
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""math utils"""
import mindspore.common.dtype as mstype
from mindspore.nn import Cell
from mindspore.ops import Log, Sub, clip_by_value
eps = 1e-15
class Artanh(Cell):
"""artanh"""
def __init__(self):
"""init"""
super(Artanh, self).__init__()
self.log = Log()
self.sub = Sub()
def construct(self, x):
"""construct fun"""
x = clip_by_value(x, -1 + eps, 1 - eps)
out = self.log(1 + x.astype(mstype.float32))
out = 0.5 * self.sub(out, self.log(1 - x.astype(mstype.float32)))
return out
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""mobius liner"""
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
import mindspore.common.dtype as mstype
from mindspore.nn import Cell
from mindspore.ops import Zeros
from numpy import ones
from numpy.random import randn
from src.poincare import Proj, MobiusMatvec, Expmap0, MobiusAdd
class MobiusLinear(Cell):
"""Mobius linear layer."""
def __init__(self, in_features, out_features, c, use_bias=True):
"""init fun"""
super(MobiusLinear, self).__init__()
self.zeros = Zeros()
self.use_bias = use_bias
self.in_features = in_features
self.out_features = out_features
self.c = c
self.bias = Parameter(Tensor(ones([1, out_features]), mstype.float32))
self.weight = Parameter(
Tensor(randn(out_features, in_features), mstype.float32))
self.min_norm = 1e-15
self.mobius_matvec = MobiusMatvec(self.min_norm)
self.proj = Proj(self.min_norm)
self.expmap0 = Expmap0(min_norm=self.min_norm)
self.mobius_add = MobiusAdd(self.min_norm)
def construct(self, x):
"""class construction"""
mv = self.mobius_matvec(self.weight, x, self.c)
res = self.proj(mv, self.c)
if self.use_bias:
proj_tan0 = self.bias.view(1, -1)
bias = proj_tan0
hyp_bias = self.expmap0(bias, self.c)
hyp_bias = self.proj(hyp_bias, self.c)
res = self.mobius_add(res, hyp_bias, c=self.c)
res = self.proj(res, self.c)
return res
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""poincare file"""
import mindspore.numpy as mnp
from mindspore.nn import Cell, Norm
from mindspore.ops import Shape, ReduceSum, Sqrt, ExpandDims, Tanh, Transpose, matmul, Pow, Reshape, clip_by_value
import mindspore.common.dtype as mstype
from src.math_utils import Artanh
class LorentzFactors(Cell):
"""lorentz_factors class"""
def __init__(self, min_norm):
"""init"""
super(LorentzFactors, self).__init__()
self.min_norm = min_norm
self.norm = Norm(axis=-1)
def construct(self, x):
"""class construction"""
x_norm = self.norm(x)
return 1.0 / (1.0 - x_norm ** 2 + self.min_norm)
class ClampMin(Cell):
"""clamp_min class"""
def __init__(self):
"""init fun"""
super(ClampMin, self).__init__()
self.shape = Shape()
def construct(self, tensor, min1):
"""class construction"""
min_mask = (tensor <= min1)
min_mask1 = (tensor >= min1)
min_add = mnp.ones(self.shape(tensor)) * min1 * min_mask
return tensor * min_mask1 + min_add
class Proj(Cell):
"""proj class"""
def __init__(self, min_norm):
"""init fun"""
super(Proj, self).__init__()
self.clamp_min = ClampMin()
self.min_norm = min_norm
self.norm_k = Norm(axis=-1, keep_dims=True)
self.maxnorm = 1 - 4e-3
def construct(self, x, c):
"""class construction"""
norm = self.clamp_min(self.norm_k(x), self.min_norm)
maxnorm = self.maxnorm / (c ** 0.5)
cond = norm > maxnorm
projected = x / norm * maxnorm
return mnp.where(cond, projected, x)
class Clamp(Cell):
"""clamp class"""
def __init__(self):
super(Clamp, self).__init__()
self.shape = Shape()
def construct(self, tensor, min1, max1):
"""class construction"""
return clip_by_value(tensor, min1, max1)
class Logmap0(Cell):
"""logmap0 class"""
def __init__(self, min_norm):
"""init fun"""
super(Logmap0, self).__init__()
self.min_norm = min_norm
self.norm_k = Norm(axis=-1, keep_dims=True)
self.artanh = Artanh()
self.norm_k = Norm(axis=-1, keep_dims=True)
self.clamp_min = ClampMin()
def construct(self, p, c):
"""class construction"""
sqrt_c = c ** 0.5
p_norm = self.clamp_min(self.norm_k(p), self.min_norm)
scale = 1. / sqrt_c * self.artanh(sqrt_c * p_norm) / p_norm
return scale * p
class KleinToPoincare(Cell):
    """klein to poincare class"""

    def __init__(self, min_norm):
        """init"""
        super(KleinToPoincare, self).__init__()
        self.min_norm = min_norm
        self.sqrt = Sqrt()
        self.sum = ReduceSum(keep_dims=True)
        self.proj = Proj(self.min_norm)

    def construct(self, x, c):
        """class construction"""
        x_poincare = x / (1.0 + self.sqrt(1.0 - self.sum(x * x, -1)))
        x_poincare = self.proj(x_poincare, c)
        return x_poincare


class ToKlein(Cell):
    """to klein class"""

    def __init__(self, min_norm):
        """init fun"""
        super(ToKlein, self).__init__()
        self.min_norm = min_norm
        self.sum = ReduceSum(keep_dims=True)
        self.klein_constraint = KleinConstraint(self.min_norm)

    def construct(self, x, c):
        """class construction"""
        x_2 = self.sum(x * x, -1)
        x_klein = 2 * x / (1.0 + x_2)
        x_klein = self.klein_constraint(x_klein)
        return x_klein


class KleinConstraint(Cell):
    """klein constraint class"""

    def __init__(self, min_norm):
        """init fun"""
        super(KleinConstraint, self).__init__()
        self.norm = Norm(axis=-1)
        self.min_norm = min_norm
        self.maxnorm = 1 - 4e-3
        self.shape = Shape()
        self.reshape = Reshape()

    def construct(self, x):
        """class construction"""
        last_dim_val = self.shape(x)[-1]
        norm = self.reshape(self.norm(x), (-1, 1))
        maxnorm = self.maxnorm
        cond = norm > maxnorm
        x_reshape = self.reshape(x, (-1, last_dim_val))
        projected = x_reshape / (norm + self.min_norm) * maxnorm
        x_reshape = mnp.where(cond, projected, x_reshape)
        x = self.reshape(x_reshape, self.shape(x))
        return x


class EinsteinMidpoint(Cell):
    """einstein midpoint class"""

    def __init__(self, min_norm):
        """init fun"""
        super(EinsteinMidpoint, self).__init__()
        self.to_klein = ToKlein(min_norm)
        self.lorentz_factors = LorentzFactors(min_norm)
        self.sum = ReduceSum(keep_dims=True)
        self.unsqueeze = ExpandDims()
        self.sumFalse = ReduceSum(keep_dims=False)
        self.klein_constraint = KleinConstraint(min_norm)
        self.klein_to_poincare = KleinToPoincare(min_norm)

    def construct(self, x, c):
        """class construction"""
        x = self.to_klein(x, c)
        x_lorentz = self.lorentz_factors(x)
        x_norm = mnp.norm(x, axis=-1)
        # deal with pad value
        x_lorentz = (1.0 - (x_norm == 0.0).astype(mstype.float32)) * x_lorentz
        x_lorentz_sum = self.sum(x_lorentz, -1)
        x_lorentz_expand = self.unsqueeze(x_lorentz, -1)
        x_midpoint = self.sumFalse(x_lorentz_expand * x, 1) / x_lorentz_sum
        x_midpoint = self.klein_constraint(x_midpoint)
        x_p = self.klein_to_poincare(x_midpoint, c)
        return x_p


class ClampTanh(Cell):
    """clamp tanh class"""

    def __init__(self):
        """init fun"""
        super(ClampTanh, self).__init__()
        self.clamp = Clamp()
        self.tanh = Tanh()

    def construct(self, x, c=15):
        """class construction"""
        return self.tanh(self.clamp(x, -c, c))
class MobiusMatvec(Cell):
    """mobius matvec class"""

    def __init__(self, min_norm):
        """init fun"""
        super(MobiusMatvec, self).__init__()
        self.min_norm = min_norm
        self.norm_k = Norm(axis=-1, keep_dims=True)
        self.artanh = Artanh()
        self.clamp_min = ClampMin()
        self.transpose = Transpose()
        self.clamp_tanh = ClampTanh()

    def construct(self, m, x, c):
        """class construction"""
        sqrt_c = c ** 0.5
        x_norm = self.clamp_min(self.norm_k(x), self.min_norm)
        mx = matmul(x, self.transpose(m, (1, 0)))
        mx_norm = self.clamp_min(self.norm_k(mx), self.min_norm)
        t1 = self.artanh(sqrt_c * x_norm)
        t2 = self.clamp_tanh(mx_norm / x_norm * t1)
        res_c = t2 * mx / (mx_norm * sqrt_c)
        # rows whose matvec result is all-zero should map to the origin;
        # cond is all-False here, so res_c is returned for every row
        cond = mnp.array([[0]] * len(mx))
        res_0 = mnp.zeros(1)
        res = mnp.where(cond, res_0, res_c)
        return res


class Expmap0(Cell):
    """expmap0 class: exp map at the origin of the Poincare ball"""

    def __init__(self, min_norm):
        """init fun"""
        super(Expmap0, self).__init__()
        self.clamp_min = ClampMin()
        self.min_norm = min_norm
        self.clamp_tanh = ClampTanh()
        self.norm_k = Norm(axis=-1, keep_dims=True)

    def construct(self, u, c):
        """construct fun"""
        sqrt_c = c ** 0.5
        u_norm = self.clamp_min(self.norm_k(u), self.min_norm)
        gamma_1 = self.clamp_tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
        return gamma_1


class MobiusAdd(Cell):
    """mobius add"""

    def __init__(self, min_norm):
        """init fun"""
        super(MobiusAdd, self).__init__()
        self.pow = Pow()
        self.sum = ReduceSum(keep_dims=True)
        self.clamp_min = ClampMin()
        self.min_norm = min_norm

    def construct(self, x, y, c, dim=-1):
        """Mobius addition of x and y on the Poincare ball with curvature c"""
        x2 = self.sum(self.pow(x, 2), dim)
        y2 = self.sum(self.pow(y, 2), dim)
        xy = self.sum(x * y, dim)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
        return num / self.clamp_min(denom, self.min_norm)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""riemannian adam"""
import numpy as np
from mindspore import Parameter
import mindspore.common.dtype as mstype
from mindspore.common import Tensor
from mindspore.nn.optim.optimizer import opt_init_args_register, Optimizer
from mindspore.ops import Sqrt, Add, Assign, Pow, Mul, ReduceSum
class RiemannianAdam(Optimizer):
"""RiemannianAdam optimizer"""
@opt_init_args_register
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, \
weight_decay=0.0):
"""init fun"""
super(RiemannianAdam, self).__init__(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.sum = ReduceSum(keep_dims=True)
self.sumFalse = ReduceSum(keep_dims=False)
self.sqrt = Sqrt()
self.add = Add()
self.exp_avg = self.parameters.clone(prefix='exp_avg', init='zeros')
self.exp_avg_sq = self.parameters.clone(prefix='exp_avg_sq', init='zeros')
self.step = Parameter(Tensor(0, mstype.int32), name='step')
self.assign = Assign()
self.pow = Pow()
self.mul = Mul()
def construct(self, gradients):
"""class construction"""
beta1 = self.beta1
beta2 = self.beta2
eps = self.eps
learning_rate = self.get_lr()
params = self.parameters
success = None
step = self.step
for exp_avg, exp_avg_sq, param, grad in zip(self.exp_avg, self.exp_avg_sq, params, gradients):
point = param
if grad is None:
continue
exp_avg_update = self.add(self.mul(exp_avg, beta1), (1 - beta1) * grad)
exp_avg_sq_update = self.add(self.mul(exp_avg_sq, beta2),
(1 - beta2) * (self.sum(grad * grad, -1))
)
denom = self.add(self.sqrt(exp_avg_sq_update), eps)
step += 1
bias_cor1 = 1 - self.pow(beta1, step)
bias_cor2 = 1 - self.pow(beta2, step)
step_size = learning_rate * bias_cor2 ** 0.5 / bias_cor1
direction = exp_avg_update / denom
new_point = point - step_size * direction
step += 1
self.assign(exp_avg, exp_avg_update)
self.assign(exp_avg_sq, exp_avg_sq_update)
success = self.assign(param, new_point)
self.assign(self.step, step)
return success
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train file"""
import argparse
import os
from mindspore import load_checkpoint, load_param_into_net, context, Model
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.communication import management as MultiDevice
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from src.config import Config
from src.dataset import build_dataset, build_dataloader
from src.hypertext_train import HModelWithLoss, HModelTrainOneStepCell, EvalCallBack
from src.radam_optimizer import RiemannianAdam
parser = argparse.ArgumentParser(description='HyperText Text Classification')
parser.add_argument('--model', type=str, default='HyperText',
help='HyperText')
parser.add_argument('--modelPath', default='./output/save.ckpt', type=str, help='save model path')
parser.add_argument('--num_epochs', default=2, type=int, help='num_epochs')
parser.add_argument('--datasetdir', default='./data/iflytek_public', type=str,
help='dataset dir iflytek_public tnews_public')
parser.add_argument('--outputdir', default='./output', type=str, help='output dir')
parser.add_argument('--batch_size', default=32, type=int, help='batch_size')
parser.add_argument('--datasetType', default='iflytek', type=str, help='iflytek/tnews')
parser.add_argument('--device', default='GPU', type=str, help='device GPU Ascend')
parser.add_argument("--run_distribute", type=str, default=False, help="run_distribute")
args = parser.parse_args()
config = Config(args.datasetdir, args.outputdir, args.device)
if config.device == 'GPU':
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
elif config.device == 'Ascend':
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
config.num_epochs = int(args.num_epochs)
config.batch_size = int(args.batch_size)
config.outputdir = args.outputdir
if not os.path.exists(config.outputdir):
os.mkdir(config.outputdir)
if args.datasetType == 'tnews':
config.useTnews()
else:
config.useIflyek()
print('start process data ..........')
vocab, train_dataset, dev_dataset, test_dataset = build_dataset(config, use_word=True, min_freq=int(config.min_freq))
config.n_vocab = len(vocab)
def build_train(dataset, eval_data, lr, save_path=None, run_distribute=False):
"""build train"""
net_with_loss = HModelWithLoss(config)
net_with_loss.init_parameters_data()
if save_path is not None:
parameter_dict = load_checkpoint(save_path)
load_param_into_net(net_with_loss, parameter_dict)
if dataset is None:
raise ValueError("pre-process dataset must be provided")
optimizer = RiemannianAdam(learning_rate=lr,
params=filter(lambda x: x.requires_grad, net_with_loss.get_parameters()))
net_with_grads = HModelTrainOneStepCell(net_with_loss, optimizer=optimizer)
net_with_grads.set_train()
model = Model(net_with_grads)
print("Prepare to Training....")
epoch_size = dataset.get_repeat_count()
print("Epoch size ", epoch_size)
eval_cb = EvalCallBack(net_with_loss.hmodel, eval_data, config.eval_step,
config.outputdir + '/' + 'hypertext_' + config.datasetType + '.ckpt')
callbacks = [LossMonitor(10), eval_cb, TimeMonitor(50)]
if run_distribute:
print(f" | Rank {MultiDevice.get_rank()} Call model train.")
model.train(epoch=config.num_epochs, train_dataset=dataset, callbacks=callbacks, dataset_sink_mode=False)
def set_parallel_env():
"""set parallel env"""
context.reset_auto_parallel_context()
MultiDevice.init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=MultiDevice.get_group_size(),
gradients_mean=True)
def train_single(train_data, dev_data, lr):
"""train single"""
print("Starting training on single device.")
data_iter = build_dataloader(train_data, config.batch_size, config.max_length)
dev_iter = build_dataloader(dev_data, config.batch_size, config.max_length)
build_train(data_iter, dev_iter, lr, save_path=None, run_distribute=False)
def train_paralle(train_data, dev_data, lr):
"""train paralle"""
set_parallel_env()
print("Starting training on multiple devices.")
data_iter = build_dataloader(train_data, config.batch_size, config.max_length,
rank_size=MultiDevice.get_group_size(),
rank_id=MultiDevice.get_rank(),
shuffle=False)
dev_iter = build_dataloader(dev_data, config.batch_size, config.max_length,
rank_size=MultiDevice.get_group_size(),
rank_id=MultiDevice.get_rank(),
shuffle=False)
build_train(data_iter, dev_iter, lr, save_path=None, run_distribute=True)
def run_train(train_data, dev_data, lr, run_distribute):
"""run train"""
if config.device == "GPU":
init("nccl")
config.rank_id = get_rank()
if run_distribute:
train_paralle(train_data, dev_data, lr)
else:
train_single(train_data, dev_data, lr)
run_train(train_dataset, dev_dataset, config.learning_rate, args.run_distribute)