Commit e0bd3b3c authored by d294270681

add train scripts on Ascend NPU

update official/recommend/tbnet/scripts/run_eval.sh.

update official/recommend/tbnet/scripts/run_standalone_train.sh.
parent 03294cfb
Showing with 969 additions and 179 deletions
......@@ -54,10 +54,17 @@ Note that the \<item\> needs to traverse candidate items (all items by default)
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
```
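For illustration only, assuming PER_ITEM_NUM_PATHS=2 and made-up IDs, a single row would look like `1,5,1,0,425,0,7,1,426,1,9`: user 1 and candidate item 5 with label 1, followed by two [relation1,entity,relation2,hist_item] modules.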
Download the data package and put it under the current project path.
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
```
# [Environment Requirements](#contents)
- Hardware(GPU)
- Prepare hardware environment with GPU processor.
- Hardware(NVIDIA GPU or Ascend NPU)
- Prepare hardware environment with NVIDIA GPU or Ascend NPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
......@@ -70,51 +77,57 @@ After installing MindSpore via the official website, you can start training and
- Data preprocessing
Process the data to the format in chapter [Dataset](#Dataset) (e.g. 'steam' dataset), and then run code as follows.
Download the data package (e.g. the 'steam' dataset) and put it under the current project path.
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
cd scripts
```
and then run the scripts as follows.
- Training
```bash
python train.py \
--dataset [DATASET] \
--epochs [EPOCHS]
bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Example:
```bash
python train.py \
--dataset steam \
--epochs 20
bash run_standalone_train.sh steam 0 Ascend
```
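The script runs training in the background; as shown by `run_standalone_train.sh` later in this commit, both dataset preprocessing and training output are redirected to `train_standalone_log` in the directory where the script is invoked (the `scripts` directory in the quick start above), so progress can be followed with, for example:

```bash
tail -f train_standalone_log
```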
- Evaluation
Evaluate the model on the test dataset.
```bash
python eval.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID]
bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Argument `--checkpoint_id` is required.
Argument `[CHECKPOINT_ID]` is required.
Example:
```bash
python eval.py \
--dataset steam \
--checkpoint_id 8
bash run_eval.sh 19 steam 0 Ascend
```
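Like the training script, `run_eval.sh` runs in the background and writes its output to `eval_standalone_gpu_log` in the directory where it is invoked (the file name is the same regardless of the device target), so the result can be checked with, for example:

```bash
cat eval_standalone_gpu_log
```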
- Inference and Explanation
Recommend items to the user specified by `user`; the number of items is determined by `items`.
```bash
python infer.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID] \
--user [USER] \
--items [ITEMS] \
--explanations [EXPLANATIONS]
--explanations [EXPLANATIONS] \
--csv [CSV] \
--device_target [DEVICE_TARGET]
```
Arguments `--checkpoint_id` and `--user` are required.
......@@ -124,10 +137,12 @@ Example:
```bash
python infer.py \
--dataset steam \
--checkpoint_id 8 \
--user 1 \
--checkpoint_id 19 \
--user 2 \
--items 1 \
--explanations 3
--explanations 3 \
--csv test.csv \
--device_target Ascend
```
# [Script Description](#contents)
......@@ -139,14 +154,16 @@ python infer.py \
└─tbnet
├─README.md
├── scripts
│ └─run_infer_310.sh # Ascend310 inference script
├─run_infer_310.sh # Ascend310 inference script
├─run_standalone_train.sh # NVIDIA GPU or Ascend NPU training script
└─run_eval.sh # NVIDIA GPU or Ascend NPU evaluation script
├─data
├─steam
├─config.json # data and training parameter configuration
├─infer.csv # inference and explanation dataset
├─test.csv # evaluation dataset
├─train.csv # training dataset
└─trainslate.json # explanation configuration
├─src_infer.csv # inference and explanation dataset
├─src_test.csv # evaluation dataset
├─src_train.csv # training dataset
└─id_maps.json # explanation configuration
├─src
├─aggregator.py # inference result aggregation
├─config.py # parsing parameter configuration
......@@ -156,6 +173,7 @@ python infer.py \
├─steam.py # 'steam' dataset text explainer
└─tbnet.py # TB-Net model
├─export.py # export mindir script
├─preprocess_dataset.py # dataset preprocess script
├─preprocess.py # inference data preprocess script
├─postprocess.py # inference result calculation script
├─eval.py # evaluation
......@@ -165,6 +183,14 @@ python infer.py \
## [Script Parameters](#contents)
- preprocess_dataset.py parameters
```text
--dataset 'steam' dataset is supported currently
--device_target run code on GPU or Ascend NPU
--same_relation only generate paths that relation1 is same as relation2
```
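Normally `preprocess_dataset.py` is invoked for you by `run_standalone_train.sh`, but for reference a standalone invocation might look like the following sketch (argument values are illustrative):

```bash
python preprocess_dataset.py --dataset steam --device_target Ascend
```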
- train.py parameters
```text
......@@ -173,7 +199,7 @@ python infer.py \
--test_csv the test csv datafile inside the dataset folder
--device_id device id
--epochs number of training epochs
--device_target run code on GPU
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
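For reference, a direct `train.py` invocation roughly equivalent to the quick-start script might look like the following sketch (values are illustrative; the wrapper script remains the documented entry point):

```bash
python train.py \
    --dataset steam \
    --epochs 20 \
    --device_id 0 \
    --device_target Ascend
```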
......@@ -184,7 +210,7 @@ python infer.py \
--csv the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id use which checkpoint(.ckpt) file to eval
--device_id device id
--device_target run code on GPU
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
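A direct `eval.py` invocation might look like the following sketch (values are illustrative):

```bash
python eval.py \
    --dataset steam \
    --checkpoint_id 19 \
    --csv test.csv \
    --device_id 0 \
    --device_target Ascend
```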
......@@ -198,7 +224,7 @@ python infer.py \
--items no. of items to be recommended
--reasons no. of recommendation reasons to be shown
--device_id device id
--device_target run code on GPU
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
......@@ -215,6 +241,17 @@ python export.py --config_path [CONFIG_PATH] --checkpoint_path [CKPT_PATH] --dev
- `DEVICE` should be in ['Ascend', 'GPU'].
- `FILE_FORMAT` should be in ['MINDIR', 'AIR'].
Example:
```bash
python export.py \
--config_path ./data/steam/config.json \
--checkpoint_path ./checkpoints/tbnet_epoch19.ckpt \
--device_target Ascend \
--file_name model \
--file_format MINDIR
```
### [Infer on Ascend310](#contents)
Before performing inference, the MINDIR file must be exported with the `export.py` script. We only provide an example of inference using the MINDIR model.
......@@ -228,6 +265,12 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
- `DATA_PATH` specifies the path of test.csv.
- `DEVICE_ID` is optional; the default value is 0.
Example:
```bash
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
### [Result](#contents)
Inference results are saved in the current path; you can find results like the following in the acc.log file.
......@@ -242,35 +285,35 @@ auc: 0.8251359368836292
### Training Performance
| Parameters | GPU |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | TB-Net |
| Resource |Tesla V100-SXM2-32GB |
| Uploaded Date | 2021-08-01 |
| MindSpore Version | 1.3.0 |
| Dataset | steam |
| Training Parameter | epoch=20, batch_size=1024, lr=0.001 |
| Optimizer | Adam |
| Loss Function | Sigmoid Cross Entropy |
| Outputs | AUC=0.8596,Accuracy=0.7761 |
| Loss | 0.57 |
| Speed | 1pc: 90ms/step |
| Total Time | 1pc: 297s |
| Checkpoint for Fine Tuning | 104.66M (.ckpt file) |
| Scripts | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
| Parameters | GPU | Ascend NPU |
| -------------------------- |--------------------------------------------------------------------------------------------| ---------------------------------------------|
| Model Version | TB-Net | TB-Net |
| Resource | NVIDIA RTX 3090 | Ascend 910 |
| Uploaded Date | 2022-07-14 | 2022-06-30 |
| MindSpore Version | 1.6.1 | 1.6.1 |
| Dataset | steam | steam |
| Training Parameter | epoch=20, batch_size=1024, lr=0.001 | epoch=20, batch_size=1024, lr=0.001 |
| Optimizer | Adam | Adam |
| Loss Function | Sigmoid Cross Entropy | Sigmoid Cross Entropy |
| Outputs                    | AUC=0.8573, Accuracy=0.7733                                                                | AUC=0.8592, Accuracy=0.7741                   |
| Loss                       | 0.57                                                                                       | 0.59                                          |
| Speed                      | 1pc: 90ms/step                                                                             | 1pc: 80ms/step                                |
| Total Time                 | 1pc: 297s                                                                                  | 1pc: 336s                                     |
| Checkpoint for Fine Tuning | 686.3K (.ckpt file)                                                                        | 671K (.ckpt file)                             |
| Scripts                    | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
### Evaluation Performance
| Parameters | GPU |
| ------------------------- | ----------------------------- |
| Model Version | TB-Net |
| Resource | Tesla V100-SXM2-32GB |
| Uploaded Date | 2021-08-01 |
| MindSpore Version | 1.3.0 |
| Dataset | steam |
| Batch Size | 1024 |
| Outputs | AUC=0.8252,Accuracy=0.7503 |
| Total Time | 1pc: 5.7s |
| Parameters | GPU | Ascend NPU |
| ------------------------- |----------------------------| ----------------------------- |
| Model Version | TB-Net | TB-Net |
| Resource | NVIDIA RTX 3090 | Ascend 910 |
| Uploaded Date | 2022-07-14 | 2022-06-30 |
| MindSpore Version | 1.3.0 | 1.5.1 |
| Dataset | steam | steam |
| Batch Size | 1024 | 1024 |
| Outputs | AUC=0.8487,Accuracy=0.7699 | AUC=0.8486,Accuracy=0.7704 |
| Total Time                | 1pc: 5.7s                  | 1pc: 1.1s                     |
### Inference and Explanation Performance
......
......@@ -58,8 +58,8 @@ TB-Net将用户和物品的交互信息以及物品的属性信息在知识图
# [Environment Requirements](#目录)
- Hardware (GPU)
- Prepare the hardware environment with a GPU processor.
- Hardware (NVIDIA GPU or Ascend NPU)
- Prepare the hardware environment with an NVIDIA GPU or Ascend NPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
......@@ -72,51 +72,57 @@ TB-Net将用户和物品的交互信息以及物品的属性信息在知识图
- Data preparation
Process the data into the format described in the [Dataset](#数据集) section above (taking the 'steam' dataset as an example), and then run the code as follows.
Download the sample dataset package (e.g. the 'steam' dataset) and extract it under the current project path.
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
cd scripts
```
Then run the scripts as follows.
- Training
```bash
python train.py \
--dataset [DATASET] \
--epochs [EPOCHS]
bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Example:
```bash
python train.py \
--dataset steam \
--epochs 20
bash run_standalone_train.sh steam 0 Ascend
```
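The script runs training in the background; as shown by `run_standalone_train.sh` later in this commit, both dataset preprocessing and training output are redirected to `train_standalone_log` in the directory where the script is invoked (the `scripts` directory in the quick start above), so progress can be followed with, for example:

```bash
tail -f train_standalone_log
```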
- Evaluation
Evaluate the model's metrics on the test set.
```bash
python eval.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID]
bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
The argument `--checkpoint_id` is required.
The argument `[CHECKPOINT_ID]` is required.
Example:
```bash
python eval.py \
--dataset steam \
--checkpoint_id 8
bash run_eval.sh 19 steam 0 Ascend
```
- Inference and Explanation
Recommend items to the user specified by `user`; the number of items is determined by `items`.
```bash
python infer.py \
--dataset [DATASET] \
--checkpoint_id [CHECKPOINT_ID] \
--user [USER] \
--items [ITEMS] \
--explanations [EXPLANATIONS]
--explanations [EXPLANATIONS] \
--csv [CSV] \
--device_target [DEVICE_TARGET]
```
The arguments `--checkpoint_id` and `--user` are required.
......@@ -126,10 +132,12 @@ python infer.py \
```bash
python infer.py \
--dataset steam \
--checkpoint_id 8 \
--user 1 \
--checkpoint_id 19 \
--user 2 \
--items 1 \
--explanations 3
--explanations 3 \
--csv test.csv \
--device_target Ascend
```
# [Script Description](#目录)
......@@ -141,14 +149,16 @@ python infer.py \
└─tbnet
├─README.md
├── scripts
│ └─run_infer_310.sh # Ascend310 inference script
├─run_infer_310.sh # Ascend310 inference script
├─run_standalone_train.sh # NVIDIA GPU or Ascend NPU training script
└─run_eval.sh # NVIDIA GPU or Ascend NPU evaluation script
├─data
├─steam
├─config.json # data and training parameter configuration
├─infer.csv # inference and explanation dataset
├─test.csv # test dataset
├─train.csv # training dataset
└─trainslate.json # explanation output configuration
├─src_infer.csv # inference and explanation dataset
├─src_test.csv # test dataset
├─src_train.csv # training dataset
└─id_maps.json # explanation output configuration
├─src
├─aggregator.py # inference result aggregation
├─config.py # parameter configuration parsing
......@@ -157,9 +167,10 @@ python infer.py \
├─metrics.py # model metrics
├─steam.py # 'steam' dataset text parser
└─tbnet.py # TB-Net model
├─export.py # export MINDIR script
├─preprocess.py # inference data preprocessing script
├─postprocess.py # inference result calculation script
├─export.py # export MINDIR script
├─preprocess_dataset.py # dataset preprocessing script
├─preprocess.py # inference data preprocessing script
├─postprocess.py # inference result calculation script
├─eval.py # evaluation
├─infer.py # inference and explanation
└─train.py # training
......@@ -167,6 +178,14 @@ python infer.py \
## [Script Parameters](#目录)
- preprocess_dataset.py parameters
```text
--dataset 'steam' dataset is supported currently
--device_target run code on GPU or Ascend NPU
--same_relation only generate paths that relation1 is same as relation2
```
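Normally `preprocess_dataset.py` is invoked for you by `run_standalone_train.sh`, but for reference a standalone invocation might look like the following sketch (argument values are illustrative):

```bash
python preprocess_dataset.py --dataset steam --device_target Ascend
```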
- train.py parameters
```text
......@@ -175,7 +194,7 @@ python infer.py \
--test_csv the test csv datafile inside the dataset folder
--device_id device id
--epochs number of training epochs
--device_target run code on GPU
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
......@@ -186,7 +205,7 @@ python infer.py \
--csv the csv datafile inside the dataset folder (e.g. test.csv)
--checkpoint_id use which checkpoint(.ckpt) file to eval
--device_id device id
--device_target run code on GPU
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
......@@ -200,7 +219,7 @@ python infer.py \
--items no. of items to be recommended
--reasons no. of recommendation reasons to be shown
--device_id device id
--device_target run code on GPU
--device_target run code on GPU or Ascend NPU
--run_mode run code by GRAPH mode or PYNATIVE mode
```
......@@ -209,7 +228,12 @@ python infer.py \
### Export MindIR
```shell
python export.py --config_path [CONFIG_PATH] --checkpoint_path [CKPT_PATH] --device_target [DEVICE] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
python export.py \
--config_path [CONFIG_PATH] \
--checkpoint_path [CKPT_PATH] \
--device_target [DEVICE] \
--file_name [FILE_NAME] \
--file_format [FILE_FORMAT]
```
- `CKPT_PATH` is required.
......@@ -217,6 +241,17 @@ python export.py --config_path [CONFIG_PATH] --checkpoint_path [CKPT_PATH] --dev
- `DEVICE` should be in ['Ascend', 'GPU'].
- `FILE_FORMAT` should be in ['MINDIR', 'AIR'].
Example:
```bash
python export.py \
--config_path ./data/steam/config.json \
--checkpoint_path ./checkpoints/tbnet_epoch19.ckpt \
--device_target Ascend \
--file_name model \
--file_format MINDIR
```
### Infer on Ascend310
Before performing inference, the MINDIR file must be exported with the `export.py` script. The following shows an example of running inference with the MINDIR model.
......@@ -230,6 +265,12 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
- `DATA_PATH` specifies the path of the inference dataset test.csv.
- `DEVICE_ID` is optional; the default value is 0.
Example:
```bash
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
### Result
Inference results are saved in the current path where the script is executed; you can find the following accuracy results in acc.log.
......@@ -244,35 +285,35 @@ auc: 0.8251359368836292
### [Training Performance](#目录)

| Parameters                 | GPU                                                 |
| -------------------------- | --------------------------------------------------- |
| Model Version              | TB-Net                                              |
| Resource                   | Tesla V100-SXM2-32GB                                |
| Uploaded Date              | 2021-08-01                                          |
| MindSpore Version          | 1.3.0                                               |
| Dataset                    | steam                                               |
| Training Parameters        | epoch=20, batch_size=1024, lr=0.001                 |
| Optimizer                  | Adam                                                |
| Loss Function              | Sigmoid Cross Entropy                               |
| Outputs                    | AUC=0.8596, Accuracy=0.7761                         |
| Loss                       | 0.57                                                |
| Speed                      | 1pc: 90ms/step                                      |
| Total Time                 | 1pc: 297s                                           |
| Checkpoint for Fine Tuning | 104.66M (.ckpt file)                                |
| Parameters                 | GPU                                                                                        | Ascend NPU                          |
| -------------------------- | ------------------------------------------------------------------------------------------ | ----------------------------------- |
| Model Version              | TB-Net                                                                                     | TB-Net                              |
| Resource                   | NVIDIA RTX 3090                                                                            | Ascend 910                          |
| Uploaded Date              | 2022-07-14                                                                                 | 2022-06-30                          |
| MindSpore Version          | 1.6.1                                                                                      | 1.6.1                               |
| Dataset                    | steam                                                                                      | steam                               |
| Training Parameters        | epoch=20, batch_size=1024, lr=0.001                                                        | epoch=20, batch_size=1024, lr=0.001 |
| Optimizer                  | Adam                                                                                       | Adam                                |
| Loss Function              | Sigmoid Cross Entropy                                                                      | Sigmoid Cross Entropy               |
| Outputs                    | AUC=0.8573, Accuracy=0.7733                                                                | AUC=0.8592, Accuracy=0.7741         |
| Loss                       | 0.57                                                                                       | 0.59                                |
| Speed                      | 1pc: 90ms/step                                                                             | 1pc: 80ms/step                      |
| Total Time                 | 1pc: 297s                                                                                  | 1pc: 336s                           |
| Checkpoint for Fine Tuning | 686.3K (.ckpt file)                                                                        | 671K (.ckpt file)                   |
| Scripts                    | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
### [Evaluation Performance](#目录)

| Parameters        | GPU                         |
| ----------------- | --------------------------- |
| Model Version     | TB-Net                      |
| Resource          | Tesla V100-SXM2-32GB        |
| Uploaded Date     | 2021-08-01                  |
| MindSpore Version | 1.3.0                       |
| Dataset           | steam                       |
| Batch Size        | 1024                        |
| Outputs           | AUC=0.8252, Accuracy=0.7503 |
| Total Time        | 1pc: 5.7s                   |

| Parameters        | GPU                         | Ascend NPU                  |
| ----------------- | --------------------------- | --------------------------- |
| Model Version     | TB-Net                      | TB-Net                      |
| Resource          | NVIDIA RTX 3090             | Ascend 910                  |
| Uploaded Date     | 2022-07-14                  | 2022-06-30                  |
| MindSpore Version | 1.6.1                       | 1.6.1                       |
| Dataset           | steam                       | steam                       |
| Batch Size        | 1024                        | 1024                        |
| Outputs           | AUC=0.8487, Accuracy=0.7699 | AUC=0.8486, Accuracy=0.7704 |
| Total Time        | 1pc: 5.7s                   | 1pc: 1.1s                   |

### [Inference and Explanation Performance](#目录)
......
{
"num_item": 3005,
"num_relation": 5,
"num_entity": 5138,
"per_item_num_paths": 39,
"embedding_dim": 26,
"batch_size": 1024,
"lr": 0.001,
"kge_weight": 0.05,
"node_weight": 0.002,
"l2_weight": 1e-6
}
\ No newline at end of file
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
\ No newline at end of file
#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
\ No newline at end of file
{
"item": {
"0": "Star Wars",
"1": "Battlefield 1"
},
"relation": {
"0": "Developer",
"1": "Genre"
},
"entity": {
"425": "EA Games",
"426": "Shooting"
}
}
\ No newline at end of file
......@@ -16,8 +16,10 @@
import os
import argparse
import math
from mindspore import context, Model, load_checkpoint, load_param_into_net
import mindspore.common.dtype as mstype
from src import tbnet, config, metrics, dataset
......@@ -62,8 +64,8 @@ def get_args():
type=str,
required=False,
default='GPU',
choices=['GPU'],
help="run code on GPU"
choices=['GPU', 'Ascend'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
......@@ -95,10 +97,15 @@ def eval_tbnet():
print(f"creating dataset from {test_csv_path}...")
net_config = config.TBNetConfig(config_path)
eval_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
if args.device_target == 'Ascend':
net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
eval_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
print(f"creating TBNet from checkpoint {args.checkpoint_id} for evaluation...")
network = tbnet.TBNet(net_config)
if args.device_target == 'Ascend':
network.to_float(mstype.float16)
param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
load_param_into_net(network, param_dict)
......
......@@ -16,6 +16,7 @@
import os
import argparse
import math
import numpy as np
from mindspore import context, load_checkpoint, load_param_into_net, Tensor, export
......@@ -103,20 +104,23 @@ def export_tbnet():
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
net_config = config.TBNetConfig(config_path)
if args.device_target == 'Ascend':
net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
network = tbnet.TBNet(net_config)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)
eval_net = tbnet.PredictWithSigmoid(network)
item = Tensor(np.ones((1,)).astype(np.int))
rl1 = Tensor(np.ones((1, 39)).astype(np.int))
ety = Tensor(np.ones((1, 39)).astype(np.int))
rl2 = Tensor(np.ones((1, 39)).astype(np.int))
his = Tensor(np.ones((1, 39)).astype(np.int))
rl1 = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
ety = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
rl2 = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
his = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
rate = Tensor(np.ones((1,)).astype(np.float32))
inputs = [item, rl1, ety, rl2, his, rate]
export(eval_net, *inputs, file_name=args.file_name, file_format=args.file_format)
if __name__ == '__main__':
export_tbnet()
......@@ -16,8 +16,10 @@
import os
import argparse
import math
from mindspore import load_checkpoint, load_param_into_net, context
import mindspore.common.dtype as mstype
from src.config import TBNetConfig
from src.tbnet import TBNet
from src.aggregator import InferenceAggregator
......@@ -88,8 +90,8 @@ def get_args():
type=str,
required=False,
default='GPU',
choices=['GPU'],
help="run code on GPU"
choices=['GPU', 'Ascend'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
......@@ -121,12 +123,17 @@ def infer_tbnet():
print(f"creating TBNet from checkpoint {args.checkpoint_id}...")
config = TBNetConfig(config_path)
if args.device_target == 'Ascend':
config.per_item_paths = math.ceil(config.per_item_paths / 16) * 16
config.embedding_dim = math.ceil(config.embedding_dim / 16) * 16
network = TBNet(config)
if args.device_target == 'Ascend':
network.to_float(mstype.float16)
param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
load_param_into_net(network, param_dict)
print(f"creating dataset from {data_path}...")
infer_ds = dataset.create(data_path, config.per_item_num_paths, train=False, users=args.user)
infer_ds = dataset.create(data_path, config.per_item_paths, train=False, users=args.user)
infer_ds = infer_ds.batch(config.batch_size)
print("inferring...")
......
......@@ -17,6 +17,7 @@
import os
import argparse
import shutil
import math
import numpy as np
from mindspore import context
......@@ -44,7 +45,6 @@ def get_args():
help="the csv datafile inside the dataset folder (e.g. test.csv)"
)
parser.add_argument(
'--device_id',
type=int,
......@@ -58,8 +58,8 @@ def get_args():
type=str,
required=False,
default='Ascend',
choices=['Ascend'],
help="run code on GPU"
choices=['Ascend', 'GPU'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
......@@ -90,7 +90,9 @@ def preprocess_tbnet():
print(f"creating dataset from {test_csv_path}...")
net_config = config.TBNetConfig(config_path)
eval_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(1)
if args.device_target == 'Ascend':
net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
eval_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(1)
item_path = os.path.join('./preprocess_Result/', '00_item')
rl1_path = os.path.join('./preprocess_Result/', '01_rl1')
ety_path = os.path.join('./preprocess_Result/', '02_ety')
......@@ -134,5 +136,7 @@ def preprocess_tbnet():
rate_rst.tofile(rate_real_path)
idx += 1
if __name__ == '__main__':
preprocess_tbnet()
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data Preprocessing app."""
# This script should be run directly with 'python <script> <args>'.
import os
import io
import argparse
import json
import math
from src.path_gen import PathGen
from src.config import TBNetConfig
def get_args():
"""Parse commandline arguments."""
parser = argparse.ArgumentParser(description='Preprocess TB-Net data.')
parser.add_argument(
'--dataset',
type=str,
required=False,
default='steam',
help="'steam' dataset is supported currently"
)
parser.add_argument(
'--device_target',
type=str,
required=False,
default='GPU',
choices=['GPU', 'Ascend'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
'--same_relation',
required=False,
action='store_true',
default=False,
help="only generate paths that relation1 is same as relation2"
)
return parser.parse_args()
def preprocess_csv(path_gen, data_home, src_name, out_name):
"""Pre-process a csv file."""
src_path = os.path.join(data_home, src_name)
out_path = os.path.join(data_home, out_name)
print(f'converting {src_path} to {out_path} ...')
rows = path_gen.generate(src_path, out_path)
print(f'{rows} rows of path data generated.')
def preprocess_data():
"""Pre-process the dataset."""
args = get_args()
home = os.path.dirname(os.path.realpath(__file__))
data_home = os.path.join(home, 'data', args.dataset)
config_path = os.path.join(data_home, 'config.json')
id_maps_path = os.path.join(data_home, 'id_maps.json')
cfg = TBNetConfig(config_path)
if args.device_target == 'Ascend':
cfg.per_item_paths = math.ceil(cfg.per_item_paths / 16) * 16
path_gen = PathGen(per_item_paths=cfg.per_item_paths, same_relation=args.same_relation)
preprocess_csv(path_gen, data_home, 'src_train.csv', 'train.csv')
# save id maps for the later use by Recommender in infer.py
with io.open(id_maps_path, mode="w", encoding="utf-8") as f:
json.dump(path_gen.id_maps(), f, indent=4)
# count distinct objects from the training set
cfg.num_items = path_gen.num_items
cfg.num_references = path_gen.num_references
cfg.num_relations = path_gen.num_relations
cfg.save(config_path)
print(f'{config_path} updated.')
print(f'num_items: {cfg.num_items}')
print(f'num_references: {cfg.num_references}')
print(f'num_relations: {cfg.num_relations}')
# treat new items and references in test and infer set as unseen entities
# dummy internal id 0 will be assigned to them
path_gen.grow_id_maps = False
preprocess_csv(path_gen, data_home, 'src_test.csv', 'test.csv')
# for inference, only take interacted('c') and other('x') items as candidate items,
# the purchased('p') items won't be recommended.
# assume there is only one user in src_infer.csv
path_gen.subject_ratings = "cx"
preprocess_csv(path_gen, data_home, 'src_infer.csv', 'infer.csv')
print(f'Dataset {data_home} processed.')
if __name__ == '__main__':
preprocess_data()
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: bash run_train.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
CHECKPOINT_ID means model checkpoint id.
DATA_NAME means dataset name, it's value is 'steam'.
DEVICE_ID means device id, it can be set by environment variable DEVICE_ID.
DEVICE_TARGET is optional, it's value is ['GPU', 'Ascend'], default 'GPU'."
exit 1
fi
CHECKPOINT_ID=$1
DATA_NAME=$2
DEVICE_ID=$3
DEVICE_TARGET='GPU'
if [ $# == 4 ]; then
DEVICE_TARGET=$4
fi
python ../eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
--device_id $DEVICE_ID &> eval_standalone_gpu_log &
\ No newline at end of file
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: bash run_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
DATA_NAME means dataset name, it's value is 'steam'.
DEVICE_ID means device id, it can be set by environment variable DEVICE_ID.
DEVICE_TARGET is optional, it's value is ['GPU', 'Ascend'], default 'GPU'."
exit 1
fi
DATA_NAME=$1
DEVICE_ID=$2
DEVICE_TARGET='GPU'
if [ $# == 3 ]; then
DEVICE_TARGET=$3
fi
python ../preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> train_standalone_log &&
python ../train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> train_standalone_log &
\ No newline at end of file
......@@ -27,13 +27,17 @@ class TBNetConfig:
def __init__(self, config_path):
with open(config_path) as f:
json_dict = json.load(f)
self.num_item = int(json_dict['num_item'])
self.num_relation = int(json_dict['num_relation'])
self.num_entity = int(json_dict['num_entity'])
self.per_item_num_paths = int(json_dict['per_item_num_paths'])
self.num_items = int(json_dict['num_items'])
self.num_relations = int(json_dict['num_relations'])
self.num_references = int(json_dict['num_references'])
self.per_item_paths = int(json_dict['per_item_paths'])
self.embedding_dim = int(json_dict['embedding_dim'])
self.batch_size = int(json_dict['batch_size'])
self.lr = float(json_dict['lr'])
self.kge_weight = float(json_dict['kge_weight'])
self.node_weight = float(json_dict['node_weight'])
self.l2_weight = float(json_dict['l2_weight'])
def save(self, config_path):
with open(config_path, 'w') as f:
json.dump(self.__dict__, f, indent=4)
......@@ -14,10 +14,12 @@
# ============================================================================
"""Dataset loader."""
import os
from functools import partial
import numpy as np
from mindspore.dataset import GeneratorDataset
import mindspore.dataset as ds
import mindspore.mindrecord as record
def create(data_path, per_item_num_paths, train, users=None, **kwargs):
......@@ -39,13 +41,71 @@ def create(data_path, per_item_num_paths, train, users=None, **kwargs):
"""
if isinstance(users, int):
users = (users,)
kwargs['source'] = partial(csv_generator, data_path, per_item_num_paths, users, train)
if train:
kwargs['column_names'] = ['item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
kwargs['columns_list'] = ['item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
else:
kwargs['column_names'] = ['user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
return GeneratorDataset(**kwargs)
kwargs['columns_list'] = ['user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
mindrecord_file_path = csv_dataset(partial(csv_generator, data_path, per_item_num_paths, users, train), data_path,
train)
return ds.MindDataset(mindrecord_file_path, **kwargs)
def csv_dataset(generator, csv_path, train):
"""Dataset for csv datafile."""
file_name = os.path.basename(csv_path)
mindrecord_file_path = os.path.join(os.path.dirname(csv_path), file_name[0:file_name.rfind('.')] + '.mindrecord')
if os.path.exists(mindrecord_file_path):
os.remove(mindrecord_file_path)
if os.path.exists(mindrecord_file_path + ".db"):
os.remove(mindrecord_file_path + ".db")
data_schema = {
"item": {"type": "int32", "shape": []},
"relation1": {"type": "int32", "shape": [-1]},
"entity": {"type": "int32", "shape": [-1]},
"relation2": {"type": "int32", "shape": [-1]},
"hist_item": {"type": "int32", "shape": [-1]},
"rating": {"type": "float32", "shape": []},
}
if not train:
data_schema["user"] = {"type": "int32", "shape": []}
writer = record.FileWriter(file_name=mindrecord_file_path, shard_num=1)
writer.add_schema(data_schema, "Preprocessed dataset.")
data = []
for i, row in enumerate(generator()):
if train:
sample = {
"item": row[0],
"relation1": row[1],
"entity": row[2],
"relation2": row[3],
"hist_item": row[4],
"rating": row[5],
}
else:
sample = {
"user": row[0],
"item": row[1],
"relation1": row[2],
"entity": row[3],
"relation2": row[4],
"hist_item": row[5],
"rating": row[6],
}
data.append(sample)
if i % 10 == 0:
writer.write_raw_data(data)
data = []
if data:
writer.write_raw_data(data)
writer.commit()
return mindrecord_file_path
def csv_generator(csv_path, per_item_num_paths, users, train):
......@@ -81,8 +141,8 @@ def csv_generator(csv_path, per_item_num_paths, users, train):
if train:
# item, relation1, entity, relation2, hist_item, rating
yield np.array(item, dtype=np.int), relation1, entity, relation2, hist_item, \
np.array(rating, dtype=np.float32)
np.array(rating, dtype=np.float32)
else:
# user, item, relation1, entity, relation2, hist_item, rating
yield np.array(user, dtype=np.int), np.array(item, dtype=np.int),\
relation1, entity, relation2, hist_item, np.array(rating, dtype=np.float32)
yield np.array(user, dtype=np.int), np.array(item, dtype=np.int), \
relation1, entity, relation2, hist_item, np.array(rating, dtype=np.float32)
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Relation path data generator."""
import io
import random
import csv
import warnings
class _UserRec:
"""User record, helper class for path generation."""
def __init__(self, src_id, intern_id):
self.src_id = src_id
self.intern_id = intern_id
self.positive_items = dict()
self.interact_items = dict()
self.other_items = dict()
self.has_unseen_ref = False
def add_item(self, item_rec, rating):
"""Add an item."""
if rating == 'p':
item_dict = self.positive_items
elif rating == 'c':
item_dict = self.interact_items
else:
item_dict = self.other_items
item_dict[item_rec.intern_id] = item_rec
class _ItemRec:
"""Item record, helper class for path generation."""
def __init__(self, src_id, intern_id, ref_src_ids, ref_ids):
self.src_id = src_id
self.intern_id = intern_id
self.ref_src_ids = ref_src_ids
self.ref_ids = ref_ids
class PathGen:
"""
Generate relation path csv from the source csv table.
Args:
per_item_paths (int): Number of relation paths per subject item, must be positive.
same_relation (bool): True to only generate paths whose relation1 is the same as relation2, usually faster.
id_maps (dict[str, Union[dict[str, int], int]], Optional): Object id maps used as the internal id baseline; new user,
item and entity IDs will be assigned based on them. If it is None or empty, grow_id_maps will be True by
default.
"""
def __init__(self, per_item_paths, same_relation=False, id_maps=None):
self._per_item_paths = per_item_paths
self._same_relation = same_relation
self._user_id_counter = 1
self._entity_id_counter = 1
self._num_relations = 0
self._rows_generated = 0
self._user_rec = None
if id_maps:
self._item_id_map = id_maps.get('item', dict())
self._ref_id_map = id_maps.get('reference', dict())
self._rl_id_map = id_maps.get('relation', None)
self._user_id_counter = id_maps.get('_user_id_counter', self._user_id_counter)
max_item_id = max(self._item_id_map.values()) if self._item_id_map else 0
max_ref_id = max(self._ref_id_map.values()) if self._ref_id_map else 0
self._entity_id_counter = max(max_item_id, max_ref_id) + 1
else:
self._item_id_map = dict()
self._ref_id_map = dict()
self._rl_id_map = None
self.grow_id_maps = not (bool(self._item_id_map) and bool(self._ref_id_map))
self.subject_ratings = ""
self._unseen_items = 0
self._unseen_refs = 0
@property
def num_users(self):
"""int, the number of distinct users."""
return self._user_id_counter - 1
@property
def num_references(self):
"""int, the number of distinct references."""
return len(self._ref_id_map)
@property
def num_items(self):
"""int, the number of distinct items."""
return len(self._item_id_map)
@property
def num_relations(self):
"""int, the number of distinct relations."""
return self._num_relations
@property
def rows_generated(self):
"""int, total number of rows generated to the output CSVs."""
return self._rows_generated
@property
def per_item_paths(self):
"""int, the number of path per subject item."""
return self._per_item_paths
@property
def same_relation(self):
"""bool, only generate paths with the same relation on both sides."""
return self._same_relation
@property
def unseen_items(self):
"""int, total number of unseen items has encountered."""
return self._unseen_items
@property
def unseen_refs(self):
"""int, total number of unseen references has encountered."""
return self._unseen_refs
def id_maps(self):
"""dict, object ID maps."""
maps = {
"item": dict(self._item_id_map),
"reference": dict(self._ref_id_map),
"_user_id_counter": self._user_id_counter
}
if self._rl_id_map is not None:
maps["relation"] = dict(self._rl_id_map)
return maps
def generate(self, in_csv, out_csv, in_sep=',', in_mv_sep=';', in_encoding='utf-8'):
"""
Generate paths csv from the source CSV files.
args:
in_csv (Union[str, TextIOBase]): The input source csv path or stream.
out_csv (Union[str, TextIOBase]): The output source csv path or stream.
in_sep (str): Separator of the input csv.
in_mv_sep (str): Multi-value separator of the input csv in a single column.
in_encoding (str): Encoding of the input source csv, ignored if in_csv is a text stream already.
Returns:
int, the number of rows that generated to the output csv in this call.
"""
if not isinstance(in_csv, (str, io.TextIOBase)):
raise TypeError(f"Unexpected in_csv type:{type(in_csv)}")
if not isinstance(out_csv, (str, io.TextIOBase)):
raise TypeError(f"Unexpected out_csv type:{type(out_csv)}")
opened_files = []
try:
if isinstance(in_csv, str):
in_csv = io.open(in_csv, mode="r", encoding=in_encoding)
opened_files.append(in_csv)
in_csv = csv.reader(in_csv, delimiter=in_sep)
col_indices = self._pre_generate(in_csv, None)
if isinstance(out_csv, str):
out_csv = io.open(out_csv, mode="w", encoding="ascii")
opened_files.append(out_csv)
rows_generated = self._do_generate(in_csv, out_csv, in_mv_sep, col_indices)
except (IOError, ValueError, RuntimeError, PermissionError, KeyError) as e:
raise e
finally:
for f in opened_files:
f.close()
return rows_generated
def _pre_generate(self, in_csv, in_col_map):
"""Prepare for the path generation."""
if in_col_map is not None:
expected_cols = self._default_abstract_header(len(in_col_map) - 3)
map_values = list(in_col_map.values())
for col in expected_cols:
if col not in map_values:
raise ValueError(f"col_map has no '{col}' value.")
header = self._read_header(in_csv)
if len(header) < 4:
raise IOError(f"No. of in_csv columns:{len(header)} is less than 4.")
num_relations = len(header) - 3
if self._num_relations > 0:
if num_relations != self._num_relations:
raise IOError(f"Inconsistent no. of in_csv relations.")
else:
self._num_relations = num_relations
col_indices = self._get_col_indices(header, in_col_map)
rl_id_map = self._to_relation_id_map(header, col_indices)
if not self._rl_id_map:
self._rl_id_map = rl_id_map
elif rl_id_map != self._rl_id_map:
raise IOError(f"Inconsistent in_csv relations.")
return col_indices
def _do_generate(self, in_csv, out_csv, in_mv_sep, col_indices):
"""Do generate the paths."""
old_rows_generated = self._rows_generated
old_unseen_items = self._unseen_items
old_unseen_refs = self._unseen_refs
col_count = len(col_indices)
self._user_rec = None
for line in in_csv:
values = list(map(lambda x: x.strip(), line))
if len(values) != col_count:
raise IOError(f"No. of in_csv columns:{len(values)} is not {col_count}.")
self._process_line(values, in_mv_sep, col_indices, out_csv)
if self._user_rec is not None:
self._process_user_rec(self._user_rec, out_csv)
self._user_rec = None
delta_unseen_items = self._unseen_items - old_unseen_items
delta_unseen_refs = self._unseen_refs - old_unseen_refs
if delta_unseen_items > 0:
warnings.warn(f"{delta_unseen_items} unseen items' internal IDs were set to 0, "
f"set grow_id_maps to True for adding new internal IDs.", RuntimeWarning)
if delta_unseen_refs > 0:
warnings.warn(f"{delta_unseen_refs} unseen references' internal IDs were set to 0, "
f"set grow_id_maps to True for adding new internal IDs.", RuntimeWarning)
return self._rows_generated - old_rows_generated
def _process_line(self, values, in_mv_sep, col_indices, out_csv):
"""Process a line from the input CSV."""
user_src = values[col_indices[0]]
item_src = values[col_indices[1]]
rating = values[col_indices[2]].lower()
if rating not in ('p', 'c', 'x'):
raise IOError(f"Unrecognized rating:'{rating}', must be one of 'p', 'c' or 'x'.")
ref_srcs = [values[col_indices[i]] for i in range(3, len(col_indices))]
if in_mv_sep:
ref_srcs = list(map(lambda x: list(map(lambda y: y.strip(), x.split(in_mv_sep))), ref_srcs))
else:
ref_srcs = list(map(lambda x: [x], ref_srcs))
if self._user_rec is not None and user_src != self._user_rec.src_id:
# user changed
self._process_user_rec(self._user_rec, out_csv)
self._user_rec = None
if self._user_rec is None:
self._user_rec = _UserRec(user_src, self._user_id_counter)
self._user_id_counter += 1
item_rec, has_unseen_ref = self._to_item_rec(item_src, ref_srcs)
self._user_rec.add_item(item_rec, rating)
self._user_rec.has_unseen_ref |= has_unseen_ref
def _process_user_rec(self, user_rec, out_csv):
"""Generate paths for an user."""
positive_count = 0
subject_items = []
if self.subject_ratings == "":
subject_items.extend(user_rec.positive_items.values())
subject_items.extend(user_rec.other_items.values())
positive_count = len(user_rec.positive_items)
else:
if 'p' in self.subject_ratings:
subject_items.extend(user_rec.positive_items.values())
positive_count = len(user_rec.positive_items)
if 'c' in self.subject_ratings:
subject_items.extend(user_rec.interact_items.values())
if 'x' in self.subject_ratings:
subject_items.extend(user_rec.other_items.values())
hist_items = []
hist_items.extend(user_rec.positive_items.values())
hist_items.extend(user_rec.interact_items.values())
for i, subject in enumerate(subject_items):
paths = []
for hist in hist_items:
if hist.src_id == subject.src_id:
continue
self._find_paths(not user_rec.has_unseen_ref, subject, hist, paths)
if not paths:
continue
paths = random.sample(paths, min(len(paths), self._per_item_paths))
row = [0] * (3 + self._per_item_paths * 4)
row[0] = user_rec.src_id
row[1] = subject.intern_id # subject item
row[2] = 1 if i < positive_count else 0 # label
for p, path in enumerate(paths):
offset = 3 + p * 4
for j in range(4):
row[offset + j] = path[j]
out_csv.write(','.join(map(str, row)))
out_csv.write('\n')
self._rows_generated += 1
def _find_paths(self, by_intern_id, subject_item, hist_item, paths):
"""Find paths between the subject and historical item."""
if by_intern_id:
for i, ref_list in enumerate(subject_item.ref_ids):
for ref in ref_list:
self._find_paths_by_intern_id(i, ref, hist_item, paths)
else:
for i, (ref_src_list, ref_list) in enumerate(zip(subject_item.ref_src_ids,
subject_item.ref_ids)):
for src_ref, ref in zip(ref_src_list, ref_list):
self._find_paths_by_src(i, src_ref, ref, hist_item, paths)
def _find_paths_by_intern_id(self, subject_ridx, ref_id, hist_item, paths):
"""Find paths by internal reference ID, a bit faster."""
if self._same_relation:
if ref_id in hist_item.ref_ids[subject_ridx]:
relation_id = self._ridx_to_relation_id(subject_ridx)
paths.append((relation_id,
ref_id,
relation_id,
hist_item.intern_id))
else:
for hist_ridx, hist_ref_list in enumerate(hist_item.ref_ids):
if ref_id in hist_ref_list:
paths.append((self._ridx_to_relation_id(subject_ridx),
ref_id,
self._ridx_to_relation_id(hist_ridx),
hist_item.intern_id))
def _find_paths_by_src(self, subject_ridx, ref_src_id, ref_id, hist_item, paths):
"""Find paths by source reference ID."""
if self._same_relation:
if ref_src_id in hist_item.ref_src_ids[subject_ridx]:
relation_id = self._ridx_to_relation_id(subject_ridx)
paths.append((relation_id,
ref_id,
relation_id,
hist_item.intern_id))
else:
for hist_ridx, hist_ref_src_list in enumerate(hist_item.ref_src_ids):
if ref_src_id in hist_ref_src_list:
paths.append((self._ridx_to_relation_id(subject_ridx),
ref_id,
self._ridx_to_relation_id(hist_ridx),
hist_item.intern_id))
def _ridx_to_relation_id(self, idx):
"""Relation index to id."""
return idx
def _to_relation_id_map(self, header, col_indices):
"""Convert input csv header to a relation id map."""
id_map = {}
id_counter = 0
for i in range(3, len(col_indices)):
id_map[header[col_indices[i]]] = id_counter
id_counter += 1
if len(id_map) < len(header) - 3:
raise IOError("Duplicated column!")
return id_map
def _to_item_rec(self, item_src, ref_srcs):
"""Convert the item src id and the source reference to an item record."""
item_id = self._item_id_map.get(item_src, -1)
if item_id == -1:
if not self.grow_id_maps:
item_id = 0
self._unseen_items += 1
else:
item_id = self._entity_id_counter
self._item_id_map[item_src] = item_id
self._entity_id_counter += 1
has_unseen_ref = False
ref_ids = [[] for _ in range(len(ref_srcs))]
for i, ref_src_list in enumerate(ref_srcs):
for ref_src in ref_src_list:
ref_id = self._ref_id_map.get(ref_src, -1)
if ref_id == -1:
if not self.grow_id_maps:
ref_id = 0
self._unseen_refs += 1
has_unseen_ref = True
else:
ref_id = self._entity_id_counter
self._ref_id_map[ref_src] = ref_id
self._entity_id_counter += 1
ref_ids[i].append(ref_id)
return _ItemRec(item_src, item_id, ref_srcs, ref_ids), has_unseen_ref
def _get_col_indices(self, header, col_map):
"""Find the column indices base on the mapping."""
if col_map:
mapped = [col_map[col] for col in header]
default_header = self._default_abstract_header(len(header) - 3)
return [mapped.index(col) for col in default_header]
return range(len(header))
@staticmethod
def _read_header(in_csv):
"""Read the CSV header."""
line = next(in_csv)
splited = list(map(lambda x: x.strip(), line))
return splited
@staticmethod
def _default_abstract_header(num_relation):
"""Get the default abstract header."""
abstract_header = ["user", "item", "rating"]
abstract_header.extend([f"r{i + 1}" for i in range(num_relation)])
return abstract_header
......@@ -14,16 +14,25 @@
# ============================================================================
"""TB-Net Model."""
from mindspore import nn
from mindspore import nn, Tensor
from mindspore import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import mindspore.common.dtype as mstype
from src.embedding import EmbeddingMatrix
grad_scale = C.MultitypeFuncGraph("grad_scale")
@grad_scale.register("Tensor", "Tensor")
def gradient_scale(scale, grad):
return grad * F.cast(scale, F.dtype(grad))
class TBNet(nn.Cell):
"""
......@@ -68,8 +77,8 @@ class TBNet(nn.Cell):
def _parse_config(self, config):
"""Argument parsing."""
self.num_entity = config.num_entity
self.num_relation = config.num_relation
self.num_entity = config.num_items + config.num_references + 1
self.num_relation = config.num_relations
self.dim = config.embedding_dim
self.kge_weight = config.kge_weight
self.node_weight = config.node_weight
......@@ -279,7 +288,7 @@ class NetWithLossClass(nn.Cell):
class TrainStepWrap(nn.Cell):
"""TrainStepWrap definition."""
def __init__(self, network, lr, sens=1):
def __init__(self, network, lr, sens=1, loss_scale=False):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
......@@ -294,11 +303,13 @@ class TrainStepWrap(nn.Cell):
loss_scale=sens)
self.hyper_map = C.HyperMap()
self.reciprocal_sense = Tensor(1 / sens, mstype.float32)
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.loss_scale = loss_scale
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
......@@ -307,6 +318,10 @@ class TrainStepWrap(nn.Cell):
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
def scale_grad(self, gradients):
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_sense), gradients)
return gradients
def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
"""
Args:
......@@ -325,11 +340,14 @@ class TrainStepWrap(nn.Cell):
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(items, relation1, mid_entity, relation2, hist_item, labels, sens)
if self.loss_scale:
grads = self.scale_grad(grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
self.optimizer(grads)
return loss
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
......
......@@ -16,11 +16,13 @@
import os
import argparse
import math
import numpy as np
from mindspore import context, Model, Tensor
from mindspore.train.serialization import save_checkpoint
from mindspore.train.callback import Callback, TimeMonitor
import mindspore.common.dtype as mstype
from src import tbnet, config, metrics, dataset
......@@ -104,8 +106,8 @@ def get_args():
type=str,
required=False,
default='GPU',
choices=['GPU'],
help="run code on GPU"
choices=['GPU', 'Ascend'],
help="run code on GPU or Ascend NPU"
)
parser.add_argument(
......@@ -141,13 +143,21 @@ def train_tbnet():
print(f"creating dataset from {train_csv_path}...")
net_config = config.TBNetConfig(config_path)
train_ds = dataset.create(train_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
test_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
if args.device_target == 'Ascend':
net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
train_ds = dataset.create(train_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
test_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
print("creating TBNet for training...")
network = tbnet.TBNet(net_config)
loss_net = tbnet.NetWithLossClass(network, net_config)
train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
if args.device_target == 'Ascend':
loss_net.to_float(mstype.float16)
train_net = tbnet.TrainStepWrap(loss_net, net_config.lr, loss_scale=True)
else:
train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
train_net.set_train()
eval_net = tbnet.PredictWithSigmoid(network)
time_callback = TimeMonitor(data_size=train_ds.get_dataset_size())
......@@ -161,7 +171,8 @@ def train_tbnet():
test_out = model.eval(test_ds, dataset_sink_mode=False)
print(f'Train AUC:{train_out["auc"]} ACC:{train_out["acc"]} Test AUC:{test_out["auc"]} ACC:{test_out["acc"]}')
save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))
if i >= args.epochs-5:
save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))
if __name__ == '__main__':
......