diff --git a/research/cv/slowfast/README.md b/research/cv/slowfast/README.md
index 8cf28285f2317adcdcd2b80836cb36674b47dba0..10ecff01434dfb4a0d0d6aede2e4215f6fc897bb 100644
--- a/research/cv/slowfast/README.md
+++ b/research/cv/slowfast/README.md
@@ -141,10 +141,11 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
 
 # Environment Requirements
 
-- Hardware (Ascend)
-    - Prepare the hardware environment with Ascend processors.
+- Hardware (Ascend/GPU)
+    - Prepare the hardware environment with Ascend or GPU processors.
 - Framework
-    - [MindSpore1.5.2](https://www.mindspore.cn/install/en)
+    - Ascend: [MindSpore1.5.2](https://www.mindspore.cn/install/en)
+    - GPU: [MindSpore1.7.0](https://www.mindspore.cn/install/en)
 - For more information, see the following resources:
     - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/zh-CN/master/index.html)
     - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html)
@@ -153,30 +154,43 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
 
 # Quick Start
 
-After installing MindSpore from the official website, you can train and evaluate as follows:
+After installing MindSpore from the official website, change the paths in the relevant launch scripts to match your machine before running the launch commands. You can then train and evaluate as follows:
 
 - Running on Ascend
 
   ```text
   # Run the training example
-  bash scripts/run_standalone_train.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+  bash scripts/run_standalone_train_ascend.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
 
   # Run the distributed training example
-  bash scripts/run_distribute_train.sh RANK_TABLE_FILE configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+  bash scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
 
   # Run the inference example
-  bash scripts/run_standalone_eval.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
+  bash scripts/run_standalone_eval_ascend.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
 
   # 310 offline inference
   bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [NEED_PREPROCESS] [DEVICE_TARGET] [DEVICE_ID]
   ```
 
-  Distributed training requires an hccl configuration file in JSON format, created in advance.
+  Distributed training on Ascend requires an hccl configuration file in JSON format, created in advance.
 
   Follow the instructions at the link below:
<https://gitee.com/mindspore/models/tree/master/utils/hccl_tools.>
 
+- Running on GPU
+
+  ```text
+  # Run the training example
+  bash scripts/run_standalone_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+
+  # Run the distributed training example
+  bash scripts/run_distribute_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+
+  # Run the inference example
+  bash scripts/run_standalone_eval_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+  ```
+
 # Script Description
 
 ## Scripts and Sample Code
 
@@ -189,10 +203,13 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
     ├── ascend310_infer // source code for 310 inference
     ├── scripts
     │ ├──run_310_infer.sh // shell script for 310 offline inference
-    │ ├──run_distribute_train.sh // shell script for training on Ascend
+    │ ├──run_distribute_train_ascend.sh // shell script for distributed training on Ascend
+    │ ├──run_distribute_train_gpu.sh // shell script for distributed training on GPU
     │ ├──run_export.sh // shell script for exporting checkpoint files
-    │ ├──run_standalone_eval.sh // shell script for inference on Ascend
-    │ ├──run_standalone_train.sh // shell script for distributed training on Ascend
+    │ ├──run_standalone_eval_ascend.sh // shell script for inference on Ascend
+    │ ├──run_standalone_train_ascend.sh // shell script for single-device training on Ascend
+    │ ├──run_standalone_eval_gpu.sh // shell script for inference on GPU
+    │ ├──run_standalone_train_gpu.sh // shell script for single-device training on GPU
     ├── src
     │ ├── datasets // AVA dataset processing
     │ ├── models
@@ -253,16 +270,16 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
 - Running on Ascend
 
   ```text
-  bash scripts/run_standalone_train.sh CFG DATA_DIR CHECKPOINT_FILE_PATH
+  bash scripts/run_standalone_train_ascend.sh CFG DATA_DIR CHECKPOINT_FILE_PATH
   ```
 
   For example:
 
   ```text
-  bash scripts/run_distribute_train.sh RANK_TABLE_FILE configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+  bash scripts/run_standalone_train_ascend.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
   ```
 
-  The python command above runs in the background. You can check the results in the train.log file.
+  The python command above runs in the background. You can check the results in the log_standalone_ascend file.
 
   After training, you can find the checkpoint files in the default script folder. The loss values can be obtained as follows:
 
@@ -273,6 +290,23 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
   ...
   ```
 
+- Running on GPU
+
+  ```text
+  bash scripts/run_standalone_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+  ```
+
+  The python command above runs in the background. You can check the results in the log_standalone_gpu file.
+
+  After training, you can find the checkpoint files in the default script folder. The loss values can be obtained as follows:
+
+  ```text
+  # grep "loss is " log_standalone_gpu
+  epoch:1 step:390, loss is 0.0990763
+  epoch:2 step:390, loss is 0.0603111
+  ...
+  ```
+
   Model checkpoints are saved in the current directory.
 
 ### Distributed Training
 
@@ -283,10 +317,10 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
   bash scripts/run_distribute_train ~/hccl_8p_01234567_127.0.0.1.json
   ```
 
-  The shell script above runs distributed training in the background. You can check the results in the train_parallel[X]/log file. The loss values can be obtained as follows:
+  The shell script above runs distributed training in the background. You can check the results in the log_distributed_ascend file. The loss values can be obtained as follows:
 
   ```text
-  # grep "result:" train_parallel*/log
+  # grep "loss is " log_distributed_ascend
   train_parallel0/log:epoch:1 step:48, loss is 1.4302931
   train_parallel0/log:epoch:2 step:48, loss is 1.4023874
   ...
@@ -296,6 +330,25 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
   ...
   ```
 
+- Running on GPU
+
+  ```text
+  bash scripts/run_distribute_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
+  ```
+
+  The shell script above runs distributed training in the background. You can check the results in the log_distributed_gpu file. The loss values can be obtained as follows:
+
+  ```text
+  # grep "loss is " log_distributed_gpu
+  train_parallel0/log:epoch:1 step:48, loss is 0.2674269
+  train_parallel0/log:epoch:2 step:48, loss is 0.0610401
+  ...
+  train_parallel1/log:epoch:1 step:48, loss is 0.2730093
+  train_parallel1/log:epoch:2 step:48, loss is 0.0648247
+  ...
+  ...
+  ```
 
 ## Export Process
 
 ### Export
 
@@ -313,13 +366,12 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
 
 ### Inference
 
-Before running inference, the model must be exported first.
-
 - Inference on Ascend 910 with the AVA dataset
 
+  Before running inference, the model must be exported first.
   Before running the command below, modify the AVA configuration file. The items to change are AVA.FRAME_DIR, AVA.FRAME_LIST_DIR, AVA.ANNOTATION_DIR, and TRAIN.CHECKPOINT_FILE_PATH.
 
-  Inference results are saved in the current directory. Results like the following can be found in the evalX/log log file.
+  Inference results are saved in the current directory. Results like the following can be found in the log_eval_ascend log file.
 
   ```text
   'PascalBoxes_PerformanceByCategory/AP@0.5IOU/turn (e.g., a screwdriver)': 0.0031881659969238293,
@@ -334,13 +386,33 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
   ```
 
   ```text
-  bash scripts/run_standalone_eval.sh CFG DATA_DIR CHECKPOINT_FILE_PATH DEVICE_ID
+  bash scripts/run_standalone_eval_ascend.sh CFG DATA_DIR CHECKPOINT_FILE_PATH DEVICE_ID
   ```
 
   Example
 
   ```text
-  bash scripts/run_standalone_eval.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
+  bash scripts/run_standalone_eval_ascend.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava checkpoint_epoch_00020_best248.pyth.ckpt 1
+  ```
+
+- Inference on GPU with the AVA dataset
+
+  Inference results are saved in the current directory. Results like the following can be found in the log_eval_gpu log file.
+
+  ```text
+  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/turn (e.g., a screwdriver)': 0.0031881659969238293,
+  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/walk': 0.7207324941463648,
+  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/watch (a person)': 0.6626902737325869,
+  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/watch (e.g., TV)': 0.10220154817817734,
+  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/work on a computer': 0.028072906328370745,
+  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/write': 0.0774830044468495,
+  'PascalBoxes_Precision/mAP@0.5IOU': 0.2173776249697695}
+  [04/04 14:49:23][INFO] ava_eval_helper.py: 169: AVA eval done in 698.487868 seconds.
+  [04/04 14:49:23][INFO] logging.py: 84: json_stats: {"map": 0.21738, "mode": "test"}
+  ```
+
+  ```text
+  bash scripts/run_standalone_eval_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt
   ```
 
 - Inference on Ascend 310 with the AVA dataset
 
@@ -380,6 +452,8 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
 
 #### Training slowfast
 
+- Using Ascend
+
 | Parameter | Ascend |
 | -------------------------- | ----------------------------------------------------------- |
 | Model version | Kunpeng-920 |
@@ -392,10 +466,26 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
 | Speed | 8 devices: 476 ms/step |
 | Total time | 8 devices: 8.1 hours |
 
+- Using GPU
+
+| Parameter | GPU |
+| -------------------------- | ----------------------------------------------------------- |
+| Model version | Nvidia |
+| Resources | Nvidia-GeForce RTX 3090; CPU 2.90GHz, 64 cores; memory 251G |
+| MindSpore version | 1.7.0 |
+| Dataset | AVA2.2 |
+| Training parameters | lr=0.15,fp=32,mmt=0.9,nesterov=false,roiend=1 |
+| Optimizer | Momentum |
+| Loss function | BCELoss (binary cross-entropy) |
+| Speed | 8 devices: 1500 ms/step |
+| Total time | 8 devices: 30.6 hours |
+
 ### Evaluation Performance
 
 #### Evaluating slowfast
 
+- Using Ascend
+
 | Parameter | Ascend |
 | ------------------- | --------------------------- |
 | Model version | Kunpeng-920 |
@@ -406,6 +496,18 @@ slowfast是由Facebook AI研究团队提出的一种新颖的方法来分析视
 | Output | Probability |
 | Accuracy | 8 devices: 21.73% |
 
+- Using GPU
+
+| Parameter | GPU |
+| ------------------- | --------------------------- |
+| Model version | Nvidia |
+| Resources | Nvidia-GeForce RTX 3090 |
+| MindSpore version | 1.7.0 |
+| Dataset | AVA2.2 |
+| batch_size | 16 |
+| Output | Probability |
+| Accuracy | 8 devices: 21.73% |
+
 # ModelZoo Homepage
 
 Please visit the official [homepage](https://gitee.com/mindspore/models).
diff --git a/research/cv/slowfast/eval.py b/research/cv/slowfast/eval.py
index bef3825964cfcdf2ad82ec6de253bee2cfbaabee..409f97b1548cf6a64ec2b43781df976d1ad81b22 100644
--- a/research/cv/slowfast/eval.py
+++ b/research/cv/slowfast/eval.py
@@ -35,8 +35,7 @@ def run_eval():
     logger.info(cfg)
     # setup context
     device_id = int(os.getenv('DEVICE_ID', '0'))
-    context.set_context(device_id=device_id,
-                        mode=context.GRAPH_MODE, device_target="Ascend")
+    context.set_context(device_id=device_id, mode=context.GRAPH_MODE,
device_target=args.device_target)
     context.set_context(save_graphs=True, save_graphs_path='irs_eval')
     # build dataset
     dataset = build_dataset(cfg, "test")
diff --git a/research/cv/slowfast/scripts/run_distribute_train.sh b/research/cv/slowfast/scripts/run_distribute_train_ascend.sh
similarity index 98%
rename from research/cv/slowfast/scripts/run_distribute_train.sh
rename to research/cv/slowfast/scripts/run_distribute_train_ascend.sh
index 01d98500d92ad0c7023a080975a697fa95ca2cdf..d409f4f12a1ba478bc3b7eac7a74d970490114ab 100644
--- a/research/cv/slowfast/scripts/run_distribute_train.sh
+++ b/research/cv/slowfast/scripts/run_distribute_train_ascend.sh
@@ -53,6 +53,6 @@ do
     AVA.FRAME_DIR "${DATA_DIR}/frames" \
     AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
     AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
-    TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log 2>&1 &
+    TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_distributed_ascend 2>&1 &
     cd ..
 done
diff --git a/research/cv/slowfast/scripts/run_distribute_train_gpu.sh b/research/cv/slowfast/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4c7a72417bccf0e635a98d255bc290ed1eb36782
--- /dev/null
+++ b/research/cv/slowfast/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# A simple example follows; more parameters can be set.
+
+if [ $# != 3 ] ; then
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash scripts/run_distribute_train_gpu.sh CFG DATA_DIR CHECKPOINT_FILE_PATH"
+echo "for example: bash scripts/run_distribute_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt"
+echo "=============================================================================================================="
+exit 1;
+fi
+export HCCL_CONNECT_TIMEOUT=600
+export DEVICE_NUM=8
+export RANK_SIZE=8
+CFG=$(realpath "$1")
+DATA_DIR=$(realpath "$2")
+CHECKPOINT_FILE_PATH=$(realpath "$3")
+
+mpirun -n 8 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
+python -u train.py --device_target="GPU" --dataset_sink_mode=0 --cfg "$CFG" \
+    AVA.FRAME_DIR "${DATA_DIR}/frames" \
+    AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
+    AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
+    TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_distributed_gpu 2>&1 &
diff --git a/research/cv/slowfast/scripts/run_standalone_eval.sh b/research/cv/slowfast/scripts/run_standalone_eval_ascend.sh
similarity index 96%
rename from research/cv/slowfast/scripts/run_standalone_eval.sh
rename to research/cv/slowfast/scripts/run_standalone_eval_ascend.sh
index 15de822ed9607c702d8282fccb3494ddf7fbda08..d2e1534436750875a39e018a43e7c6997ee4ae6b 100644
--- a/research/cv/slowfast/scripts/run_standalone_eval.sh
+++ b/research/cv/slowfast/scripts/run_standalone_eval_ascend.sh
@@ -48,5 +48,5 @@ taskset -c $cpu_range python -u eval.py --cfg "${CFG}" \
     AVA.FRAME_DIR "${DATA_DIR}/frames" \
     AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
     AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
-    TEST.CHECKPOINT_FILE_PATH "${CHECKPOINT_FILE_PATH}"
> log 2>&1 &
+    TEST.CHECKPOINT_FILE_PATH "${CHECKPOINT_FILE_PATH}" > log_eval_ascend 2>&1 &
 cd ..
diff --git a/research/cv/slowfast/scripts/run_standalone_eval_gpu.sh b/research/cv/slowfast/scripts/run_standalone_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..73923b42eb1c68745ecd502b8756ca93624aff50
--- /dev/null
+++ b/research/cv/slowfast/scripts/run_standalone_eval_gpu.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 3 ] ; then
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash scripts/run_standalone_eval_gpu.sh CFG DATA_DIR CHECKPOINT_FILE_PATH"
+echo "for example: bash scripts/run_standalone_eval_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt"
+echo "=============================================================================================================="
+exit 1;
+fi
+
+CFG=$(realpath "$1")
+DATA_DIR=$(realpath "$2")
+CHECKPOINT_FILE_PATH=$(realpath "$3")
+
+python -u eval.py --device_target="GPU" --dataset_sink_mode=0 --cfg "$CFG" \
+    AVA.FRAME_DIR "${DATA_DIR}/frames" \
+    AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
+    AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
+    TEST.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_eval_gpu 2>&1 &
+
diff --git a/research/cv/slowfast/scripts/run_standalone_train.sh b/research/cv/slowfast/scripts/run_standalone_train_ascend.sh
similarity index 94%
rename from research/cv/slowfast/scripts/run_standalone_train.sh
rename to research/cv/slowfast/scripts/run_standalone_train_ascend.sh
index fc5442a30ae8853d3008a71ea17a618d9072cca5..633319a854649d2c7d8f27979a3f72ec49008767 100644
--- a/research/cv/slowfast/scripts/run_standalone_train.sh
+++ b/research/cv/slowfast/scripts/run_standalone_train_ascend.sh
@@ -28,4 +28,4 @@ python -u train.py --cfg "$CFG" \
     AVA.FRAME_DIR "${DATA_DIR}/frames" \
     AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
     AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
-    TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log 2>&1 &
+    TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_standalone_ascend 2>&1 &
diff --git a/research/cv/slowfast/scripts/run_standalone_train_gpu.sh b/research/cv/slowfast/scripts/run_standalone_train_gpu.sh
new file mode 100644
index
0000000000000000000000000000000000000000..eae77a82cd500d37122940d195be1731653f69af
--- /dev/null
+++ b/research/cv/slowfast/scripts/run_standalone_train_gpu.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [ $# != 3 ] ; then
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "bash scripts/run_standalone_train_gpu.sh CFG DATA_DIR CHECKPOINT_FILE_PATH"
+echo "for example: bash scripts/run_standalone_train_gpu.sh configs/AVA/SLOWFAST_32x2_R50_SHORT.yaml data/ava SLOWFAST_8x8_R50.pkl.ckpt"
+echo "=============================================================================================================="
+exit 1;
+fi
+
+export DEVICE_NUM=1
+CFG=$(realpath "$1")
+DATA_DIR=$(realpath "$2")
+CHECKPOINT_FILE_PATH=$(realpath "$3")
+
+python -u train.py --device_target="GPU" --dataset_sink_mode=0 --cfg "$CFG" \
+    AVA.FRAME_DIR "${DATA_DIR}/frames" \
+    AVA.FRAME_LIST_DIR "${DATA_DIR}/ava_annotations" \
+    AVA.ANNOTATION_DIR "${DATA_DIR}/ava_annotations" \
+    TRAIN.CHECKPOINT_FILE_PATH "$CHECKPOINT_FILE_PATH" > log_standalone_gpu 2>&1 &
diff --git a/research/cv/slowfast/src/datasets/build.py b/research/cv/slowfast/src/datasets/build.py
index
d2103dc667670aa6dbd2b94737c8f3038cdfd3b6..0ecb1e8e6d65ad47ac775927d055c2b030715230 100644
--- a/research/cv/slowfast/src/datasets/build.py
+++ b/research/cv/slowfast/src/datasets/build.py
@@ -19,7 +19,7 @@ from src.datasets.ava_dataset import Ava
 ds.config.set_prefetch_size(8)
 ds.config.set_numa_enable(True)
 
-def build_dataset(cfg, split, num_shards=None, shard_id=None):
+def build_dataset(cfg, split, num_shards=None, shard_id=None, device_target='Ascend'):
     """
     Args:
         cfg (CfgNode): configs. Details can be found in
@@ -36,7 +36,7 @@ def build_dataset(cfg, split, num_shards=None, shard_id=None):
     if split == 'train':
         dataset = ds.GeneratorDataset(dataset_generator,
                                       ["slowpath", "fastpath", "boxes", "labels", "mask"],
-                                      num_parallel_workers=16,
+                                      num_parallel_workers=16 if device_target == 'Ascend' else 6,
                                       python_multiprocessing=False,
                                       shuffle=True,
                                       num_shards=num_shards,
@@ -45,7 +45,7 @@ def build_dataset(cfg, split, num_shards=None, shard_id=None):
     else:
         dataset = ds.GeneratorDataset(dataset_generator,
                                       ["slowpath", "fastpath", "boxes", "labels", "ori_boxes", "metadata", "mask"],
-                                      num_parallel_workers=16,
+                                      num_parallel_workers=16 if device_target == 'Ascend' else 6,
                                       python_multiprocessing=False,
                                       shuffle=False)
         dataset = dataset.batch(cfg.TEST.BATCH_SIZE)
diff --git a/research/cv/slowfast/src/utils/parser.py b/research/cv/slowfast/src/utils/parser.py
index efce9eb92564be5aef5ad4c28ef187fc2b4565aa..ad10b1fb38b63c4dd10670d688e473061f060341 100644
--- a/research/cv/slowfast/src/utils/parser.py
+++ b/research/cv/slowfast/src/utils/parser.py
@@ -68,6 +68,9 @@ def parse_args():
         default=None,
         nargs=argparse.REMAINDER,
     )
+    # Define parameters for device target and dataset_sink_mode
+    parser.add_argument("--device_target", default="Ascend", type=str, help="Ascend/GPU")
+    parser.add_argument("--dataset_sink_mode", default=1, type=int, help="dataset_sink_mode")
     # define 2 parameters for running on modelArts
     parser.add_argument('--data_url',
                         help='path to training/inference
dataset folder',
diff --git a/research/cv/slowfast/train.py b/research/cv/slowfast/train.py
index a4be829ec4f35d9f33924988accc5afcb0c711d2..88279b0fd60a8b7cbd1ae9456493d16765f63a0d 100644
--- a/research/cv/slowfast/train.py
+++ b/research/cv/slowfast/train.py
@@ -15,9 +15,11 @@
 """Train."""
 
 import os
+import numpy as np
 from mindspore import context, nn, dtype, load_checkpoint, set_seed
 from mindspore import DynamicLossScaleManager
-from mindspore.communication import init
+from mindspore.communication import init, get_rank
+from mindspore.common.tensor import Tensor
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train import Model
 from src.config.defaults import assert_and_infer_cfg
@@ -29,6 +31,40 @@ from src.models import optimizer as optim
 
 set_seed(42)
 
+class LossMonitor_Standalone(LossMonitor):
+    def __init__(self, per_print_times=1):
+        super(LossMonitor_Standalone, self).__init__(per_print_times=1)
+        self._per_print_times = per_print_times
+        self._last_print_time = 0
+
+    def step_end(self, run_context):
+        """
+        Print training loss at the end of step.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        cb_params = run_context.original_args()
+        loss = cb_params.net_outputs
+
+        print(loss)
+        if isinstance(loss, (tuple, list)):
+            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
+                loss = loss[0]
+
+        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
+            loss = float(np.mean(loss.asnumpy()))
+
+        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
+
+        if self._per_print_times != 0 and (cb_params.cur_step_num <= self._last_print_time):
+            while cb_params.cur_step_num <= self._last_print_time:
+                self._last_print_time -=\
+                    max(self._per_print_times, cb_params.batch_num if cb_params.dataset_sink_mode else 1)
+        if self._per_print_times != 0 and (cb_params.cur_step_num - self._last_print_time) >= self._per_print_times:
+            self._last_print_time = cb_params.cur_step_num
+            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
+
 class NetWithLoss(nn.Cell):
     """Construct Loss Net."""
     def __init__(self, net):
@@ -59,16 +95,16 @@ def train():
     rank_id = int(os.getenv('RANK_ID', '0'))
     device_id = int(os.getenv('DEVICE_ID', '0'))
     device_num = int(os.getenv('DEVICE_NUM', '1'))
-    context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend")
+    context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target=args.device_target)
     context.set_context(save_graphs=True, save_graphs_path='irs')
     if device_num > 1:
         init()
         context.set_auto_parallel_context(device_num=device_num,
                                           parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
-
+        rank_id = get_rank()
     # build dataset
-    dataset = build_dataset(cfg, "train", num_shards=device_num, shard_id=rank_id)
+    dataset = build_dataset(cfg, "train", num_shards=device_num, shard_id=rank_id, device_target=args.device_target)
     steps_per_epoch = dataset.get_dataset_size()
     # build net with loss
     network = SlowFast(cfg).set_train(True)
@@ -81,8 +117,11 @@ def train():
     optimizer =
optim.construct_optimizer(net_with_loss, steps_per_epoch, cfg)
 
     # setup callbacks
-    callbacks = [TimeMonitor(), LossMonitor()]
-    if (device_num == 1) or (device_num > 1 and device_id in [0, 7]):
+    if device_num > 1:
+        callbacks = [TimeMonitor(), LossMonitor()]
+    else:
+        callbacks = [TimeMonitor(), LossMonitor_Standalone()]
+    if rank_id == 0:
         ckpt_cfg = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=cfg.SOLVER.MAX_EPOCH)
         ckpt_cb = ModelCheckpoint(prefix="slowfast", directory='checkpoints', config=ckpt_cfg)
         callbacks.append(ckpt_cb)
@@ -92,7 +131,7 @@ def train():
     # start training
     logger.info("============== Starting Training ==============")
     logger.info("total_epoch=%d, steps_per_epoch=%d", cfg.SOLVER.MAX_EPOCH, steps_per_epoch)
-    model.train(cfg.SOLVER.MAX_EPOCH, dataset, callbacks=callbacks, dataset_sink_mode=True)
+    model.train(cfg.SOLVER.MAX_EPOCH, dataset, callbacks=callbacks, dataset_sink_mode=bool(args.dataset_sink_mode))
 
 if __name__ == "__main__":
     train()
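
Note on the new flags: the two command-line options this change adds to `parse_args()` and the worker-count rule added to `build_dataset()` can be tried in isolation, without MindSpore installed. The following is a minimal sketch under stated assumptions, not the repository's parser: `parse_device_args` and `pick_num_workers` are hypothetical helper names, and the `choices` restriction is an assumption that the patch itself does not include.

```python
import argparse


def parse_device_args(argv=None):
    """Parse only the two flags added by this change (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    # `choices` is an assumption for illustration; the patch accepts any string.
    parser.add_argument("--device_target", default="Ascend", type=str,
                        choices=["Ascend", "GPU"], help="Ascend/GPU")
    parser.add_argument("--dataset_sink_mode", default=1, type=int,
                        help="dataset_sink_mode")
    return parser.parse_args(argv)


def pick_num_workers(device_target):
    """Mirror the rule in build_dataset: 16 workers on Ascend, 6 otherwise."""
    return 16 if device_target == "Ascend" else 6


# Same flags the GPU launch scripts pass to train.py and eval.py.
args = parse_device_args(["--device_target=GPU", "--dataset_sink_mode=0"])
sink_mode = bool(args.dataset_sink_mode)  # 0 -> False, any non-zero int -> True
print(args.device_target, sink_mode, pick_num_workers(args.device_target))
# prints: GPU False 6
```

Because `--dataset_sink_mode` is parsed as an integer and converted with `bool()`, passing `0` turns sink mode off and any non-zero value turns it on. The default of `1` keeps the previous Ascend behaviour (`dataset_sink_mode=True`).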