Unverified Commit 6c2151f2 authored by zhaoting, committed by Gitee

!3007 [Xi'an Jiaotong University][University Contribution][Mindspore][SlowFast] - GPU version meets the accuracy target

Merge pull request !3007 from 李晨阳/master
parents 2ed37b03 a1108b23
......@@ -141,10 +141,11 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
# Environment Requirements
- Hardware (Ascend)
- Use an Ascend processor to set up the hardware environment.
- Hardware (Ascend/GPU)
- Use an Ascend or GPU processor to set up the hardware environment.
- Framework
- [MindSpore1.5.2](https://www.mindspore.cn/install/en)
- Ascend: [MindSpore1.5.2](https://www.mindspore.cn/install/en)
- GPU: [MindSpore1.7.0](https://www.mindspore.cn/install/en)
- For details, see the following resources:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html)
......@@ -153,30 +154,43 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
# Quick Start
After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
After installing MindSpore from the official website, and before running the launch commands, change the paths in the launch scripts to the corresponding paths on your machine; then you can follow the steps below for training and evaluation:
- Running on Ascend
```text
# Run the training example
bash scripts/run_standalone_train.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
bash scripts/run_standalone_train_ascend.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
# Run the distributed training example
bash scripts/run_distribute_train.sh RANK_TABLE_FILE configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
bash scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
# Run the inference example
bash scripts/run_standalone_eval.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
bash scripts/run_standalone_eval_ascend.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
# 310 offline inference
bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID]
```
For distributed training, an hccl configuration file in JSON format needs to be created in advance.
For Ascend distributed training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
<https://gitee.com/mindspore/models/tree/master/utils/hccl_tools>
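For reference, a single-server rank table for two devices typically has the shape sketched below; the server_id and device_ip values here are placeholders, and the hccl_tools script linked above generates the real file for your machine:

```text
{
    "version": "1.0",
    "server_count": "1",
    "server_list": [{
        "server_id": "127.0.0.1",
        "device": [
            {"device_id": "0", "device_ip": "192.98.91.1", "rank_id": "0"},
            {"device_id": "1", "device_ip": "192.98.92.1", "rank_id": "1"}
        ],
        "host_nic_ip": "reserve"
    }],
    "status": "completed"
}
```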
- Running on GPU
```text
# Run the training example
bash scripts/run_standalone_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
# Run the distributed training example
bash scripts/run_distribute_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
# Run the inference example
bash scripts/run_standalone_eval_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
```
# Script Description
## Scripts and Sample Code
......@@ -189,10 +203,13 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
├── ascend310_infer // source code for 310 inference
├── scripts
│ ├──run_310_infer.sh // shell script for 310 offline inference
│ ├──run_distribute_train.sh // shell script for Ascend training
│ ├──run_distribute_train.sh // shell script for distributed Ascend training
│ ├──run_distribute_train_gpu.sh // shell script for distributed GPU training
│ ├──run_export.sh // shell script for exporting checkpoint files
│ ├──run_standalone_eval.sh // shell script for Ascend inference
│ ├──run_standalone_train.sh // shell script for distributed Ascend training
│ ├──run_standalone_train.sh // shell script for standalone Ascend training
│ ├──run_standalone_eval_gpu.sh // shell script for GPU inference
│ ├──run_standalone_train_gpu.sh // shell script for standalone GPU training
├── src
│ ├── datasets // ava dataset processing
│ ├── models
......@@ -253,16 +270,16 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
- Running on Ascend
```text
bash scripts/run_standalone_train.sh CFG DATA_DIR CHECKPOINT_FILE_PATH
bash scripts/run_standalone_train_ascend.sh CFG DATA_DIR CHECKPOINT_FILE_PATH
```
```text
bash scripts/run_distribute_train.sh RANK_TABLE_FILE configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
bash scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
```
The above command will run in the background, and you can view the results in the train.log file.
The above command will run in the background, and you can view the results in the log_standalone_ascend file.
After training, you can find the checkpoint files in the default script folder. You can obtain the loss values as follows:
......@@ -273,6 +290,23 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
...
```
- Running on GPU
```text
bash scripts/run_standalone_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
```
The above command will run in the background, and you can view the results in the log_standalone_gpu file.
After training, you can find the checkpoint files in the default script folder. You can obtain the loss values as follows:
```text
# grep "loss is " log_standalone_gpu
epoch:1 step:390, loss is 0.0990763
epoch:2 step:390, loss is 0.0603111
...
```
The model checkpoints are saved in the current directory.
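For reference, these checkpoint files come from the ModelCheckpoint callback configured in train.py later in this diff; a minimal sketch of that wiring, where the two numeric values are placeholders for quantities read from the dataset and config at run time:

```python
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

steps_per_epoch = 390  # placeholder: dataset.get_dataset_size() at run time
max_epoch = 20         # placeholder: cfg.SOLVER.MAX_EPOCH
# Save one checkpoint per epoch and keep at most max_epoch files;
# files are typically named checkpoints/slowfast-<epoch>_<step>.ckpt.
ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                            keep_checkpoint_max=max_epoch)
ckpt_cb = ModelCheckpoint(prefix="slowfast", directory='checkpoints', config=ckpt_cfg)
```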
### Distributed Training
......@@ -283,10 +317,10 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
bash scripts/run_distribute_train.sh ~/hccl_8p_01234567_127.0.0.1.json
```
The above shell script will run distributed training in the background. You can view the results in the train_parallel[X]/log files. You can obtain the loss values as follows:
The above shell script will run distributed training in the background. You can view the results in the log_distributed_ascend file. You can obtain the loss values as follows:
```text
# grep "result:" train_parallel*/log
# grep "result:" log_distributed_ascend
train_parallel0/log:epoch:1 step:48, loss is 1.4302931
train_parallel0/log:epoch:2 step:48, loss is 1.4023874
...
......@@ -296,6 +330,25 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
...
```
- Running on GPU
```text
bash scripts/run_distribute_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
```
The above shell script will run distributed training in the background. You can view the results in the log_distributed_gpu file. You can obtain the loss values as follows:
```text
# grep "result:" log_distributed_gpu
train_parallel0/log:epoch:1 step:48, loss is 0.2674269
train_parallel0/log:epoch:2 step:48, loss is 0.0610401
...
train_parallel1/log:epoch:1 step:48, loss is 0.2730093
train_parallel1/log:epoch:2 step:48, loss is 0.0648247
...
...
```
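run_distribute_train_gpu.sh launches eight training processes with mpirun; each process then sets up NCCL-based data parallelism. A condensed sketch of the per-process setup, mirroring the train.py change shown later in this diff:

```python
from mindspore import context
from mindspore.communication import init, get_rank

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init()  # under mpirun, this wires up NCCL across the 8 processes
context.set_auto_parallel_context(device_num=8,
                                  parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)
rank_id = get_rank()  # each process shards the dataset by its own rank
```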
## Export Process
### Export
......@@ -313,13 +366,12 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
### Inference
Before running inference, we need to export the model first.
- Inference on Ascend 910 with the ava dataset
Before running inference, we need to export the model first.
Before executing the command below, we need to modify the ava configuration file. The items to modify are AVA.FRAME_DIR, AVA.FRAME_LIST_DIR, AVA.ANNOTATION_DIR, and TRAIN.CHECKPOINT_FILE_PATH.
The inference results are saved in the current directory; you can find results similar to the following in the evalX/log file.
The inference results are saved in the current directory; you can find results similar to the following in the log_eval_ascend file.
```text
'PascalBoxes_PerformanceByCategory/AP@0.5IOU/turn (e.g., a screwdriver)': 0.0031881659969238293,
......@@ -334,13 +386,33 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
```
```text
bash scripts/run_standalone_eval.sh CFG DATA_DIR CHECKPOINT_FILE_PATH DEVICE_ID
bash scripts/run_standalone_eval_ascend.sh CFG DATA_DIR CHECKPOINT_FILE_PATH DEVICE_ID
```
Example
```text
bash scripts/run_standalone_eval.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
bash scripts/run_standalone_eval_ascend.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
```
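Internally, the evaluation script applies the AVA.* configuration items mentioned above as command-line overrides; the Ascend eval script shown later in this diff ends with an invocation of the following shape:

```text
python -u eval.py --cfg "${CFG}" \
    AVA.FRAME_DIR "${DATA_DIR}/frames" \
    AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
    AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
    TEST.CHECKPOINT_FILE_PATH "${CHECKPOINT_FILE_PATH}"
```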
- Inference on GPU with the ava dataset
The inference results are saved in the current directory; you can find results similar to the following in the log_eval_gpu file.
```text
'PascalBoxes_PerformanceByCategory/AP@0.5IOU/turn (e.g., a screwdriver)': 0.0031881659969238293,
'PascalBoxes_PerformanceByCategory/AP@0.5IOU/walk': 0.7207324941463648,
'PascalBoxes_PerformanceByCategory/AP@0.5IOU/watch (a person)': 0.6626902737325869,
'PascalBoxes_PerformanceByCategory/AP@0.5IOU/watch (e.g., TV)': 0.10220154817817734,
'PascalBoxes_PerformanceByCategory/AP@0.5IOU/work on a computer': 0.028072906328370745,
'PascalBoxes_PerformanceByCategory/AP@0.5IOU/write': 0.0774830044468495,
'PascalBoxes_Precision/mAP@0.5IOU': 0.2173776249697695}
[04/04 14:49:23][INFO] ava_eval_helper.py: 169: AVA eval done in 698.487868 seconds.
[04/04 14:49:23][INFO] logging.py: 84: json_stats: {"map": 0.21738, "mode": "test"}
```
```text
bash scripts/run_standalone_eval_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
```
- Inference on Ascend 310 with the ava dataset
......@@ -380,6 +452,8 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
#### Training slowfast
- Using Ascend
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model version | Kunpeng-920 |
......@@ -392,10 +466,26 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
| Speed | 8 cards: 476 ms/step |
| Total time | 8 cards: 8.1 hours |
- Using GPU
| Parameters | GPU |
| -------------------------- | ----------------------------------------------------------- |
| Model version | Nvidia |
| Resources | Nvidia GeForce RTX 3090; CPU 2.90 GHz, 64 cores; memory 251 GB |
| MindSpore version | 1.7.0 |
| Dataset | AVA2.2 |
| Training parameters | lr=0.15, fp=32, mmt=0.9, nesterov=false, roiend=1 |
| Optimizer | Momentum |
| Loss function | BCELoss (binary cross-entropy) |
| Speed | 8 cards: 1500 ms/step |
| Total time | 8 cards: 30.6 hours |
### Evaluation Performance
#### Evaluating slowfast
- Using Ascend
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model version | Kunpeng-920 |
......@@ -406,6 +496,18 @@ slowfast is a novel method proposed by the Facebook AI research team to analyze video
| Output | Probability |
| Accuracy | 8 cards: 21.73% |
- Using GPU
| Parameters | GPU |
| ------------------- | --------------------------- |
| Model version | Nvidia |
| Resources | Nvidia GeForce RTX 3090 |
| MindSpore version | 1.7.0 |
| Dataset | AVA2.2 |
| batch_size | 16 |
| Output | Probability |
| Accuracy | 8 cards: 21.73% |
# ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/models).
......@@ -35,8 +35,7 @@ def run_eval():
    logger.info(cfg)
    # setup context
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(device_id=device_id,
                        mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target=args.device_target)
    context.set_context(save_graphs=True, save_graphs_path='irs_eval')
    # build dataset
    dataset = build_dataset(cfg, "test")
......
......@@ -53,6 +53,6 @@ do
AVA.FRAME_DIR "${DATA_DIR}/frames" \
AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log 2>&1 &
TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_distributed_ascend 2>&1 &
cd ..
done
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# a simple tutorial follows; more parameters can be set
if [ $# != 3 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_distribute_train_gpu.sh CFG DATA_DIR CHECKPOINT_FILE_PATH"
echo "for example: bash scripts/run_distribute_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt"
echo "=============================================================================================================="
exit 1;
fi
export HCCL_CONNECT_TIMEOUT=600
export DEVICE_NUM=8
export RANK_SIZE=8
CFG=$(realpath $1)
DATA_DIR=$(realpath $2)
CHECKPOINT_FILE_PATH=$(realpath $3)
mpirun -n 8 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python -u train.py --device_target="GPU" --dataset_sink_mode=0 --cfg "$CFG" \
AVA.FRAME_DIR "${DATA_DIR}/frames" \
AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_distributed_gpu 2>&1 &
......@@ -48,5 +48,5 @@ taskset -c $cpu_range python -u eval.py --cfg "${CFG}" \
AVA.FRAME_DIR "${DATA_DIR}/frames" \
AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
TEST.CHECKPOINT_FILE_PATH "${CHECKPOINT_FILE_PATH}" > log 2>&1 &
TEST.CHECKPOINT_FILE_PATH "${CHECKPOINT_FILE_PATH}" > log_eval_ascend 2>&1 &
cd ..
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval_gpu.sh CFG DATA_DIR CHECKPOINT_FILE_PATH"
echo "for example: bash scripts/run_standalone_eval_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt"
echo "=============================================================================================================="
exit 1;
fi
CFG=$(realpath $1)
DATA_DIR=$(realpath $2)
CHECKPOINT_FILE_PATH=$(realpath $3)
python -u eval.py --device_target="GPU" --dataset_sink_mode=0 --cfg "$CFG" \
AVA.FRAME_DIR "${DATA_DIR}/frames" \
AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
TEST.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_eval_gpu 2>&1 &
......@@ -28,4 +28,4 @@ python -u train.py --cfg "$CFG" \
AVA.FRAME_DIR "${DATA_DIR}/frames" \
AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log 2>&1 &
TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_standalone_ascend 2>&1 &
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ] ; then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_train_gpu.sh CFG DATA_DIR CHECKPOINT_FILE_PATH"
echo "for example: bash scripts/run_standalone_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt"
echo "=============================================================================================================="
exit 1;
fi
export DEVICE_NUM=1
CFG=$(realpath $1)
DATA_DIR=$(realpath $2)
CHECKPOINT_FILE_PATH=$(realpath $3)
python -u train.py --device_target="GPU" --dataset_sink_mode=0 --cfg "$CFG" \
AVA.FRAME_DIR "${DATA_DIR}/frames" \
AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_standalone_gpu 2>&1 &
......@@ -19,7 +19,7 @@ from src.datasets.ava_dataset import Ava
ds.config.set_prefetch_size(8)
ds.config.set_numa_enable(True)
def build_dataset(cfg, split, num_shards=None, shard_id=None):
def build_dataset(cfg, split, num_shards=None, shard_id=None, device_target='Ascend'):
"""
Args:
cfg (CfgNode): configs. Details can be found in
......@@ -36,7 +36,7 @@ def build_dataset(cfg, split, num_shards=None, shard_id=None):
    if split == 'train':
        dataset = ds.GeneratorDataset(dataset_generator,
                                      ["slowpath", "fastpath", "boxes", "labels", "mask"],
                                      num_parallel_workers=16,
                                      num_parallel_workers=16 if device_target == 'Ascend' else 6,
                                      python_multiprocessing=False,
                                      shuffle=True,
                                      num_shards=num_shards,
......@@ -45,7 +45,7 @@ def build_dataset(cfg, split, num_shards=None, shard_id=None):
    else:
        dataset = ds.GeneratorDataset(dataset_generator,
                                      ["slowpath", "fastpath", "boxes", "labels", "ori_boxes", "metadata", "mask"],
                                      num_parallel_workers=16,
                                      num_parallel_workers=16 if device_target == 'Ascend' else 6,
                                      python_multiprocessing=False,
                                      shuffle=False)
        dataset = dataset.batch(cfg.TEST.BATCH_SIZE)
......
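The new device_target parameter only affects how many parallel workers feed the GeneratorDataset. A minimal, self-contained sketch of the rule (the helper name is illustrative; in the repository the expression lives inline in build_dataset):

```python
def num_dataset_workers(device_target: str = 'Ascend') -> int:
    """16 parallel dataset workers on Ascend, 6 on GPU, matching dataset.py."""
    return 16 if device_target == 'Ascend' else 6

assert num_dataset_workers('Ascend') == 16
assert num_dataset_workers('GPU') == 6
```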
......@@ -68,6 +68,9 @@ def parse_args():
        default=None,
        nargs=argparse.REMAINDER,
    )
    # Define parameters for device target and dataset_sink_mode
    parser.add_argument("--device_target", default="Ascend", type=str, help="Ascend/GPU")
    parser.add_argument("--dataset_sink_mode", default=1, type=int, help="dataset_sink_mode")
    # define 2 parameters for running on modelArts
    parser.add_argument('--data_url',
                        help='path to training/inference dataset folder',
......
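These two flags are what the new GPU launch scripts pass on the command line, for example (from run_standalone_train_gpu.sh later in this diff):

```text
python -u train.py --device_target="GPU" --dataset_sink_mode=0 --cfg "$CFG" ...
```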
......@@ -15,9 +15,11 @@
"""Train."""
import os
import numpy as np
from mindspore import context, nn, dtype, load_checkpoint, set_seed
from mindspore import DynamicLossScaleManager
from mindspore.communication import init
from mindspore.communication import init, get_rank
from mindspore.common.tensor import Tensor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from src.config.defaults import assert_and_infer_cfg
......@@ -29,6 +31,40 @@ from src.models import optimizer as optim
set_seed(42)
class LossMonitor_Standalone(LossMonitor):
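    """LossMonitor variant for standalone runs: prints the loss at every step when dataset sink mode is off."""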
    def __init__(self, per_print_times=1):
        super(LossMonitor_Standalone, self).__init__(per_print_times=1)
        self._per_print_times = per_print_times
        self._last_print_time = 0

    def step_end(self, run_context):
        """
        Print training loss at the end of step.

        Args:
            run_context (RunContext): Include some information of the model.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs
        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]
        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = float(np.mean(loss.asnumpy()))
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        # rewind the print bookkeeping if the global step counter moved backwards (e.g. a re-run)
        if self._per_print_times != 0 and (cb_params.cur_step_num <= self._last_print_time):
            while cb_params.cur_step_num <= self._last_print_time:
                self._last_print_time -= \
                    max(self._per_print_times, cb_params.batch_num if cb_params.dataset_sink_mode else 1)
        if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
            self._last_print_time = cb_params.cur_step_num
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
class NetWithLoss(nn.Cell):
"""Construct Loss Net."""
def __init__(self, net):
......@@ -59,16 +95,16 @@ def train():
    rank_id = int(os.getenv('RANK_ID', '0'))
    device_id = int(os.getenv('DEVICE_ID', '0'))
    device_num = int(os.getenv('DEVICE_NUM', '1'))
    context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target=args.device_target)
    context.set_context(save_graphs=True, save_graphs_path='irs')
    if device_num > 1:
        init()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        rank_id = get_rank()
    # build dataset
    dataset = build_dataset(cfg, "train", num_shards=device_num, shard_id=rank_id)
    dataset = build_dataset(cfg, "train", num_shards=device_num, shard_id=rank_id, device_target=args.device_target)
    steps_per_epoch = dataset.get_dataset_size()
    # build net with loss
    network = SlowFast(cfg).set_train(True)
......@@ -81,8 +117,11 @@ def train():
    optimizer = optim.construct_optimizer(net_with_loss, steps_per_epoch, cfg)
    # setup callbacks
    callbacks = [TimeMonitor(), LossMonitor()]
    if (device_num == 1) or (device_num > 1 and device_id in [0, 7]):
    if device_num > 1:
        callbacks = [TimeMonitor(), LossMonitor()]
    else:
        callbacks = [TimeMonitor(), LossMonitor_Standalone()]
    if rank_id == 0:
        ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=cfg.SOLVER.MAX_EPOCH)
        ckpt_cb = ModelCheckpoint(prefix="slowfast", directory='checkpoints', config=ckpt_cfg)
        callbacks.append(ckpt_cb)
......@@ -92,7 +131,7 @@ def train():
    # start training
    logger.info("============== Starting Training ==============")
    logger.info("total_epoch=%d, steps_per_epoch=%d", cfg.SOLVER.MAX_EPOCH, steps_per_epoch)
    model.train(cfg.SOLVER.MAX_EPOCH, dataset, callbacks=callbacks, dataset_sink_mode=True)
    model.train(cfg.SOLVER.MAX_EPOCH, dataset, callbacks=callbacks, dataset_sink_mode=bool(args.dataset_sink_mode))

if __name__ == "__main__":
    train()