From da16dcad8fcc38e0a41efc20124810c604c3b963 Mon Sep 17 00:00:00 2001
From: mo-hai <mohai3737@gmail.com>
Date: Tue, 2 Aug 2022 14:20:44 +0800
Subject: [PATCH] Implement training and evaluation of STGCN on a GPU

---
 research/cv/stgcn/README.md                   | 313 +++++++++++++++++
 research/cv/stgcn/README_CN.md                | 228 ++++++++-----
 research/cv/stgcn/eval.py                     | 145 ++++++++
 research/cv/stgcn/export.py                   |  43 +--
 research/cv/stgcn/modelarts/start_train.py    |   4 +-
 research/cv/stgcn/postprocess.py              |  19 +-
 research/cv/stgcn/preprocess.py               |  23 +-
 research/cv/stgcn/requirements.txt            |   4 +
 .../cv/stgcn/scripts/run_distribute_train.sh  |  14 +-
 research/cv/stgcn/scripts/run_eval_ascend.sh  |  27 +-
 research/cv/stgcn/scripts/run_eval_gpu.sh     |  37 ++
 .../cv/stgcn/scripts/run_single_train_gpu.sh  |  67 ++++
 research/cv/stgcn/src/argparser.py            |  55 +++
 research/cv/stgcn/src/config.py               |   6 +
 research/cv/stgcn/src/dataloader.py           |  16 +-
 research/cv/stgcn/src/model/layers.py         |  68 ++--
 research/cv/stgcn/src/model/metric.py         |   1 +
 research/cv/stgcn/src/model/models.py         |  11 +-
 research/cv/stgcn/src/utility.py              |  14 +-
 research/cv/stgcn/test.py                     | 190 -----------
 research/cv/stgcn/train.py                    | 320 +++++++++---------
 21 files changed, 1025 insertions(+), 580 deletions(-)
 create mode 100644 research/cv/stgcn/README.md
 create mode 100644 research/cv/stgcn/eval.py
 create mode 100644 research/cv/stgcn/requirements.txt
 create mode 100644 research/cv/stgcn/scripts/run_eval_gpu.sh
 create mode 100644 research/cv/stgcn/scripts/run_single_train_gpu.sh
 create mode 100644 research/cv/stgcn/src/argparser.py
 delete mode 100644 research/cv/stgcn/test.py

diff --git a/research/cv/stgcn/README.md b/research/cv/stgcn/README.md
new file mode 100644
index 000000000..df0581e4f
--- /dev/null
+++ b/research/cv/stgcn/README.md
@@ -0,0 +1,313 @@
+# Contents
+
+<!-- TOC -->
+
+[View Chinese](./README_CN.md)
+
+- [STGCN Description](#stgcn-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick start](#quick-start)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+        - [Usage](#usage)
+        - [Result](#result)
+    - [Evaluation Process](#evaluation-process)
+        - [Usage](#usage-1)
+        - [Result](#result-1)
+    - [Model Export](#model-export)
+    - [Inference Process](#inference-process)
+        - [Usage](#usage-2)
+        - [Result](#result-2)
+- [Model Description](#model-description)
+    - [Performance](#performance)  
+        - [Training Performance](#training-performance)
+        - [Evaluation Performance](#evaluation-performance)
+- [Description of Random State](#description-of-random-state)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [STGCN Description](#contents)
+
+Spatio-Temporal Graph Convolutional Network (STGCN) is a deep learning framework proposed in the paper to solve the time-series
+prediction problem in the traffic domain. The authors formulate the problem on graphs and build the model with complete convolutional
+structures, which enables much faster training with fewer parameters. STGCN effectively captures comprehensive spatio-temporal correlations
+by modeling multi-scale traffic networks and consistently outperforms state-of-the-art baselines on various real-world traffic datasets.
+
+[Paper](https://arxiv.org/abs/1709.04875): Bing Yu, Haoteng Yin, and Zhanxing Zhu. "Spatio-Temporal Graph Convolutional Networks:
+A Deep Learning Framework for Traffic Forecasting." Proceedings of the 27th International Joint Conference on Artificial Intelligence. 2018.
+
+# [Model Architecture](#contents)
+
+The STGCN model is composed of two spatio-temporal convolution blocks (ST-Conv blocks) followed by a fully-connected output layer.
+Each ST-Conv block contains two temporal gated convolution layers with one spatial graph convolution layer in between.
+Two graph convolution methods are supported for the spatial layer: Chebyshev polynomial approximation (chebconv) and first-order GCN approximation (gcnconv); the block channel configuration is sketched below.
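+
+The per-block channel configuration follows a bottleneck design. A minimal Python sketch, assuming the default config values (`stblock_num=2`, `Kt=3`, `n_his=12`), of how the `blocks` list is assembled (it mirrors the construction used in `eval.py` and `export.py`):
+
+```python
+# Bottleneck channel settings: input -> ST-Conv blocks -> output layer(s) -> prediction
+stblock_num, Kt, n_his = 2, 3, 12
+Ko = n_his - (Kt - 1) * 2 * stblock_num          # temporal length left after all ST-Conv blocks
+blocks = [[1]]                                   # one input channel (traffic speed)
+for _ in range(stblock_num):
+    blocks.append([64, 16, 64])                  # temporal conv -> graph conv -> temporal conv
+blocks.append([128] if Ko == 0 else [128, 128])  # output layer channels
+blocks.append([1])                               # final prediction channel
+print(blocks)  # [[1], [64, 16, 64], [64, 16, 64], [128, 128], [1]]
+```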
+
+# [Dataset](#contents)
+
+Dataset used:
+
+- Only the [PeMSD7-M](https://github.com/hazdzz/STGCN/tree/main/data/pemsd7-m) dataset is available for download. The adjacency matrix can be found in an older revision of that repository: [adj_mat.csv](https://github.com/hazdzz/STGCN/blob/3ca6c36f0e4b874976891d5b09d9d5b0858680d3/data/train/road_traffic/pemsd7-m/adj_mat.csv). A quick format check is sketched after this list.
+- BJER4 is not used, as it is a private dataset restricted by a confidentiality agreement.
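+
+Both CSV files have no header row: `vel.csv` holds the speed records (one column per sensor station) and `adj_mat.csv` holds the weighted adjacency matrix. A quick sanity check, mirroring the vertex-count check performed in `eval.py` (the local paths here are assumptions):
+
+```python
+import pandas as pd
+
+vel = pd.read_csv("./data/pemsd7-m/vel.csv", header=None)      # shape: (time steps, n_vertex)
+adj = pd.read_csv("./data/pemsd7-m/adj_mat.csv", header=None)  # shape: (n_vertex, n_vertex)
+assert vel.shape[1] == adj.shape[0] == adj.shape[1] == 228     # PeMSD7-M has 228 stations
+```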
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend/GPU)
+    - Prepare hardware environment with Ascend or GPU.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information about MindSpore, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html)
+- Other
+    - pandas
+    - sklearn
+    - easydict
+
+# [Quick start](#contents)
+
+After installing MindSpore via the official website, you can start training and evaluation as follows:
+
+- running on Ascend with default parameters
+
+```shell
+# single card
+python train.py --device_target="Ascend" --train_url="" --data_url="" --run_distribute=False --run_modelarts=False --graph_conv_type="chebconv" --n_pred=9
+
+# multi card
+bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type rank_table
+```
+
+- running on GPU with default parameters
+
+```shell
+# single card
+python train.py --device_target="GPU" --train_url="" --data_url="" --run_distribute=False --run_modelarts=False --graph_conv_type="chebconv" --n_pred=9
+
+# single card (shell script)
+bash scripts/run_single_train_gpu.sh data_path n_pred graph_conv_type device_id
+```
+
+# [Script Description](#contents)
+
+## [Script and Sample Code](#contents)
+
+```text
+├── STGCN
+    ├── scripts
+        ├── run_distribute_train.sh       # training on Ascend with 8P
+        ├── run_single_train_gpu.sh       # training on GPU 1P
+        ├── run_eval_ascend.sh            # testing on Ascend
+        ├── run_eval_gpu.sh               # testing on GPU
+    ├── src
+        ├── model
+            ├──layers.py                  # model layer
+            ├──metric.py                  # network with losscell
+            ├──models.py                  # network model
+        ├──argparser.py                   # command line parameters
+        ├──config.py                      # parameters
+        ├──dataloader.py                  # creating dataset
+        ├──utility.py                     # calculate laplacian matrix and evaluate metric
+        ├──weight_init.py                 # layernorm weight init
+    ├── train.py                          # training network
+    ├── eval.py                           # testing network performance
+    ├── export.py
+    ├── postprocess.py                    # compute accuracy for ascend310
+    ├── preprocess.py                     # process dataset for ascend310
+    ├── README.md
+    ├── README_CN.md
+```
+
+## [Script Parameters](#contents)
+
+Training and evaluation parameters can be set in `config.py`.
+
+- config for STGCN
+
+```text
+    stgcn_chebconv_45min_cfg = edict({
+    'learning_rate': 0.003,
+    'n_his': 12,
+    'n_pred': 9,
+    'epochs': 50,
+    'batch_size': 8,  # config.batch_size * int(8 / device_num)
+    'decay_epoch': 10,
+    'gamma': 0.7,
+    'stblock_num': 2,
+    'Ks': 3,
+    'Kt': 3,
+    'time_intvl': 5,
+    'drop_rate': 0.5,
+    'weight_decay_rate': 0.0005,
+    'gated_act_func':"glu",
+    'graph_conv_type': "chebconv",
+    'mat_type': "wid_sym_normd_lap_mat",
+    })
+```
+
+For more information, please check `config.py`.
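+
+The `--graph_conv_type` and `--n_pred` command line arguments select which of the six predefined configs is used. A compact sketch of the selection logic, equivalent to the if/elif chain in `eval.py` and `export.py` (`select_cfg` is an illustrative helper name, not part of the repository):
+
+```python
+from src import config
+
+def select_cfg(graph_conv_type, n_pred):
+    """Map (graph_conv_type, n_pred) to one of the predefined configs in src/config.py."""
+    table = {
+        ("chebconv", 9): config.stgcn_chebconv_45min_cfg,
+        ("chebconv", 6): config.stgcn_chebconv_30min_cfg,
+        ("chebconv", 3): config.stgcn_chebconv_15min_cfg,
+        ("gcnconv", 9): config.stgcn_gcnconv_45min_cfg,
+        ("gcnconv", 6): config.stgcn_gcnconv_30min_cfg,
+        ("gcnconv", 3): config.stgcn_gcnconv_15min_cfg,
+    }
+    if (graph_conv_type, n_pred) not in table:
+        raise ValueError("Unsupported graph_conv_type / n_pred combination.")
+    return table[(graph_conv_type, n_pred)]
+```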
+
+## [Training process](#contents)
+
+### Usage
+
+- running on Ascend
+
+```shell
+# train single card
+python train.py --device_target="Ascend" --train_url="" --data_url="" --run_distribute=False --run_modelarts=True --graph_conv_type="chebconv" --n_pred=9
+
+# train 8 card
+bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type rank_table
+```
+
+> Note: to train on 8 Ascend devices, put the `RANK_TABLE_FILE` in the `scripts` folder. [How to generate RANK_TABLE_FILE](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)
+
+- running on GPU
+
+```shell
+# train single card
+python train.py --device_target="GPU" --train_url="./checkpoint" --data_url="./data" --run_distribute=False --run_modelarts=False --graph_conv_type="chebconv" --n_pred=9
+
+# train single card (shell script)
+bash scripts/run_single_train_gpu.sh data_path n_pred graph_conv_type device_id
+```
+
+### Result
+
+During training, the epoch, step and current loss are displayed in the terminal:
+
+```text
+  epoch: 1 step: 139, loss is 0.429
+  epoch time: 203885.163 ms, per step time: 1466.800 ms
+  epoch: 2 step: 139, loss is 0.2097
+  epoch time: 6330.939 ms, per step time: 45.546 ms
+  epoch: 3 step: 139, loss is 0.4192
+  epoch time: 6364.882 ms, per step time: 45.791 ms
+  epoch: 4 step: 139, loss is 0.2917
+  epoch time: 6378.299 ms, per step time: 45.887 ms
+  epoch: 5 step: 139, loss is 0.2365
+  epoch time: 6369.215 ms, per step time: 45.822 ms
+  epoch: 6 step: 139, loss is 0.2269
+  epoch time: 6389.238 ms, per step time: 45.966 ms
+  epoch: 7 step: 139, loss is 0.3071
+  epoch time: 6365.901 ms, per step time: 45.798 ms
+  epoch: 8 step: 139, loss is 0.2336
+  epoch time: 6358.127 ms, per step time: 45.742 ms
+  epoch: 9 step: 139, loss is 0.2812
+  epoch time: 6333.794 ms, per step time: 45.567 ms
+  epoch: 10 step: 139, loss is 0.2622
+  epoch time: 6334.013 ms, per step time: 45.568 ms
+  ...
+```
+
+The model checkpoints are stored in the `train_url` path.
+
+## [Evaluation process](#contents)
+
+### Usage
+
+Evaluation uses the PeMSD7-M test set.
+
+- on Ascend
+
+When running `eval.py` directly, you need to specify the device, the checkpoint path, the spatial convolution type, and the prediction interval.
+
+```shell
+python eval.py --device_target="Ascend" --run_modelarts=False --run_distribute=False --device_id=0 --ckpt_url="" --graph_conv_type="" --n_pred=9
+
+# using script to run
+bash scripts/run_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [N_PRED] [GRAPH_CONV_TYPE] [DEVICE_ID]
+```
+
+- on GPU
+
+When running `eval.py` directly, you need to specify the device, the checkpoint path, the spatial convolution type, and the prediction interval.
+
+```shell
+python eval.py --device_target="GPU" --run_modelarts=False --run_distribute=False --device_id=0 --ckpt_url="" --graph_conv_type="" --n_pred=9
+
+# Using script to run
+bash scripts/run_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [N_PRED] [GRAPH_CONV_TYPE] [DEVICE_ID]
+```
+
+### Result
+
+The above Python command prints the evaluation result to the terminal (the shell scripts redirect it to `eval.log`). The test-set accuracy is reported as follows:
+
+```text
+MAE 3.23 | MAPE 8.32 | RMSE 6.06
+```
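+
+MAE, MAPE and RMSE follow their standard definitions; the repository computes them in `src/utility.py` (`evaluate_metric`) after the predictions are transformed back from z-score space. A NumPy sketch of the definitions (illustrative arrays `y_true`/`y_pred`; MAPE is reported in percent, as in the line above):
+
+```python
+import numpy as np
+
+def mae_mape_rmse(y_true, y_pred):
+    err = np.abs(y_true - y_pred)
+    mae = err.mean()
+    mape = (err / np.abs(y_true)).mean() * 100        # percent
+    rmse = np.sqrt(((y_true - y_pred) ** 2).mean())
+    return mae, mape, rmse
+```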
+
+## [Model Export](#contents)
+
+```shell
+python export.py --data_url [DATA_URL] --train_url [TRAIN_URL] --ckpt_url [CKPT_PATH] --n_pred [N_PRED] --graph_conv_type [GRAPH_CONV_TYPE] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
+```
+
+## [Inference Process](#contents)
+
+### Usage
+
+Before performing inference, the MINDIR file must be exported with `export.py`. The input data must be in BIN format.
+
+```shell
+# Ascend310 inference
+bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID]
+```
+
+### Result
+
+The inference result is saved in the current path; the accuracy can be found in the `acc.log` file.
+
+# [Model Description](#contents)
+
+## [Performance](#contents)
+
+### Training Performance
+
+#### STGCN on PeMSD7-m (Cheb, n_pred=9)
+
+| Parameters                 | Ascend 8p                                        | GPU 1p |
+| -------------------------- | ------------------------------------------------ | ------ |
+| Model                      | STGCN                                            | STGCN  |
+| Environment                | ModelArts; Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB | Ubuntu 18.04.6, 1x RTX 3090, CPU 2.90GHz, 64 cores, RAM 252 GB |
+| Uploaded Date (month/day/year) | 05/07/2021                                   | 17/01/2021 |
+| MindSpore Version          | 1.2.0                                            | 1.5.0 |
+| Dataset                    | PeMSD7-M                                         | PeMSD7-M |
+| Training Parameters        | epoch=500, steps=139, batch_size=8, lr=0.003     | epoch=50, steps=139, batch_size=64, lr=0.003 |
+| Optimizer                  | AdamWeightDecay                                  | AdamWeightDecay |
+| Loss Function              | MSE Loss                                         | MSE Loss |
+| Outputs                    | probability                                      | probability |
+| Final loss                 | 0.183                                            | 0.23 |
+| Speed                      | 45.601 ms/step                                   | 44 ms/step |
+| Total time                 | 56min                                            | 5 min |
+| Scripts                    | [STGCN script](https://gitee.com/mindspore/models/tree/master/research/cv/stgcn#https://arxiv.org/abs/1709.04875) | [STGCN script](https://gitee.com/mindspore/models/tree/master/research/cv/stgcn#https://arxiv.org/abs/1709.04875) |
+
+### Evaluation Performance
+
+#### STGCN on PeMSD7-m (Cheb, n_pred=9)
+
+| Parameters          | Ascend                      | GPU 1P |
+| ------------------- | --------------------------- | ------ |
+| Model Version       | STGCN                       | STGCN  |
+| Resource            | Ascend 910                  | Ubuntu 18.04.6, NVIDIA GeForce RTX3090, CPU 2.90GHz, 64cores, RAM 252GB |
+| Uploaded Date       | 05/07/2021 (month/day/year) | 17/01/2021 |
+| MindSpore Version   | 1.2.0                       | 1.5.0    |
+| Dataset             | PeMSD7-M                    | PeMSD7-M |
+| batch_size          | 8                           | 64        |
+| outputs             | probability                 | probability |
+| MAE                 | 3.23                        | 3.24 |
+| MAPE                | 8.32                        | 8.25 |
+| RMSE                | 6.06                        | 6.03 |
+| Model for inference | about 6 MB (.ckpt file)     | about 6 MB (.ckpt file) |
+
+# [Description of Random State](#contents)
+
+The random seed is set in the `train.py` script.
+
+# [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/models).
diff --git a/research/cv/stgcn/README_CN.md b/research/cv/stgcn/README_CN.md
index 2cae1e7e4..63767d674 100644
--- a/research/cv/stgcn/README_CN.md
+++ b/research/cv/stgcn/README_CN.md
@@ -1,7 +1,10 @@
 # Contents
 
-- [Contents](#contents)
-- [STGCN 介绍](#stgcn-介绍)
+<!-- TOC -->
+
+[View English](./README.md)
+
+- [STGCN 介绍](#STGCN-介绍)
 - [模型架构](#模型架构)
 - [数据集](#数据集)
 - [环境要求](#环境要求)
@@ -11,20 +14,20 @@
     - [脚本参数](#脚本参数)
     - [训练步骤](#训练步骤)
         - [训练](#训练)
+        - [结果](#结果)
     - [评估步骤](#评估步骤)
         - [评估](#评估)
+        - [结果](#结果-2)
     - [导出mindir模型](#导出mindir模型)
     - [推理过程](#推理过程)
         - [用法](#用法)
-        - [结果](#结果)
+        - [结果](#结果-3)
 - [模型介绍](#模型介绍)
     - [性能](#性能)
+        - [训练性能](#训练性能)
         - [评估性能](#评估性能)
-            - [STGCN on PeMSD7-m (Cheb,n_pred=9)](#stgcn-on-pemsd7-m-chebn_pred9)
-        - [Inference Performance](#inference-performance)
-            - [STGCN on PeMSD7-m (Cheb,n_pred=9)](#stgcn-on-pemsd7-m-chebn_pred9-1)
 - [随机事件介绍](#随机事件介绍)
-- [ModelZoo 主页](#modelzoo-主页)
+- [ModelZoo 主页](#ModelZoo-主页)
 
 # [STGCN 介绍](#contents)
 
@@ -41,10 +44,8 @@ STGCN模型结构是由两个时空卷积快和一个输出层构成。时空卷
 
 Dataset used:
 
-PeMED7(PeMSD7-m、PeMSD7-L)
-BJER4
-
-由于数据集下载原因,只找到了[PeMSD7-M](https://github.com/hazdzz/STGCN/tree/main/data/pemsd7-m)数据集。
+- 由于数据集下载原因,只找到了[PeMSD7-M](https://github.com/hazdzz/STGCN/tree/main/data/pemsd7-m) 数据集。Adj_mat.csv can be found in older version [adj_mat.csv](https://github.com/hazdzz/STGCN/blob/3ca6c36f0e4b874976891d5b09d9d5b0858680d3/data/train/road_traffic/pemsd7-m/adj_mat.csv)
+- BJER4 is not used, as it is a private dataset which is restricted by a confidentiality agreement.
 
 # [环境要求](#contents)
 
@@ -54,7 +55,11 @@ BJER4
     - [MindSpore](https://www.mindspore.cn/install/en)
 - 如需获取更多信息,请查看如下链接:
     - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
-    - [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html)
+- Other
+    - pandas
+    - sklearn
+    - easydict
 
 # [快速开始](#contents)
 
@@ -62,37 +67,52 @@ BJER4
 
 - running on Ascend with default parameters
 
-  ```python
-  # 单卡训练
-  python train.py --train_url="" --data_url="" --run_distribute=False --run_modelarts=False --graph_conv_type="chebgcn" --n_pred=9
+```shell
+# 单卡训练
+python train.py --device_target="Ascend" --train_url="" --data_url="" --run_distribute=False --run_modelarts=False --graph_conv_type="chebgcn" --n_pred=9
+
+# 多卡训练
+bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type rank_table
+```
+
+- running on GPU with default parameters
 
-  # 多卡训练
-  bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type rank_table
-  ```
+```shell
+# 单卡训练
+python train.py --device_target="GPU" --train_url="" --data_url="" --run_distribute=False --run_modelarts=False --graph_conv_type="chebconv" --n_pred=9
+
+# 单卡训练
+bash scripts/run_single_train_gpu.sh data_path n_pred graph_conv_type device_id
+```
 
 # [脚本介绍](#contents)
 
 ## [脚本以及简单代码](#contents)
 
-```python
+```text
 ├── STGCN
     ├── scripts
-        ├── run_distribute_train.sh     //traing on Ascend with 8P
-        ├── run_eval_ascend.sh     //testing on Ascend
+        ├── run_distribute_train.sh       # training on Ascend with 8P
+        ├── run_single_train_gpu.sh       # training on GPU 1P
+        ├── run_eval_ascend.sh            # testing on Ascend
+        ├── run_eval_gpu.sh               # testing on GPU
     ├── src
         ├── model
-            ├──layers.py       // model layer
-            ├──metric.py          // network with losscell
-            ├──models.py          // network model
-        ├──config.py       // parameter
-        ├──dataloder.py          // creating dataset
-        ├──utility.py          // calculate laplacian matrix and evaluate metric
-        ├──weight_init.py       // layernorm weight init
-    ├── train.py                // traing network
-    ├── test.py                 // tesing network performance
-    ├── postprocess.py                 // compute accuracy for ascend310
-    ├── preprocess.py                 // process dataset for ascend310
-    ├── README.md                 // descriptions
+            ├──layers.py                  # model layer
+            ├──metric.py                  # network with losscell
+            ├──models.py                  # network model
+        ├──argparser.py                   # command line parameters
+        ├──config.py                      # parameters
+        ├──dataloader.py                  # creating dataset
+        ├──utility.py                     # calculate laplacian matrix and evaluate metric
+        ├──weight_init.py                 # layernorm weight init
+    ├── train.py                          # training network
+    ├── eval.py                           # testing network performance
+    ├── export.py
+    ├── postprocess.py                    # compute accuracy for ascend310
+    ├── preprocess.py                     # process dataset for ascend310
+    ├── README.md
+    ├── README_CN.md
 ```
 
 ## [脚本参数](#contents)
@@ -101,18 +121,17 @@ BJER4
 
 - config for STGCN
 
-  ```python
-     stgcn_chebconv_45min_cfg = edict({
+```text
+    stgcn_chebconv_45min_cfg = edict({
     'learning_rate': 0.003,
     'n_his': 12,
     'n_pred': 9,
-    'n_vertex': 228,
-    'epochs': 500,
-    'batch_size': 8,
+    'epochs': 50,
+    'batch_size': 8,  # config.batch_size * int(8 / device_num)
     'decay_epoch': 10,
     'gamma': 0.7,
     'stblock_num': 2,
-    'Ks': 2,
+    'Ks': 3,
     'Kt': 3,
     'time_intvl': 5,
     'drop_rate': 0.5,
@@ -121,7 +140,7 @@ BJER4
     'graph_conv_type': "chebconv",
     'mat_type': "wid_sym_normd_lap_mat",
     })
-  ```
+```
 
 如需查看更多信息,请查看`config.py`.
 
@@ -131,18 +150,29 @@ BJER4
 
 - running on Ascend
 
-  ```python
-  #1P训练
-  python train.py --train_url="" --data_url="" --run_distribute=False --run_modelarts=True --graph_conv_type="chebgcn" --n_pred=9
-  #8P训练
-  bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type rank_table
-  ```
+```shell
+# 单卡训练
+python train.py --device_target="Ascend" --train_url="" --data_url="" --run_distribute=False --run_modelarts=True --graph_conv_type="chebgcn" --n_pred=9
+
+# 八卡训练
+bash scripts/run_distribute_train.sh train_code_path data_path n_pred graph_conv_type rank_table
+```
+
+> 注意:8P训练时需要将RANK_TABLE_FILE放在scripts文件夹中,RANK_TABLE_FILE[生成方法](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)
+
+- running on GPU
 
-  8P训练时需要将RANK_TABLE_FILE放在scripts文件夹中,RANK_TABLE_FILE[生成方法](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)
+```shell
+# 单卡训练
+python train.py --device_target="GPU" --train_url="" --data_url="" --run_distribute=False --run_modelarts=True --graph_conv_type="chebgcn" --n_pred=9
+
+# 单卡训练
+bash scripts/run_single_train_gpu.sh data_path n_pred graph_conv_type device_id
+```
 
-  训练时,训练过程中的epch和step以及此时的loss和精确度会呈现在终端上:
+训练时,训练过程中的epoch和step以及此时的loss和精确度会呈现在终端上:
 
-  ```python
+```text
   epoch: 1 step: 139, loss is 0.429
   epoch time: 203885.163 ms, per step time: 1466.800 ms
   epoch: 2 step: 139, loss is 0.2097
@@ -164,9 +194,9 @@ BJER4
   epoch: 10 step: 139, loss is 0.2622
   epoch time: 6334.013 ms, per step time: 45.568 ms
   ...
-  ```
+```
 
-  此模型的checkpoint存储在train_url路径中
+此模型的checkpoint存储在train_url路径中
 
 ## [评估步骤](#contents)
 
@@ -174,19 +204,32 @@ BJER4
 
 - 在Ascend上使用PeMSD7-m 测试集进行评估
 
-  在使用命令运行时,需要传入模型参数地址、模型参数名称、空域卷积方式、预测时段。
+在使用命令运行时,需要传入模型参数地址、空域卷积方式、预测时段。
+
+```shell
+python eval.py --device_target="Ascend" --run_modelarts=False --run_distribute=False --device_id=0 --ckpt_url="" --graph_conv_type="" --n_pred=9
+# 使用脚本评估
+bash scripts/run_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [N_PRED] [GRAPH_CONV_TYPE] [DEVICE_ID]
+```
+
+- 在GPU上使用PeMSD7-m 测试集进行评估
+
+在使用命令运行时,需要传入模型参数地址、空域卷积方式、预测时段。
 
-  ```python
-  python test.py --run_modelarts=False --run_distribute=False --device_id=0 --ckpt_url="" --ckpt_name="" --graph_conv_type="" --n_pred=9
-  #使用脚本评估
-  bash scripts/run_eval_ascend.sh data_path ckpt_url ckpt_name device_id graph_conv_type n_pred
-  ```
+```shell
+python eval.py --device_target="GPU" --run_modelarts=False --run_distribute=False --device_id=0 --ckpt_url="" --graph_conv_type="" --n_pred=9
 
-  以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以如下方式呈现:
+# 使用脚本评估
+bash scripts/run_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [N_PRED] [GRAPH_CONV_TYPE] [DEVICE_ID]
+```
 
-  ```python
-  MAE 3.23 | MAPE 8.32 | RMSE 6.06
-  ```
+### 结果
+
+以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以如下方式呈现:
+
+```text
+MAE 3.23 | MAPE 8.32 | RMSE 6.06
+```
 
 ## [导出mindir模型](#contents)
 
@@ -213,42 +256,43 @@ bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TAR
 
 ## [性能](#contents)
 
-### 评估性能
+### 训练性能
 
 #### STGCN on PeMSD7-m (Cheb,n_pred=9)
 
-| Parameters                 | ModelArts
-| -------------------------- | -----------------------------------------------------------
-| Model Version              | STGCN
-| Resource                   | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G
-| uploaded Date              | 05/07/2021 (month/day/year)
-| MindSpore Version          | 1.2.0
-| Dataset                    | PeMSD7-m
-| Training Parameters        | epoch=500, steps=139, batch_size = 8, lr=0.003
-| Optimizer                  | AdamWeightDecay
-| Loss Function              | MES Loss
-| outputs                    | probability
-| Loss                       | 0.183
-| Speed                      | 8pc: 45.601 ms/step;
-| Scripts                    | [STGCN script]
-
-### Inference Performance
+| Parameters                 | ModelArts                                       | GPU 1p |
+| -------------------------- | ----------------------------------------------- | ------ |
+| Model Version              | STGCN                                           | STGCN  |
+| Resource                   | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G | Ubuntu 18.04.6, 1pcs RTX3090, CPU 2.90GHz, 64cores, RAM 252GB |
+| Uploaded Date (month/day/year) | 05/07/2021                                  | 17/01/2021 |
+| MindSpore Version          | 1.2.0                                           | 1.5.0 |
+| Dataset                    | PeMSD7-m                                        | PeMSD7-M |
+| Training Parameters        | epoch=500, steps=139, batch_size = 8, lr=0.003  | epoch=50, steps=139, batch_size=64, lr=0.003 |
+| Optimizer                  | AdamWeightDecay                                 | AdamWeightDecay |
+| Loss Function              | MSE Loss                                        | MSE Loss |
+| Outputs                    | probability                                     | probability |
+| Loss                       | 0.183                                           | 0.23 |
+| Speed                      | 8pc: 45.601 ms/step                             | 44 ms/step |
+| Total time                 | 56min                                           | 5 min |
+| Scripts                    | [STGCN script](https://gitee.com/mindspore/models/tree/master/research/cv/stgcn#https://arxiv.org/abs/1709.04875) | [STGCN script](https://gitee.com/mindspore/models/tree/master/research/cv/stgcn#https://arxiv.org/abs/1709.04875) |
 
-#### STGCN on PeMSD7-m (Cheb,n_pred=9)
+### 评估性能
 
-| Parameters          | Ascend
-| ------------------- | ---------------------------
-| Model Version       | STGCN
-| Resource            | Ascend 910
-| Uploaded Date       | 05/07/2021 (month/day/year)
-| MindSpore Version   | 1.2.0
-| Dataset             | PeMSD7-m
-| batch_size          | 8
-| outputs             | probability
-| MAE                 | 3.23
-| MAPE                | 8.32
-| RMSE                | 6.06
-| Model for inference | about 6M(.ckpt fil)
+#### STGCN on PeMSD7-m (Cheb, n_pred=9)
+
+| Parameters          | Ascend                      | GPU 1P   |
+| ------------------- | --------------------------- | -------- |
+| Model Version       | STGCN                       | STGCN    |
+| Resource            | Ascend 910                  | Ubuntu 18.04.6, 1pcs RTX3090, CPU 2.90GHz, 64cores, RAM 252GB |
+| Uploaded Date       | 05/07/2021 (month/day/year) | 17/01/2021 |
+| MindSpore Version   | 1.2.0                       | 1.5.0    |
+| Dataset             | PeMSD7-M                    | PeMSD7-M |
+| Batch_size          | 8                           | 64        |
+| Outputs             | probability                 | probability |
+| MAE                 | 3.23                        | 3.24     |
+| MAPE                | 8.32                        | 8.25     |
+| RMSE                | 6.06                        | 6.03     |
+| Model for inference | about 6 MB (.ckpt file)     | about 6 MB (.ckpt file) |
 
 # [随机事件介绍](#contents)
 
diff --git a/research/cv/stgcn/eval.py b/research/cv/stgcn/eval.py
new file mode 100644
index 000000000..7760b0e68
--- /dev/null
+++ b/research/cv/stgcn/eval.py
@@ -0,0 +1,145 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+testing network performance.
+"""
+
+import os
+import pandas as pd
+from sklearn import preprocessing
+
+from mindspore import context, Tensor
+from mindspore.common import set_seed, dtype as mstype
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src.argparser import arg_parser
+from src.model import models
+from src import dataloader, utility, config
+
+os.system("export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python")
+
+
+def get_config(args):
+    """return config based on selected n_pred and graph_conv_type"""
+    if args.graph_conv_type == "chebconv":
+        if args.n_pred == 9:
+            cfg = config.stgcn_chebconv_45min_cfg
+        elif args.n_pred == 6:
+            cfg = config.stgcn_chebconv_30min_cfg
+        elif args.n_pred == 3:
+            cfg = config.stgcn_chebconv_15min_cfg
+        else:
+            raise ValueError("Unsupported n_pred.")
+    elif args.graph_conv_type == "gcnconv":
+        if args.n_pred == 9:
+            cfg = config.stgcn_gcnconv_45min_cfg
+        elif args.n_pred == 6:
+            cfg = config.stgcn_gcnconv_30min_cfg
+        elif args.n_pred == 3:
+            cfg = config.stgcn_gcnconv_15min_cfg
+        else:
+            raise ValueError("Unsupported pred.")
+    else:
+        raise ValueError("Unsupported graph_conv_type.")
+
+    return cfg
+
+
+def get_params():
+    """get and preprocess parameters"""
+    args = arg_parser()
+    cfg = get_config(args)
+
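+    # Each ST-Conv block shrinks the temporal dimension by 2 * (Kt - 1); Ko is the temporal length left after all blocks.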
+    if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
+        raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.')
+    Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
+    if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
+        raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
+
+    if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
+        cfg.Ks = 2
+
+    target = args.device_target
+    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+
+    # blocks: settings of channel size in st_conv_blocks and output layer,
+    # using the bottleneck design in st_conv_blocks
+    blocks = [[1]]
+    for _ in range(cfg.stblock_num):
+        blocks.append([64, 16, 64])
+    if Ko == 0:
+        blocks.append([128])
+    elif Ko > 0:
+        blocks.append([128, 128])
+    blocks.append([1])
+
+    cfg.n_pred = args.n_pred
+
+    return args, cfg, blocks
+
+
+def run_eval(args, cfg, blocks):
+    """evaluate stgcn model"""
+    if args.run_modelarts:
+        import moxing as mox
+        device_num = 1
+        cfg.batch_size = cfg.batch_size*int(8/device_num)
+        local_data_url = '/cache/data'
+        local_ckpt_url = '/cache/ckpt'
+        mox.file.copy_parallel(args.data_url, local_data_url)
+        mox.file.copy_parallel(args.ckpt_url, local_ckpt_url)
+        data_dir = local_data_url + '/'
+        local_ckpt_url = local_ckpt_url + '/'
+    else:
+        context.set_context(device_id=args.device_id)
+        data_dir = args.data_url + '/'
+        local_ckpt_url = args.ckpt_url
+
+    adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
+
+    n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
+    n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
+    if n_vertex_vel == n_vertex_adj:
+        n_vertex = n_vertex_vel
+    else:
+        raise ValueError('ERROR: number of vertices in dataset is not equal to '
+                         'number of vertices in weighted adjacency matrix.')
+
+    mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
+    conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
+    if cfg.graph_conv_type == "chebconv":
+        if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
+            raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
+    elif cfg.graph_conv_type == "gcnconv":
+        if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
+            raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
+
+    net = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, cfg.gated_act_func,
+                            cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
+
+    # start evaluation
+    zscore = preprocessing.StandardScaler()
+    dataset = dataloader.create_dataset(data_dir+args.data_path, cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, mode=2)
+
+    param_dict = load_checkpoint(local_ckpt_url)
+    load_param_into_net(net, param_dict)
+
+    test_MAE, test_RMSE, test_MAPE = utility.evaluate_metric(net, dataset, zscore)
+    print(f'MAE {test_MAE:.2f} | MAPE {test_MAPE*100:.2f} | RMSE {test_RMSE:.2f}')
+
+
+if __name__ == "__main__":
+    set_seed(1)
+    run_eval(*get_params())
diff --git a/research/cv/stgcn/export.py b/research/cv/stgcn/export.py
index 04d614d2b..01b49a5a5 100644
--- a/research/cv/stgcn/export.py
+++ b/research/cv/stgcn/export.py
@@ -16,50 +16,36 @@
 ##############export checkpoint file into air, onnx, mindir models#################
 python export.py
 """
-import argparse
 import numpy as np
 import pandas as pd
 
 import mindspore as ms
 from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
-from src import dataloader, utility
-from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
+
+from src.argparser import arg_parser
+from src import dataloader, utility, config
 from src.model import models
 
-parser = argparse.ArgumentParser(description='Tracking')
-parser.add_argument("--device_id", type=int, default=0, help="Device id")
-parser.add_argument("--batch_size", type=int, default=1, help="batch size")
-parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
-                    help='device where the code will be implemented (default: Ascend)')
-parser.add_argument('--data_url', type=str, help='Train dataset directory.')
-parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
-parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.')
-parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
-parser.add_argument("--n_pred", type=int, default=3, help="The number of time interval for predcition.")
-parser.add_argument("--graph_conv_type", type=str, default="chebconv", help="Grapg convolution type.")
-parser.add_argument("--file_name", type=str, default="stgcn", help="output file name.")
-parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
-args = parser.parse_args()
+args = arg_parser()
 
 context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
 
 if args.graph_conv_type == "chebconv":
     if args.n_pred == 9:
-        cfg = stgcn_chebconv_45min_cfg
+        cfg = config.stgcn_chebconv_45min_cfg
     elif args.n_pred == 6:
-        cfg = stgcn_chebconv_30min_cfg
+        cfg = config.stgcn_chebconv_30min_cfg
     elif args.n_pred == 3:
-        cfg = stgcn_chebconv_15min_cfg
+        cfg = config.stgcn_chebconv_15min_cfg
     else:
         raise ValueError("Unsupported n_pred.")
 elif args.graph_conv_type == "gcnconv":
     if args.n_pred == 9:
-        cfg = stgcn_gcnconv_45min_cfg
+        cfg = config.stgcn_gcnconv_45min_cfg
     elif args.n_pred == 6:
-        cfg = stgcn_gcnconv_30min_cfg
+        cfg = config.stgcn_gcnconv_30min_cfg
     elif args.n_pred == 3:
-        cfg = stgcn_gcnconv_15min_cfg
+        cfg = config.stgcn_gcnconv_15min_cfg
     else:
         raise ValueError("Unsupported pred.")
 else:
@@ -76,9 +62,8 @@ if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
 
 # blocks: settings of channel size in st_conv_blocks and output layer,
 # using the bottleneck design in st_conv_blocks
-blocks = []
-blocks.append([1])
-for l in range(cfg.stblock_num):
+blocks = [[1]]
+for _ in range(cfg.stblock_num):
     blocks.append([64, 16, 64])
 if Ko == 0:
     blocks.append([128])
@@ -118,8 +103,8 @@ elif cfg.graph_conv_type == "gcnconv":
     if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
         raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
 
-stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
-    cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
+stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, cfg.gated_act_func,
+                               cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
 net = stgcn_conv
 
 if __name__ == '__main__':
diff --git a/research/cv/stgcn/modelarts/start_train.py b/research/cv/stgcn/modelarts/start_train.py
index 6ed4aa7d5..ab03d52f1 100644
--- a/research/cv/stgcn/modelarts/start_train.py
+++ b/research/cv/stgcn/modelarts/start_train.py
@@ -71,7 +71,7 @@ parser = argparse.ArgumentParser('mindspore stgcn training')
 # The way of training
 parser.add_argument('--device_target', type=str, default='Ascend', \
  help='device where the code will be implemented. (Default: Ascend)')
-parser.add_argument('--save_check_point', type=bool, default=True, help='Whether save checkpoint')
+parser.add_argument('--save_checkpoint', type=bool, default=True, help='Whether save checkpoint')
 
 # Parameter
 parser.add_argument('--epochs', type=int, default=2, help='Whether save checkpoint')
@@ -209,7 +209,7 @@ if __name__ == "__main__":
     callbacks = [time_cb, loss_cb]
     prefix = ""
     #save training results
-    if args.save_check_point and (device_num == 1 or device_id == 0):
+    if args.save_checkpoint and (device_num == 1 or device_id == 0):
         config_ck = CheckpointConfig(
             save_checkpoint_steps=data_len*args.epochs, keep_checkpoint_max=args.epochs)
         prefix = 'STGCN' + cfg.graph_conv_type + str(cfg.n_pred) + '-'
diff --git a/research/cv/stgcn/postprocess.py b/research/cv/stgcn/postprocess.py
index 3329c0016..33dc25742 100644
--- a/research/cv/stgcn/postprocess.py
+++ b/research/cv/stgcn/postprocess.py
@@ -14,23 +14,14 @@
 # ============================================================================
 """compute acc for ascend 310"""
 import os
-import argparse
 import numpy as np
+from sklearn import preprocessing
 
+from src.argparser import arg_parser
 from src.config import stgcn_chebconv_45min_cfg
 from src import dataloader
-from sklearn import preprocessing
-
-parser = argparse.ArgumentParser('mindspore stgcn testing')
-# Path for data
-parser.add_argument('--data_url', type=str, default='./data/', help='Test dataset directory.')
-parser.add_argument('--label_dir', type=str, default='', help='label data directory.')
-parser.add_argument('--result_dir', type=str, default="./result_Files", help='infer result dir.')
-parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
-# Super parameters for testing
-parser.add_argument('--n_pred', type=int, default=9, help='The number of time interval for predcition')
 
-args, _ = parser.parse_known_args()
+args = arg_parser()
 
 cfg = stgcn_chebconv_45min_cfg
 cfg.batch_size = 1
@@ -42,8 +33,8 @@ if __name__ == "__main__":
     rst_path = args.result_dir
     labels = np.load(args.label_dir)
 
-    dataset = dataloader.create_dataset(args.data_url+args.data_path, \
-     cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, mode=2)
+    dataset = dataloader.create_dataset(args.data_url+args.data_path, cfg.batch_size,
+                                        cfg.n_his, cfg.n_pred, zscore, mode=2)
 
     mae, sum_y, mape, mse = [], [], [], []
 
diff --git a/research/cv/stgcn/preprocess.py b/research/cv/stgcn/preprocess.py
index 6d4c5a39c..26b37f41c 100644
--- a/research/cv/stgcn/preprocess.py
+++ b/research/cv/stgcn/preprocess.py
@@ -14,34 +14,23 @@
 # ============================================================================
 """generate dataset for ascend 310"""
 import os
-import argparse
 import numpy as np
-
-from src.config import stgcn_chebconv_45min_cfg
-from src import dataloader
 from sklearn import preprocessing
 
-parser = argparse.ArgumentParser('mindspore stgcn testing')
-parser.add_argument('--device_target', type=str, default='Ascend', \
- help='device where the code will be implemented. (Default: Ascend)')
-# Path for data and checkpoint
-parser.add_argument('--data_url', type=str, default='', help='Test dataset directory.')
-parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
-parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='result path')
-# Super parameters for testing
-parser.add_argument('--n_pred', type=int, default=9, help='The number of time interval for predcition')
+from src import dataloader, config
+from src.argparser import arg_parser
 
-args, _ = parser.parse_known_args()
+args = arg_parser()
 
-cfg = stgcn_chebconv_45min_cfg
+cfg = config.stgcn_chebconv_45min_cfg
 cfg.batch_size = 1
 
 if __name__ == "__main__":
 
     zscore = preprocessing.StandardScaler()
 
-    dataset = dataloader.create_dataset(args.data_url+args.data_path, \
-     cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, mode=2)
+    dataset = dataloader.create_dataset(args.data_url+args.data_path, cfg.batch_size,
+                                        cfg.n_his, cfg.n_pred, zscore, mode=2)
 
     img_path = os.path.join(args.result_path, "00_data")
     os.mkdir(img_path)
diff --git a/research/cv/stgcn/requirements.txt b/research/cv/stgcn/requirements.txt
new file mode 100644
index 000000000..5669c06a5
--- /dev/null
+++ b/research/cv/stgcn/requirements.txt
@@ -0,0 +1,4 @@
+numpy
+pandas
+sklearn
+easydict
diff --git a/research/cv/stgcn/scripts/run_distribute_train.sh b/research/cv/stgcn/scripts/run_distribute_train.sh
index d357f9898..1179b223f 100644
--- a/research/cv/stgcn/scripts/run_distribute_train.sh
+++ b/research/cv/stgcn/scripts/run_distribute_train.sh
@@ -15,7 +15,7 @@
 # ============================================================================
 
 if [ $# != 5 ]; then
-  echo "Usage: sh run_distribute_train.sh [train_code_path][data_path][n_pred][graph_conv_type] [rank_table]"
+  echo "Usage: bash scripts/run_distribute_train.sh [train_code_path] [data_path] [n_pred] [graph_conv_type] [rank_table]"
   exit 1
 fi
 
@@ -75,9 +75,11 @@ do
     mkdir ${train_code_path}/device${DEVICE_ID}
     cd ${train_code_path}/device${DEVICE_ID} || exit
     python ${train_code_path}/train.py    --data_url=${data_path}   \
-                                               --train_url=./checkpoint   \
-                                               --run_distribute=True   \
-                                               --run_modelarts=False \
-                                               --n_pred=$n_pred     \
-                                               --graph_conv_type=$graph_conv_type > out.log 2>&1 &
+                                          --device_target="Ascend" \
+                                          --epochs=500 \
+                                          --train_url=./checkpoint \
+                                          --run_distribute=True   \
+                                          --run_modelarts=False \
+                                          --n_pred=$n_pred     \
+                                          --graph_conv_type=$graph_conv_type > out.log 2>&1 &
 done
diff --git a/research/cv/stgcn/scripts/run_eval_ascend.sh b/research/cv/stgcn/scripts/run_eval_ascend.sh
index 3f29e5e1f..d2bc3944e 100644
--- a/research/cv/stgcn/scripts/run_eval_ascend.sh
+++ b/research/cv/stgcn/scripts/run_eval_ascend.sh
@@ -15,25 +15,24 @@
 # limitations under the License.
 # ============================================================================
 
-if [ $# != 6 ]
+if [ $# != 5 ]
 then
-    echo "Usage: sh run_standalone_eval_ascend.sh [data_path][ckpt_url][ckpt_name][device_id][graph_conv_type][n_pred]"
+    echo "Usage: bash scripts/run_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [N_PRED] [GRAPH_CONV_TYPE] [DEVICE_ID]"
 exit 1
 fi
 
-export Data_path=$1
-export Ckpt_path=$2
-export Ckpt_name=$3
-export Device_id=$4
-export Graph_conv_type=$5
-export N_pred=$6
+export DATA_PATH=$1
+export CKPT_PATH=$2
+export N_PRED=$3
+export GRAPH_CONV_TYPE=$4
+export DEVICE_ID=$5
 
-python test.py --data_url=$Data_path   \
+python eval.py --data_url=$DATA_PATH   \
+                --device_target="Ascend"  \
                 --train_url=./checkpoint   \
                 --run_distribute=False   \
                 --run_modelarts=False   \
-                --device_id=$Device_id  \
-                --ckpt_url=$Ckpt_path   \
-                --ckpt_name=$Ckpt_name  \
-                --n_pred=$N_pred    \
-                --graph_conv_type=$Graph_conv_type > test.log 2>&1 &
+                --device_id=$DEVICE_ID  \
+                --ckpt_url=$CKPT_PATH   \
+                --n_pred=$N_PRED    \
+                --graph_conv_type=$GRAPH_CONV_TYPE > eval.log 2>&1 &
diff --git a/research/cv/stgcn/scripts/run_eval_gpu.sh b/research/cv/stgcn/scripts/run_eval_gpu.sh
new file mode 100644
index 000000000..e1469e28d
--- /dev/null
+++ b/research/cv/stgcn/scripts/run_eval_gpu.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 5 ]
+then
+    echo "Usage: bash scripts/run_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [N_PRED] [GRAPH_CONV_TYPE] [DEVICE_ID]"
+exit 1
+fi
+
+export DATA_PATH=$1
+export CKPT_PATH=$2
+export N_PRED=$3
+export GRAPH_CONV_TYPE=$4
+export DEVICE_ID=$5
+
+python eval.py --data_url=$DATA_PATH \
+                --device_target="GPU"  \
+                --train_url=./checkpoint \
+                --run_distribute=False \
+                --run_modelarts=False \
+                --device_id=$DEVICE_ID \
+                --ckpt_url=$CKPT_PATH \
+                --n_pred=$N_PRED \
+                --graph_conv_type=$GRAPH_CONV_TYPE > eval.log 2>&1 &
diff --git a/research/cv/stgcn/scripts/run_single_train_gpu.sh b/research/cv/stgcn/scripts/run_single_train_gpu.sh
new file mode 100644
index 000000000..4724e478f
--- /dev/null
+++ b/research/cv/stgcn/scripts/run_single_train_gpu.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 4 ]; then
+  echo "Usage: bash scripts/run_single_train_gpu.sh [DATA_PATH] [N_PRED] [GRAPH_CONV_TYPE] [DEVICE_ID]"
+  echo: "Example: bash scripts/run_single_train_gpu.sh ./data/pemsd7-m 9 chebconv 0"
+  exit 1
+fi
+
+
+get_real_path(){
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m $PWD/$1)"
+  fi
+}
+
+
+DATA_PATH=$(get_real_path $1)
+echo $DATA_PATH
+
+
+if [ ! -d $DATA_PATH ]
+then
+    echo "error: train_code_path=$DATA_PATH is not a dictionary."
+exit 1
+fi
+
+ulimit -c unlimited
+export SLOG_PRINT_TO_STDOUT=0
+export N_PRED=$2
+export GRAPH_CONV_TYPE=$3
+export RANK_SIZE=1
+export DEVICE_ID=$4
+
+rm -rf ./train$DEVICE_ID
+mkdir ./train$DEVICE_ID
+cp ./*.py ./train$DEVICE_ID
+cp -r ./src ./train$DEVICE_ID
+cd ./train$DEVICE_ID || exit
+
+echo "start training on GPU device id $DEVICE_ID"
+python train.py \
+       --device_target="GPU"  \
+       --epochs=50 \
+       --run_distribute=False   \
+       --device_id=$DEVICE_ID  \
+       --data_url=${DATA_PATH}   \
+       --train_url=./checkpoint   \
+       --run_modelarts=False \
+       --n_pred=$N_PRED     \
+       --graph_conv_type=$GRAPH_CONV_TYPE > train.log 2>&1 &
+cd ..
diff --git a/research/cv/stgcn/src/argparser.py b/research/cv/stgcn/src/argparser.py
new file mode 100644
index 000000000..b5c30dc69
--- /dev/null
+++ b/research/cv/stgcn/src/argparser.py
@@ -0,0 +1,55 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Command line arguments parsing"""
+
+import argparse
+import ast
+
+
+def arg_parser():
+    """Parsing of command line arguments"""
+    parser = argparse.ArgumentParser('stgcn parameters')
+
+    # The way of training
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
+                        help='device where the code will be implemented (default: Ascend)')
+    parser.add_argument('--epochs', type=int, default=500, help='Epochs to train model')
+    parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on modelarts')
+    parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
+    parser.add_argument('--device_id', type=int, default=0, help='Device id, default is 0.')
+    parser.add_argument('--save_checkpoint', type=bool, default=True, help='Whether to save checkpoint')
+
+    # Path to data and checkpoints
+    parser.add_argument('--data_url', type=str, required=True, help='Dataset directory.')
+    parser.add_argument('--train_url', type=str, required=True, help='Save checkpoint directory.')
+    parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
+    parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Weighted adjacency matrix file.')
+    parser.add_argument('--label_dir', type=str, default='', help='label data directory.')
+    parser.add_argument('--ckpt_url', type=str, default="", help='Path to saved checkpoint.')
+    parser.add_argument("--file_name", type=str, default="stgcn", help="output file name.")
+    parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR",
+                        help="file format to export")
+    parser.add_argument('--result_dir', type=str, default="./result_Files", help='infer result dir.')
+    parser.add_argument('--result_path', type=str, default='./preprocess_Result/', help='preprocess result path.')
+
+    # Parameters for training
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
+    parser.add_argument('--n_pred', type=int, default=3, help='The number of time intervals for prediction, default: 3')
+    parser.add_argument('--opt', type=str, default='RMSProp', help='optimizer, default: RMSProp')
+
+    # network
+    parser.add_argument('--graph_conv_type', type=str, default="gcnconv", choices=["gcnconv", "chebconv"],
+                        help='Graph convolution type, default: gcnconv')
+
+    return parser.parse_args()
diff --git a/research/cv/stgcn/src/config.py b/research/cv/stgcn/src/config.py
index 12bb5846c..d7a7328c2 100644
--- a/research/cv/stgcn/src/config.py
+++ b/research/cv/stgcn/src/config.py
@@ -32,6 +32,7 @@ stgcn_chebconv_45min_cfg = edict({
     'time_intvl': 5,
     'drop_rate': 0.5,
     'weight_decay_rate': 0.0005,
+    'opt': "AdamW",
     'gated_act_func': "glu",
     'graph_conv_type': "chebconv",
     'mat_type': "wid_sym_normd_lap_mat",
@@ -51,6 +52,7 @@ stgcn_chebconv_30min_cfg = edict({
     'time_intvl': 5,
     'drop_rate': 0.5,
     'weight_decay_rate': 0.0005,
+    'opt': "AdamW",
     'gated_act_func': "glu",
     'graph_conv_type': "chebconv",
     'mat_type': "wid_sym_normd_lap_mat",
@@ -70,6 +72,7 @@ stgcn_chebconv_15min_cfg = edict({
     'time_intvl': 5,
     'drop_rate': 0.5,
     'weight_decay_rate': 0.0005,
+    'opt': "AdamW",
     'gated_act_func': "glu",
     'graph_conv_type': "chebconv",
     'mat_type': "wid_rw_normd_lap_mat",
@@ -89,6 +92,7 @@ stgcn_gcnconv_45min_cfg = edict({
     'time_intvl': 5,
     'drop_rate': 0.5,
     'weight_decay_rate': 0.0005,
+    'opt': "AdamW",
     'gated_act_func': "glu",
     'graph_conv_type': "gcnconv",
     'mat_type': "hat_sym_normd_lap_mat",
@@ -108,6 +112,7 @@ stgcn_gcnconv_30min_cfg = edict({
     'time_intvl': 5,
     'drop_rate': 0.5,
     'weight_decay_rate': 0.0005,
+    'opt': "AdamW",
     'gated_act_func': "glu",
     'graph_conv_type': "gcnconv",
     'mat_type': "hat_sym_normd_lap_mat",
@@ -127,6 +132,7 @@ stgcn_gcnconv_15min_cfg = edict({
     'time_intvl': 5,
     'drop_rate': 0.5,
     'weight_decay_rate': 0.0005,
+    'opt': "AdamW",
     'gated_act_func': "glu",
     'graph_conv_type': "gcnconv",
     'mat_type': "hat_rw_normd_lap_mat",
diff --git a/research/cv/stgcn/src/dataloader.py b/research/cv/stgcn/src/dataloader.py
index c94684693..a8a8a88bf 100644
--- a/research/cv/stgcn/src/dataloader.py
+++ b/research/cv/stgcn/src/dataloader.py
@@ -20,6 +20,7 @@ import numpy as np
 import pandas as pd
 import mindspore.dataset as ds
 
+
 class STGCNDataset:
     """ BRDNetDataset.
     Args:
@@ -68,7 +69,6 @@ class STGCNDataset:
             self.x[i, :, :, :] = self.dataset[head: tail].reshape(1, self.n_his, self.n_vertex)
             self.y[i] = self.dataset[tail + self.n_pred - 1]
 
-
     def __getitem__(self, index):
         """
         Args:
@@ -89,18 +89,16 @@ def load_weighted_adjacency_matrix(file_path):
     return df.to_numpy()
 
 
-def create_dataset(data_path, batch_size, n_his, n_pred, zscore, is_sigle, device_num=1, device_id=0, mode=0):
+def create_dataset(data_path, batch_size, n_his, n_pred, zscore,
+                   device_num=1, device_id=0, mode=0):
     """
     generate dataset for train or test.
     """
-    data = STGCNDataset(data_path, n_his, n_pred, zscore, mode=mode)
     shuffle = True
     if mode != 0:
         shuffle = False
-    if not is_sigle:
-        dataset = ds.GeneratorDataset(data, column_names=["inputs", "labels"], num_parallel_workers=32, \
-         shuffle=shuffle, num_shards=device_num, shard_id=device_id)
-    else:
-        dataset = ds.GeneratorDataset(data, column_names=["inputs", "labels"], num_parallel_workers=32, shuffle=shuffle)
-    dataset = dataset.batch(batch_size)
+    data = STGCNDataset(data_path, n_his, n_pred, zscore, mode=mode)
+    dataset = ds.GeneratorDataset(data, column_names=["inputs", "labels"], num_parallel_workers=8,
+                                  shuffle=shuffle, num_shards=device_num, shard_id=device_id)
+    dataset = dataset.batch(batch_size, drop_remainder=False)
     return dataset
diff --git a/research/cv/stgcn/src/model/layers.py b/research/cv/stgcn/src/model/layers.py
index 8a33d270c..82391d0af 100644
--- a/research/cv/stgcn/src/model/layers.py
+++ b/research/cv/stgcn/src/model/layers.py
@@ -20,14 +20,16 @@ import mindspore.common.dtype as mstype
 
 from mindspore.common.initializer import initializer
 
+
 class Align(nn.Cell):
     """align"""
+
     def __init__(self, c_in, c_out):
         super(Align, self).__init__()
         self.c_in = c_in
         self.c_out = c_out
-        self.align_conv = nn.Conv2d(in_channels=self.c_in, out_channels=self.c_out, kernel_size=1, \
-         pad_mode='valid', weight_init='he_uniform')
+        self.align_conv = nn.Conv2d(in_channels=self.c_in, out_channels=self.c_out, kernel_size=1,
+                                    pad_mode='valid', weight_init='he_uniform')
         self.concat = ops.Concat(axis=1)
         self.zeros = ops.Zeros()
 
@@ -41,10 +43,12 @@ class Align(nn.Cell):
             x_align = self.concat((x, y))
         return x_align
 
+
 class CausalConv2d(nn.Cell):
     """causal conv2d"""
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
-     enable_padding=False, dilation=1, groups=1, bias=True):
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 enable_padding=False, dilation=1, groups=1, bias=True):
         super(CausalConv2d, self).__init__()
         if isinstance(kernel_size, int):
             kernel_size = (kernel_size, kernel_size)
@@ -59,15 +63,18 @@ class CausalConv2d(nn.Cell):
             self.__padding = 0
         if isinstance(self.__padding, int):
             self.left_padding = (self.__padding, self.__padding)
-        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, \
-         padding=0, pad_mode='valid', dilation=dilation, group=groups, has_bias=bias, weight_init='he_uniform')
+        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
+                                padding=0, pad_mode='valid', dilation=dilation, group=groups,
+                                has_bias=bias, weight_init='he_uniform')
         self.pad = ops.Pad(((0, 0), (0, 0), (self.left_padding[0], 0), (self.left_padding[1], 0)))
+
     def construct(self, x):
         if self.__padding != 0:
             x = self.pad(x)
         result = self.conv2d(x)
         return result
 
+
 class TemporalConvLayer(nn.Cell):
     """
     # Temporal Convolution Layer (GLU)
@@ -81,6 +88,7 @@ class TemporalConvLayer(nn.Cell):
 
     #param x: tensor, [batch_size, c_in, timestep, n_vertex]
     """
+
     def __init__(self, Kt, c_in, c_out, n_vertex, act_func):
         super(TemporalConvLayer, self).__init__()
         self.Kt = Kt
@@ -89,8 +97,8 @@ class TemporalConvLayer(nn.Cell):
         self.n_vertex = n_vertex
         self.act_func = act_func
         self.align = Align(self.c_in, self.c_out)
-        self.causal_conv = CausalConv2d(in_channels=self.c_in, out_channels=2 * self.c_out, \
-         kernel_size=(self.Kt, 1), enable_padding=False, dilation=1)
+        self.causal_conv = CausalConv2d(in_channels=self.c_in, out_channels=2 * self.c_out,
+                                        kernel_size=(self.Kt, 1), enable_padding=False, dilation=1)
         self.linear = nn.Dense(self.n_vertex, self.n_vertex).to_float(mstype.float16)
         self.sigmoid = nn.Sigmoid()
         self.tanh = nn.Tanh()
@@ -107,8 +115,6 @@ class TemporalConvLayer(nn.Cell):
         x_pq = self.split(x_tc_out)
         x_p = x_pq[0]
         x_q = x_pq[1]
-        x_glu = x_causal_conv
-        x_gtu = x_causal_conv
         if self.act_func == 'glu':
             # (x_p + x_in) ⊙ Sigmoid(x_q)
             x_glu = self.mul(self.add(x_p, x_in), self.sigmoid(x_q))
@@ -120,8 +126,10 @@ class TemporalConvLayer(nn.Cell):
             x_tc_out = x_gtu
         return x_tc_out
 
+
 class ChebConv(nn.Cell):
     """cheb conv"""
+
     def __init__(self, c_in, c_out, Ks, chebconv_matrix):
         super(ChebConv, self).__init__()
         self.c_in = c_in
@@ -155,14 +163,16 @@ class ChebConv(nn.Cell):
                 x_list.append(self.matmul(2 * self.chebconv_matrix, x_list[k - 1]) - x_list[k - 2])
         x_tensor = self.stack(x_list)
 
-        x_mul = self.matmul(self.reshape(x_tensor, (-1, self.Ks * c_in)), self.reshape(self.weight, \
-         (self.Ks * c_in, -1)))
+        x_mul = self.matmul(self.reshape(x_tensor, (-1, self.Ks * c_in)), self.reshape(self.weight,
+                                                                                       (self.Ks * c_in, -1)))
         x_mul = self.reshape(x_mul, (-1, self.c_out))
         x_chebconv = self.bias_add(x_mul, self.bias)
         return x_chebconv
 
+
 class GCNConv(nn.Cell):
     """gcn conv"""
+
     def __init__(self, c_in, c_out, gcnconv_matrix):
         super(GCNConv, self).__init__()
         self.c_in = c_in
@@ -190,8 +200,10 @@ class GCNConv(nn.Cell):
 
         return x_gcnconv_out
 
+
 class GraphConvLayer(nn.Cell):
     """grarh conv layer"""
+
     def __init__(self, Ks, c_in, c_out, graph_conv_type, graph_conv_matrix):
         super(GraphConvLayer, self).__init__()
         self.Ks = Ks
@@ -220,6 +232,7 @@ class GraphConvLayer(nn.Cell):
         x_gc_out = x_gc_with_rc
         return x_gc_out
 
+
 class STConvBlock(nn.Cell):
     """
     # STConv Block contains 'TNSATND' structure
@@ -230,8 +243,9 @@ class STConvBlock(nn.Cell):
     # D: Dropout
     #Kt    Ks   n_vertex
     """
-    def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, gated_act_func, graph_conv_type, \
-     graph_conv_matrix, drop_rate):
+
+    def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, gated_act_func, graph_conv_type,
+                 graph_conv_matrix, drop_rate):
         super(STConvBlock, self).__init__()
         self.Kt = Kt
         self.Ks = Ks
@@ -243,14 +257,14 @@ class STConvBlock(nn.Cell):
         self.graph_conv_type = graph_conv_type
         self.graph_conv_matrix = graph_conv_matrix
         self.drop_rate = drop_rate
-        self.tmp_conv1 = TemporalConvLayer(self.Kt, self.last_block_channel, self.channels[0], \
-         self.n_vertex, self.gated_act_func)
-        self.graph_conv = GraphConvLayer(self.Ks, self.channels[0], self.channels[1], \
-         self.graph_conv_type, self.graph_conv_matrix)
-        self.tmp_conv2 = TemporalConvLayer(self.Kt, self.channels[1], self.channels[2], \
-         self.n_vertex, self.gated_act_func)
-        self.tc2_ln = nn.LayerNorm([self.n_vertex, self.channels[2]], begin_norm_axis=2, \
-         begin_params_axis=2, epsilon=1e-05)
+        self.tmp_conv1 = TemporalConvLayer(self.Kt, self.last_block_channel, self.channels[0],
+                                           self.n_vertex, self.gated_act_func)
+        self.graph_conv = GraphConvLayer(self.Ks, self.channels[0], self.channels[1],
+                                         self.graph_conv_type, self.graph_conv_matrix)
+        self.tmp_conv2 = TemporalConvLayer(self.Kt, self.channels[1], self.channels[2],
+                                           self.n_vertex, self.gated_act_func)
+        self.tc2_ln = nn.LayerNorm([self.n_vertex, self.channels[2]],
+                                   begin_norm_axis=2, begin_params_axis=2, epsilon=1e-05)
 
         self.relu = nn.ReLU()
         self.do = nn.Dropout(keep_prob=self.drop_rate)
@@ -269,6 +283,7 @@ class STConvBlock(nn.Cell):
         x_st_conv_out = x_do
         return x_st_conv_out
 
+
 class OutputBlock(nn.Cell):
     """
     # Output block contains 'TNFF' structure
@@ -277,6 +292,7 @@ class OutputBlock(nn.Cell):
     # F: Fully-Connected Layer
     # F: Fully-Connected Layer
     """
+
     def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex, gated_act_func, drop_rate):
         super(OutputBlock, self).__init__()
         self.Ko = Ko
@@ -286,12 +302,12 @@ class OutputBlock(nn.Cell):
         self.n_vertex = n_vertex
         self.gated_act_func = gated_act_func
         self.drop_rate = drop_rate
-        self.tmp_conv1 = TemporalConvLayer(self.Ko, self.last_block_channel, \
-         self.channels[0], self.n_vertex, self.gated_act_func)
+        self.tmp_conv1 = TemporalConvLayer(self.Ko, self.last_block_channel,
+                                           self.channels[0], self.n_vertex, self.gated_act_func)
         self.fc1 = nn.Dense(self.channels[0], self.channels[1]).to_float(mstype.float16)
         self.fc2 = nn.Dense(self.channels[1], self.end_channel).to_float(mstype.float16)
-        self.tc1_ln = nn.LayerNorm([self.n_vertex, self.channels[0]], begin_norm_axis=2, \
-         begin_params_axis=2, epsilon=1e-05)
+        self.tc1_ln = nn.LayerNorm([self.n_vertex, self.channels[0]],
+                                   begin_norm_axis=2, begin_params_axis=2, epsilon=1e-05)
         self.sigmoid = nn.Sigmoid()
         self.transpose = ops.Transpose()
 
diff --git a/research/cv/stgcn/src/model/metric.py b/research/cv/stgcn/src/model/metric.py
index ee3024aee..0f695a703 100644
--- a/research/cv/stgcn/src/model/metric.py
+++ b/research/cv/stgcn/src/model/metric.py
@@ -19,6 +19,7 @@ stgcn network with loss.
 import mindspore.nn as nn
 import mindspore.ops as P
 
+
 class LossCellWithNetwork(nn.Cell):
     """STGCN loss."""
     def __init__(self, network):
diff --git a/research/cv/stgcn/src/model/models.py b/research/cv/stgcn/src/model/models.py
index fcd3d7217..52b719eb6 100644
--- a/research/cv/stgcn/src/model/models.py
+++ b/research/cv/stgcn/src/model/models.py
@@ -20,6 +20,7 @@ import mindspore.nn as nn
 
 from src.model import layers
 
+
 class STGCN_Conv(nn.Cell):
     """
     # STGCN(ChebConv) contains 'TGTND TGTND TNFF' structure
@@ -49,14 +50,14 @@ class STGCN_Conv(nn.Cell):
     def __init__(self, Kt, Ks, blocks, T, n_vertex, gated_act_func, graph_conv_type, chebconv_matrix, drop_rate):
         super(STGCN_Conv, self).__init__()
         modules = []
-        for l in range(len(blocks) - 3):
-            modules.append(layers.STConvBlock(Kt, Ks, n_vertex, blocks[l][-1], blocks[l+1], \
-             gated_act_func, graph_conv_type, chebconv_matrix, drop_rate))
+        for i in range(len(blocks) - 3):
+            modules.append(layers.STConvBlock(Kt, Ks, n_vertex, blocks[i][-1], blocks[i+1],
+                                              gated_act_func, graph_conv_type, chebconv_matrix, drop_rate))
         self.st_blocks = nn.SequentialCell(modules)
         Ko = T - (len(blocks) - 3) * 2 * (Kt - 1)
         self.Ko = Ko
-        self.output = layers.OutputBlock(self.Ko, blocks[-3][-1], blocks[-2], \
-         blocks[-1][0], n_vertex, gated_act_func, drop_rate)
+        self.output = layers.OutputBlock(self.Ko, blocks[-3][-1], blocks[-2], blocks[-1][0],
+                                         n_vertex, gated_act_func, drop_rate)
 
     def construct(self, x):
         x_stbs = self.st_blocks(x)
diff --git a/research/cv/stgcn/src/utility.py b/research/cv/stgcn/src/utility.py
index 605e5ba54..10706872e 100644
--- a/research/cv/stgcn/src/utility.py
+++ b/research/cv/stgcn/src/utility.py
@@ -23,18 +23,18 @@ import mindspore.ops as ops
 from scipy.linalg import fractional_matrix_power
 from scipy.sparse.linalg import eigs
 
+
 def calculate_laplacian_matrix(adj_mat, mat_type):
     """
     calculate laplacian matrix used for graph convolution layer.
     """
-    n_vertex = adj_mat.shape[0]
 
     # row sum
     deg_mat_row = np.asmatrix(np.diag(np.sum(adj_mat, axis=1)))
     # column sum
-    #deg_mat_col = np.asmatrix(np.diag(np.sum(adj_mat, axis=0)))
     deg_mat = deg_mat_row
 
+    n_vertex = adj_mat.shape[0]
     adj_mat = np.asmatrix(adj_mat)
     id_mat = np.asmatrix(np.identity(n_vertex))
 
@@ -43,8 +43,8 @@ def calculate_laplacian_matrix(adj_mat, mat_type):
 
     # For SpectraConv
     # To [0, 1]
-    sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(deg_mat, -0.5), \
-     com_lap_mat), fractional_matrix_power(deg_mat, -0.5))
+    sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(deg_mat, -0.5), com_lap_mat),
+                                  fractional_matrix_power(deg_mat, -0.5))
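+    # i.e. D^(-1/2) * com_lap_mat * D^(-1/2), the symmetrically normalized Laplacian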
 
     # For ChebConv
     # From [0, 1] to [-1, 1]
@@ -54,8 +54,8 @@ def calculate_laplacian_matrix(adj_mat, mat_type):
     # For GCNConv
     wid_deg_mat = deg_mat + id_mat
     wid_adj_mat = adj_mat + id_mat
-    hat_sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(wid_deg_mat, -0.5), \
-     wid_adj_mat), fractional_matrix_power(wid_deg_mat, -0.5))
+    hat_sym_normd_lap_mat = np.matmul(np.matmul(fractional_matrix_power(wid_deg_mat, -0.5), wid_adj_mat),
+                                      fractional_matrix_power(wid_deg_mat, -0.5))
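+    # i.e. (D + I)^(-1/2) * (A + I) * (D + I)^(-1/2), the renormalized form used by GCNConv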
 
     # Random Walk
     rw_lap_mat = np.matmul(np.linalg.matrix_power(deg_mat, -1), adj_mat)
@@ -84,6 +84,7 @@ def calculate_laplacian_matrix(adj_mat, mat_type):
         return hat_rw_normd_lap_mat
     raise ValueError(f'ERROR: "{mat_type}" is unknown.')
 
+
 def evaluate_metric(model, dataset, scaler):
     """
     evaluate the performance of network.
@@ -104,6 +105,5 @@ def evaluate_metric(model, dataset, scaler):
     MAE = np.array(mae).mean()
     MAPE = np.array(mape).mean()
     RMSE = np.sqrt(np.array(mse).mean())
-    #WMAPE = np.sum(np.array(mae)) / np.sum(np.array(sum_y))
 
     return MAE, RMSE, MAPE
diff --git a/research/cv/stgcn/test.py b/research/cv/stgcn/test.py
deleted file mode 100644
index f2b3aae57..000000000
--- a/research/cv/stgcn/test.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-testing network performance.
-"""
-
-import os
-import ast
-import argparse
-
-import pandas as pd
-from sklearn import preprocessing
-
-from mindspore.common import dtype as mstype
-
-from mindspore import context
-from mindspore import Tensor
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init
-from mindspore.context import ParallelMode
-
-from src.model import models
-from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
-from src import dataloader, utility
-
-os.system("export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python")
-parser = argparse.ArgumentParser('mindspore stgcn testing')
-parser.add_argument('--device_target', type=str, default='Ascend', \
- help='device where the code will be implemented. (Default: Ascend)')
-
-# The way of testing
-parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on modelarts.')
-parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
-parser.add_argument('--device_id', type=int, default=0, help='Device id.')
-
-# Path for data and checkpoint
-parser.add_argument('--data_url', type=str, default='', help='Test dataset directory.')
-parser.add_argument('--train_url', type=str, default='', help='Output directory.')
-parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
-parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.')
-parser.add_argument('--ckpt_url', type=str, default='', help='The path of checkpoint.')
-parser.add_argument('--ckpt_name', type=str, default="", help='the name of checkpoint.')
-
-# Super parameters for testing
-parser.add_argument('--n_pred', type=int, default=3, help='The number of time interval for predcition')
-
-#network
-parser.add_argument('--graph_conv_type', type=str, default="gcnconv", help='Grapg convolution type')
-#dataset
-
-
-args, _ = parser.parse_known_args()
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
-
-if args.graph_conv_type == "chebconv":
-    if args.n_pred == 9:
-        cfg = stgcn_chebconv_45min_cfg
-    elif args.n_pred == 6:
-        cfg = stgcn_chebconv_30min_cfg
-    elif args.n_pred == 3:
-        cfg = stgcn_chebconv_15min_cfg
-    else:
-        raise ValueError("Unsupported n_pred.")
-elif args.graph_conv_type == "gcnconv":
-    if args.n_pred == 9:
-        cfg = stgcn_gcnconv_45min_cfg
-    elif args.n_pred == 6:
-        cfg = stgcn_gcnconv_30min_cfg
-    elif args.n_pred == 3:
-        cfg = stgcn_gcnconv_15min_cfg
-    else:
-        raise ValueError("Unsupported pred.")
-else:
-    raise ValueError("Unsupported graph_conv_type.")
-
-
-if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
-    raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.')
-Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
-if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
-    raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
-
-
-
-if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
-    cfg.Ks = 2
-
-# blocks: settings of channel size in st_conv_blocks and output layer,
-# using the bottleneck design in st_conv_blocks
-blocks = []
-blocks.append([1])
-for l in range(cfg.stblock_num):
-    blocks.append([64, 16, 64])
-if Ko == 0:
-    blocks.append([128])
-elif Ko > 0:
-    blocks.append([128, 128])
-blocks.append([1])
-
-
-day_slot = int(24 * 60 / cfg.time_intvl)
-cfg.n_pred = cfg.n_pred
-
-time_pred = cfg.n_pred * cfg.time_intvl
-time_pred_str = str(time_pred) + '_mins'
-
-if args.run_modelarts:
-    import moxing as mox
-    device_id = int(os.getenv('DEVICE_ID'))
-    device_num = int(os.getenv('RANK_SIZE'))
-    cfg.batch_size = cfg.batch_size*int(8/device_num)
-    local_data_url = '/cache/data'
-    local_ckpt_url = '/cache/ckpt'
-    mox.file.copy_parallel(args.data_url, local_data_url)
-    mox.file.copy_parallel(args.ckpt_url, local_ckpt_url)
-    if device_num > 1:
-        init()
-        context.set_auto_parallel_context(device_num=device_num, \
-         parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
-    data_dir = local_data_url + '/'
-    local_ckpt_url = local_ckpt_url + '/'
-else:
-    if args.run_distribute:
-        device_id = int(os.getenv('DEVICE_ID'))
-        device_num = int(os.getenv('RANK_SIZE'))
-        cfg.batch_size = cfg.batch_size*int(8/device_num)
-        context.set_context(device_id=device_id)
-        init()
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(device_num=device_num, \
-         parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
-    else:
-        device_num = 1
-        device_id = args.device_id
-        context.set_context(device_id=args.device_id)
-    data_dir = args.data_url + '/'
-    local_ckpt_url = args.ckpt_url + '/'
-
-adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
-
-n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
-n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
-if n_vertex_vel == n_vertex_adj:
-    n_vertex = n_vertex_vel
-else:
-    raise ValueError(f'ERROR: number of vertices in dataset is not equal to \
-     number of vertices in weighted adjacency matrix.')
-
-mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
-conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
-if cfg.graph_conv_type == "chebconv":
-    if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
-        raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
-elif cfg.graph_conv_type == "gcnconv":
-    if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
-        raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
-
-stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
-    cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
-net = stgcn_conv
-
-
-if __name__ == "__main__":
-
-    zscore = preprocessing.StandardScaler()
-    if args.run_modelarts or args.run_distribute:
-        dataset = dataloader.create_dataset(data_dir+args.data_path, \
-         cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, False, device_num, device_id, mode=2)
-    else:
-        dataset = dataloader.create_dataset(data_dir+args.data_path, \
-         cfg.batch_size, cfg.n_his, cfg.n_pred, zscore, True, device_num, device_id, mode=2)
-    data_len = dataset.get_dataset_size()
-
-    param_dict = load_checkpoint(local_ckpt_url+args.ckpt_name)
-    load_param_into_net(net, param_dict)
-
-    test_MAE, test_RMSE, test_MAPE = utility.evaluate_metric(net, dataset, zscore)
-    print(f'MAE {test_MAE:.2f} | MAPE {test_MAPE*100:.2f} | RMSE {test_RMSE:.2f}')
diff --git a/research/cv/stgcn/train.py b/research/cv/stgcn/train.py
index 1ac61cca2..546d6ca16 100644
--- a/research/cv/stgcn/train.py
+++ b/research/cv/stgcn/train.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,207 +17,189 @@ train network.
 """
 
 import os
-import argparse
-import ast
 import pandas as pd
 from sklearn import preprocessing
 
-from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-
-from mindspore import context
-from mindspore import Tensor
-from mindspore.communication.management import init
+from mindspore import context, Tensor
+from mindspore.common import set_seed, dtype as mstype
 from mindspore.train.model import Model
-from mindspore.context import ParallelMode
 from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor
-from mindspore.common import set_seed
+from mindspore.context import ParallelMode
+from mindspore.communication.management import init, get_rank, get_group_size
 
-from src.config import stgcn_chebconv_45min_cfg, stgcn_chebconv_30min_cfg, stgcn_chebconv_15min_cfg, stgcn_gcnconv_45min_cfg, stgcn_gcnconv_30min_cfg, stgcn_gcnconv_15min_cfg
-from src import dataloader, utility
+from src.argparser import arg_parser
+from src import dataloader, utility, config
 from src.model import models, metric
 
-set_seed(1)
 
-parser = argparse.ArgumentParser('mindspore stgcn training')
+def get_config(args):
+    """return config based on selected n_pred and graph_conv_type"""
+    if args.graph_conv_type == "chebconv":
+        if args.n_pred == 9:
+            cfg = config.stgcn_chebconv_45min_cfg
+        elif args.n_pred == 6:
+            cfg = config.stgcn_chebconv_30min_cfg
+        elif args.n_pred == 3:
+            cfg = config.stgcn_chebconv_15min_cfg
+        else:
+            raise ValueError("Unsupported n_pred.")
+    elif args.graph_conv_type == "gcnconv":
+        if args.n_pred == 9:
+            cfg = config.stgcn_gcnconv_45min_cfg
+        elif args.n_pred == 6:
+            cfg = config.stgcn_gcnconv_30min_cfg
+        elif args.n_pred == 3:
+            cfg = config.stgcn_gcnconv_15min_cfg
+        else:
+            raise ValueError("Unsupported n_pred.")
+    else:
+        raise ValueError("Unsupported graph_conv_type.")
 
-# The way of training
-parser.add_argument('--device_target', type=str, default='Ascend', \
- help='device where the code will be implemented. (Default: Ascend)')
-parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
-parser.add_argument('--device_id', type=int, default=0, help='Device id')
-parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on modelarts')
-parser.add_argument('--save_check_point', type=bool, default=True, help='Whether save checkpoint')
+    return cfg
 
-# Path for data and checkpoint
-parser.add_argument('--data_url', type=str, required=True, help='Train dataset directory.')
-parser.add_argument('--train_url', type=str, required=True, help='Save checkpoint directory.')
-parser.add_argument('--data_path', type=str, default="vel.csv", help='Dataset file of vel.')
-parser.add_argument('--wam_path', type=str, default="adj_mat.csv", help='Dataset file of warm.')
 
-# Super parameters for training
-parser.add_argument('--n_pred', type=int, default=3, help='The number of time interval for predcition, default as 3')
-parser.add_argument('--opt', type=str, default='AdamW', help='optimizer, default as AdamW')
+def get_params(args):
+    """get and preprocess parameters"""
+    cfg = get_config(args)
 
-#network
-parser.add_argument('--graph_conv_type', type=str, default="gcnconv", help='Grapg convolution type')
+    if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
+        cfg.Ks = 2
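+    # every ST block applies two temporal convolutions with kernel Kt, shrinking
+    # the time axis by 2 * (Kt - 1); Ko is the length left for the output block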
+    Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
 
-args, _ = parser.parse_known_args()
+    # blocks: settings of channel size in st_conv_blocks and output layer,
+    # using the bottleneck design in st_conv_blocks
+    blocks = [[1]]
+    for _ in range(cfg.stblock_num):
+        blocks.append([64, 16, 64])
+    if Ko == 0:
+        blocks.append([128])
+    elif Ko > 0:
+        blocks.append([128, 128])
+    blocks.append([1])
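+    # e.g. with stblock_num = 2 and Ko > 0 this yields
+    # blocks = [[1], [64, 16, 64], [64, 16, 64], [128, 128], [1]]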
 
-if args.graph_conv_type == "chebconv":
-    if args.n_pred == 9:
-        cfg = stgcn_chebconv_45min_cfg
-    elif args.n_pred == 6:
-        cfg = stgcn_chebconv_30min_cfg
-    elif args.n_pred == 3:
-        cfg = stgcn_chebconv_15min_cfg
-    else:
-        raise ValueError("Unsupported n_pred.")
-elif args.graph_conv_type == "gcnconv":
-    if args.n_pred == 9:
-        cfg = stgcn_gcnconv_45min_cfg
-    elif args.n_pred == 6:
-        cfg = stgcn_gcnconv_30min_cfg
-    elif args.n_pred == 3:
-        cfg = stgcn_gcnconv_15min_cfg
-    else:
-        raise ValueError("Unsupported pred.")
-else:
-    raise ValueError("Unsupported graph_conv_type.")
-
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=False)
-
-if ((cfg.Kt - 1) * 2 * cfg.stblock_num > cfg.n_his) or ((cfg.Kt - 1) * 2 * cfg.stblock_num <= 0):
-    raise ValueError(f'ERROR: {cfg.Kt} and {cfg.stblock_num} are unacceptable.')
-
-Ko = cfg.n_his - (cfg.Kt - 1) * 2 * cfg.stblock_num
-
-if (cfg.graph_conv_type != "chebconv") and (cfg.graph_conv_type != "gcnconv"):
-    raise NotImplementedError(f'ERROR: {cfg.graph_conv_type} is not implemented.')
-
-if (cfg.graph_conv_type == 'gcnconv') and (cfg.Ks != 2):
-    cfg.Ks = 2
-
-# blocks: settings of channel size in st_conv_blocks and output layer,
-# using the bottleneck design in st_conv_blocks
-blocks = []
-blocks.append([1])
-for l in range(cfg.stblock_num):
-    blocks.append([64, 16, 64])
-if Ko == 0:
-    blocks.append([128])
-elif Ko > 0:
-    blocks.append([128, 128])
-blocks.append([1])
-
-
-day_slot = int(24 * 60 / cfg.time_intvl)
-cfg.n_pred = cfg.n_pred
-
-time_pred = cfg.n_pred * cfg.time_intvl
-time_pred_str = str(time_pred) + '_mins'
-
-if args.run_modelarts:
-    import moxing as mox
-    device_id = int(os.getenv('DEVICE_ID'))
-    device_num = int(os.getenv('RANK_SIZE'))
-    cfg.batch_size = cfg.batch_size*int(8/device_num)
-    context.set_context(device_id=device_id)
-    local_data_url = '/cache/data'
-    local_train_url = '/cache/train'
-    #mox.file.make_dirs(local_train_url)
-    mox.file.copy_parallel(args.data_url, local_data_url)
-    if device_num > 1:
-        init()
-        #context.set_auto_parallel_context(parameter_broadcast=True)
-        context.set_auto_parallel_context(device_num=device_num, \
-         parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
-    data_dir = local_data_url + '/'
-else:
-    if args.run_distribute:
+    time_pred = cfg.n_pred * cfg.time_intvl
+    time_pred_str = str(time_pred) + '_mins'
+
+    if cfg.graph_conv_type == "chebconv":
+        if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
+            raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
+    elif cfg.graph_conv_type == "gcnconv":
+        if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
+            raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
+
+    return args, cfg, blocks, time_pred_str
+
+
+def run_train(args, cfg, blocks, time_pred_str):
+    """train stgcn model"""
+    target = args.device_target
+    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
+    if args.run_modelarts:
+        import moxing as mox
         device_id = int(os.getenv('DEVICE_ID'))
         device_num = int(os.getenv('RANK_SIZE'))
-        cfg.batch_size = cfg.batch_size*int(8/device_num)
+        cfg.batch_size = cfg.batch_size * int(8/device_num)
         context.set_context(device_id=device_id)
-        init()
-        context.reset_auto_parallel_context()
-        #context.set_auto_parallel_context(parameter_broadcast=True)
-        context.set_auto_parallel_context(device_num=device_num, \
-         parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
+        local_data_url = '/cache/data'
+        local_train_url = '/cache/train'
+        mox.file.copy_parallel(args.data_url, local_data_url)
+        if device_num > 1:
+            init()
+            context.set_auto_parallel_context(device_num=device_num,
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+        data_dir = local_data_url + '/'
     else:
-        context.set_context(device_id=args.device_id)
-        device_num = 1
-        cfg.batch_size = cfg.batch_size*int(8/device_num)
-        device_id = args.device_id
-    data_dir = args.data_url + '/'
-    model_save_path = args.train_url + cfg.graph_conv_type + '_' + time_pred_str
-
-adj_mat = dataloader.load_weighted_adjacency_matrix(data_dir+args.wam_path)
-
-n_vertex_vel = pd.read_csv(data_dir+args.data_path, header=None).shape[1]
-n_vertex_adj = pd.read_csv(data_dir+args.wam_path, header=None).shape[1]
-if n_vertex_vel == n_vertex_adj:
-    n_vertex = n_vertex_vel
-else:
-    raise ValueError(f"ERROR: number of vertices in dataset is not equal to number \
-     of vertices in weighted adjacency matrix.")
-
-mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
-conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
-if cfg.graph_conv_type == "chebconv":
-    if (cfg.mat_type != "wid_sym_normd_lap_mat") and (cfg.mat_type != "wid_rw_normd_lap_mat"):
-        raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
-elif cfg.graph_conv_type == "gcnconv":
-    if (cfg.mat_type != "hat_sym_normd_lap_mat") and (cfg.mat_type != "hat_rw_normd_lap_mat"):
-        raise ValueError(f'ERROR: {cfg.mat_type} is wrong.')
-
-stgcn_conv = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, \
-    cfg.gated_act_func, cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
-net = stgcn_conv
+        if target == "Ascend":
+            device_id = 0
+            device_num = 1
+            context.set_context(device_id=args.device_id)
+            if args.run_distribute:
+                device_id = int(os.getenv('DEVICE_ID'))
+                device_num = int(os.getenv('RANK_SIZE'))
+                context.set_context(device_id=device_id)
+                init()
+                context.reset_auto_parallel_context()
+                # context.set_auto_parallel_context(parameter_broadcast=True)
+                context.set_auto_parallel_context(device_num=device_num,
+                                                  parallel_mode=ParallelMode.DATA_PARALLEL,
+                                                  gradients_mean=True)
+        elif target == "GPU":
+            device_id = args.device_id
+            device_num = 1
+            if args.run_distribute:
+                init()
+                device_id = get_rank()
+                device_num = get_group_size()
+                context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
+                                                  device_num=device_num)
+        else:
+            raise ValueError("Unsupported platform, only GPU or Ascend is supported.")
+
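+        # scale the per-card batch so the global batch (device_num * per-card batch)
+        # stays at 8 * cfg.batch_size regardless of how many cards are used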
+        cfg.batch_size = cfg.batch_size * int(8 / device_num)
+        data_dir = args.data_url + '/'
+        model_save_path = args.train_url + cfg.graph_conv_type + '_' + time_pred_str
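+        # in distributed GPU training each rank saves into its own ckpt_<rank>
+        # sub-directory so the ranks do not overwrite each other's checkpoints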
+        if args.device_target == "GPU" and args.run_distribute:
+            model_save_path = os.path.join(model_save_path, "ckpt_" + str(device_id) + "/")
+
+    adj_mat = dataloader.load_weighted_adjacency_matrix(os.path.join(data_dir, args.wam_path))
+    n_vertex_vel = pd.read_csv(os.path.join(data_dir, args.data_path), header=None).shape[1]
+    n_vertex_adj = pd.read_csv(os.path.join(data_dir, args.wam_path), header=None).shape[1]
+    if n_vertex_vel == n_vertex_adj:
+        n_vertex = n_vertex_vel
+    else:
+        raise ValueError(f"ERROR: number of vertices in dataset is not equal to number \
+            of vertices in weighted adjacency matrix.")
+    mat = utility.calculate_laplacian_matrix(adj_mat, cfg.mat_type)
+    conv_matrix = Tensor(Tensor.from_numpy(mat), mstype.float32)
 
-if __name__ == "__main__":
-    #start training
+    net = models.STGCN_Conv(cfg.Kt, cfg.Ks, blocks, cfg.n_his, n_vertex, cfg.gated_act_func,
+                            cfg.graph_conv_type, conv_matrix, cfg.drop_rate)
 
+    # start training
     zscore = preprocessing.StandardScaler()
-    if args.run_modelarts or args.run_distribute:
-        dataset = dataloader.create_dataset(data_dir+args.data_path, cfg.batch_size, cfg.n_his, \
-         cfg.n_pred, zscore, False, device_num, device_id, mode=0)
-    else:
-        dataset = dataloader.create_dataset(data_dir+args.data_path, cfg.batch_size, cfg.n_his, \
-         cfg.n_pred, zscore, True, device_num, device_id, mode=0)
-    data_len = dataset.get_dataset_size()
+    dataset = dataloader.create_dataset(os.path.join(data_dir, args.data_path), cfg.batch_size, cfg.n_his, cfg.n_pred,
+                                        zscore, device_num, device_id, mode=0)
+    dataset_size = dataset.get_dataset_size()
+
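+    # per-step learning-rate schedule: the rate decays smoothly by a factor of
+    # cfg.gamma every cfg.decay_epoch epochs over the whole args.epochs run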
+    learning_rate = nn.exponential_decay_lr(learning_rate=cfg.learning_rate, decay_rate=cfg.gamma,
+                                            total_step=dataset_size*args.epochs, step_per_epoch=dataset_size,
+                                            decay_epoch=cfg.decay_epoch)
 
-    learning_rate = nn.exponential_decay_lr(learning_rate=cfg.learning_rate, decay_rate=cfg.gamma, \
-     total_step=data_len*cfg.epochs, step_per_epoch=data_len, decay_epoch=cfg.decay_epoch)
-    if args.opt == "RMSProp":
+    if cfg.opt == "RMSProp":
         optimizer = nn.RMSProp(net.trainable_params(), learning_rate=learning_rate)
-    elif args.opt == "Adam":
-        optimizer = nn.Adam(net.trainable_params(), learning_rate=learning_rate, \
-         weight_decay=cfg.weight_decay_rate)
-    elif args.opt == "AdamW":
-        optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=learning_rate, \
-         weight_decay=cfg.weight_decay_rate)
+    elif cfg.opt == "Adam":
+        optimizer = nn.Adam(net.trainable_params(), learning_rate=learning_rate, weight_decay=cfg.weight_decay_rate)
+    elif cfg.opt == "AdamW":
+        optimizer = nn.AdamWeightDecay(net.trainable_params(), learning_rate=learning_rate,
+                                       weight_decay=cfg.weight_decay_rate)
     else:
-        raise ValueError(f'ERROR: optimizer {args.opt} is undefined.')
+        raise ValueError(f'ERROR: optimizer {cfg.opt} is undefined.')
 
-    loss_cb = LossMonitor()
-    time_cb = TimeMonitor(data_size=data_len)
+    loss_cb = LossMonitor(per_print_times=dataset_size)
+    time_cb = TimeMonitor()
     callbacks = [time_cb, loss_cb]
 
-    #save training results
-    if args.save_check_point and (device_num == 1 or device_id == 0):
-        config_ck = CheckpointConfig(
-            save_checkpoint_steps=data_len*cfg.epochs, keep_checkpoint_max=cfg.epochs)
+    # save training results
+    if args.save_checkpoint and (device_num == 1 or device_id == 0):
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args.epochs,
+                                       keep_checkpoint_max=args.epochs)
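+        # save_checkpoint_steps equals the total step count, so a checkpoint is
+        # written only once, at the end of training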
         if args.run_modelarts:
-            ckpoint_cb = ModelCheckpoint(prefix='STGCN'+cfg.graph_conv_type+str(cfg.n_pred)+'-', \
-             directory=local_train_url, config=config_ck)
+            ckpt_cb = ModelCheckpoint(prefix='STGCN' + cfg.graph_conv_type + str(cfg.n_pred) + '-',
+                                      directory=local_train_url, config=ckpt_config)
         else:
-            ckpoint_cb = ModelCheckpoint(prefix='STGCN', directory=model_save_path, config=config_ck)
-        callbacks += [ckpoint_cb]
+            ckpt_cb = ModelCheckpoint(prefix='STGCN', directory=model_save_path, config=ckpt_config)
+        callbacks += [ckpt_cb]
 
     net = metric.LossCellWithNetwork(net)
-    model = Model(net, optimizer=optimizer, amp_level='O3')
+    model = Model(net, optimizer=optimizer)
 
-    model.train(cfg.epochs, dataset, callbacks=callbacks)
+    model.train(args.epochs, dataset, callbacks=callbacks, dataset_sink_mode=False)
     if args.run_modelarts:
         mox.file.copy_parallel(src_url=local_train_url, dst_url=args.train_url)
+
+
+if __name__ == "__main__":
+    set_seed(1)
+    run_train(*get_params(arg_parser()))
-- 
GitLab