## Contents
[View the Chinese version](./README_CN.md)
- [Contents](#contents)
- [Wukong Dataset](#wukong-dataset)
- [Environment requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Prepare Dataset](#prepare-dataset)
- [Prepare files required for tokenizer](#prepare-files-required-for-tokenizer)
- [Prepare prompt files](#prepare-prompt-files)
- [Prepare pretrained model checkpoint](#prepare-pretrained-model-checkpoint)
- [Zero-shot Classification](#zero-shot-classification)
## Wukong Dataset
This project provides zero-shot classification on the ILSVRC dataset using a large-scale multi-modal model pretrained on the Noah-Wukong dataset. The model structure is as follows:
|Model|Wukong_ViT|
|:----|:----|
|Embedding dimension|256|
|Input image resolution|224x224|
|Image encoder| |
|patch_size|14|
|width|1024|
|#layers|24|
|#heads|16|
|Input text token length|32|
|Text encoder| |
|#layers|12|
|width|768|
|#heads|12|
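These hyperparameters correspond one-to-one to the encoder constructors invoked in main.py of this repository; a minimal instantiation sketch:
```python
from src.model import VisualTransformer, BERT_Wukong

# Values mirror the table above; 256 is the shared embedding dimension.
visual_encoder = VisualTransformer(
    input_resolution=224, patch_size=14, width=1024, layers=24, output_dim=256)
text_encoder = BERT_Wukong(
    context_length=32, vocab_size=21128, width=768, heads=12, layers=12, output_dim=256)
```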
For more benchmarks of the multi-modal model, please refer to the [Noah-Wukong Benchmark](https://wukong-dataset.github.io/wukong-dataset/benchmark.html).
## Environment requirements
- Hardware
    - Ascend processor
- Framework
    - [MindSpore](https://www.mindspore.cn/ "MindSpore")
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
## Quick Start
### Prepare Dataset
- Download the ILSVRC dataset and organize the files as follows:
```text
.
└── data_root
├── class1
│ ├── 000000000001.jpg
│ ├── 000000000002.jpg
│ ├── ...
├── class2
│ ├── 000000000001.jpg
│ ├── 000000000002.jpg
│ ├── ...
├── class3
│ ├── 000000000001.jpg
│ ├── 000000000002.jpg
│ ├── ...
├── classN
├── ...
```
- Download the corresponding Chinese class name file [imagenet_class_name_zh.json](https://drive.google.com/file/d/1LL0GygtD-ob19EwRuSTfm43ZuFqqy4Q_/view?usp=sharing) and place it in the same folder as main.py; hypothetical example entries are shown below.
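main.py sorts the keys of this file and pairs the values with ImageFolderDataset's labels, so every key must match a class directory name under data_root. The entries below are hypothetical and only illustrate the expected format:
```text
{
    "class1": "丁鲷",
    "class2": "金鱼"
}
```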
### Prepare files required for tokenizer
Download the following files and place them under src/tools/:
- English: [bpe_simple_vocab_16e6.txt.gz](https://drive.google.com/file/d/1SCrD7wewUhxljCggEQxQr1khCfT6mGnj/view?usp=sharing)
- Chinese: [vocab_zh.txt](https://drive.google.com/file/d/1jmbTqpnef3czYWMK2QXYm_i79FpV1bxl/view?usp=sharing)
### Prepare prompt files
Download the prompt file [zh_templates.txt](https://drive.google.com/file/d/1Zky3V9LYRGBaAZzGEuTNAINYHLVPn8bd/view?usp=sharing) to src/tools/. This file defines the prompts used in the zero-shot classification task; the number of prompts can be adjusted to trade running time against accuracy, and custom prompts are also allowed.
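Each line of the file is one template containing a `{}` placeholder, which src/tools/template_generate.py substitutes with the class name. Hypothetical example lines, only to illustrate the format:
```text
一张{}的照片。
图里有{}。
```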
### Prepare pretrained model checkpoint
Download the pretrained checkpoint file [wk100m_yfcc_vit_l_14_filip_lit.pth](https://drive.google.com/file/d/19Xx9UbDeitSoy5MB-vs9LSHa5nDNu4FX/view?usp=sharing), then use src/tools/convert.py to convert it to the format expected by eval.py:
```shell
python convert.py [pth_path] [pkl_path]
```
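The script dumps the torch state dict as a pickle of NumPy arrays (see src/tools/convert.py in this commit); the converted file is what `--ckpt_path` expects in the next step. A quick sanity check on the output, with a placeholder path:
```python
import pickle

# Placeholder path: whatever you passed as [pkl_path] above.
with open('wukong_vit.pkl', 'rb') as pkl_file:
    params = pickle.load(pkl_file)
print(len(params))                         # number of weight tensors
print(type(next(iter(params.values()))))   # <class 'numpy.ndarray'>
```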
### Zero-shot Classification
Run eval.py to perform zero-shot classification:
```shell
python eval.py --ckpt_path [ckpt_path] --dataset_path [/path/to/data_root] --batch_size [batch size]
```
The evaluation result is:
```text
INFO:main:correct @1: 51.51; correct @5: 78.33
```
Detailed zero-shot classification performance across datasets is shown below:
|Dataset|ResNet50|ResNet101|ViT-B/32|ViT-B/16|Wukong_ViT (global similarity)|Wukong_ViT|Wukong_ViT-500M|Wukong_Swin (global similarity)|Wukong_Swin|
|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|
|CIFAR10|49|60.3|89|89.5|93.6|90.6|90.3|95.3|95.5|
|CIFAR100|23.5|31.1|57.3|49.4|64.6|66.3|65.3|69.1|77.2|
|Caltech101|72.3|76.3|83.6|84.4|86|89.9|89.2|87.6|91.6|
|Caltech256|58.4|64.3|70.5|75.4|76.8|86.2|86|78.2|88.4|
|Sports|78|83.3|90.6|88.1|86.8|97.8|96.9|93.4|99.1|
|Flowers|29|30.9|38|42.9|55.1|69.4|71.6|54.5|75.1|
|Food101|37|43.6|42.7|49.4|53.5|70|65.2|46.6|66.1|
|Pets|41.7|43.8|44.9|51.2|46.4|61.3|67|47.7|64.5|
|SUN397|33.1|38.4|39.8|42.7|44.9|60.2|58.9|42.4|56.5|
|ImageNet|28.3|32.8|33.2|38.3|44.8|54|54.3|43.6|58.5|
|ImageNet-r|38.9|47.1|52.3|61.9|67.4|72.2|77.5|49.3|55.3|
|ImageNet-a|14.8|20.8|22.2|35.4|55.1|52.2|53.2|36.8|41.9|
|ImageNet-s|16|20.6|24|29|33.3|36.5|36.8|24.7|31.4|
|DTD|22.4|25.2|27.1|31.7|35.6|46.4|44.6|31.1|39.8|
|Dogs|12.1|12.6|17.3|23.3|21.1|29.4|35.4|22.9|40.3|
|EuroSAT|17.6|12.5|35.3|43.9|39.7|25.5|32.3|28.8|21|
|Aircraft|10.1|10|10.3|14.8|20.8|22.3|21.5|8.9|10.1|
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import json
import logging
import argparse
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from src.model import VisualTransformer, BERT_Wukong, TemplateEncoder, FilipEval
from src.tools import generate_zh_template, load_model
from src.dataset import get_dataset
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('main')
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def main():
parser = argparse.ArgumentParser(description='evaluation for wukong dataset')
parser.add_argument('--ckpt_path', help="checkpoint file path for torch model", required=True)
parser.add_argument('--dataset_path', help="ILSVRC dataset path root", required=True)
parser.add_argument('--batch_size', help="evaluate dataset batch size", type=int, default=4)
args = parser.parse_args()
ckpt_path = args.ckpt_path
dataset_path = args.dataset_path
text_encoder = BERT_Wukong(
context_length=32,
vocab_size=21128,
width=768,
heads=12,
layers=12,
output_dim=256
)
visual_encoder = VisualTransformer(
input_resolution=224,
layers=24,
width=1024,
patch_size=14,
output_dim=256
)
load_model(ckpt_path, visual_encoder, text_encoder)
val_dataset = get_dataset(dataset_path, args.batch_size)
dataset_size = val_dataset.get_dataset_size()
logger.info("start generating template feature")
class_name_file = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'imagenet_class_name_zh.json'
)
    with open(class_name_file, 'r') as mapping_file:
        mapping = json.load(mapping_file)
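    # Sort keys so label indices line up with ImageFolderDataset, which assigns
    # integer labels in sorted directory-name order.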
sort_keys = sorted(list(mapping.keys()))
dataset_labels = [mapping[key] for key in sort_keys]
template_tokens = generate_zh_template(dataset_labels)
template_tokens = Tensor(template_tokens)
template_encoder = TemplateEncoder(text_encoder)
template_encoder.set_train(False)
template_feature, n_template = template_encoder(template_tokens)
logger.info("template feature generated successfully")
logger.info("==========================")
filip_eval = FilipEval(template_feature, n_template, visual_encoder, text_encoder)
filip_eval.set_train(False)
filip_eval = filip_eval.to_float(mstype.float16)
correct_1 = []
correct_5 = []
logger.info('total iter: %d', dataset_size)
for i, data in enumerate(val_dataset):
logger.info('processing %d/%d', i, dataset_size)
output = filip_eval(*data)
acc1, acc5 = output[0].asnumpy(), output[1].asnumpy()
correct_1.append(acc1)
correct_5.append(acc5)
correct_1 = np.hstack(correct_1)
correct_1 = correct_1.mean()
correct_5 = np.hstack(correct_5)
correct_5 = correct_5.mean()
logger.info("correct @1: {:.2f}; correct @5: {:.2f}".format(correct_1 * 100, correct_5 * 100))
if __name__ == '__main__':
main()
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .dataset import get_dataset
__all__ = ['get_dataset']
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import dtype as mstype
import mindspore.dataset as ds
from mindspore.dataset.vision import Inter
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
def get_dataset(dataset_path, batch_size):
norm_mean = (0.48145466, 0.4578275, 0.40821073)
norm_std = (0.26862954, 0.26130258, 0.27577711)
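    # These are the standard CLIP preprocessing statistics; they are rescaled
    # to the 0-255 range below because Normalize runs on raw decoded pixels.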
norm_mean_2 = tuple(map(lambda x: x * 255, norm_mean))
norm_std_2 = tuple(map(lambda x: x * 255, norm_std))
val_dataset = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4)
val_dataset = val_dataset.map(
[C.Decode(),
C.Normalize(mean=norm_mean_2, std=norm_std_2),
C.Resize(224, Inter.BICUBIC),
C.CenterCrop(224),
C.HWC2CHW(),
C2.TypeCast(mstype.float32)],
input_columns=["image"], output_columns=None, column_order=["image", "label"])
val_dataset = val_dataset.batch(batch_size)
return val_dataset
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .visual_encoder import VisualTransformer
from .text_encoder import BERT_Wukong
from .matrics import TemplateEncoder, FilipEval
__all__ = ['VisualTransformer', 'BERT_Wukong', 'TemplateEncoder', 'FilipEval']
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import dtype as mstype
class TemplateEncoder(nn.Cell):
def __init__(self, text_encoder):
super(TemplateEncoder, self).__init__()
self.text_encoder = text_encoder
self.text_norm = nn.Norm(axis=-1, keep_dims=True)
self.concat = ops.Concat()
self.expand_dims = ops.ExpandDims()
def construct(self, text_tokens):
n_class, n_templates, token_len = text_tokens.shape
text_tokens = text_tokens.reshape((n_class * n_templates, token_len))
res = []
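        # Encode prompts in 10 equal chunks to bound peak memory. This assumes
        # n_class * n_templates is divisible by 10 (true for ImageNet's 1000
        # classes); otherwise the tail prompts would be silently dropped.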
batch_num = n_class * n_templates // 10
for i in range(10):
text_tokens_part = text_tokens[batch_num * i: batch_num * (i + 1), :]
text_tokens_features_part = self.text_encoder(text_tokens_part)
text_tokens_features_part = text_tokens_features_part / self.text_norm(text_tokens_features_part)
text_pad_mask = text_tokens_part > 0
text_pad_mask = self.expand_dims(text_pad_mask, -1)
text_tokens_features_part = text_tokens_features_part * text_pad_mask
res.append(text_tokens_features_part)
text_features = self.concat(res)
if n_templates > 1:
text_features = text_features.reshape(n_class, n_templates, token_len, -1)
return text_features, n_templates
class LateSimilarity(nn.Cell):
def __init__(self, chunk_size=200):
super(LateSimilarity, self).__init__()
self.chunk_size = chunk_size
self.matmul = ops.MatMul(transpose_b=True)
self.concat = ops.Concat(1)
def construct(self, rep1, rep2):
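        # FILIP-style late interaction: compute all token-pair similarities,
        # let each image token take its best-matching text token (max over the
        # last axis), then average over image tokens -> (batch1, batch2) scores.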
batch_size1, n_token1, feat_dim = rep1.shape
_, n_token2, _ = rep2.shape
out = self.matmul(rep1.reshape(-1, feat_dim), rep2.reshape(-1, feat_dim))
out = out.reshape(batch_size1, n_token1, -1, n_token2).max(3)
out = out.mean(1)
return out
class FilipEval(nn.Cell):
def __init__(self, text_features, n_template, image_encoder, text_encoder):
super(FilipEval, self).__init__()
self.image_encoder = image_encoder
self.text_encoder = text_encoder
self.text_features = text_features
self.n_template = n_template
self.image_norm = nn.Norm(axis=-1, keep_dims=True)
self.text_norm = nn.Norm(axis=-1, keep_dims=True)
self.topk = ops.TopK(sorted=True)
self.equal = ops.Equal()
self.cast = ops.Cast()
self.concat = ops.Concat()
self.softmax = ops.Softmax()
self.expand_dims = ops.ExpandDims()
self.sim_func = LateSimilarity()
def construct(self, images, targets):
        # images: batch x 3 x 224 x 224; targets: batch of integer class labels
image_features = self.image_encoder(images)
total = image_features.shape[0]
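        # Index 0 is the prepended class embedding; keep only patch tokens for
        # token-wise (FILIP) matching against the text tokens.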
image_features = image_features[:, 1:, :]
image_features = image_features / self.image_norm(image_features)
if self.n_template > 1:
all_sim = []
for i in range(self.text_features.shape[1]):
text_feat = self.text_features[:, i, :, :]
sim_one = self.softmax(self.sim_func(image_features, text_feat))
sim_one = self.expand_dims(sim_one, 0)
all_sim.append(sim_one)
similarity = self.concat(all_sim)
similarity = similarity.mean(0)
else:
similarity = self.sim_func(image_features, self.text_features)
pred = self.topk(similarity, 5)[1].transpose()
correct = self.equal(pred, targets.view(1, -1).expand_as(pred))
correct = self.cast(correct, mstype.float32)
return correct[:1].sum(0), correct[:5].sum(0), self.cast(total, mstype.float32)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import operations as P
from mindspore import Parameter, Tensor
from mindspore.common.initializer import TruncatedNormal, initializer
class BERTMultiheadAttention(nn.Cell):
def __init__(self, d_model, n_head):
"""
:param d_model: width of tensor/embedding dim
:param n_head: output of mutlithead attention/num_heads
"""
super(BERTMultiheadAttention, self).__init__()
self.embed_dim = d_model
self.num_heads = n_head
self.head_dim = self.embed_dim // self.num_heads
self.in_proj = nn.Dense(self.embed_dim, 3 * self.embed_dim)
self.out_proj = nn.Dense(self.embed_dim, self.embed_dim)
self.split = ops.Split(-1, 3)
self.expand_dims = P.ExpandDims()
self.softmax = nn.Softmax(-1)
self.transpose = ops.Transpose()
self.scaling = self.head_dim ** -0.5
def construct(self, query, key, value, attn_mask):
tgt_len, bsz, embed_dim = query.shape
qkv = self.in_proj(query).view(tgt_len, bsz, 3, embed_dim).transpose((2, 0, 1, 3))
q = qkv[0:1]
k = qkv[1:2]
v = qkv[2:3]
q = ops.Squeeze(0)(q)
k = ops.Squeeze(0)(k)
v = ops.Squeeze(0)(v)
q = q * self.scaling
q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose((1, 0, 2)) # (bs) x (HW + 1) x h
k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose((1, 0, 2)) # (bs) x (HW + 1) x h
v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose((1, 0, 2)) # (bs) x (HW + 1) x h
attn_output_weights = ops.matmul(q, k.transpose((0, 2, 1))) # bs x (HW + 1) x (HW + 1)
attn_output_weights += self.expand_dims(attn_mask, 0)
attn_output_weights = self.softmax(attn_output_weights) # bs x (HW + 1) x (HW + 1)
attn_output = ops.matmul(attn_output_weights, v) # bs x (HW + 1) x h
attn_output = self.transpose(attn_output, (1, 0, 2))
attn_output = attn_output.view(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class QuickGELU(nn.Cell):
def __init__(self):
super(QuickGELU, self).__init__()
self.ratio = 1.702
self.sigmoid = nn.Sigmoid()
def construct(self, x):
return x * self.sigmoid(self.ratio * x)
class BERTAttentionWithMask(nn.Cell):
def __init__(self, d_model, n_head, attn_mask):
super(BERTAttentionWithMask, self).__init__()
self.attn = BERTMultiheadAttention(d_model, n_head)
self.attn_mask = attn_mask
def construct(self, x):
return self.attn(x, x, x, self.attn_mask)
class BERTResidualAttentionBlock(nn.Cell):
def __init__(self, d_model, n_head, attn_mask):
super(BERTResidualAttentionBlock, self).__init__()
self.attn = BERTAttentionWithMask(d_model, n_head, attn_mask)
self.ln_1 = nn.LayerNorm([d_model])
self.c_fc = nn.Dense(d_model, d_model * 4)
self.gelu = QuickGELU()
self.c_proj = nn.Dense(d_model * 4, d_model)
self.mlp = nn.SequentialCell([
self.c_fc,
self.gelu,
self.c_proj
])
self.ln_2 = nn.LayerNorm([d_model])
def construct(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class BERTTransformer(nn.Cell):
def __init__(self, width, layers, heads, attn_mask):
super(BERTTransformer, self).__init__()
self.width = width
self.layers = layers
self.resblocks = nn.SequentialCell(
*[BERTResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
)
def construct(self, x):
return self.resblocks(x)
class BERT_Wukong(nn.Cell):
def __init__(self, context_length, vocab_size, output_dim, width, layers, heads):
super(BERT_Wukong, self).__init__()
self.width = width
self.layers = layers
self.vocab_size = vocab_size
self.embedding_table = Parameter(initializer(TruncatedNormal(0.02), [vocab_size, width]))
self.gather = ops.Gather()
self.reshape = ops.Reshape()
self.positional_embedding = Parameter(initializer(TruncatedNormal(0.01), [context_length, width]))
self.ln_final = nn.LayerNorm([self.width])
self.text_projection = Parameter(
Tensor(np.random.normal(0, self.width ** -0.5, size=(self.width, output_dim)).astype(np.float32)))
        self.transformer_layer = BERTTransformer(width, layers, heads, self.build_attention_mask(context_length))
    @staticmethod
    def build_attention_mask(context_length):
        # Upper-triangular -inf mask, i.e. causal self-attention over the text.
        mask = np.triu(np.full((context_length, context_length), -np.inf).astype(np.float32), 1)
        mask = Tensor(mask)
        return mask
def construct(self, text):
bsz, ctx_len = text.shape
flatten_id = text.flatten()
gather_result = self.gather(self.embedding_table, flatten_id, 0)
x = self.reshape(gather_result, (bsz, ctx_len, -1))
x = x + self.positional_embedding
x = x.transpose(1, 0, 2)
x = self.transformer_layer(x)
x = x.transpose(1, 0, 2)
x = self.ln_final(x)
x = ops.matmul(x, self.text_projection)
return x
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import operations as P
from mindspore import Parameter, Tensor
class VITMultiheadAttention(nn.Cell):
def __init__(self, d_model, n_head):
"""
:param d_model: width of tensor/embedding dim
:param n_head: output of mutlithead attention/num_heads
"""
super(VITMultiheadAttention, self).__init__()
self.embed_dim = d_model
self.num_heads = n_head
self.head_dim = self.embed_dim // self.num_heads
self.in_proj = nn.Dense(self.embed_dim, 3 * self.embed_dim)
self.out_proj = nn.Dense(self.embed_dim, self.embed_dim)
self.split = ops.Split(-1, 3)
self.expand_dims = P.ExpandDims()
self.softmax = nn.Softmax(-1)
self.transpose = ops.Transpose()
self.scaling = self.head_dim ** -0.5
def construct(self, query, key, value):
tgt_len, bsz, embed_dim = query.shape
qkv = self.in_proj(query).view(tgt_len, bsz, 3, embed_dim).transpose((2, 0, 1, 3))
q = qkv[0:1]
k = qkv[1:2]
v = qkv[2:3]
q = ops.Squeeze(0)(q)
k = ops.Squeeze(0)(k)
v = ops.Squeeze(0)(v)
q = q * self.scaling
q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose((1, 0, 2)) # (bs) x (HW + 1) x h
k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose((1, 0, 2)) # (bs) x (HW + 1) x h
v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose((1, 0, 2)) # (bs) x (HW + 1) x h
attn_output_weights = ops.matmul(q, k.transpose((0, 2, 1))) # bs x (HW + 1) x (HW + 1)
attn_output_weights = self.softmax(attn_output_weights) # bs x (HW + 1) x (HW + 1)
attn_output = ops.matmul(attn_output_weights, v) # bs x (HW + 1) x h
attn_output = self.transpose(attn_output, (1, 0, 2))
attn_output = attn_output.view(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class QuickGELU(nn.Cell):
def __init__(self):
super(QuickGELU, self).__init__()
self.ratio = 1.702
self.sigmoid = nn.Sigmoid()
def construct(self, x):
return x * self.sigmoid(self.ratio * x)
class VITAttentionWithMask(nn.Cell):
def __init__(self, d_model, n_head):
super(VITAttentionWithMask, self).__init__()
self.attn = VITMultiheadAttention(d_model, n_head)
def construct(self, x):
return self.attn(x, x, x)
class VITResidualAttentionBlock(nn.Cell):
def __init__(self, d_model, n_head):
super(VITResidualAttentionBlock, self).__init__()
self.attn = VITAttentionWithMask(d_model, n_head)
self.ln_1 = nn.LayerNorm([d_model])
self.c_fc = nn.Dense(d_model, d_model * 4)
self.gelu = QuickGELU()
self.c_proj = nn.Dense(d_model * 4, d_model)
self.mlp = nn.SequentialCell([
self.c_fc,
self.gelu,
self.c_proj
])
self.ln_2 = nn.LayerNorm([d_model])
def construct(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class VITTransformer(nn.Cell):
def __init__(self, width, layers, heads):
super(VITTransformer, self).__init__()
self.width = width
self.layers = layers
self.resblocks = nn.SequentialCell(
*[VITResidualAttentionBlock(width, heads) for _ in range(layers)]
)
def construct(self, x):
return self.resblocks(x)
class VisualTransformer(nn.Cell):
def __init__(self, input_resolution, patch_size, width, layers, output_dim, heads=None):
super(VisualTransformer, self).__init__()
if heads is None:
heads = width // 64
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(3, width, patch_size, patch_size)
scale = width ** -0.5
self.class_embedding = Parameter(scale * Tensor(np.random.normal(0, 1, size=(width)).astype(np.float32)))
self.positional_embedding = Parameter(
scale * Tensor(np.random.normal(
size=((input_resolution // patch_size) ** 2 + 1, width)).astype(np.float32)))
self.ln_pre = nn.LayerNorm([width])
self.transformer = VITTransformer(width, layers, heads)
self.ln_post = nn.LayerNorm([width])
self.proj = Parameter(scale * Tensor(np.random.normal(0, 1, size=(width, output_dim)).astype(np.float32)))
self.cat = ops.Concat(1)
self.tile = ops.Tile()
def construct(self, x):
x = self.conv1(x)
x = x.reshape(x.shape[0], x.shape[1], -1)
x = x.transpose(0, 2, 1)
class_embedding = self.tile(self.class_embedding, (x.shape[0], 1, 1))
x = self.cat((class_embedding, x))
x = x + self.positional_embedding
x = self.ln_pre(x)
x = x.transpose(1, 0, 2)
x = self.transformer(x)
x = x.transpose(1, 0, 2)
x = self.ln_post(x)
x = ops.matmul(x, self.proj)
return x
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from .template_generate import generate_zh_template
from .model_utils import load_model
__all__ = ['generate_zh_template', 'load_model']
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pickle
import torch
def convert(pth_path, pkl_path):
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    param_dict = torch.load(pth_path, map_location='cpu')
data = dict()
for k, v in param_dict.items():
data[k] = v.cpu().detach().numpy()
with open(pkl_path, 'wb') as pkl_writer:
pickle.dump(data, pkl_writer)
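# Minimal CLI so the README's documented invocation
# `python convert.py [pth_path] [pkl_path]` works as written.
if __name__ == '__main__':
    import sys
    convert(sys.argv[1], sys.argv[2])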
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pickle
import logging
from mindspore import Tensor
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('model utils')
def update_param(net_param_dict, params, ms_full_name, torch_full_name):
old_param = net_param_dict[ms_full_name]
new_param = Tensor(params[torch_full_name], old_param.data.dtype)
old_param.set_data(new_param)
def load_visual_encoder(net, param_dict):
for mindspore_full_name, torch_full_name in [
('class_embedding', 'visual.class_embedding'),
('positional_embedding', 'visual.positional_embedding'),
('proj', 'visual.proj'),
('conv1.weight', 'visual.conv1.weight'),
('ln_pre.gamma', 'visual.ln_pre.weight'),
('ln_pre.beta', 'visual.ln_pre.bias'),
('ln_post.gamma', 'visual.ln_post.weight'),
('ln_post.beta', 'visual.ln_post.bias')
]:
update_param(net, param_dict, mindspore_full_name, torch_full_name)
for i in range(24):
mindspore_prefix = 'transformer.resblocks.'
torch_prefix = 'visual.transformer.resblocks.'
for mindspore_name, torch_name in [
('attn.attn.in_proj.weight', 'attn.in_proj_weight'),
('attn.attn.in_proj.bias', 'attn.in_proj_bias'),
('attn.attn.out_proj.weight', 'attn.out_proj.weight'),
('attn.attn.out_proj.bias', 'attn.out_proj.bias'),
('ln_1.gamma', 'ln_1.weight'),
('ln_1.beta', 'ln_1.bias'),
('ln_2.gamma', 'ln_2.weight'),
('ln_2.beta', 'ln_2.bias'),
('c_fc.weight', 'mlp.c_fc.weight'),
('c_fc.bias', 'mlp.c_fc.bias'),
('c_proj.weight', 'mlp.c_proj.weight'),
('c_proj.bias', 'mlp.c_proj.bias')
]:
mindspore_full_name = '{}{}.{}'.format(mindspore_prefix, i, mindspore_name)
torch_full_name = '{}{}.{}'.format(torch_prefix, i, torch_name)
update_param(net, param_dict, mindspore_full_name, torch_full_name)
def load_text_encoder(net, param_dict):
for mindspore_full_name, torch_full_name in [
('embedding_table', 'transformer.token_embedding.weight'),
('positional_embedding', 'transformer.positional_embedding'),
('text_projection', 'transformer.text_projection'),
('ln_final.gamma', 'transformer.ln_final.weight'),
('ln_final.beta', 'transformer.ln_final.bias')
]:
update_param(net, param_dict, mindspore_full_name, torch_full_name)
mindspore_prefix = 'transformer_layer.resblocks.'
torch_prefix = 'transformer.resblocks.'
for i in range(12):
for mindspore_name, torch_name in [
('attn.attn.in_proj.weight', 'attn.in_proj_weight'),
('attn.attn.in_proj.bias', 'attn.in_proj_bias'),
('attn.attn.out_proj.weight', 'attn.out_proj.weight'),
('attn.attn.out_proj.bias', 'attn.out_proj.bias'),
('ln_1.gamma', 'ln_1.weight'),
('ln_1.beta', 'ln_1.bias'),
('c_fc.weight', 'mlp.c_fc.weight'),
('c_fc.bias', 'mlp.c_fc.bias'),
('c_proj.weight', 'mlp.c_proj.weight'),
('c_proj.bias', 'mlp.c_proj.bias'),
('ln_2.gamma', 'ln_2.weight'),
('ln_2.beta', 'ln_2.bias')
]:
mindspore_full_name = '{}{}.{}'.format(mindspore_prefix, i, mindspore_name)
torch_full_name = '{}{}.{}'.format(torch_prefix, i, torch_name)
update_param(net, param_dict, mindspore_full_name, torch_full_name)
def load_model(ckpt_path, visual_encoder, text_encoder):
with open(ckpt_path, 'rb') as ckpt_fp:
param_dict = pickle.load(ckpt_fp)
visual_encoder_param = visual_encoder.parameters_dict()
text_encoder_param = text_encoder.parameters_dict()
load_visual_encoder(visual_encoder_param, param_dict)
load_text_encoder(text_encoder_param, param_dict)
logger.info("model loaded")
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import gzip
import html
from functools import lru_cache
from pathlib import Path
from typing import Union, List
import ftfy
import numpy as np
import regex as re
from .utils import is_control, is_whitespace, is_chinese_char, \
is_punctuation, strip_accents
SOT_TEXT = "<|startoftext|>"
EOT_TEXT = "<|endoftext|>"
CONTEXT_LEN = 77
vocab_path_en = "bpe_simple_vocab_16e6.txt.gz"
vocab_path_zh = "vocab_zh.txt"
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
class BpeTokenizer:
def __init__(self, bpe_path):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges = merges[1: 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
for merge in merges:
vocab.append("".join(merge))
vocab.extend([SOT_TEXT, EOT_TEXT])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
SOT_TEXT: SOT_TEXT,
EOT_TEXT: EOT_TEXT,
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
                try:
                    # tuple.index raises ValueError (not IndexError) when
                    # `first` does not occur at or after position i.
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
class WordpieceTokenizer:
def __init__(self, vocab_path):
with open(vocab_path) as vocab_file:
vocab = [line.strip() for line in vocab_file]
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.max_input_chars_per_word = 100
self.tokenize_chinese_chars = True
self.unk_token = "[UNK]"
self.never_split = [self.unk_token, SOT_TEXT, EOT_TEXT]
@staticmethod
def __whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
def __split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if self.never_split and text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
@staticmethod
def __clean_text(text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or is_control(char):
continue
if is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
@staticmethod
def __tokenize_chinese_chars(text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def __wordpiece_tokenize(self, text):
output_tokens = []
for token in self.__whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.encoder:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def __basic_tokenize(self, text):
text = self.__clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
if self.tokenize_chinese_chars:
text = self.__tokenize_chinese_chars(text)
orig_tokens = self.__whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if token not in self.never_split:
token = token.lower()
token = strip_accents(token)
split_tokens.extend(self.__split_on_punc(token))
output_tokens = self.__whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def text_tokenize(self, text):
split_tokens = []
for token in self.__basic_tokenize(text):
if token in self.never_split:
split_tokens.append(token)
else:
split_tokens += self.__wordpiece_tokenize(token)
return split_tokens
def encode(self, text):
tokens = self.text_tokenize(text)
        # Fall back to the id of the unknown token rather than its string form.
        return [self.encoder.get(token, self.encoder[self.unk_token]) for token in tokens]
def decode(self, tokens):
segments = [self.decoder.get(token, self.unk_token) for token in tokens]
text = ""
for segment in segments:
if segment in self.never_split:
text += segment
else:
text += segment.lstrip("##")
return text
# default tokenizer for 'en'
_tokenizer = BpeTokenizer(Path(__file__).with_name(vocab_path_en).as_posix())
def set_tokenizer_lang(lang="en", context_length=77):
global _tokenizer, SOT_TEXT, EOT_TEXT, CONTEXT_LEN
CONTEXT_LEN = context_length
if lang == "en":
vocab_en = Path(__file__).with_name(vocab_path_en).as_posix()
_tokenizer = BpeTokenizer(vocab_en)
elif lang == "zh":
vocab_zh = Path(__file__).with_name(vocab_path_zh).as_posix()
SOT_TEXT = "[CLS]"
EOT_TEXT = "[SEP]"
_tokenizer = WordpieceTokenizer(vocab_zh)
else:
raise RuntimeError("Tokenizer for language \"{}\" is not supported."
.format(lang))
@lru_cache()
def get_sot_token():
return _tokenizer.encoder[SOT_TEXT]
@lru_cache()
def get_eot_token():
return _tokenizer.encoder[EOT_TEXT]
def tokenize(texts: Union[str, List[str]]) -> np.ndarray:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, CONTEXT_LEN]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder[SOT_TEXT]
eot_token = _tokenizer.encoder[EOT_TEXT]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = np.zeros((len(all_tokens), CONTEXT_LEN), dtype=np.int64)
for i, tokens in enumerate(all_tokens):
if len(tokens) > CONTEXT_LEN:
tokens = tokens[:CONTEXT_LEN - 1] + [eot_token]
result[i, : len(tokens)] = np.array(tokens)
return result
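# Illustrative usage: switch to the Chinese tokenizer and tokenize one prompt;
# shapes follow the docstring above.
#   set_tokenizer_lang('zh', 32)
#   tokens = tokenize(['一张猫的照片。'])  # -> np.ndarray of shape (1, 32)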
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
from .simple_tokenizer import set_tokenizer_lang, tokenize
def generate_zh_template(label_list):
set_tokenizer_lang('zh', 32)
template_list = []
template_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'zh_templates.txt'
)
templates = []
for line in open(template_path, 'r'):
templates.append(line.strip())
num_prompts = len(templates)
num_labels = len(label_list)
for label in label_list:
for template in templates:
template_list.append(template.replace('{}', label))
token = tokenize(template_list).reshape((num_labels, num_prompts, -1))
return token
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import unicodedata
def abs_root_dir(cfg, data_root=None):
def get_abs_path(data_dir, data_root):
if os.path.isabs(data_dir):
return data_dir
return os.path.join(data_root, data_dir)
if isinstance(cfg, dict):
for key, value in cfg.items():
if key == 'root_dir':
cfg[key] = get_abs_path(value, data_root)
break
abs_root_dir(value, data_root=data_root)
elif isinstance(cfg, list):
for item in cfg:
abs_root_dir(item, data_root=data_root)
else:
return
def is_control(char):
"""Checks whether `char` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char in ['\t', '\n', '\r']:
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def is_whitespace(char):
"""Checks whether `char` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char in [' ', '\t', '\n', '\r']:
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(0x4E00 <= cp <= 0x9FFF)
or (0x3400 <= cp <= 0x4DBF) #
or (0x20000 <= cp <= 0x2A6DF) #
or (0x2A700 <= cp <= 0x2B73F) #
or (0x2B740 <= cp <= 0x2B81F) #
or (0x2B820 <= cp <= 0x2CEAF) #
or (0xF900 <= cp <= 0xFAFF)
or (0x2F800 <= cp <= 0x2FA1F) #
):
return True
return False
def is_punctuation(char):
"""Checks whether `char` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (33 <= cp <= 47) or (58 <= cp <= 64) \
or (91 <= cp <= 96) or (123 <= cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def strip_accents(text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)