Skip to content
Snippets Groups Projects
Unverified Commit 34160cbd authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1995 predrnn++ script update

Merge pull request !1995 from wangzeyangyi/predrnn++
parents 003072e4 91798006
No related branches found
No related tags found
No related merge requests found
......@@ -69,7 +69,8 @@ The moving mnist dataset from the official website contains 3 files: train/valid
```bash
# standalone training example in Ascend
$ bash scripts/run_standalone_train.sh [DATASET_PATH] [DEVICE_ID]
# TRAIN_MINDRECORD_PATH is the path of the file mnist_train.mindrecord
$ bash scripts/run_standalone_train.sh [TRAIN_MINDRECORD_PATH] [DEVICE_ID]
```
## [Script Description](#contents)
......@@ -95,11 +96,10 @@ Predrnn++
│ ├── mnist_to_mindrecord.py # Dataset conversion
│ ├── mnist.py # Moving mnist dataset preprocess
│ ├── datasets_factory.py
├── utils
│ ├── config.py # Default configurations
│ ├── preprocess.py # Dataset patch functions
│ ├── metrics.py # Evaluation metric functions
└── train.py # Training script
│ ├── preprocess.py # Default configurations
├── metrics.py # Evaluation metric functions
├── config.py # Dataset patch functions
├── train.py # Training script
├── eval.py # Evaluation script
├── default_config.yaml # Config file
```
......@@ -109,14 +109,12 @@ Predrnn++
#### Training Script Parameters
```shell
# distributed training in Ascend or GPU
Usage: bash scripts/run_standalone_train.sh [DATASET_PATH] [DEVICE_ID]
#### Parameters Configuration
Parameters for both training and evaluation can be set in default_config.yaml.
```yaml
```
train_mindrecord: "" # path to train dataset
test_mindrecord: "" # path to test dataset
......@@ -141,7 +139,7 @@ sink_size: 10 # number of data to sink per epo
device_num: 1 # number of NPU used
device_id: 0 # id of NPU used
```
```lang-yaml
## [Training Process](#contents)
......@@ -153,7 +151,7 @@ device_id: 0 # id of NPU used
``` bash
bash scripts/run_standalone_train.sh [DATASET_PATH] [DEVICE_ID]
bash scripts/run_standalone_train.sh [TRAIN_MINDRECORD_PATH] [DEVICE_ID]
```
......@@ -161,11 +159,11 @@ bash scripts/run_standalone_train.sh [DATASET_PATH] [DEVICE_ID]
### [Evaluation](#contents)
- Run `run_eval.sh` for evaluation.
- Run `run_eval.sh` for evaluation. TEST_MINDRECORD_PATH is the path of the file mnist_test.mindrecord
``` bash
bash scripts/run_eval.sh [DATASET_PATH] [DEVICE_ID] [CHECKPOINT_PATH]
bash scripts/run_eval.sh [TEST_MINDRECORD_PATH] [DEVICE_ID] [CHECKPOINT_PATH]
```
......
......@@ -18,7 +18,7 @@ import argparse
from pprint import pprint, pformat
import yaml
_config_path = '../default_config.yaml'
_config_path = 'default_config.yaml'
class Config:
"""
......
......@@ -17,7 +17,7 @@ import numpy as np
import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter
from data_provider import mnist
from utils import preprocess
from data_provider import preprocess
class MnistToRecord:
def __init__(self, input_param):
......
......@@ -23,10 +23,10 @@ from mindspore import Tensor, context
from mindspore import load_checkpoint, load_param_into_net
from data_provider.mnist_to_mindrecord import create_mnist_dataset
from data_provider import preprocess
from nets.predrnn_pp import PreRNN
from utils import metrics
from utils import preprocess
from utils.config import config
from config import config
import metrics
if __name__ == '__main__':
......
......@@ -19,7 +19,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
from data_provider.mnist_to_mindrecord import create_mnist_dataset
from nets.predrnn_pp import PreRNN, NetWithLossCell
from utils.config import config
from config import config
if __name__ == '__main__':
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment