Skip to content
Snippets Groups Projects
Commit 580f097c authored by YinanF's avatar YinanF
Browse files

add new model tgcn to model_zoo

parent 4aa3dc75
No related branches found
No related tags found
No related merge requests found
Showing
with 1462 additions and 0 deletions
# 目录
- [T-GCN概述](#T-GCN概述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [环境要求](#环境要求)
- [快速开始](#快速开始)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练流程](#训练流程)
- [运行](#运行)
- [结果](#结果)
- [评估流程](#评估流程)
- [运行](#运行)
- [结果](#结果)
- [MINDIR模型导出流程](#MINDIR模型导出流程)
- [运行](#运行)
- [结果](#结果)
- [模型说明](#模型说明)
- [训练性能](#训练性能)
- [评估性能](#评估性能)
- [随机情况说明](#随机情况说明)
- [ModelZoo主页](#ModelZoo主页)
# [T-GCN概述](#目录)
时间图卷积网络(Temporal Graph Convolutional Network,T-GCN)模型,简称T-GCN模型,是Zhao L等人提出的一种适用于城市道路交通预测的模型。所谓交通预测,即基于道路历史交通信息,对一定时期内的交通信息进行预测,包括但不限于交通速度、交通流量、交通密度等信息。
[论文](https://arxiv.org/pdf/1811.05320.pdf):Zhao L, Song Y, Zhang C, et al. T-gcn: A temporal graph convolutional network for traffic prediction[J]. IEEE Transactions on Intelligent Transportation Systems, 2019, 21(9): 3848-3858.
# [模型架构](#目录)
T-GCN模型主要由两大模块构成,分别为图卷积网络(Graph Convolutional Network,GCN)与门控循环单元(Gated Recurrent Unit,GRU)。
模型整体处理流程如下:输入n组历史时间序列数据,利用图卷积网络捕获城市路网拓扑结构,以获取数据的空间特征。再将得到的具有空间特征的数据输入门控循环单元,利用单元间的信息传递捕获数据的动态变化,以获取数据的时间特征。最后,经过全连接层,输出最终预测结果。
其中,GCN模块通过在傅里叶域中构造一个作用于图数据的节点及其一阶邻域的滤波器来捕获节点间的空间特征,之后在其上叠加多个卷积层来实现。GCN模块可对城市中心道路与其周围道路间的拓扑结构及道路属性实现编码,捕获数据的空间相关性。而GRU模块则是作为一种经典的递归神经网络变体来捕获交通流量数据中的时间相关性。该模块使用门控机制来记忆尽可能多的长期信息,且结构相对简单,参数较少,训练速度较快,可以在捕获当前时刻交通信息的同时,仍然保持历史交通信息的变化趋势,具有捕获数据的时间相关性的能力。
# [数据集](#目录)
- 数据集:实验基于两大由现实采集的[SZ-taxi数据集](https://github.com/lehaifeng/T-GCN/tree/master/T-GCN/T-GCN-PyTorch/data)[Los-loop数据集](https://github.com/lehaifeng/T-GCN/tree/master/T-GCN/T-GCN-PyTorch/data)
(1)SZ-taxi数据集选取深圳市罗湖区的156条主要城市道路为研究区域,记录了2015年1月1日至1月31日的出租车运行轨迹。该数据集主要包含两个部分,一是记录了城市道路间拓扑关系的一个156*156大小的邻接矩阵,其中每行代表一条道路,矩阵中的值表示道路间的连接。二是记录了每一条道路上速度值随时间变化的特征矩阵,其中每行代表一条道路,每列为不同时间段道路上的交通速度,每15分钟记录一次。
(2)Los-loop数据集由洛杉矶高速公路上共计207个环形探测器于2012年3月1日至2012年3月7日实时采集得到,数据每5分钟记录一次。与SZ-taxi数据集相似,该数据集主要包含邻接矩阵与特征矩阵两个部分,邻接矩阵中的值由探测器之间的距离计算得到。由于该数据集中存在数据缺失,因此论文作者采用线性插值的方法进行了缺失值填充。
- 数据处理:输入数据被归一化至[0,1]区间,并划分其中的80%作训练集,20%作测试集,来分别预测未来15分钟、30分钟、45分钟、60分钟的交通速度。
# [环境要求](#目录)
- 硬件(Ascend / GPU)
- 需要准备具有Ascend或GPU处理能力的硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/)
- 如需获取更多信息,请查看如下链接:
- [MindSpore 教程](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html)
# [快速开始](#目录)
通过官方指南[安装MindSpore](https://www.mindspore.cn/install)后,下载[数据集](https://github.com/lehaifeng/T-GCN/tree/master/T-GCN/T-GCN-PyTorch/data),将下载好的数据集按如下目录结构进行组织,也可按此结构自行添加数据集:
```python
.
└─tgcn
├─data
├─SZ-taxi
├─adj.csv # 邻接矩阵
└─feature.csv # 特征矩阵
├─Los-loop
├─adj.csv # 邻接矩阵
└─feature.csv # 特征矩阵
...
```
准备好数据集后,即可按顺序依次进行模型训练与评估/导出操作:
- 训练:
```python
# 单卡训练
bash ./scripts/run_standalone_train.sh [DEVICE_ID]
# Ascend多卡训练
bash ./scripts/run_distributed_train_ascend.sh [RANK_TABLE] [RANK_SIZE] [DEVICE_START] [DATA_PATH]
```
示例:
```python
# 单卡训练
bash ./scripts/run_standalone_train.sh 0
# Ascend多卡训练(8卡)
bash ./scripts/run_distributed_train_ascend.sh ./rank_table_8pcs.json 8 0 ./data
```
- 评估:
```python
# 评估
bash ./scripts/run_eval.sh [DEVICE_ID]
```
示例:
```python
# 评估
bash ./scripts/run_eval.sh 0
```
- MINDIR模型导出
```python
# MINDIR模型导出
bash ./scripts/run_export.sh [DEVICE_ID]
```
示例:
```python
# MINDIR模型导出
bash ./scripts/run_export.sh 0
```
# [脚本说明](#目录)
## [脚本及样例代码](#目录)
```python
.
└─tgcn
├─README_CN.md # 中文指南
├─requirements.txt # pip依赖文件
├─scripts
├─run_distributed_train_ascend.sh # Ascend多卡训练运行脚本
├─run_eval.sh # 评估运行脚本
├─run_export.sh # MINDIR模型导出运行脚本
└─run_standalone_train.sh # 单卡训练运行脚本
├─src
├─model
├─__init__.py
├─graph_conv.py # 图卷积计算
├─loss.py # 自定义损失函数
└─tgcn.py # T-GCN模型架构
├─__init__.py
├─callback.py # 自定义回调函数
├─config.py # 模型参数设定
├─dataprocess.py # 数据处理模块
├─metrics.py # 模型评估指标
└─task.py # 监督预测任务
├─eval.py # 评估
├─export.py # MINDIR模型导出
└─train.py # 训练
```
## [脚本参数](#目录)
- 训练、评估、MINDIR模型导出等操作相关参数皆在`config.py`脚本中设定:
```python
class ConfigTGCN:
device = 'Ascend'
seed = 1
dataset = 'SZ-taxi'
hidden_dim = 100
seq_len = 4
pre_len = 1
train_split_rate = 0.8
epochs = 3000
batch_size = 64
learning_rate = 0.001
weight_decay = 1.5e-3
data_sink = True
```
如需查阅相关参数信息说明,请参阅`config.py`脚本内容。
## [训练流程](#目录)
### [运行](#目录)
开始训练前,请确认已在`config.py`脚本中完成相关训练参数设定,在同一任务下,后续评估流程与MINDIR模型导出流程请保持参数一致。
```python
# 单卡训练
# 用法:
bash ./scripts/run_standalone_train.sh [DEVICE_ID]
# 示例:
bash ./scripts/run_standalone_train.sh 0
# Ascend多卡训练
# 用法:
bash ./scripts/run_distributed_train_ascend.sh [RANK_TABLE] [RANK_SIZE] [DEVICE_START] [DATA_PATH]
# 示例(8卡):
bash ./scripts/run_distributed_train_ascend.sh ./rank_table_8pcs.json 8 0 ./data
```
单卡训练中`[DEVICE_ID]`为训练所调用卡的卡号。
Ascend多卡训练中`[RANK_TABLE]`为相应RANK_TABLE_FILE文件路径(如8卡训练使用的`./rank_table_8pcs.json`),RANK_TABLE_FILE可按[此方法](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)生成。`[RANK_SIZE]`为训练所调用卡的数量,`[DEVICE_START]`为起始卡号,`[DATA_PATH]`为数据集存放根目录。
### [结果](#目录)
训练时,当前训练轮次数,模型损失值,每轮次运行时间等有关信息会以如下形式展示:
```python
==========Training Start==========
epoch: 1 step: 37, loss is 47.07869
epoch time: 20385.370 ms, per step time: 550.956 ms
RMSE eval: 8.408103
Best checkpoint saved!
epoch: 2 step: 37, loss is 26.325077
epoch time: 607.063 ms, per step time: 16.407 ms
RMSE eval: 6.355909
Best checkpoint saved!
epoch: 3 step: 37, loss is 24.1607
epoch time: 606.936 ms, per step time: 16.404 ms
RMSE eval: 6.126811
Best checkpoint saved!
epoch: 4 step: 37, loss is 23.835127
epoch time: 606.999 ms, per step time: 16.405 ms
RMSE eval: 6.077283
Best checkpoint saved!
epoch: 5 step: 37, loss is 23.536343
epoch time: 606.879 ms, per step time: 16.402 ms
RMSE eval: 6.035936
Best checkpoint saved!
epoch: 6 step: 37, loss is 23.218105
epoch time: 606.861 ms, per step time: 16.402 ms
RMSE eval: 5.993234
Best checkpoint saved!
...
```
单卡训练将会把上述信息以运行日志的形式保存至`./logs/train.log`,且模型会以覆盖的形式自动保存最优检查点(.ckpt 文件)于`./checkpoints`目录下,供后续评估与模型导出流程加载使用(如`./checkpoints/SZ-taxi_1.ckpt`)。
Ascend多卡训练与单卡训练所展示信息的形式基本一致,运行日志及最优检查点将保存在以对应卡号ID命名的`./device{ID}`目录下(如`./device0/logs/train.log``./device0/checkpoints/SZ-taxi_1.ckpt`)。
## [评估流程](#目录)
### [运行](#目录)
在完成训练流程的基础上,评估流程将自动从`./checkpoints`目录加载对应任务的最优检查点(.ckpt 文件)用于模型评估。
```python
# 评估
# 用法:
bash ./scripts/run_eval.sh [DEVICE_ID]
# 示例:
bash ./scripts/run_eval.sh 0
```
### [结果](#目录)
训练后模型在验证集上的相关指标评估结果将以如下形式展示,且以运行日志的形式保存至`./logs/eval.log`
```python
=====Evaluation Results=====
RMSE: 4.083120
MAE: 2.730229
Accuracy: 0.715577
R2: 0.847140
Var: 0.847583
============================
```
## [MINDIR模型导出流程](#目录)
### [运行](#目录)
在完成训练流程的基础上,MINDIR模型导出流程将自动从`./checkpoints`目录加载对应任务的最优检查点(.ckpt 文件)用于对应MINDIR模型导出。
```python
# MINDIR模型导出
# 用法:
bash ./scripts/run_export.sh [DEVICE_ID]
# 示例:
bash ./scripts/run_export.sh 0
```
### [结果](#目录)
若模型导出成功,程序将以如下形式展示,且以运行日志的形式保存至`./logs/export.log`
```python
==========================================
SZ-taxi_1.mindir exported successfully!
==========================================
```
同时MINDIR模型文件将导出至`./outputs`目录下,供后续进一步使用(如`./outputs/SZ-taxi_1.mindir`)。
# [模型说明](#目录)
## [训练性能](#目录)
- 下表中训练性能由T-GCN模型基于SZ-taxi数据集分别预测未来15分钟、30分钟、45分钟、60分钟(即pre_len分别取1、2、3、4)的交通速度得到,相关指标为4组训练任务平均值:
| 参数 | Ascend |
| -------------------------- | -----------------------------------------------------------|
| 模型名称 | T-GCN |
| 运行环境 | 操作系统 Euler 2.8;Ascend 910;处理器 2.60GHz,192核心;内存,755G |
| 上传日期 | 2021-09-30 |
| MindSpore版本 | 1.3.0 |
| 数据集 | SZ-taxi(hidden_dim=100;seq_len=4) |
| 训练参数 | seed=1;epoch=3000;batch_size = 64;lr=0.001;train_split_rate = 0.8;weight_decay = 1.5e-3 |
| 优化器 | Adam with Weight Decay |
| 损失函数 | 自定义损失函数 |
| 输出 | 交通速度预测值 |
| 平均检查点(.ckpt 文件)大小 | 839 KB |
| 平均性能 | 单卡:23毫秒/步,871毫秒/轮;8卡:25毫秒/步,101毫秒/轮 |
| 平均总耗时 | 单卡:49分19秒;8卡:11分35秒 |
| 脚本 | [训练脚本](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/train.py) |
- 下表中训练性能由T-GCN模型基于Los-loop数据集分别预测未来15分钟、30分钟、45分钟、60分钟(即pre_len分别取3、6、9、12)的交通速度得到,相关指标为4组训练任务平均值:
| 参数 | Ascend |
| -------------------------- | -----------------------------------------------------------|
| 模型名称 | T-GCN |
| 运行环境 | 操作系统 Euler 2.8;Ascend 910;处理器 2.60GHz,192核心;内存,755G |
| 上传日期 | 2021-09-30 |
| MindSpore版本 | 1.3.0 |
| 数据集 | Los-loop(hidden_dim=64;seq_len=12) |
| 训练参数 | seed=1;epoch=3000;batch_size = 64;lr=0.001;train_split_rate = 0.8;weight_decay = 1.5e-3 |
| 优化器 | Adam with Weight Decay |
| 损失函数 | 自定义损失函数 |
| 输出 | 交通速度预测值 |
| 平均检查点(.ckpt 文件)大小 | 993KB |
| 平均性能 | 单卡:44毫秒/步,1066毫秒/轮;8卡:46毫秒/步,139毫秒/轮 |
| 平均总耗时 | 单卡:1时00分40秒;8卡:15分05秒 |
| 脚本 | [训练脚本](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/train.py) |
## [评估性能](#目录)
- 下表中评估性能由T-GCN模型基于SZ-taxi数据集分别预测未来15分钟、30分钟、45分钟、60分钟(即pre_len分别取1、2、3、4)的交通速度得到,相关指标为4组评估任务平均值:
| 参数 | Ascend|
| ------------------- | ---------------------------|
| 模型名称 | T-GCN |
| 运行环境 | 操作系统 Euler 2.8;Ascend 910;处理器 2.60GHz,192核心;内存,755G |
| 上传日期 | 2021-09-30 |
| MindSpore版本 | 1.3.0 |
| 数据集 | SZ-taxi(hidden_dim=100;seq_len=4) |
| 输出 | 交通速度预测值 |
| 均方根误差(RMSE)平均值 | 4.1003 |
| 平均绝对误差(MAE)平均值 | 2.7498 |
| 预测准确率(Accuracy)平均值 | 0.7144 |
| R平方($R^2$)平均值 | 0.8458 |
| 可释方差(Explained Variance)平均值 | 0.8461 |
| 脚本 | [评估脚本](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/eval.py) |
- 下表中评估性能由T-GCN模型基于Los-loop数据集分别预测未来15分钟、30分钟、45分钟、60分钟(即pre_len分别取3、6、9、12)的交通速度得到,相关指标为4组评估任务平均值:
| 参数 | Ascend|
| ------------------- | ---------------------------|
| 模型名称 | T-GCN |
| 运行环境 | 操作系统 Euler 2.8;Ascend 910;处理器 2.60GHz,192核心;内存,755G |
| 上传日期 | 2021-09-30 |
| MindSpore版本 | 1.3.0 |
| 数据集 | Los-loop(hidden_dim=64;seq_len=12) |
| 输出 | 交通速度预测值 |
| 均方根误差(RMSE)平均值 | 6.1869 |
| 平均绝对误差(MAE)平均值 | 3.8552 |
| 预测准确率(Accuracy)平均值 | 0.8946 |
| R平方($R^2$)平均值 | 0.8000 |
| 可释方差(Explained Variance)平均值 | 0.8002 |
| 脚本 | [评估脚本](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/eval.py) |
# [随机情况说明](#目录)
`train.py`脚本中使用`mindspore.set_seed()`对全局随机种子进行了固定(默认值为1),可在`config.py`脚本中进行修改。
# [ModelZoo主页](#目录)
[T-GCN](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn)
\ No newline at end of file
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Evaluation script
"""
import os
import argparse
from mindspore import context
from mindspore import load_checkpoint, load_param_into_net
from src.config import ConfigTGCN
from src.task import SupervisedForecastTask
from src.dataprocess import load_adj_matrix, load_feat_matrix, generate_dataset_np
from src.metrics import evaluate_network
# Set DEVICE_ID
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', help="DEVICE_ID", type=int, default=0)
args = parser.parse_args()
if __name__ == '__main__':
# Config initialization
config = ConfigTGCN()
# Runtime
context.set_context(mode=context.GRAPH_MODE, device_target=config.device, device_id=args.device_id)
# Create network
net = SupervisedForecastTask(load_adj_matrix(config.dataset), config.hidden_dim, config.pre_len)
# Load parameters from checkpoint into network
ckpt_file_name = config.dataset + "_" + str(config.pre_len) + ".ckpt"
param_dict = load_checkpoint(os.path.join('checkpoints', ckpt_file_name))
load_param_into_net(net, param_dict)
# Evaluation
feat, max_val = load_feat_matrix(config.dataset)
_, _, eval_inputs, eval_targets = generate_dataset_np(feat, config.seq_len, config.pre_len, config.train_split_rate)
evaluate_network(net, max_val, eval_inputs, eval_targets)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Export checkpoints into MINDIR model files
"""
import os
import argparse
import numpy as np
from mindspore import export, load_checkpoint, load_param_into_net, Tensor, context
from src.config import ConfigTGCN
from src.task import SupervisedForecastTask
from src.dataprocess import load_adj_matrix
# Set DEVICE_ID
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', help="DEVICE_ID", type=int, default=0)
args = parser.parse_args()
if __name__ == '__main__':
# Config initialization
config = ConfigTGCN()
# Runtime
context.set_context(mode=context.GRAPH_MODE, device_target=config.device, device_id=args.device_id)
# Create network
adj = (load_adj_matrix(config.dataset))
net = SupervisedForecastTask(adj, config.hidden_dim, config.pre_len)
# Load parameters from checkpoint into network
file_name = config.dataset + "_" + str(config.pre_len) + ".ckpt"
param_dict = load_checkpoint(os.path.join('checkpoints', file_name))
load_param_into_net(net, param_dict)
# Initialize dummy inputs
inputs = np.random.uniform(0.0, 1.0, size=[config.batch_size, config.seq_len, adj.shape[0]]).astype(np.float32)
# Export network into MINDIR model file
if not os.path.exists('outputs'):
os.mkdir('outputs')
file_name = config.dataset + "_" + str(config.pre_len)
path = os.path.join('outputs', file_name)
export(net, Tensor(inputs), file_name=path, file_format='MINDIR')
print("==========================================")
print(file_name + ".mindir exported successfully!")
print("==========================================")
pandas
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -ne 4 ]]; then
echo "Usage: bash ./scripts/run_distributed_train_ascend.sh [RANK_TABLE] [RANK_SIZE] [DEVICE_START] [DATA_PATH]"
exit 1
fi
ulimit -u unlimited
export RANK_SIZE=$2
RANK_TABLE_FILE=$(realpath $1)
DATA_PATH=$(realpath $4)
export RANK_TABLE_FILE
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
echo "DATA_PATH=${DATA_PATH}"
device_start=$3
for((i=0; i<${RANK_SIZE}; i++))
do
export DEVICE_ID=$((device_start + i))
export RANK_ID=$i
rm -rf ./device$i
mkdir ./device$i
cp -r ./src ./device$i
cp ./train.py ./device$i
cd ./device$i
mkdir ./logs
env > ./logs/env.log
nohup python -u train.py --device_id=$DEVICE_ID --data_path=$DATA_PATH --distributed True > ./logs/train.log 2>&1 &
echo "Start training for rank $RANK_ID, device $DEVICE_ID. PID: $!"
echo $! > ./logs/train.pid
cd ..
done
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -ne 1 ]]; then
echo "Usage: bash ./scripts/run_eval.sh [DEVICE_ID]"
exit 1
fi
if [ ! -d "logs" ]; then
mkdir logs
fi
nohup python -u eval.py --device_id=$1 > ./logs/eval.log 2>&1 &
echo "Evaluation started on device $1 ! PID: $!"
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -ne 1 ]]; then
echo "Usage: bash ./scripts/run_export.sh [DEVICE_ID]"
exit 1
fi
if [ ! -d "logs" ]; then
mkdir logs
fi
nohup python -u export.py --device_id=$1 > ./logs/export.log 2>&1 &
echo "Export started on device $1 ! PID: $!"
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [[ $# -ne 1 ]]; then
echo "Usage: bash ./scripts/run_standalone_train.sh [DEVICE_ID]"
exit 1
fi
if [ ! -d "logs" ]; then
mkdir logs
fi
nohup python -u train.py --device_id=$1 > ./logs/train.log 2>&1 &
echo "Training started on device $1 ! PID: $!"
echo $! > ./logs/train.pid
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Module initialization
"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Custom callback and related RMSE metric
"""
import os
import numpy as np
from mindspore.dataset.core.validator_helpers import INT32_MAX
from mindspore.train.callback import Callback
from mindspore import save_checkpoint
from mindspore.nn import Metric
class RMSE(Metric):
"""
RMSE metric for choosing the best checkpoint
"""
def __init__(self, max_val):
super(RMSE, self).__init__()
self.clear()
self.max_val = max_val
def clear(self):
"""Clears the internal evaluation result"""
self._squared_error_sum = 0
self._samples_num = 0
def update(self, *inputs):
"""Calculate and update internal result"""
if len(inputs) != 2:
raise ValueError('RMSE metric need 2 inputs (preds, targets), but got {}'.format(len(inputs)))
preds = self._convert_data(inputs[0])
targets = self._convert_data(inputs[1])
targets = targets.reshape((-1, targets.shape[2]))
squared_error_sum = np.power(targets - preds, 2)
self._squared_error_sum += squared_error_sum.sum()
self._samples_num += np.size(targets)
def eval(self):
"""Calculate evaluation result at the end of each epoch"""
if self._samples_num == 0:
raise RuntimeError('The number of input samples must not be 0.')
return np.sqrt(self._squared_error_sum / self._samples_num) * self.max_val
class SaveCallback(Callback):
"""
Save the best checkpoint (minimum RMSE) during training
"""
def __init__(self, eval_model, ds_eval, config):
super(SaveCallback, self).__init__()
self.model = eval_model
self.ds_eval = ds_eval
self.rmse = INT32_MAX
self.config = config
def epoch_end(self, run_context):
"""Evaluate the network and save the best checkpoint (minimum RMSE)"""
cb_params = run_context.original_args()
result = self.model.eval(self.ds_eval)
print('Eval RMSE:', '{:.6f}'.format(result['RMSE']))
if not os.path.exists('checkpoints'):
os.mkdir('checkpoints')
if result['RMSE'] < self.rmse:
self.rmse = result['RMSE']
file_name = self.config.dataset + '_' + str(self.config.pre_len) + '.ckpt'
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=os.path.join('checkpoints', file_name))
print("Best checkpoint saved!")
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Configuration of parameters
For detailed information, please refer to the paper below:
https://arxiv.org/pdf/1811.05320.pdf
"""
class ConfigTGCN:
"""
Class of parameters configuration
"""
# Choose device: ['Ascend', 'GPU']
device = 'Ascend'
# Global random seed
seed = 1
# Choose datasets: ['SZ-taxi', 'Los-loop', etc]
dataset = 'SZ-taxi'
# hidden_dim: 100 for 'SZ-taxi'; 64 for 'Los-loop'
hidden_dim = 100
# seq_len: 4 for 'SZ-taxi'; 12 for 'Los-loop'
seq_len = 4
# pre_len: [1, 2, 3, 4] separately for 'SZ-taxi'; [3, 6, 9, 12] separately for 'Los-loop'
pre_len = 1
# Training parameters
train_split_rate = 0.8
epochs = 3000
batch_size = 64
learning_rate = 0.001
weight_decay = 1.5e-3
data_sink = True
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Process datasets
Both the 'SZ-taxi' and 'Los-loop' datasets can be downloaded from the link below:
https://github.com/lehaifeng/T-GCN/tree/master/T-GCN/T-GCN-PyTorch/data
"""
import os
import numpy as np
import pandas as pd
import mindspore.dataset as ds
from mindspore.communication import get_rank, get_group_size
class TGCNDataset:
"""
Custom T-GCN datasets
"""
def __init__(self, inputs, targets):
self.inputs = inputs
self.targets = targets
def __getitem__(self, index):
return self.inputs[index], self.targets[index]
def __len__(self):
return len(self.inputs)
def load_adj_matrix(dataset, abs_path=None, dtype=np.float32):
"""
Load adjacency matrix from corresponding csv file
Args:
dataset(str): name of dataset (the same as folder name)
abs_path(str): absolute data directory path
dtype(type): data type (Default: np.float32)
Returns:
adj: adjacency matrix in ndarray
"""
if abs_path is not None:
path = os.path.join(abs_path, dataset, 'adj.csv')
else:
path = os.path.join('data', dataset, 'adj.csv')
adj_df = pd.read_csv(path, header=None)
adj = np.array(adj_df, dtype=dtype)
return adj
def load_feat_matrix(dataset, abs_path=None, dtype=np.float32):
"""
Load feature matrix from corresponding csv file
Args:
dataset(str): name of dataset (the same as folder name)
abs_path(str): absolute data directory path
dtype(type): data type (Default: np.float32)
Returns:
feat: feature matrix in ndarray
max_val: max value in feature matrix
"""
if abs_path is not None:
path = os.path.join(abs_path, dataset, 'feature.csv')
else:
path = os.path.join('data', dataset, 'feature.csv')
feat_df = pd.read_csv(path)
feat = np.array(feat_df, dtype=dtype)
max_val = np.max(feat)
return feat, max_val
def generate_dataset_np(feat, seq_len, pre_len, split_ratio, normalize=True, time_len=None):
"""
Generate ndarrays from matrixes
Args:
feat(ndarray): feature matrix
seq_len(int): length of the train data sequence
pre_len(int): length of the prediction data sequence
split_ratio(float): proportion of the training set
normalize(bool): scale the data to (0, 1], divide by the maximum value in the data
time_len(int): length of the time series in total
Returns:
Train set (inputs, targets) and evaluation set (inputs, targets) in ndarrays
"""
if time_len is None:
time_len = feat.shape[0]
if normalize:
max_val = np.max(feat)
feat = feat / max_val
train_size = int(time_len * split_ratio)
train_data = feat[0:train_size]
eval_data = feat[train_size:time_len]
train_inputs, train_targets, eval_inputs, eval_targets = list(), list(), list(), list()
for i in range(len(train_data) - seq_len - pre_len):
train_inputs.append(np.array(train_data[i: i + seq_len]))
train_targets.append(np.array(train_data[i + seq_len: i + seq_len + pre_len]))
for i in range(len(eval_data) - seq_len - pre_len):
eval_inputs.append(np.array(eval_data[i: i + seq_len]))
eval_targets.append(np.array(eval_data[i + seq_len: i + seq_len + pre_len]))
return np.array(train_inputs), np.array(train_targets), np.array(eval_inputs), np.array(eval_targets)
def generate_dataset_ms(config, training):
"""
Generate MindSpore dataset from ndarrays
Args:
config(ConfigTGCN): configuration of parameters
training(bool): generate training dataset or evaluation dataset
Returns:
dataset: MindSpore dataset for training/evaluation
"""
dataset = config.dataset
seq_len = config.seq_len
pre_len = config.pre_len
split_ratio = config.train_split_rate
batch_size = config.batch_size
feat, _ = load_feat_matrix(dataset)
train_inputs, train_targets, eval_inputs, eval_targets = generate_dataset_np(feat, seq_len, pre_len, split_ratio)
if training:
dataset_generator = TGCNDataset(train_inputs, train_targets)
else:
dataset_generator = TGCNDataset(eval_inputs, eval_targets)
dataset = ds.GeneratorDataset(dataset_generator, ["inputs", "targets"], shuffle=False)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
def generate_dataset_ms_distributed(config, training, abs_path=None):
"""
Generate MindSpore dataset from ndarrays in distributed training
Args:
config(ConfigTGCN): configuration of parameters
training(bool): generate training dataset or evaluation dataset
abs_path(str): absolute data directory path
Returns:
dataset: MindSpore dataset for training/evaluation (distributed)
"""
dataset = config.dataset
seq_len = config.seq_len
pre_len = config.pre_len
split_ratio = config.train_split_rate
if training:
batch_size = config.batch_size
else:
batch_size = 1
# Get rank_id and rank_size
rank_id = get_rank()
rank_size = get_group_size()
feat, _ = load_feat_matrix(dataset, abs_path)
train_inputs, train_targets, eval_inputs, eval_targets = generate_dataset_np(feat, seq_len, pre_len, split_ratio)
if training:
dataset_generator = TGCNDataset(train_inputs, train_targets)
else:
dataset_generator = TGCNDataset(eval_inputs, eval_targets)
dataset = ds.GeneratorDataset(dataset_generator, ["inputs", "targets"], shuffle=False,
num_shards=rank_size, shard_id=rank_id)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Evaluation metrics
"""
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import dtype as mstype
from mindspore import Tensor
def accuracy(preds, targets):
"""
Calculate the accuracy between predictions and targets
Args:
preds(Tensor): predictions
targets(Tensor): ground truth
Returns:
accuracy: defined as 1 - (norm(targets - preds) / norm(targets))
"""
return 1 - np.linalg.norm(targets.asnumpy() - preds.asnumpy(), 'fro') / np.linalg.norm(targets.asnumpy(), 'fro')
def r2(preds, targets):
"""
Calculate R square between predictions and targets
Args:
preds(Tensor): predictions
targets(Tensor): ground truth
Returns:
R square: coefficient of determination
"""
return (1 - P.ReduceSum()((targets - preds) ** 2) / P.ReduceSum()((targets - P.ReduceMean()(preds)) ** 2)).asnumpy()
def explained_variance(preds, targets):
"""
Calculate the explained variance between predictions and targets
Args:
preds(Tensor): predictions
targets(Tensor): ground truth
Returns:
Var: explained variance
"""
return (1 - (targets - preds).var() / targets.var()).asnumpy()
def evaluate_network(net, max_val, eval_inputs, eval_targets):
"""
Evaluate the performance of network
"""
eval_inputs = Tensor(eval_inputs, mstype.float32)
eval_preds = net(eval_inputs)
eval_targets = Tensor(eval_targets, mstype.float32)
eval_targets = eval_targets.reshape((-1, eval_targets.shape[2]))
rmse = P.Sqrt()(nn.MSELoss()(eval_preds, eval_targets)).asnumpy()
mae = nn.MAELoss()(eval_preds, eval_targets).asnumpy()
acc = accuracy(eval_preds, eval_targets)
r_2 = r2(eval_preds, eval_targets)
var = explained_variance(eval_preds, eval_targets)
print("=====Evaluation Results=====")
print('RMSE:', '{:.6f}'.format(rmse * max_val))
print('MAE:', '{:.6f}'.format(mae * max_val))
print('Accuracy:', '{:.6f}'.format(acc))
print('R2:', '{:.6f}'.format(r_2))
print('Var:', '{:.6f}'.format(var))
print("============================")
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Module initialization
"""
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Graph convolution operation
"""
import mindspore.numpy as np
import mindspore.ops.operations as P
from mindspore import dtype as mstype
def calculate_laplacian_with_self_loop(matrix, matmul):
"""
Calculate laplacian matrix with self loop
Args:
matrix(Tensor): input matrix
matmul(MatMul): the MatMul operator for mixed precision
Returns:
normalized_laplacian: normalized laplacian matrix
"""
matrix = matrix + P.Eye()(matrix.shape[0], matrix.shape[0], mstype.float32)
row_sum = matrix.sum(1)
d_inv_sqrt = P.Pow()(row_sum, -0.5).flatten()
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
d_mat_inv_sqrt = np.diag(d_inv_sqrt)
normalized_laplacian = matmul(matmul(matrix, d_mat_inv_sqrt).transpose(0, 1), d_mat_inv_sqrt)
return normalized_laplacian
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
T-GCN loss cell
"""
import mindspore.nn as nn
import mindspore.numpy as np
class TGCNLoss(nn.Cell):
"""
Custom T-GCN loss cell
"""
def construct(self, predictions, targets):
"""
Calculate loss
Args:
predictions(Tensor): predictions from models
targets(Tensor): ground truth
Returns:
loss: loss value
"""
targets = targets.reshape((-1, targets.shape[2]))
return np.sum((predictions - targets) ** 2) / 2
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
T-GCN architecture
For detailed information, please refer to the paper below:
https://arxiv.org/pdf/1811.05320.pdf
"""
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.common.initializer import initializer, XavierUniform, Constant
from .graph_conv import calculate_laplacian_with_self_loop
class TGCNGraphConvolution(nn.Cell):
"""
T-GCN graph convolution layer
"""
def __init__(self, adj, num_gru_units: int, output_dim: int, bias: float = 0.0):
super(TGCNGraphConvolution, self).__init__()
self._num_gru_units = num_gru_units
self._output_dim = output_dim
self._bias_init_value = bias
self.matmul = nn.MatMul()
self.laplacian = Parameter(calculate_laplacian_with_self_loop(Tensor(adj, mstype.float32), self.matmul),
name='laplacian', requires_grad=False)
self.weights = Parameter(initializer(XavierUniform(), [self._num_gru_units + 1, self._output_dim],
mstype.float32), name='weights')
self.biases = Parameter(initializer(Constant(self._bias_init_value), [self._output_dim],
mstype.float32), name='biases')
def construct(self, inputs, hidden_state):
"""
Calculate graph convolution outputs
Args:
inputs(Tensor): network inputs
hidden_state(Tensor): hidden state
Returns:
outputs: TGCNGraphConvolution outputs
"""
batch_size, num_nodes = inputs.shape
# inputs (batch_size, num_nodes) -> (batch_size, num_nodes, 1)
inputs = inputs.reshape((batch_size, num_nodes, 1))
# hidden_state (batch_size, num_nodes, num_gru_units)
hidden_state = hidden_state.reshape((batch_size, num_nodes, self._num_gru_units))
# [x, h] (batch_size, num_nodes, num_gru_units + 1)
concatenation = P.Concat(axis=2)((inputs, hidden_state))
# [x, h] (num_nodes, num_gru_units + 1, batch_size)
concatenation = concatenation.transpose(1, 2, 0)
# [x, h] (num_nodes, (num_gru_units + 1) * batch_size)
concatenation = concatenation.reshape((num_nodes, (self._num_gru_units + 1) * batch_size))
# A[x, h] (num_nodes, (num_gru_units + 1) * batch_size)
a_times_concat = self.matmul(self.laplacian, concatenation)
# A[x, h] (num_nodes, num_gru_units + 1, batch_size)
a_times_concat = a_times_concat.reshape((num_nodes, self._num_gru_units + 1, batch_size))
# A[x, h] (batch_size, num_nodes, num_gru_units + 1)
a_times_concat = a_times_concat.transpose(2, 0, 1)
# A[x, h] (batch_size * num_nodes, num_gru_units + 1)
a_times_concat = a_times_concat.reshape((batch_size * num_nodes, self._num_gru_units + 1))
# A[x, h]W + b (batch_size * num_nodes, output_dim)
outputs = self.matmul(a_times_concat, self.weights) + self.biases
# A[x, h]W + b (batch_size, num_nodes, output_dim)
outputs = outputs.reshape((batch_size, num_nodes, self._output_dim))
# A[x, h]W + b (batch_size, num_nodes * output_dim)
outputs = outputs.reshape((batch_size, num_nodes * self._output_dim))
return outputs
class TGCNCell(nn.Cell):
"""
T-GCN cell
"""
def __init__(self, adj, input_dim: int, hidden_dim: int):
super(TGCNCell, self).__init__()
self._input_dim = input_dim
self._hidden_dim = hidden_dim
self.adj = Parameter(Tensor(adj, mstype.float32), name='adj', requires_grad=False)
self.graph_conv1 = TGCNGraphConvolution(self.adj, self._hidden_dim, self._hidden_dim * 2, bias=1.0)
self.graph_conv2 = TGCNGraphConvolution(self.adj, self._hidden_dim, self._hidden_dim)
def construct(self, inputs, hidden_state):
"""
Calculate hidden states
Args:
inputs(Tensor): network inputs
hidden_state(Tensor): hidden state
Returns:
new_hidden_state: new hidden state
"""
# [r, u] = sigmoid(A[x, h]W + b)
# [r, u] (batch_size, num_nodes * (2 * num_gru_units))
concatenation = P.Sigmoid()(self.graph_conv1(inputs, hidden_state))
# r (batch_size, num_nodes, num_gru_units), u (batch_size, num_nodes, num_gru_units)
r, u = P.Split(axis=1, output_num=2)(concatenation)
# c = tanh(A[x, (r * h)W + b])
# c (batch_size, num_nodes * num_gru_units)
c = P.Tanh()(self.graph_conv2(inputs, r * hidden_state))
# h := u * h + (1 - u) * c
# h (batch_size, num_nodes * num_gru_units)
new_hidden_state = u * hidden_state + (1.0 - u) * c
return new_hidden_state, new_hidden_state
class TGCN(nn.Cell):
"""
T-GCN network
"""
def __init__(self, adj, hidden_dim: int, **kwargs):
super(TGCN, self).__init__()
self._input_dim = adj.shape[0]
self._hidden_dim = hidden_dim
self.adj = Parameter(Tensor(adj, mstype.float32), name='adj', requires_grad=False)
self.tgcn_cell = TGCNCell(self.adj, self._input_dim, self._hidden_dim)
def construct(self, inputs):
"""
Calculate the final output
Args:
inputs(Tensor): network inputs
Returns:
output: TGCN output
"""
batch_size, seq_len, num_nodes = inputs.shape
hidden_state = P.Zeros()((batch_size, num_nodes * self._hidden_dim), mstype.float32)
output = None
for i in range(seq_len):
output, hidden_state = self.tgcn_cell(inputs[:, i, :], hidden_state)
output = output.reshape((batch_size, num_nodes, self._hidden_dim))
return output
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Supervised forecast task
"""
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from .model.tgcn import TGCN
class SupervisedForecastTask(nn.Cell):
"""
T-GCN applied to supervised forecast task
"""
def __init__(self, adj, hidden_dim: int, pre_len: int):
super(SupervisedForecastTask, self).__init__()
self.adj = Parameter(Tensor(adj, mstype.float32), name='adj', requires_grad=False)
self.tgcn = TGCN(self.adj, hidden_dim)
self.fcn = nn.Dense(hidden_dim, pre_len)
def construct(self, inputs):
"""
Calculate network predictions for supervised forecast task
Args:
inputs(Tensor): network inputs
Returns:
predictions: predictions of supervised forecast task
"""
# (batch_size, seq_len, num_nodes)
batch_size, _, num_nodes = inputs.shape
# (batch_size, num_nodes, hidden_dim)
hidden = self.tgcn(inputs)
# (batch_size * num_nodes, hidden_dim)
hidden = hidden.reshape((-1, hidden.shape[2]))
# (batch_size * num_nodes, pre_len)
predictions = self.fcn(hidden)
predictions = predictions.reshape((batch_size, num_nodes, -1))
# Change data shape for the following calculation of metrics
predictions = predictions.transpose(0, 2, 1).reshape((-1, num_nodes))
return predictions
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Training script
"""
import os
import time
import argparse
from mindspore.communication import init
from mindspore.context import ParallelMode
from mindspore import dtype as mstype
from mindspore import set_seed, nn, context, Model
from mindspore.train.callback import LossMonitor, TimeMonitor
from src.config import ConfigTGCN
from src.dataprocess import load_adj_matrix, load_feat_matrix, generate_dataset_ms, generate_dataset_ms_distributed
from src.task import SupervisedForecastTask
from src.model.loss import TGCNLoss
from src.callback import RMSE, SaveCallback
def run_train(args):
"""
Run training
"""
# Config initialization
config = ConfigTGCN()
# Set global seed for MindSpore and NumPy
set_seed(config.seed)
# ModelArts runtime, datasets and network initialization
if args.run_modelarts:
import moxing as mox
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
mox.file.copy_parallel(src_url=args.data_url, dst_url='./data')
if args.distributed:
device_num = int(os.getenv('RANK_SIZE'))
init()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
training_set = generate_dataset_ms_distributed(config, training=True, abs_path=args.data_path)
eval_set = generate_dataset_ms_distributed(config, training=False, abs_path=args.data_path)
_, max_val = load_feat_matrix(config.dataset, args.data_path)
net = SupervisedForecastTask(load_adj_matrix(config.dataset, args.data_path),
config.hidden_dim, config.pre_len)
else:
training_set = generate_dataset_ms(config, training=True)
eval_set = generate_dataset_ms(config, training=False)
_, max_val = load_feat_matrix(config.dataset)
net = SupervisedForecastTask(load_adj_matrix(config.dataset), config.hidden_dim, config.pre_len)
# Offline runtime, datasets and network initialization
else:
if args.distributed:
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=args.device_id)
context.set_context(device_id=device_id)
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
training_set = generate_dataset_ms_distributed(config, training=True, abs_path=args.data_path)
eval_set = generate_dataset_ms_distributed(config, training=False, abs_path=args.data_path)
_, max_val = load_feat_matrix(config.dataset, args.data_path)
net = SupervisedForecastTask(load_adj_matrix(config.dataset, args.data_path),
config.hidden_dim, config.pre_len)
else:
context.set_context(mode=context.GRAPH_MODE, device_target=config.device, device_id=args.device_id)
training_set = generate_dataset_ms(config, training=True)
eval_set = generate_dataset_ms(config, training=False)
_, max_val = load_feat_matrix(config.dataset)
net = SupervisedForecastTask(load_adj_matrix(config.dataset), config.hidden_dim, config.pre_len)
# Mixed precision
net.tgcn.tgcn_cell.graph_conv1.matmul.to_float(mstype.float16)
net.tgcn.tgcn_cell.graph_conv2.matmul.to_float(mstype.float16)
# Loss function
loss_fn = TGCNLoss()
# Optimizer
optimizer = nn.Adam(net.trainable_params(), config.learning_rate, weight_decay=config.weight_decay)
# Create model
model = Model(net, loss_fn, optimizer, {'RMSE': RMSE(max_val)})
# Training
if args.distributed:
print("==========Distributed Training Start==========")
else:
print("==========Training Start==========")
time_start = time.time()
model.train(config.epochs, training_set,
callbacks=[LossMonitor(), TimeMonitor(), SaveCallback(model, eval_set, config)],
dataset_sink_mode=config.data_sink)
time_end = time.time()
if args.distributed:
print("==========Distributed Training End==========")
else:
print("==========Training End==========")
print("Training time in total:", '{:.6f}'.format(time_end - time_start), "s")
# Save outputs (checkpoints) on ModelArts
if args.run_modelarts:
mox.file.copy_parallel(src_url='./checkpoints', dst_url=args.train_url)
if __name__ == '__main__':
# Set universal arguments
parser = argparse.ArgumentParser()
parser.add_argument('--device_id', help="DEVICE_ID", type=int, default=0)
parser.add_argument('--distributed', help="distributed training", type=bool, default=False)
parser.add_argument('--data_path', help="directory of datasets", type=str, default='./data')
# Set ModelArts arguments
parser.add_argument('--run_modelarts', help="ModelArts runtime", type=bool, default=False)
parser.add_argument('--data_url', help='ModelArts location of data', type=str, default=None)
parser.add_argument('--train_url', help='ModelArts location of training outputs', type=str, default=None)
run_args = parser.parse_args()
# Training
run_train(run_args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment