Commit c977d561 authored by lihaoyang

Remove TB-Net code. TB-Net is moved to MindSpore XAI.

parent 06d88c07
Showing with 10 additions and 2203 deletions
# Contents
- [Contents](#contents)
    - [TBNet Description](#tbnet-description)
    - [Model Architecture](#model-architecture)
    - [Dataset](#dataset)
    - [Environment Requirements](#environment-requirements)
    - [Quick Start](#quick-start)
    - [Script Description](#script-description)
        - [Script and Sample Code](#script-and-sample-code)
        - [Script Parameters](#script-parameters)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Infer on Ascend310](#infer-on-ascend310)
        - [Result](#result)
    - [Model Description](#model-description)
        - [Performance](#performance)
            - [Training Performance](#training-performance)
            - [Evaluation Performance](#evaluation-performance)
            - [Inference and Explanation Performance](#inference-and-explanation-performance)
    - [Description of Random Situation](#description-of-random-situation)
    - [ModelZoo Homepage](#modelzoo-homepage)
# [TBNet Description](#contents)
TB-Net is a knowledge graph based explainable recommender system. The tutorial and code have been moved to the MindSpore XAI repository.
Click [here](https://www.mindspore.cn/xai/docs/en/master/using_tbnet.html) to check the tutorial of the whitebox recommendation model TB-Net.
Click [here](https://gitee.com/mindspore/xai/tree/master/models/whitebox/tbnet) to check the TB-Net code.
Paper: Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems
# [Model Architecture](#contents)
TB-Net constructs subgraphs in a knowledge graph based on the interactions between users and items as well as the features of items, and then computes paths in the subgraphs using a bidirectional conduction algorithm. Finally, explainable recommendation results are obtained.
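To make the bidirectional path computation concrete, below is a minimal NumPy sketch of how a TB-Net-style model can score the paths of one candidate item and turn the per-path weights into explanations. It is illustrative only: the function names and the exact scoring formula are assumptions, not this repository's implementation.

```python
# Illustrative sketch (assumed scoring formula, not the repository's code):
# each path is (relation1, entity, relation2, hist_item); the candidate item
# propagates forward through relation1 and the historical item propagates
# backward through relation2, meeting at the shared entity.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def score_paths(item_emb, rel1_embs, ent_embs, rel2_embs, hist_embs):
    fwd = (item_emb * rel1_embs * ent_embs).sum(axis=1)   # item -> entity
    bwd = (hist_embs * rel2_embs * ent_embs).sum(axis=1)  # hist_item -> entity
    path_scores = fwd + bwd
    importance = softmax(path_scores)          # per-path explanation weights
    rating = float(importance @ path_scores)   # aggregated user-item score
    return rating, importance

dim, n_paths = 16, 4
rng = np.random.default_rng(0)
rating, importance = score_paths(
    rng.normal(size=dim),
    rng.normal(size=(n_paths, dim)),
    rng.normal(size=(n_paths, dim)),
    rng.normal(size=(n_paths, dim)),
    rng.normal(size=(n_paths, dim)),
)
print(rating)      # candidate item's score
print(importance)  # the highest-weight paths serve as explanations
```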
# [Dataset](#contents)
The [interaction data of users and games](https://www.kaggle.com/tamber/steam-video-games) and the [games' feature data](https://www.kaggle.com/nikdavis/steam-store-games?select=steam.csv) from the game platform Steam are publicly available on Kaggle.
Dataset directory: `./data/{DATASET}/`, e.g. `./data/steam/`.
- train: train.csv, evaluation: test.csv
Each line indicates a \<user\>, an \<item\>, the user-item \<rating\> (1 or 0), and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\> (a \<hist_item\> is an item whose user-item \<rating\> is 1 in the historical data).
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
- infer and explain: infer.csv
Each line indicates the \<user\> and \<item\> to be inferred, a \<rating\>, and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\> (a \<hist_item\> is an item whose user-item \<rating\> is 1 in the historical data).
Note that \<item\> needs to traverse the candidate items (all items by default) in the dataset. \<rating\> can be randomly assigned (all values are 0 by default) and is not used in the inference and explanation phases.
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
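For illustration, the following sketch parses one such row into its path tuples (the parser is a simplified assumption, not the repository's `dataset.py`; PER_ITEM_NUM_PATHS comes from `config.json`):

```python
# Simplified row parser for the CSV format above (illustrative assumption).
def parse_row(line, per_item_num_paths):
    fields = line.strip().split(',')
    user, item, rating = fields[0], fields[1], int(fields[2])
    paths = []
    for i in range(per_item_num_paths):
        # each path module is [relation1, entity, relation2, hist_item]
        r1, entity, r2, hist_item = fields[3 + i * 4: 3 + (i + 1) * 4]
        paths.append((r1, entity, r2, hist_item))
    return user, item, rating, paths

print(parse_row("2,3,1,5,7,5,11", per_item_num_paths=1))
# ('2', '3', 1, [('5', '7', '5', '11')])
```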
Download the data package and extract it under the current project path.
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
```
# [Environment Requirements](#contents)
- Hardware (NVIDIA GPU or Ascend NPU)
    - Prepare the hardware environment with an NVIDIA GPU or Ascend NPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- Data preprocessing
Download the data package (e.g. the 'steam' dataset) and put it under the current project path.
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
```
Then run the code as follows.
- Training
```bash
bash scripts/run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Example:
```bash
bash scripts/run_standalone_train.sh steam 0 Ascend
```
- Evaluation
Evaluate the model on the test dataset.
```bash
bash scripts/run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Argument `[CHECKPOINT_ID]` is required.
Example:
```bash
bash scripts/run_eval.sh 19 steam 0 Ascend
```
- Inference and Explanation
Recommend items to the user specified by `user`; the number of recommended items is determined by `items`.
```bash
python infer.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID] \
--user [USER] \
--items [ITEMS] \
--explanations [EXPLANATIONS] \
--csv [CSV] \
--device_target [DEVICE_TARGET]
```
Arguments `--checkpoint_id` and `--user` are required.
Example:
```bash
python infer.py \
--dataset steam \
--checkpoint_id 19 \
--user 2 \
--items 1 \
--explanations 3 \
--csv test.csv \
--device_target Ascend
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
.
└─tbnet
  ├─README.md
  ├─scripts
  │ ├─run_infer_310.sh          # Ascend310 inference script
  │ ├─run_standalone_train.sh   # NVIDIA GPU or Ascend NPU training script
  │ └─run_eval.sh               # NVIDIA GPU or Ascend NPU evaluation script
  ├─data
  │ └─steam
  │   ├─config.json             # data and training parameter configuration
  │   ├─src_infer.csv           # inference and explanation dataset
  │   ├─src_test.csv            # evaluation dataset
  │   ├─src_train.csv           # training dataset
  │   └─id_maps.json            # explanation configuration
  ├─src
  │ ├─utils
  │ │ ├─__init__.py             # init file
  │ │ ├─device_adapter.py       # get cloud device id
  │ │ ├─local_adapter.py        # get local device id
  │ │ ├─moxing_adapter.py       # parameter processing
  │ │ └─param.py                # parse args
  │ ├─aggregator.py             # inference result aggregation
  │ ├─config.py                 # parameter configuration parsing
  │ ├─dataset.py                # dataset generation
  │ ├─embedding.py              # 3-dim embedding matrix initialization
  │ ├─metrics.py                # model metrics
  │ ├─steam.py                  # 'steam' dataset text explainer
  │ └─tbnet.py                  # TB-Net model
  ├─export.py                   # export MindIR script
  ├─preprocess_dataset.py       # dataset preprocessing script
  ├─preprocess.py               # inference data preprocessing script
  ├─postprocess.py              # inference result calculation script
  ├─default_config.yaml         # yaml configuration file
  ├─eval.py                     # evaluation
  ├─infer.py                    # inference and explanation
  └─train.py                    # training
```
## [Script Parameters](#contents)
The main parameters of train.py and param.py are as follows:
```python
data_path: "."              # location of the input data
load_path: "./checkpoint"   # path of the checkpoint files stored in training
checkpoint_id: 19           # checkpoint id
same_relation: False        # only generate paths whose relation1 is the same as relation2
dataset: "steam"            # dataset name
train_csv: "train.csv"      # the train csv datafile inside the dataset folder
test_csv: "test.csv"        # the test csv datafile inside the dataset folder
infer_csv: "infer.csv"      # the infer csv datafile inside the dataset folder
device_id: 0                # device id
device_target: "GPU"        # run code on GPU or Ascend NPU
run_mode: "graph"           # run code in GRAPH mode or PYNATIVE mode
epochs: 20                  # number of training epochs
```
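These parameters mirror `default_config.yaml` in the project root. As a minimal sketch of how such a file can be loaded into an attribute-style `param` object similar to the one `src/utils/param.py` exposes (the loader below is an assumption, not the repository's implementation; it requires PyYAML):

```python
# Minimal sketch: load default_config.yaml into an attribute-style object.
# This only approximates what src/utils/param.py provides.
import yaml
from types import SimpleNamespace

with open('default_config.yaml') as f:
    param = SimpleNamespace(**yaml.safe_load(f))

print(param.dataset, param.epochs, param.device_target)  # steam 20 GPU
```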
- preprocess_dataset.py parameters
```text
--dataset 'steam' dataset is supported currently
--device_target run code on GPU or Ascend NPU
--same_relation    only generate paths whose relation1 is the same as relation2
```
- train.py parameters
```text
--dataset 'steam' dataset is supported currently
--train_csv the train csv datafile inside the dataset folder
--test_csv the test csv datafile inside the dataset folder
--device_id device id
--epochs number of training epochs
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- eval.py parameters
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id    which checkpoint (.ckpt) file to use for evaluation
--device_id device id
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- infer.py parameters
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. infer.csv)
--checkpoint_id    which checkpoint (.ckpt) file to use for inference
--user             id of the user to be recommended to
--items            no. of items to be recommended
--explanations     no. of recommendation explanations to be shown
--device_id device id
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
## [Inference Process](#contents)
### [Export MindIR](#contents)
```shell
python export.py --config_path [CONFIG_PATH] --checkpoint_path [CKPT_PATH] --device_target [DEVICE] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
```
- The `CKPT_PATH` parameter is required.
- `CONFIG_PATH` is the path of the `config.json` file, which holds the data and training parameter configuration (see the sketch after this list).
- `DEVICE` should be in ['Ascend', 'GPU'].
- `FILE_FORMAT` should be in ['MINDIR', 'AIR'].
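For reference, the snippet below reads `config.json` and applies the same multiple-of-16 rounding that `eval.py` and `export.py` perform for Ascend; the fields shown (`per_item_paths`, `embedding_dim`, `batch_size`, `lr`) are the ones read by the scripts in this repository, while any other keys are dataset specific.

```python
# Sketch: inspect the fields of config.json used by the scripts (assumes the
# steam data package has already been extracted under ./data/steam/).
import json
import math

with open('./data/steam/config.json') as f:
    cfg = json.load(f)

# eval.py/export.py/infer.py round these up to multiples of 16 on Ascend.
per_item_paths = math.ceil(cfg['per_item_paths'] / 16) * 16
embedding_dim = math.ceil(cfg['embedding_dim'] / 16) * 16
print(per_item_paths, embedding_dim, cfg['batch_size'], cfg['lr'])
```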
Example:
```bash
python export.py \
--config_path ./data/steam/config.json \
--checkpoint_path ./checkpoints/tbnet_epoch19.ckpt \
--device_target Ascend \
--file_name model \
--file_format MINDIR
```
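As an optional sanity check (a sketch under the assumption that the export above used `--file_name model`, so the output file is `model.mindir`), the exported graph can be loaded back and wrapped in a `GraphCell`:

```python
# Sketch: load the exported MindIR back to verify it is readable.
import mindspore as ms
from mindspore import nn

graph = ms.load('model.mindir')
net = nn.GraphCell(graph)
# Running `net` requires the same six inputs (item, rl1, ety, rl2, his, rate)
# with the shapes and dtypes used at export time.
```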
### [Infer on Ascend310](#contents)
Before performing inference, the MindIR file must be exported with the `export.py` script. We only provide an example of inference using the MINDIR model.
```shell
# Ascend310 inference
cd scripts
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
- `MINDIR_PATH` specifies the path of the exported MINDIR model.
- `DATA_PATH` specifies the path of test.csv.
- `DEVICE_ID` is optional; the default value is 0.
Example:
```bash
cd scripts
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
### [Result](#contents)
The inference result is saved in the current path; you can find a result like the following in the acc.log file.
```bash
auc: 0.8251359368836292
```
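For reference, the AUC value above can be reproduced from labels and predicted scores with scikit-learn, as in the sketch below (`postprocess.py` computes the metric from the saved `.bin` outputs; the toy data here is only for illustration):

```python
# Toy illustration of the AUC metric reported in acc.log.
from sklearn.metrics import roc_auc_score

labels = [1, 0, 1, 1, 0]            # ground-truth ratings from test.csv
scores = [0.9, 0.2, 0.7, 0.6, 0.4]  # model's predicted item scores
print('auc:', roc_auc_score(labels, scores))
```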
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | GPU | Ascend NPU |
| -------------------------- |--------------------------------------------------------------------------------------------| ---------------------------------------------|
| Model Version | TB-Net | TB-Net |
| Resource | NVIDIA RTX 3090 | Ascend 910 |
| Uploaded Date | 2022-07-14 | 2022-06-30 |
| MindSpore Version | 1.6.1 | 1.6.1 |
| Dataset | steam | steam |
| Training Parameter | epoch=20, batch_size=1024, lr=0.001 | epoch=20, batch_size=1024, lr=0.001 |
| Optimizer | Adam | Adam |
| Loss Function | Sigmoid Cross Entropy | Sigmoid Cross Entropy |
| Outputs                    | AUC=0.8573, Accuracy=0.7733 | AUC=0.8592, Accuracy=0.7741 |
| Loss                       | 0.57 | 0.59 |
| Speed                      | 1pc: 90ms/step | 1pc: 80ms/step |
| Total Time                 | 1pc: 297s | 1pc: 336s |
| Checkpoint for Fine Tuning | 686.3K (.ckpt file) | 671K (.ckpt file) |
| Scripts | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
### Evaluation Performance
| Parameters | GPU | Ascend NPU |
| ------------------------- |----------------------------| ----------------------------- |
| Model Version | TB-Net | TB-Net |
| Resource | NVIDIA RTX 3090 | Ascend 910 |
| Uploaded Date | 2022-07-14 | 2022-06-30 |
| MindSpore Version | 1.3.0 | 1.5.1 |
| Dataset | steam | steam |
| Batch Size | 1024 | 1024 |
| Outputs | AUC=0.8487,Accuracy=0.7699 | AUC=0.8486,Accuracy=0.7704 |
| Total Time                | 1pc: 5.7s | 1pc: 1.1s |
### Inference and Explanation Performance
| Parameters | GPU |
| --------------------------| ------------------------------------- |
| Model Version | TB-Net |
| Resource | Tesla V100-SXM2-32GB |
| Uploaded Date | 2021-08-01 |
| MindSpore Version | 1.3.0 |
| Dataset | steam |
| Outputs | Recommendation Result and Explanation |
| Total Time | 1pc: 3.66s |
# [Description of Random Situation](#contents)
- Random initialization of the embedding matrices in `tbnet.py` and `embedding.py`.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).
Click [here](https://ieeexplore.ieee.org/document/9835387) to check TB-Net paper (ICDE 2022): *Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems*.
# Contents
<!-- TOC -->
- [Contents](#contents)
    - [TBNet Overview](#tbnet-overview)
    - [Model Architecture](#model-architecture)
    - [Dataset](#dataset)
    - [Environment Requirements](#environment-requirements)
    - [Quick Start](#quick-start)
    - [Script Description](#script-description)
        - [Script and Sample Code](#script-and-sample-code)
        - [Script Parameters](#script-parameters)
    - [Inference Process](#inference-process)
        - [Export MindIR](#export-mindir)
        - [Infer on Ascend310](#infer-on-ascend310)
        - [Result](#result)
    - [Model Description](#model-description)
        - [Performance](#performance)
            - [Training Performance](#training-performance)
            - [Evaluation Performance](#evaluation-performance)
            - [Inference and Explanation Performance](#inference-and-explanation-performance)
    - [Description of Random Situation](#description-of-random-situation)
    - [ModelZoo Homepage](#modelzoo-homepage)
# [TBNet Overview](#contents)
TB-Net is a knowledge graph based explainable recommender system. Its documentation and code have been moved to the MindSpore XAI repository.
Click [here](https://www.mindspore.cn/xai/docs/zh-CN/master/using_tbnet.html) to see how to use the TB-Net whitebox recommendation model in MindSpore XAI.
Click [here](https://gitee.com/mindspore/xai/tree/master/models/whitebox/tbnet) to check the TB-Net source code.
Paper: Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems
# [Model Architecture](#contents)
TB-Net constructs subgraphs in a knowledge graph from the interactions between users and items together with item attributes, computes paths in the subgraphs with a bidirectional conduction algorithm, and finally obtains explainable recommendation results.
# [Dataset](#contents)
This example uses the public Steam game-platform datasets from Kaggle, containing the [interaction records between users and games](https://www.kaggle.com/tamber/steam-video-games) and the [games' attribute data](https://www.kaggle.com/nikdavis/steam-store-games?select=steam.csv).
Dataset directory: `./data/{DATASET}/`, e.g. `./data/steam/`.
- Training: train.csv; evaluation: test.csv
Each line represents a \<user\>, an \<item\>, the user-item \<rating\> (1 or 0), and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\> (a \<hist_item\> is an item whose historical user-item \<rating\> is 1).
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
- Inference and explanation: infer.csv
Each line represents the \<user\> and \<item\> **to be inferred**, a \<rating\>, and PER_ITEM_NUM_PATHS paths between the item and the user's \<hist_item\> (a \<hist_item\> is an item whose historical user-item \<rating\> is 1).
Note that \<item\> needs to traverse **all** candidate items in the dataset (all items by default); \<rating\> can be randomly assigned (all 0 by default) and is not used in the inference and explanation phases.
```text
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
# [Environment Requirements](#contents)
- Hardware (NVIDIA GPU or Ascend NPU)
    - Prepare the hardware environment with an NVIDIA GPU or Ascend NPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can follow the steps below for training, evaluation, inference and explanation:
- Data preparation
Download the sample data package (e.g. the 'steam' dataset) and extract it to the current project path.
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
```
Then run the code as follows.
- Training
```bash
bash scripts/run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Example:
```bash
bash scripts/run_standalone_train.sh steam 0 Ascend
```
- Evaluation
Evaluate the model on the test dataset.
```bash
bash scripts/run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
The `[CHECKPOINT_ID]` argument is required.
Example:
```bash
bash scripts/run_eval.sh 19 steam 0 Ascend
```
- Inference and Explanation
Recommend items to the user specified by `user`; the number of recommended items is determined by `items`.
```bash
python infer.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID] \
--user [USER] \
--items [ITEMS] \
--explanations [EXPLANATIONS] \
--csv [CSV] \
--device_target [DEVICE_TARGET]
```
The `--checkpoint_id` and `--user` arguments are required.
Example:
```bash
python infer.py \
--dataset steam \
--checkpoint_id 19 \
--user 2 \
--items 1 \
--explanations 3 \
--csv test.csv \
--device_target Ascend
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
.
└─tbnet
  ├─README.md
  ├─scripts
  │ ├─run_infer_310.sh          # Ascend310 inference script
  │ ├─run_standalone_train.sh   # NVIDIA GPU or Ascend NPU training script
  │ └─run_eval.sh               # NVIDIA GPU or Ascend NPU evaluation script
  ├─data
  │ └─steam
  │   ├─config.json             # data and training parameter configuration
  │   ├─src_infer.csv           # inference and explanation dataset
  │   ├─src_test.csv            # evaluation dataset
  │   ├─src_train.csv           # training dataset
  │   └─id_maps.json            # explanation output configuration
  ├─src
  │ ├─utils
  │ │ ├─__init__.py             # init file
  │ │ ├─device_adapter.py       # get cloud device id
  │ │ ├─local_adapter.py        # get local device id
  │ │ ├─moxing_adapter.py       # parameter processing
  │ │ └─param.py                # parse arguments
  │ ├─aggregator.py             # inference result aggregation
  │ ├─config.py                 # parameter configuration parsing
  │ ├─dataset.py                # dataset creation
  │ ├─embedding.py              # 3-dim embedding matrix initialization
  │ ├─metrics.py                # model metrics
  │ ├─steam.py                  # 'steam' dataset text explainer
  │ └─tbnet.py                  # TB-Net model
  ├─export.py                   # export MINDIR script
  ├─preprocess_dataset.py       # dataset preprocessing script
  ├─preprocess.py               # inference data preprocessing script
  ├─postprocess.py              # inference result calculation script
  ├─default_config.yaml         # yaml configuration file
  ├─eval.py                     # evaluation
  ├─infer.py                    # inference and explanation
  └─train.py                    # training
```
## [Script Parameters](#contents)
The main parameters of train.py and param.py are as follows:
```python
data_path: "."              # dataset path
load_path: "./checkpoint"   # checkpoint save path
checkpoint_id: 19           # checkpoint id
same_relation: False        # during preprocessing, only generate paths whose `relation1` is the same as `relation2`
dataset: "steam"            # dataset name
train_csv: "train.csv"      # training set filename in the dataset
test_csv: "test.csv"        # test set filename in the dataset
infer_csv: "infer.csv"      # inference data filename in the dataset
device_id: 0                # device id
device_target: "GPU"        # target platform
run_mode: "graph"           # run mode
epochs: 20                  # number of training epochs
```
- preprocess_dataset.py parameters
```text
--dataset 'steam' dataset is supported currently
--device_target run code on GPU or Ascend NPU
--same_relation    only generate paths whose relation1 is the same as relation2
```
- train.py parameters
```text
--dataset 'steam' dataset is supported currently
--train_csv the train csv datafile inside the dataset folder
--test_csv the test csv datafile inside the dataset folder
--device_id device id
--epochs number of training epochs
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- eval.py parameters
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id    which checkpoint (.ckpt) file to use for evaluation
--device_id device id
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
- infer.py parameters
```text
--dataset 'steam' dataset is supported currently
--csv the csv datafile inside the dataset folder (e.g. infer.csv)
--checkpoint_id    which checkpoint (.ckpt) file to use for inference
--user             id of the user to be recommended to
--items            no. of items to be recommended
--explanations     no. of recommendation explanations to be shown
--device_id device id
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
## [Inference Process](#contents)
### [Export MindIR](#contents)
```shell
python export.py \
--config_path [CONFIG_PATH] \
--checkpoint_path [CKPT_PATH] \
--device_target [DEVICE] \
--file_name [FILE_NAME] \
--file_format [FILE_FORMAT]
```
- The `CKPT_PATH` parameter is required.
- `CONFIG_PATH` is the dataset's `config.json` file, which holds the data and training parameter configuration.
- `DEVICE` should be in ['Ascend', 'GPU'].
- `FILE_FORMAT` should be in ['MINDIR', 'AIR'].
Example:
```bash
python export.py \
--config_path ./data/steam/config.json \
--checkpoint_path ./checkpoints/tbnet_epoch19.ckpt \
--device_target Ascend \
--file_name model \
--file_format MINDIR
```
### [Infer on Ascend310](#contents)
Before performing inference, the MindIR file must be exported with the `export.py` script. The following shows an example of inference using a MINDIR model.
```shell
# Ascend310 inference
cd scripts
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
- `MINDIR_PATH` specifies the path of the MindIR file.
- `DATA_PATH` specifies the path of the inference dataset test.csv.
- `DEVICE_ID` is optional; the default value is 0.
Example:
```bash
cd scripts
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
### [Result](#contents)
The inference result is saved in the current path of script execution; you can find the following accuracy result in the acc.log file.
```bash
auc: 0.8251359368836292
```
# [Model Description](#contents)
## [Performance](#contents)
### [Training Performance](#contents)
| Parameters | GPU | Ascend NPU |
| -------------------------- | ------------------------------------------------------------------------------------------ | ----------------------------------- |
| Model Version | TB-Net | TB-Net |
| Resource | NVIDIA RTX 3090 | Ascend 910 |
| Uploaded Date | 2022-07-14 | 2022-06-30 |
| MindSpore Version | 1.6.1 | 1.6.1 |
| Dataset | steam | steam |
| Training Parameters | epoch=20, batch_size=1024, lr=0.001 | epoch=20, batch_size=1024, lr=0.001 |
| Optimizer | Adam | Adam |
| Loss Function | Sigmoid Cross Entropy | Sigmoid Cross Entropy |
| Outputs | AUC=0.8573, Accuracy=0.7733 | AUC=0.8592, Accuracy=0.7741 |
| Loss | 0.57 | 0.59 |
| Speed | 1pc: 90ms/step | 1pc: 80ms/step |
| Total Time | 1pc: 297s | 1pc: 336s |
| Checkpoint for Fine Tuning | 686.3K (.ckpt file) | 671K (.ckpt file) |
| Scripts | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
### [Evaluation Performance](#contents)
| Parameters | GPU | Ascend NPU |
| ----------------- | --------------------------- | --------------------------- |
| Model Version | TB-Net | TB-Net |
| Resource | NVIDIA RTX 3090 | Ascend 910 |
| Uploaded Date | 2022-07-14 | 2022-06-30 |
| MindSpore Version | 1.6.1 | 1.6.1 |
| Dataset | steam | steam |
| Batch Size | 1024 | 1024 |
| Outputs | AUC=0.8487, Accuracy=0.7699 | AUC=0.8486, Accuracy=0.7704 |
| Total Time | 1pc: 5.7s | 1pc: 1.1s |
### [Inference and Explanation Performance](#contents)
| Parameters | GPU |
| ----------------- | --------------------------------------- |
| Model Version | TB-Net |
| Resource | Tesla V100-SXM2-32GB |
| Uploaded Date | 2021-08-01 |
| MindSpore Version | 1.3.0 |
| Dataset | steam |
| Outputs | Recommendation results and explanations |
| Total Time | 1pc: 3.66s |
# [Description of Random Situation](#contents)
- Random initialization of the embedding matrices in `tbnet.py` and `embedding.py`.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/models).
Click [here](https://ieeexplore.ieee.org/document/9835387) to check the TB-Net paper (ICDE 2022): *Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems*.
cmake_minimum_required(VERSION 3.14.1)
project(Ascend310Infer)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined")
set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
option(MINDSPORE_PATH "mindspore install path" "")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
include_directories(${PROJECT_SRC_ROOT})
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)
add_executable(main src/main.cc src/utils.cc)
target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags)
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ -d out ]; then
rm -rf out
fi
mkdir out
cd out || exit
if [ -f "Makefile" ]; then
make clean
fi
cmake .. \
-DMINDSPORE_PATH="`pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`"
make
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_UTILS_H_
#define MINDSPORE_INFERENCE_UTILS_H_
#include <sys/stat.h>
#include <dirent.h>
#include <vector>
#include <string>
#include <memory>
#include "include/api/types.h"
std::vector<std::string> GetAllFiles(std::string_view dirName);
DIR *OpenDir(std::string_view dirName);
std::string RealPath(std::string_view path);
mindspore::MSTensor ReadFileToTensor(const std::string &file);
int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs);
#endif
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/time.h>
#include <gflags/gflags.h>
#include <dirent.h>
#include <iostream>
#include <string>
#include <algorithm>
#include <iosfwd>
#include <vector>
#include <map>
#include <fstream>
#include <sstream>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/api/serialization.h"
#include "include/dataset/execute.h"
#include "include/dataset/vision.h"
#include "inc/utils.h"
using mindspore::Context;
using mindspore::Serialization;
using mindspore::Model;
using mindspore::Status;
using mindspore::MSTensor;
using mindspore::dataset::Execute;
using mindspore::ModelType;
using mindspore::GraphCell;
using mindspore::kSuccess;
DEFINE_string(mindir_path, "", "mindir path");
DEFINE_string(input0_path, ".", "input0 path");
DEFINE_string(input1_path, ".", "input1 path");
DEFINE_string(input2_path, ".", "input2 path");
DEFINE_string(input3_path, ".", "input3 path");
DEFINE_string(input4_path, ".", "input4 path");
DEFINE_string(input5_path, ".", "input5 path");
DEFINE_int32(device_id, 0, "device id");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (RealPath(FLAGS_mindir_path).empty()) {
std::cout << "Invalid mindir" << std::endl;
return 1;
}
auto context = std::make_shared<Context>();
auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>();
ascend310->SetDeviceID(FLAGS_device_id);
ascend310->SetPrecisionMode("allow_fp32_to_fp16");
ascend310->SetOpSelectImplMode("high_precision");
context->MutableDeviceInfo().push_back(ascend310);
mindspore::Graph graph;
Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph);
Model model;
Status ret = model.Build(GraphCell(graph), context);
if (ret != kSuccess) {
std::cout << "ERROR: Build failed." << std::endl;
return 1;
}
std::vector<MSTensor> model_inputs = model.GetInputs();
if (model_inputs.empty()) {
std::cout << "Invalid model, inputs is empty." << std::endl;
return 1;
}
auto input0_files = GetAllFiles(FLAGS_input0_path);
auto input1_files = GetAllFiles(FLAGS_input1_path);
auto input2_files = GetAllFiles(FLAGS_input2_path);
auto input3_files = GetAllFiles(FLAGS_input3_path);
auto input4_files = GetAllFiles(FLAGS_input4_path);
auto input5_files = GetAllFiles(FLAGS_input5_path);
if (input0_files.empty() || input1_files.empty()) {
std::cout << "ERROR: input data empty." << std::endl;
return 1;
}
std::map<double, double> costTime_map;
size_t size = input0_files.size();
for (size_t i = 0; i < size; ++i) {
struct timeval start = {0};
struct timeval end = {0};
double startTimeMs;
double endTimeMs;
std::vector<MSTensor> inputs;
std::vector<MSTensor> outputs;
std::cout << "Start predict input files:" << input0_files[i] << std::endl;
auto input0 = ReadFileToTensor(input0_files[i]);
auto input1 = ReadFileToTensor(input1_files[i]);
auto input2 = ReadFileToTensor(input2_files[i]);
auto input3 = ReadFileToTensor(input3_files[i]);
auto input4 = ReadFileToTensor(input4_files[i]);
auto input5 = ReadFileToTensor(input5_files[i]);
inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(),
input0.Data().get(), input0.DataSize());
inputs.emplace_back(model_inputs[1].Name(), model_inputs[1].DataType(), model_inputs[1].Shape(),
input1.Data().get(), input1.DataSize());
inputs.emplace_back(model_inputs[2].Name(), model_inputs[2].DataType(), model_inputs[2].Shape(),
input2.Data().get(), input2.DataSize());
inputs.emplace_back(model_inputs[3].Name(), model_inputs[3].DataType(), model_inputs[3].Shape(),
input3.Data().get(), input3.DataSize());
inputs.emplace_back(model_inputs[4].Name(), model_inputs[4].DataType(), model_inputs[4].Shape(),
input4.Data().get(), input4.DataSize());
inputs.emplace_back(model_inputs[5].Name(), model_inputs[5].DataType(), model_inputs[5].Shape(),
input5.Data().get(), input5.DataSize());
gettimeofday(&start, nullptr);
ret = model.Predict(inputs, &outputs);
gettimeofday(&end, nullptr);
if (ret != kSuccess) {
std::cout << "Predict " << input0_files[i] << " failed." << std::endl;
return 1;
}
startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000;
endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000;
costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs));
WriteResult(input0_files[i], outputs);
}
double average = 0.0;
int inferCount = 0;
for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) {
double diff = 0.0;
diff = iter->second - iter->first;
average += diff;
inferCount++;
}
average = average / inferCount;
std::stringstream timeCost;
timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl;
std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl;
std::string fileName = "./time_Result" + std::string("/test_perform_static.txt");
std::ofstream fileStream(fileName.c_str(), std::ios::trunc);
fileStream << timeCost.str();
fileStream.close();
costTime_map.clear();
return 0;
}
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <algorithm>
#include <iostream>
#include "inc/utils.h"
using mindspore::MSTensor;
using mindspore::DataType;
std::vector<std::string> GetAllFiles(std::string_view dirName) {
struct dirent *filename;
DIR *dir = OpenDir(dirName);
if (dir == nullptr) {
return {};
}
std::vector<std::string> res;
while ((filename = readdir(dir)) != nullptr) {
std::string dName = std::string(filename->d_name);
if (dName == "." || dName == ".." || filename->d_type != DT_REG) {
continue;
}
res.emplace_back(std::string(dirName) + "/" + filename->d_name);
}
std::sort(res.begin(), res.end());
for (auto &f : res) {
std::cout << "image file: " << f << std::endl;
}
return res;
}
int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) {
std::string homePath = "./result_Files";
for (size_t i = 0; i < outputs.size(); ++i) {
size_t outputSize;
std::shared_ptr<const void> netOutput;
netOutput = outputs[i].Data();
outputSize = outputs[i].DataSize();
int pos = imageFile.rfind('/');
std::string fileName(imageFile, pos + 1);
fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin");
std::string outFileName = homePath + "/" + fileName;
FILE * outputFile = fopen(outFileName.c_str(), "wb");
fwrite(netOutput.get(), outputSize, sizeof(char), outputFile);
fclose(outputFile);
outputFile = nullptr;
}
return 0;
}
mindspore::MSTensor ReadFileToTensor(const std::string &file) {
if (file.empty()) {
std::cout << "File path is empty" << std::endl;
return mindspore::MSTensor();
}
std::ifstream ifs(file);
if (!ifs.good()) {
std::cout << "File: " << file << " does not exist" << std::endl;
return mindspore::MSTensor();
}
if (!ifs.is_open()) {
std::cout << "File: " << file << " open failed" << std::endl;
return mindspore::MSTensor();
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size);
ifs.close();
return buffer;
}
DIR *OpenDir(std::string_view dirName) {
if (dirName.empty()) {
std::cout << " dirName is null ! " << std::endl;
return nullptr;
}
std::string realPath = RealPath(dirName);
struct stat s;
lstat(realPath.c_str(), &s);
if (!S_ISDIR(s.st_mode)) {
std::cout << "dirName is not a valid directory !" << std::endl;
return nullptr;
}
DIR *dir;
dir = opendir(realPath.c_str());
if (dir == nullptr) {
std::cout << "Can not open dir " << dirName << std::endl;
return nullptr;
}
std::cout << "Successfully opened the dir " << dirName << std::endl;
return dir;
}
std::string RealPath(std::string_view path) {
char realPathMem[PATH_MAX] = {0};
char *realPathRet = nullptr;
realPathRet = realpath(path.data(), realPathMem);
if (realPathRet == nullptr) {
std::cout << "File: " << path << " is not exist.";
return "";
}
std::string realPath(realPathMem);
std::cout << path << " realpath is: " << realPath << std::endl;
return realPath;
}
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# url for openi
ckpt_url: ""
result_url: ""
# path for local
data_path: "."
output_path: "./output"
load_path: "./checkpoint"
# preprocess_data
same_relation: False
#train
dataset: "steam"
train_csv: "train.csv"
test_csv: "test.csv"
infer_csv: "infer.csv"
device_id: 0
epochs: 20
device_target: "GPU"
run_mode: "graph"
#eval
checkpoint_id: 19
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TB-Net evaluation."""
import os
import math
from mindspore import context, Model, load_checkpoint, load_param_into_net
import mindspore.common.dtype as mstype
from src import tbnet, config, metrics, dataset
from src.utils.param import param
from src.utils.moxing_adapter import moxing_wrapper
from preprocess_dataset import preprocess_data
@moxing_wrapper(preprocess_data)
def eval_tbnet():
    """Evaluation process."""
    config_path = os.path.join(param.data_path, 'data', param.dataset, 'config.json')
    test_csv_path = os.path.join(param.data_path, 'data', param.dataset, param.test_csv)
    ckpt_path = param.load_path
    context.set_context(device_id=param.device_id)
    if param.run_mode == 'graph':
        context.set_context(mode=context.GRAPH_MODE, device_target=param.device_target)
    else:
        context.set_context(mode=context.PYNATIVE_MODE, device_target=param.device_target)
    print(f"creating dataset from {test_csv_path}...")
    net_config = config.TBNetConfig(config_path)
    if param.device_target == 'Ascend':
        # round path count and embedding dim up to multiples of 16 for Ascend kernels
        net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
        net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
    eval_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
    print(f"creating TBNet from checkpoint {param.checkpoint_id} for evaluation...")
    network = tbnet.TBNet(net_config)
    if param.device_target == 'Ascend':
        network.to_float(mstype.float16)
    param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{param.checkpoint_id}.ckpt'))
    load_param_into_net(network, param_dict)
    loss_net = tbnet.NetWithLossClass(network, net_config)
    train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
    train_net.set_train()
    eval_net = tbnet.PredictWithSigmoid(network)
    model = Model(network=train_net, eval_network=eval_net, metrics={'auc': metrics.AUC(), 'acc': metrics.ACC()})
    print("evaluating...")
    e_out = model.eval(eval_ds, dataset_sink_mode=False)
    print(f'Test AUC:{e_out["auc"]} ACC:{e_out["acc"]}')
    if param.enable_modelarts:
        with open(os.path.join(param.output_path, 'result.txt'), 'w') as f:
            f.write(f'Test AUC:{e_out["auc"]} ACC:{e_out["acc"]}')

if __name__ == '__main__':
    eval_tbnet()
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export."""
import os
import argparse
import math
import numpy as np
from mindspore import context, load_checkpoint, load_param_into_net, Tensor, export
from src import tbnet, config
def get_args():
    """Parse commandline arguments."""
    parser = argparse.ArgumentParser(description='Export.')
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        default='',
        help="json file for dataset"
    )
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        required=True,
        help="use which checkpoint(.ckpt) file to export"
    )
    parser.add_argument(
        '--device_id',
        type=int,
        required=False,
        default=0,
        help="device id"
    )
    parser.add_argument(
        '--device_target',
        type=str,
        required=False,
        default='Ascend',
        choices=['Ascend', 'GPU'],
        help="run code on platform"
    )
    parser.add_argument(
        '--run_mode',
        type=str,
        required=False,
        default='graph',
        choices=['graph', 'pynative'],
        help="run code by GRAPH mode or PYNATIVE mode"
    )
    parser.add_argument(
        '--file_name',
        type=str,
        default='tbnet',
        help="model name."
    )
    parser.add_argument(
        '--file_format',
        type=str,
        default='MINDIR',
        choices=['MINDIR', 'AIR'],
        help="model format."
    )
    return parser.parse_args()
def export_tbnet():
    """Data preprocess for inference."""
    args = get_args()
    config_path = args.config_path
    ckpt_path = args.checkpoint_path
    if not os.path.exists(config_path):
        raise ValueError("please check the config path.")
    if not os.path.exists(ckpt_path):
        raise ValueError("please check the checkpoint path.")
    context.set_context(device_id=args.device_id)
    if args.run_mode == 'graph':
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    else:
        context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
    net_config = config.TBNetConfig(config_path)
    if args.device_target == 'Ascend':
        # round path count and embedding dim up to multiples of 16 for Ascend kernels
        net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
        net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
    network = tbnet.TBNet(net_config)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(network, param_dict)
    eval_net = tbnet.PredictWithSigmoid(network)
    # np.int was removed in NumPy 1.24; use the explicit 64-bit integer type.
    item = Tensor(np.ones((1,)).astype(np.int64))
    rl1 = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int64))
    ety = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int64))
    rl2 = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int64))
    his = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int64))
    rate = Tensor(np.ones((1,)).astype(np.float32))
    inputs = [item, rl1, ety, rl2, his, rate]
    export(eval_net, *inputs, file_name=args.file_name, file_format=args.file_format)

if __name__ == '__main__':
    export_tbnet()
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TBNet inference."""
import os
import argparse
import math
from mindspore import load_checkpoint, load_param_into_net, context
import mindspore.common.dtype as mstype
from src.config import TBNetConfig
from src.tbnet import TBNet
from src.aggregator import InferenceAggregator
from src import dataset
from src import steam
def get_args():
    """Parse commandline arguments."""
    parser = argparse.ArgumentParser(description='Infer TBNet.')
    parser.add_argument(
        '--dataset',
        type=str,
        required=False,
        default='steam',
        help="'steam' dataset is supported currently"
    )
    parser.add_argument(
        '--csv',
        type=str,
        required=False,
        default='infer.csv',
        help="the csv datafile inside the dataset folder (e.g. infer.csv)"
    )
    parser.add_argument(
        '--checkpoint_id',
        type=int,
        required=True,
        help="use which checkpoint(.ckpt) file to infer"
    )
    parser.add_argument(
        '--user',
        type=int,
        required=True,
        help="id of the user to be recommended to"
    )
    parser.add_argument(
        '--items',
        type=int,
        required=False,
        default=1,
        help="no. of items to be recommended"
    )
    parser.add_argument(
        '--explanations',
        type=int,
        required=False,
        default=3,
        help="no. of recommendation explanations to be shown"
    )
    parser.add_argument(
        '--device_id',
        type=int,
        required=False,
        default=0,
        help="device id"
    )
    parser.add_argument(
        '--device_target',
        type=str,
        required=False,
        default='GPU',
        choices=['GPU', 'Ascend'],
        help="run code on GPU or Ascend NPU"
    )
    parser.add_argument(
        '--run_mode',
        type=str,
        required=False,
        default='graph',
        choices=['graph', 'pynative'],
        help="run code by GRAPH mode or PYNATIVE mode"
    )
    return parser.parse_args()
def infer_tbnet():
    """Inference process."""
    args = get_args()
    context.set_context(device_id=args.device_id)
    if args.run_mode == 'graph':
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    else:
        context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
    home = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(home, 'data', args.dataset, 'config.json')
    translate_path = os.path.join(home, 'data', args.dataset, 'translate.json')
    data_path = os.path.join(home, 'data', args.dataset, args.csv)
    ckpt_path = os.path.join(home, 'checkpoints')
    print(f"creating TBNet from checkpoint {args.checkpoint_id}...")
    config = TBNetConfig(config_path)
    if args.device_target == 'Ascend':
        config.per_item_paths = math.ceil(config.per_item_paths / 16) * 16
        config.embedding_dim = math.ceil(config.embedding_dim / 16) * 16
    network = TBNet(config)
    if args.device_target == 'Ascend':
        network.to_float(mstype.float16)
    param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
    load_param_into_net(network, param_dict)
    print(f"creating dataset from {data_path}...")
    infer_ds = dataset.create(data_path, config.per_item_paths, train=False, users=args.user)
    infer_ds = infer_ds.batch(config.batch_size)
    print("inferring...")
    # infer and aggregate results
    aggregator = InferenceAggregator(top_k=args.items)
    for user, item, relation1, entity, relation2, hist_item, rating in infer_ds:
        del rating
        result = network(item, relation1, entity, relation2, hist_item)
        item_score = result[0]
        path_importance = result[1]
        aggregator.aggregate(user, item, relation1, entity, relation2, hist_item, item_score, path_importance)
    # show recommendations with explanations
    explainer = steam.TextExplainer(translate_path)
    recomms = aggregator.recommend()
    for user, recomm in recomms.items():
        for item_rec in recomm.item_records:
            item_name = explainer.translate_item(item_rec.item)
            print(f"Recommend <{item_name}> to user:{user}, because:")
            # show explanations
            explanation = 0
            for path in item_rec.paths:
                print(" - " + explainer.explain(path))
                explanation += 1
                if explanation >= args.explanations:
                    break
            print("")

if __name__ == '__main__':
    infer_tbnet()
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
model_path=$1
output_model_name=$2
atc --model=$model_path \
--framework=1 \
--output=$output_model_name \
--input_format=NCHW \
--soc_version=Ascend310 \
--output_type=FP32
{
"tbnet": {
"stream_config": {
"deviceId": "0"
},
"appsrc0": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:0"
},
"appsrc1": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:1"
},
"appsrc2": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:2"
},
"appsrc3": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:3"
},
"appsrc4": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:4"
},
"appsrc5": {
"props": {
"blocksize": "409600"
},
"factory": "appsrc",
"next": "mxpi_tensorinfer0:5"
},
"mxpi_tensorinfer0": {
"props": {
"dataSource": "appsrc0,appsrc1,appsrc2,appsrc3,appsrc4,appsrc5",
"modelPath": "../data/model/tbnet.om"
},
"factory": "mxpi_tensorinfer",
"next": "mxpi_dataserialize0"
},
"mxpi_dataserialize0": {
"props": {
"outputDataKeys": "mxpi_tensorinfer0"
},
"factory": "mxpi_dataserialize",
"next": "appsink0"
},
"appsink0": {
"props": {
"blocksize": "4096000"
},
"factory": "appsink"
}
}
}
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
docker_image=$1
model_dir=$2
function show_help() {
echo "Usage: docker_start.sh docker_image model_dir data_dir"
}
function param_check() {
if [ -z "${docker_image}" ]; then
echo "please input docker_image"
show_help
exit 1
fi
if [ -z "${model_dir}" ]; then
echo "please input model_dir"
show_help
exit 1
fi
}
param_check
docker run -it -u root \
--device=/dev/davinci0 \
--device=/dev/davinci_manager \
--device=/dev/devmm_svm \
--device=/dev/hisi_hdc \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v ${model_dir}:${model_dir} \
${docker_image} \
/bin/bash
cmake_minimum_required(VERSION 3.5.2)
project(Tbnet)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
set(TARGET_MAIN Tbnet)
set(ACL_LIB_PATH $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories($ENV{MX_SDK_HOME}/include)
include_directories($ENV{MX_SDK_HOME}/opensource/include)
include_directories($ENV{MX_SDK_HOME}/opensource/include/opencv4)
include_directories($ENV{MX_SDK_HOME}/opensource/include/gstreamer-1.0)
include_directories($ENV{MX_SDK_HOME}/opensource/include/glib-2.0)
include_directories($ENV{MX_SDK_HOME}/opensource/lib/glib-2.0/include)
link_directories($ENV{MX_SDK_HOME}/lib)
link_directories($ENV{MX_SDK_HOME}/opensource/lib/)
add_compile_options(-std=c++11 -fPIC -fstack-protector-all -pie -Wno-deprecated-declarations)
add_compile_options("-DPLUGIN_NAME=${PLUGIN_NAME}")
add_compile_options("-Dgoogle=mindxsdk_private")
add_definitions(-DENABLE_DVPP_INTERFACE)
include_directories(${ACL_LIB_PATH}/include)
link_directories(${ACL_LIB_PATH}/lib64/)
add_executable(${TARGET_MAIN} src/main.cpp src/Tbnet.cpp)
target_link_libraries(${TARGET_MAIN} ${TARGET_LIBRARY} glog cpprest mxbase libascendcl.so)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export ASCEND_VERSION=ascend-toolkit/latest
export ARCH_PATTERN=.
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib/modelpostprocessors:${LD_LIBRARY_PATH}
mkdir -p build
cd build || exit
function make_plugin() {
if ! cmake ..;
then
echo "cmake failed."
return 1
fi
if ! (make);
then
echo "make failed."
return 1
fi
return 0
}
if make_plugin;
then
echo "INFO: Build successfully."
else
echo "ERROR: Build failed."
fi
cd - || exit
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Tbnet.h"
#include <cstdlib>
#include <memory>
#include <string>
#include <cmath>
#include <vector>
#include <algorithm>
#include <queue>
#include <utility>
#include <fstream>
#include <map>
#include <iostream>
#include "acl/acl.h"
#include "MxBase/DeviceManager/DeviceManager.h"
#include "MxBase/Log/Log.h"
namespace {
const std::vector<std::vector<uint32_t>> SHAPE = {{1}, {1, 39}, {1, 39},
{1, 39}, {1, 39}, {1}};
const int FLOAT_SIZE = 4;
const int INT_SIZE = 8;
const int DATA_SIZE_1 = 1;
const int DATA_SIZE_39 = 39;
}
void WriteResult(const std::string &file_name, const std::vector<MxBase::TensorBase> &outputs) {
std::string homePath = "./result";
for (size_t i = 0; i < outputs.size(); ++i) {
float *boxes = reinterpret_cast<float *>(outputs[i].GetBuffer());
std::string outFileName = homePath + "/tbnet_item_bs1_" + file_name + "_" +
std::to_string(i) + ".txt";
std::ofstream outfile(outFileName, std::ios::app);
size_t outputSize;
outputSize = outputs[i].GetSize();
for (size_t j = 0; j < outputSize; ++j) {
if (j != 0) {
outfile << ",";
}
outfile << boxes[j];
}
outfile.close();
}
}
APP_ERROR Tbnet::Init(const InitParam &initParam) {
deviceId_ = initParam.deviceId;
APP_ERROR ret = MxBase::DeviceManager::GetInstance()->InitDevices();
if (ret != APP_ERR_OK) {
LogError << "Init devices failed, ret=" << ret << ".";
return ret;
}
ret = MxBase::TensorContext::GetInstance()->SetContext(initParam.deviceId);
if (ret != APP_ERR_OK) {
LogError << "Set context failed, ret=" << ret << ".";
return ret;
}
model_Tbnet = std::make_shared<MxBase::ModelInferenceProcessor>();
ret = model_Tbnet->Init(initParam.modelPath, modelDesc_);
if (ret != APP_ERR_OK) {
LogError << "ModelInferenceProcessor init failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Tbnet::DeInit() {
model_Tbnet->DeInit();
MxBase::DeviceManager::GetInstance()->DestroyDevices();
return APP_ERR_OK;
}
APP_ERROR Tbnet::ReadBin_float(const std::string &path, std::vector<std::vector<float>> &dataset,
const int datasize) {
std::ifstream inFile(path, std::ios::binary);
// read into a vector directly instead of a raw new[] buffer that was never freed
std::vector<float> temp(datasize);
inFile.read(reinterpret_cast<char *>(temp.data()), datasize * sizeof(float));
dataset.push_back(temp);
return APP_ERR_OK;
}
APP_ERROR Tbnet::ReadBin_int(const std::string &path, std::vector<std::vector<int64_t>> &dataset,
const int datasize) {
std::ifstream inFile(path, std::ios::binary);
std::vector<int64_t> temp(datasize);
inFile.read(reinterpret_cast<char *>(temp.data()), datasize * sizeof(int64_t));
dataset.push_back(temp);
return APP_ERR_OK;
}
APP_ERROR Tbnet::VectorToTensorBase_float(const std::vector<std::vector<float>> &input,
MxBase::TensorBase &tensorBase,
const std::vector<uint32_t> &shape) {
uint32_t dataSize = 1;
for (int i = 0; i < shape.size(); i++) {
dataSize = dataSize * shape[i];
} // input shape
float *metaFeatureData = new float[dataSize];
uint32_t idx = 0;
for (size_t bs = 0; bs < input.size(); bs++) {
for (size_t c = 0; c < input[bs].size(); c++) {
metaFeatureData[idx++] = input[bs][c];
}
}
MxBase::MemoryData memoryDataDst(dataSize * FLOAT_SIZE, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
MxBase::MemoryData memoryDataSrc(reinterpret_cast<void *>(metaFeatureData), dataSize * FLOAT_SIZE,
MxBase::MemoryData::MEMORY_HOST_MALLOC);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_FLOAT32);
return APP_ERR_OK;
}
APP_ERROR Tbnet::VectorToTensorBase_int(const std::vector<std::vector<int64_t>> &input,
MxBase::TensorBase &tensorBase,
const std::vector<uint32_t> &shape) {
int dataSize = 1;
for (int i = 0; i < shape.size(); i++) {
dataSize = dataSize * shape[i];
} // input shape
int64_t *metaFeatureData = new int64_t[dataSize];
uint32_t idx = 0;
for (size_t bs = 0; bs < input.size(); bs++) {
for (size_t c = 0; c < input[bs].size(); c++) {
metaFeatureData[idx++] = input[bs][c];
}
}
MxBase::MemoryData memoryDataDst(dataSize * INT_SIZE, MxBase::MemoryData::MEMORY_DEVICE, deviceId_);
MxBase::MemoryData memoryDataSrc(reinterpret_cast<void *>(metaFeatureData), dataSize * INT_SIZE,
MxBase::MemoryData::MEMORY_HOST_MALLOC);
APP_ERROR ret = MxBase::MemoryHelper::MxbsMallocAndCopy(memoryDataDst, memoryDataSrc);
if (ret != APP_ERR_OK) {
LogError << GetError(ret) << "Memory malloc failed.";
return ret;
}
tensorBase = MxBase::TensorBase(memoryDataDst, false, shape, MxBase::TENSOR_DTYPE_INT64);
return APP_ERR_OK;
}
APP_ERROR Tbnet::Inference(const std::vector<MxBase::TensorBase> &inputs,
std::vector<MxBase::TensorBase> &outputs) {
auto dtypes = model_Tbnet->GetOutputDataType();
for (size_t i = 0; i < modelDesc_.outputTensors.size(); ++i) {
std::vector<uint32_t> shape = {};
for (size_t j = 0; j < modelDesc_.outputTensors[i].tensorDims.size(); ++j) {
shape.push_back((uint32_t)modelDesc_.outputTensors[i].tensorDims[j]);
}
MxBase::TensorBase tensor(shape, dtypes[i], MxBase::MemoryData::MemoryType::MEMORY_DEVICE, deviceId_);
APP_ERROR ret = MxBase::TensorBase::TensorBaseMalloc(tensor);
if (ret != APP_ERR_OK) {
LogError << "TensorBaseMalloc failed, ret=" << ret << ".";
return ret;
}
outputs.push_back(tensor);
}
MxBase::DynamicInfo dynamicInfo = {};
dynamicInfo.dynamicType = MxBase::DynamicType::STATIC_BATCH;
auto startTime = std::chrono::high_resolution_clock::now();
APP_ERROR ret = model_Tbnet->ModelInference(inputs, outputs, dynamicInfo);
auto endTime = std::chrono::high_resolution_clock::now();
double costMs = std::chrono::duration<double, std::milli>(endTime - startTime).count();
inferCostTimeMilliSec += costMs;
if (ret != APP_ERR_OK) {
LogError << "ModelInference Tbnet failed, ret=" << ret << ".";
return ret;
}
return APP_ERR_OK;
}
APP_ERROR Tbnet::Process(const int &index, const std::string &datapath,
const InitParam &initParam, std::vector<int> &outputs) {
std::vector<MxBase::TensorBase> inputs = {};
std::vector<MxBase::TensorBase> outputs_tb = {};
std::vector<std::vector<int64_t>> item;
// Read the six preprocessed inputs for this sample, failing fast on error.
APP_ERROR ret = ReadBin_int(datapath + "00_item/tbnet_item_bs1_" +
std::to_string(index) + ".bin", item, DATA_SIZE_1);
std::vector<std::vector<int64_t>> rl1;
if (ret == APP_ERR_OK) {
ret = ReadBin_int(datapath + "01_rl1/tbnet_rl1_bs1_" +
std::to_string(index) + ".bin", rl1, DATA_SIZE_39);
}
std::vector<std::vector<int64_t>> ety;
if (ret == APP_ERR_OK) {
ret = ReadBin_int(datapath + "02_ety/tbnet_ety_bs1_" +
std::to_string(index) + ".bin", ety, DATA_SIZE_39);
}
std::vector<std::vector<int64_t>> rl2;
if (ret == APP_ERR_OK) {
ret = ReadBin_int(datapath + "03_rl2/tbnet_rl2_bs1_" +
std::to_string(index) + ".bin", rl2, DATA_SIZE_39);
}
std::vector<std::vector<int64_t>> his;
if (ret == APP_ERR_OK) {
ret = ReadBin_int(datapath + "04_his/tbnet_his_bs1_" +
std::to_string(index) + ".bin", his, DATA_SIZE_39);
}
std::vector<std::vector<float>> rate;
if (ret == APP_ERR_OK) {
ret = ReadBin_float(datapath + "05_rate/tbnet_rate_bs1_" +
std::to_string(index) + ".bin", rate, DATA_SIZE_1);
}
if (ret != APP_ERR_OK) {
LogError << "ReadBin failed, ret=" << ret << ".";
return ret;
}
// Convert each input to a device tensor; check the accumulated status once.
MxBase::TensorBase tensorBase0;
ret = VectorToTensorBase_int(item, tensorBase0, SHAPE[0]);
inputs.push_back(tensorBase0);
MxBase::TensorBase tensorBase1;
if (ret == APP_ERR_OK) {
ret = VectorToTensorBase_int(rl1, tensorBase1, SHAPE[1]);
}
inputs.push_back(tensorBase1);
MxBase::TensorBase tensorBase2;
if (ret == APP_ERR_OK) {
ret = VectorToTensorBase_int(ety, tensorBase2, SHAPE[2]);
}
inputs.push_back(tensorBase2);
MxBase::TensorBase tensorBase3;
if (ret == APP_ERR_OK) {
ret = VectorToTensorBase_int(rl2, tensorBase3, SHAPE[3]);
}
inputs.push_back(tensorBase3);
MxBase::TensorBase tensorBase4;
if (ret == APP_ERR_OK) {
ret = VectorToTensorBase_int(his, tensorBase4, SHAPE[4]);
}
inputs.push_back(tensorBase4);
MxBase::TensorBase tensorBase5;
if (ret == APP_ERR_OK) {
ret = VectorToTensorBase_float(rate, tensorBase5, SHAPE[5]);
}
inputs.push_back(tensorBase5);
if (ret != APP_ERR_OK) {
LogError << "VectorToTensorBase failed, ret=" << ret << ".";
return ret;
}
ret = Inference(inputs, outputs_tb);
// Inference() already accumulates its own cost into inferCostTimeMilliSec,
// so the duration is not measured and added again here (that would double
// count the time and inflate the reported FPS).
if (ret != APP_ERR_OK) {
LogError << "Inference failed, ret=" << ret << ".";
return ret;
}
// Move the output tensors back to host memory before writing them out.
for (size_t i = 0; i < outputs_tb.size(); ++i) {
if (!outputs_tb[i].IsHost()) {
outputs_tb[i].ToHost();
}
}
WriteResult(std::to_string(index), outputs_tb);
return APP_ERR_OK;
}
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MXBASE_Tbnet_H
#define MXBASE_Tbnet_H
#include <memory>
#include <string>
#include <vector>
#include "acl/acl.h"
#include "MxBase/DvppWrapper/DvppWrapper.h"
#include "MxBase/ModelInfer/ModelInferenceProcessor.h"
#include "MxBase/Tensor/TensorContext/TensorContext.h"
#include "MxBase/CV/Core/DataType.h"
struct InitParam {
uint32_t deviceId;
bool checkTensor;
std::string modelPath;
};
class Tbnet {
public:
APP_ERROR Init(const InitParam &initParam);
APP_ERROR DeInit();
APP_ERROR VectorToTensorBase_int(const std::vector<std::vector<int64_t>> &input, MxBase::TensorBase &tensorBase,
const std::vector<uint32_t> &shape);
APP_ERROR VectorToTensorBase_float(const std::vector<std::vector<float>> &input, MxBase::TensorBase &tensorBase,
const std::vector<uint32_t> &shape);
APP_ERROR Inference(const std::vector<MxBase::TensorBase> &inputs, std::vector<MxBase::TensorBase> &outputs);
APP_ERROR Process(const int &index, const std::string &datapath,
const InitParam &initParam, std::vector<int> &outputs);
APP_ERROR ReadBin_int(const std::string &path, std::vector<std::vector<int64_t>> &dataset,
const int shape);
APP_ERROR ReadBin_float(const std::string &path, std::vector<std::vector<float>> &dataset,
const int shape);
// get infer time
double GetInferCostMilliSec() const {return inferCostTimeMilliSec;}
private:
std::shared_ptr<MxBase::ModelInferenceProcessor> model_Tbnet;
MxBase::ModelDesc modelDesc_;
uint32_t deviceId_ = 0;
// infer time
double inferCostTimeMilliSec = 0.0;
};
#endif
/*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dirent.h>
#include <fstream>
#include "MxBase/Log/Log.h"
#include "Tbnet.h"
namespace {
const uint32_t DATA_SIZE = 18415;  // number of preprocessed samples to infer
} // namespace
int main(int argc, char* argv[]) {
InitParam initParam = {};
initParam.deviceId = 0;
initParam.checkTensor = true;
initParam.modelPath = "../data/model/tbnet.om";
std::string dataPath = "../../preprocess_Result/";
auto model_Tbnet = std::make_shared<Tbnet>();
APP_ERROR ret = model_Tbnet->Init(initParam);
if (ret != APP_ERR_OK) {
LogError << "Tagging init failed, ret=" << ret << ".";
model_Tbnet->DeInit();
return ret;
}
std::vector<int> outputs;
for (uint32_t i = 0; i < DATA_SIZE; i++) {
LogInfo << "processing " << i;
ret = model_Tbnet->Process(i, dataPath, initParam, outputs);
if (ret != APP_ERR_OK) {
LogError << "Tbnet process failed, ret=" << ret << ".";
model_Tbnet->DeInit();
return ret;
}
}
model_Tbnet->DeInit();
double total_time = model_Tbnet->GetInferCostMilliSec() / 1000;  // ms -> s
LogInfo << "inference total cost time: " << total_time << " s, FPS: " << DATA_SIZE / total_time;
return APP_ERR_OK;
}
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" main.py """
import argparse
import os
from StreamManagerApi import StreamManagerApi, StringVector
from StreamManagerApi import MxDataInput, InProtobufVector, MxProtobufIn
import MxpiDataType_pb2 as MxpiDataType
import numpy as np
def parse_args(parsers):
"""
Parse commandline arguments.
"""
parsers.add_argument('--data_path', type=str,
default="../../preprocess_Result",
help='path of the preprocessed input binaries')
return parsers
def create_protobuf(path, id1, shape):
    """Pack one preprocessed .bin file into a tensor protobuf for stream plugin appsrc<id1>."""
data_input = MxDataInput()
with open(path, 'rb') as f:
data = f.read()
data_input.data = data
tensorPackageList1 = MxpiDataType.MxpiTensorPackageList()
tensorPackage1 = tensorPackageList1.tensorPackageVec.add()
tensorVec1 = tensorPackage1.tensorVec.add()
tensorVec1.deviceId = 0
tensorVec1.memType = 0
for t in shape:
tensorVec1.tensorShape.append(t)
tensorVec1.dataStr = data_input.data
tensorVec1.tensorDataSize = len(data)
protobuf1 = MxProtobufIn()
protobuf1.key = b'appsrc%d' % id1
protobuf1.type = b'MxTools.MxpiTensorPackageList'
protobuf1.protobuf = tensorPackageList1.SerializeToString()
return protobuf1
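# Illustrative call (example values only): pack the scalar rating input of
# sample 0 so it is routed to plugin appsrc5 with tensor shape [1]:
#     create_protobuf(os.path.join(args.data_path, '05_rate', 'tbnet_rate_bs1_0.bin'), 5, [1])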
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Om tbnet Inference')
parser = parse_args(parser)
args, _ = parser.parse_known_args()
# init stream manager
stream_manager = StreamManagerApi()
ret = stream_manager.InitManager()
if ret != 0:
print("Failed to init Stream manager, ret=%s" % str(ret))
exit()
# create streams by pipeline config file
with open("../data/config/tbnet.pipeline", 'rb') as fl:
pipeline = fl.read()
ret = stream_manager.CreateMultipleStreams(pipeline)
if ret != 0:
print("Failed to create Stream, ret=%s" % str(ret))
exit()
# Construct the input of the stream
res_dir_name = 'result'
if not os.path.exists(res_dir_name):
os.makedirs(res_dir_name)
results = []
input_names = ['00_item', '01_rl1', '02_ety', '03_rl2', '04_his', '05_rate']
shape_list = [[1], [1, 39], [1, 39], [1, 39], [1, 39], [1]]
for idx in range(18415):  # all preprocessed samples; keep in sync with DATA_SIZE in main.cpp
print('infer %d' % idx)
for index, name in enumerate(input_names):
protobufVec = InProtobufVector()
path_tmp = os.path.join(args.data_path, name,
'tbnet_' + name.split('_')[1] + '_bs1_' + str(idx) + '.bin')
protobufVec.push_back(create_protobuf(path_tmp, index, shape_list[index]))
ret = stream_manager.SendProtobuf(b'tbnet', b'appsrc%d' % index, protobufVec)
if ret < 0:
    print("Failed to send data to stream, ret=%s" % str(ret))
    exit()
keyVec = StringVector()
keyVec.push_back(b'mxpi_tensorinfer0')
infer_result = stream_manager.GetProtobuf(b'tbnet', 0, keyVec)
if infer_result.size() == 0:
print("inferResult is null")
exit()
if infer_result[0].errorCode != 0:
print("GetProtobuf error. errorCode=%d" % (
infer_result[0].errorCode))
exit()
# get infer result
result = MxpiDataType.MxpiTensorPackageList()
result.ParseFromString(infer_result[0].messageBuf)
# save the model's four output tensors for this sample as text
for i in range(4):
res = np.frombuffer(result.tensorPackageVec[0].tensorVec[i].dataStr, dtype=np.float32)
np.savetxt("./result/tbnet_item_bs1_%d_%d.txt" % (idx, i), res, fmt='%.06f')
# destroy streams
stream_manager.DestroyAllStreams()
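# A minimal post-processing sketch (not part of the original script), assuming
# the result layout written above; it loads the first output tensor of sample 0:
#     scores = np.loadtxt('./result/tbnet_item_bs1_0_0.txt', dtype=np.float32)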