diff --git a/official/cv/retinanet/README_CN.md b/official/cv/retinanet/README_CN.md
index 7e6d09379074d9e559ae8c5ae74b03cea0e7867f..f1498442452bcda1ecfb75aefda8e002161c4890 100644
--- a/official/cv/retinanet/README_CN.md
+++ b/official/cv/retinanet/README_CN.md
@@ -83,8 +83,11 @@ MSCOCO2017
├─scripts
├─run_single_train.sh # 使用Ascend环境单卡训练
├─run_distribute_train.sh # 使用Ascend环境八卡并行训练
+ ├─run_distribute_train_gpu.sh # 使用GPU环境八卡并行训练
+ ├─run_single_train_gpu.sh # 使用GPU环境单卡训练
├─run_infer_310.sh # Ascend推理shell脚本
├─run_eval.sh # 使用Ascend环境运行推理脚本
+ ├─run_eval_gpu.sh # 使用GPU环境运行推理脚本
├─src
├─dataset.py # 数据预处理
├─retinanet.py # 网络模型定义
@@ -179,11 +182,11 @@ MSCOCO2017
# 八卡并行训练示例:
创建 RANK_TABLE_FILE
-bash scripts/run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE MINDRECORD_DIR PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
+bash scripts/run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE MINDRECORD_DIR CONFIG_PATH PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
# 单卡训练示例:
-bash scripts/run_single_train.sh DEVICE_ID MINDRECORD_DIR PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
+bash scripts/run_single_train.sh DEVICE_ID MINDRECORD_DIR CONFIG_PATH PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional)
```
@@ -213,21 +216,30 @@ mindrecord_dr: /home/DataSet/MindRecord_COCO
```MindRecord
# 生成训练数据集
-python create_data.py --create_dataset coco --prefix retinanet.mindrecord --is_training True
+python create_data.py --create_dataset coco --prefix retinanet.mindrecord --is_training True --config_path [CONFIG_PATH]
+(例如:python create_data.py --create_dataset coco --prefix retinanet.mindrecord --is_training True --config_path /home/retinanet/config/default_config.yaml)
# 生成测试数据集
-python create_data.py --create_dataset coco --prefix retinanet_eval.mindrecord --is_training False
+python create_data.py --create_dataset coco --prefix retinanet_eval.mindrecord --is_training False --config_path [CONFIG_PATH]
+(例如:python create_data.py --create_dataset coco --prefix retinanet_eval.mindrecord --is_training False --config_path /home/retinanet/config/default_config.yaml)
```
```bash
Ascend:
# 八卡并行训练示例(在retinanet目录下运行):
-bash scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] [MINDRECORD_DIR] [PRE_TRAINED(optional)] [PRE_TRAINED_EPOCH_SIZE(optional)]
-# example: bash scripts/run_distribute_train.sh 8 ~/hccl_8p.json /home/DataSet/MindRecord_COCO/
+bash scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] [MINDRECORD_DIR] [CONFIG_PATH] [PRE_TRAINED(optional)] [PRE_TRAINED_EPOCH_SIZE(optional)]
+# example: bash scripts/run_distribute_train.sh 8 ~/hccl_8p.json /home/DataSet/MindRecord_COCO/ /home/retinanet/config/default_config.yaml
# 单卡训练示例(在retinanet目录下运行):
-bash scripts/run_single_train.sh [DEVICE_ID] [MINDRECORD_DIR]
-# example: bash scripts/run_single_train.sh 0 /home/DataSet/MindRecord_COCO/
+bash scripts/run_single_train.sh [DEVICE_ID] [MINDRECORD_DIR] [CONFIG_PATH]
+# example: bash scripts/run_single_train.sh 0 /home/DataSet/MindRecord_COCO/ /home/retinanet/config/default_config.yaml
+```
+
+```bash
+GPU:
+# 八卡并行训练示例(在retinanet目录下运行):
+bash scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [MINDRECORD_DIR] [CONFIG_PATH] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [PRE_TRAINED(optional)] [PRE_TRAINED_EPOCH_SIZE(optional)]
+# example: bash scripts/run_distribute_train_gpu.sh 8 /home/DataSet/MindRecord_COCO/ /home/retinanet/config/default_config_gpu.yaml 0,1,2,3,4,5,6,7
```
#### 结果
@@ -309,15 +321,16 @@ Epoch time: 164531.610, per step time: 359.239
使用shell脚本进行评估。shell脚本的用法如下:
-```eval
-bash scripts/run_eval.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [CHECKPOINT_PATH] [ANN_FILE PATH]
-# example: bash scripts/run_eval.sh 0 coco /home/DataSet/MindRecord_COCO/ /home/model/retinanet/ckpt/retinanet_500-458.ckpt /home/DataSet/cocodataset/annotations/instances_{}.json
+```bash
+Ascend:
+bash scripts/run_eval.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [CHECKPOINT_PATH] [ANN_FILE PATH] [CONFIG_PATH]
+# example: bash scripts/run_eval.sh 0 coco /home/DataSet/MindRecord_COCO/ /home/model/retinanet/ckpt/retinanet_500-458.ckpt /home/DataSet/cocodataset/annotations/instances_{}.json /home/retinanet/config/default_config.yaml
```
-#### <span id="running">运行</span>
-
-```eval运行
-bash scripts/run_eval.sh 0 coco /home/DataSet/MindRecord_COCO/ /home/model/retinanet/ckpt/retinanet_500-458.ckpt /home/DataSet/cocodataset/annotations/instances_{}.json
+```bash
+GPU:
+bash scripts/run_eval_gpu.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [CHECKPOINT_PATH] [ANN_FILE PATH] [CONFIG_PATH]
+# example: bash scripts/run_eval_gpu.sh 0 coco /home/DataSet/MindRecord_COCO/ /home/model/retinanet/ckpt/retinanet_500-458.ckpt /home/DataSet/cocodataset/annotations/instances_{}.json /home/retinanet/config/default_config_gpu.yaml
```
> checkpoint 可以在训练过程中产生.
@@ -327,6 +340,7 @@ bash scripts/run_eval.sh 0 coco /home/DataSet/MindRecord_COCO/ /home/model/retin
计算结果将存储在示例路径中,您可以在 `eval.log` 查看.
```mAP
+Ascend:
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.347
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.503
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.385
@@ -345,6 +359,26 @@ bash scripts/run_eval.sh 0 coco /home/DataSet/MindRecord_COCO/ /home/model/retin
mAP: 0.34747137754625645
```
+```mAP
+GPU:
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.349
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.504
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.385
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.136
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.366
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.506
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.302
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.414
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.415
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.156
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.434
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.608
+
+========================================
+
+mAP: 0.34852168035724435
+```
+
### [模型导出](#content)
#### <span id="usage">用法</span>
@@ -430,34 +464,34 @@ mAP: 0.3499478734634595
#### 训练性能
-| 参数 | Ascend |
-| -------------------------- | ------------------------------------- |
-| 模型名称 | Retinanet |
-| 运行环境 | Ascend 910;CPU 2.6GHz,192cores;Memory 755G;系统 Euler2.8 |
-| 上传时间 | 10/01/2021 |
-| MindSpore 版本 | 1.2.0 |
-| 数据集 | 123287 张图片 |
-| Batch_size | 32 |
-| 训练参数 | src/config.py |
-| 优化器 | Momentum |
-| 损失函数 | Focal loss |
-| 最终损失 | 0.582 |
-| 精确度 (8p) | mAP[0.3475] |
-| 训练总时间 (8p) | 23h16m54s |
-| 脚本 | [链接](https://gitee.com/mindspore/models/tree/master/official/cv/retinanet) |
+| 参数 | Ascend |GPU|
+| -------------------------- | ------------------------------------- |------------------------------------- |
+| 模型名称 | Retinanet |Retinanet |
+| 运行环境 | Ascend 910;CPU 2.6GHz,192cores;Memory 755G;系统 Euler2.8 | Rtx3090;Memory 512G |
+| 上传时间 | 10/01/2021 |17/02/2022 |
+| MindSpore 版本 | 1.2.0 |1.5.0|
+| 数据集 | 123287 张图片 |123287 张图片 |
+| Batch_size | 32 |32 |
+| 训练参数 | src/config.py |config/default_config_gpu.yaml |
+| 优化器 | Momentum |Momentum |
+| 损失函数 | Focal loss |Focal loss |
+| 最终损失 | 0.582 |0.57|
+| 精确度 (8p) | mAP[0.3475] |mAP[0.3499] |
+| 训练总时间 (8p) | 23h16m54s |51h39m6s|
+| 脚本 | [链接](https://gitee.com/mindspore/models/tree/master/official/cv/retinanet) |[链接](https://gitee.com/mindspore/models/tree/master/official/cv/retinanet) |
#### 推理性能
-| 参数 | Ascend |
-| ------------------- | --------------------------- |
-| 模型名称 | Retinanet |
-| 运行环境 | Ascend 910;CPU 2.6GHz,192cores;Memory 755G;系统 Euler2.8|
-| 上传时间 | 10/01/2021 |
-| MindSpore 版本 | 1.2.0 |
-| 数据集 | 5k 张图片 |
-| Batch_size | 32 |
-| 精确度 | mAP[0.3475] |
-| 总时间 | 10 mins and 50 seconds |
+| 参数 | Ascend |GPU|
+| ------------------- | --------------------------- |--|
+| 模型名称 | Retinanet |Retinanet |
+| 运行环境 | Ascend 910;CPU 2.6GHz,192cores;Memory 755G;系统 Euler2.8|Rtx3090;Memory 512G |
+| 上传时间 | 10/01/2021 |17/02/2022 |
+| MindSpore 版本 | 1.2.0 |1.5.0|
+| 数据集 | 5k 张图片 |5k 张图片 |
+| Batch_size | 32 |32 |
+| 精确度 | mAP[0.3475] |mAP[0.3499] |
+| 总时间 | 10 mins and 50 seconds |13 mins and 40 seconds |
## [随机情况的描述](#content)
diff --git a/official/cv/retinanet/default_config.yaml b/official/cv/retinanet/config/default_config.yaml
similarity index 100%
rename from official/cv/retinanet/default_config.yaml
rename to official/cv/retinanet/config/default_config.yaml
diff --git a/official/cv/retinanet/config/default_config_gpu.yaml b/official/cv/retinanet/config/default_config_gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..62adf63947d85b3641230f938f94537905781c64
--- /dev/null
+++ b/official/cv/retinanet/config/default_config_gpu.yaml
@@ -0,0 +1,153 @@
+# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "GPU"
+enable_profiling: False
+need_modelarts_dataset_unzip: True
+modelarts_dataset_unzip_name: "MindRecord_COCO"
+
+# ======================================================================================
+# common options
+distribute: False
+
+# ======================================================================================
+# create dataset
+create_dataset: "coco"
+prefix: "retinanet.mindrecord"
+is_training: True
+
+# ======================================================================================
+# Training options
+img_shape: [600, 600]
+num_retinanet_boxes: 67995
+match_thershold: 0.5
+nms_thershold: 0.6
+min_score: 0.1
+max_boxes: 100
+device_num: 1
+
+# learning rate settings
+lr: 0.1
+global_step: 0
+lr_init: 1e-6
+lr_end_rate: 5e-3
+warmup_epochs1: 2
+warmup_epochs2: 5
+warmup_epochs3: 23
+warmup_epochs4: 60
+warmup_epochs5: 160
+momentum: 0.9
+weight_decay: 1.5e-4
+
+# network
+num_default: [9, 9, 9, 9, 9]
+extras_out_channels: [256, 256, 256, 256, 256]
+feature_size: [75, 38, 19, 10, 5]
+aspect_ratios: [[0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0]]
+steps: [8, 16, 32, 64, 128]
+anchor_size: [32, 64, 128, 256, 512]
+prior_scaling: [0.1, 0.2]
+gamma: 2.0
+alpha: 0.75
+num_classes: 81
+
+# `mindrecord_dir` and `coco_root` are better to use absolute path.
+mindrecord_dir: "./"
+coco_root: "./"
+train_data_type: "train2017"
+val_data_type: "val2017"
+instances_set: "annotations/instances_{}.json"
+coco_classes: ["background", "person", "bicycle", "car", "motorcycle", "airplane", "bus",
+ "train", "truck", "boat", "traffic light", "fire hydrant",
+ "stop sign", "parking meter", "bench", "bird", "cat", "dog",
+ "horse", "sheep", "cow", "elephant", "bear", "zebra",
+ "giraffe", "backpack", "umbrella", "handbag", "tie",
+ "suitcase", "frisbee", "skis", "snowboard", "sports ball",
+ "kite", "baseball bat", "baseball glove", "skateboard",
+ "surfboard", "tennis racket", "bottle", "wine glass", "cup",
+ "fork", "knife", "spoon", "bowl", "banana", "apple",
+ "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
+ "donut", "cake", "chair", "couch", "potted plant", "bed",
+ "dining table", "toilet", "tv", "laptop", "mouse", "remote",
+ "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
+ "refrigerator", "book", "clock", "vase", "scissors",
+ "teddy bear", "hair drier", "toothbrush"]
+
+
+# The annotation.json position of voc validation dataset
+voc_root: ""
+
+# voc original dataset
+voc_dir: ""
+
+# if coco or voc used, `image_dir` and `anno_path` are useless
+image_dir: ""
+anno_path: ""
+save_checkpoint: True
+save_checkpoint_epochs: 1
+keep_checkpoint_max: 2
+save_checkpoint_path: "./ckpt"
+finish_epoch: 0
+
+# optimizer options
+workers: 16
+mode: "sink"
+epoch_size: 500
+batch_size: 32
+pre_trained: ""
+pre_trained_epoch_size: 0
+loss_scale: 200
+filter_weight: False
+
+# ======================================================================================
+# Eval options
+dataset: "coco"
+checkpoint_path: "./"
+
+# ======================================================================================
+# export options
+device_id: 0
+file_format: "MINDIR"
+export_batch_size: 1
+file_name: "retinanet"
+
+# ======================================================================================
+# postprocess options
+result_path: ""
+img_path: ""
+img_id_file: ""
+
+---
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of input data"
+output_path: "The location of the output file"
+device_target: "Target device type, GPU or Ascend. (Default: GPU)"
+enable_profiling: "Whether enable profiling while training default: False"
+workers: "Num parallel workers."
+lr: "Learning rate, default is 0.1."
+mode: "Run sink mode or not, default is sink."
+epoch_size: "Epoch size, default is 500."
+batch_size: "Batch size, default is 32."
+pre_trained: "Pretrained Checkpoint file path."
+pre_trained_epoch_size: "Pretrained epoch size."
+save_checkpoint_epochs: "Save checkpoint epochs, default is 1."
+loss_scale: "Loss scale, default is 200."
+filter_weight: "Filter weight parameters, default is False."
+dataset: "Dataset, default is coco."
+device_id: "Device id, default is 0."
+file_format: "file format choices [AIR, MINDIR]"
+file_name: "output file name."
+export_batch_size: "batch size"
+result_path: "result file path."
+img_path: "image file path."
+img_id_file: "image id file."
diff --git a/official/cv/retinanet/eval.py b/official/cv/retinanet/eval.py
index 701133bb2e24ec8bbf90d2f799e3a0ed2e1ae379..493d30a7767cfd5c7aacffa81d25de6d4c14b22a 100644
--- a/official/cv/retinanet/eval.py
+++ b/official/cv/retinanet/eval.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -158,7 +158,7 @@ def retinanet_eval():
make_dataset_dir(mindrecord_dir, mindrecord_file, prefix)
batch_size = 1
- ds = create_retinanet_dataset(mindrecord_file, batch_size=batch_size, is_training=False)
+ ds = create_retinanet_dataset(mindrecord_file, batch_size=batch_size, repeat_num=1, is_training=False)
backbone = resnet50(config.num_classes)
net = retinanet50(backbone, config)
net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
diff --git a/official/cv/retinanet/scripts/run_distribute_train.sh b/official/cv/retinanet/scripts/run_distribute_train.sh
index b39944d675dbc624568e8dd66850d022947a8c40..eeba0526b251ba296fe91c050838ddc8124da15e 100644
--- a/official/cv/retinanet/scripts/run_distribute_train.sh
+++ b/official/cv/retinanet/scripts/run_distribute_train.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,29 +16,30 @@
echo "=============================================================================================================="
echo "Please run the script as: "
-echo "sh scripts/run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE MINDRECORD_DIR PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
-echo "for example: sh scripts/run_distribute_train.sh 8 /data/hccl.json /cache/mindrecord_dir/ /opt/retinanet-500_458.ckpt(optional) 200(optional)"
+echo "sh scripts/run_distribute_train.sh DEVICE_NUM RANK_TABLE_FILE MINDRECORD_DIR CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "for example: sh scripts/run_distribute_train.sh 8 /data/hccl.json /cache/mindrecord_dir/ /config/default_config.yaml /opt/retinanet-500_458.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "================================================================================================================="
-if [ $# != 3 ] && [ $# != 5 ]
+if [ $# != 4 ] && [ $# != 6 ]
then
- echo "Usage: sh scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE] \
- [MINDRECORD_DIR] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+ echo "Usage: sh scripts/run_distribute_train.sh [DEVICE_NUM] [RANK_TABLE_FILE]\
+ [MINDRECORD_DIR] [CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
exit 1
fi
-core_num=`cat /proc/cpuinfo |grep "processor"|wc -l`
-process_cores=$(($core_num/8))
-
-echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
-
export RANK_SIZE=$1
MINDRECORD_DIR=$3
-PRE_TRAINED=$4
-PRE_TRAINED_EPOCH_SIZE=$5
+CONFIG_PATH=$4
+PRE_TRAINED=$5
+PRE_TRAINED_EPOCH_SIZE=$6
export RANK_TABLE_FILE=$2
+core_num=`cat /proc/cpuinfo |grep "processor"|wc -l`
+process_cores=$(($core_num/$RANK_SIZE))
+
+echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
+
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
@@ -47,7 +48,7 @@ do
cp ./*.py ./LOG$i
cp -r ./src ./LOG$i
cp -r ./scripts ./LOG$i
- cp ./*yaml ./LOG$i
+ cp ./config/*yaml ./LOG$i
start=`expr $i \* $process_cores`
end=`expr $start \+ $(($process_cores-1))`
cmdopt=$start"-"$end
@@ -55,20 +56,22 @@ do
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
- if [ $# == 3 ]
+ if [ $# == 4 ]
then
taskset -c $cmdopt python train.py \
--workers=$process_cores \
--distribute=True \
+ --config_path=$CONFIG_PATH \
--mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
fi
- if [ $# == 5 ]
+ if [ $# == 6 ]
then
taskset -c $cmdopt python train.py \
--workers=$process_cores \
--distribute=True \
--mindrecord_dir=$MINDRECORD_DIR \
+ --config_path=$CONFIG_PATH \
--pre_trained=$PRE_TRAINED \
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE > log.txt 2>&1 &
fi
diff --git a/official/cv/retinanet/scripts/run_distribute_train_gpu.sh b/official/cv/retinanet/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..04fe70a41fd565e080f3534bad97b1dba21b8fd3
--- /dev/null
+++ b/official/cv/retinanet/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "sh scripts/run_distribute_train_gpu.sh DEVICE_NUM MINDRECORD_DIR CONFIG_PATH VISIBLE_DEVICES PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "for example: sh scripts/run_distribute_train_gpu.sh 8 /cache/mindrecord_dir/ /config/default_config_gpu.yaml 0,1,2,3,4,5,6,7 /opt/retinanet-500_458.ckpt(optional) 200(optional)"
+echo "It is better to use absolute path."
+echo "================================================================================================================="
+
+if [ $# != 4 ] && [ $# != 6 ]
+then
+ echo "Usage: sh scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [MINDRECORD_DIR] \
+ [CONFIG_PATH] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+ exit 1
+fi
+
+export RANK_SIZE=$1
+MINDRECORD_DIR=$2
+CONFIG_PATH=$3
+export CUDA_VISIBLE_DEVICES="$4"
+PRE_TRAINED=$5
+PRE_TRAINED_EPOCH_SIZE=$6
+
+core_num=`cat /proc/cpuinfo |grep "processor"|wc -l`
+process_cores=$(($core_num/$RANK_SIZE))
+
+echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
+
+rm -rf LOG
+mkdir ./LOG
+cp ./*.py ./LOG
+cp -r ./src ./LOG
+cp -r ./scripts ./LOG
+cp ./config/*yaml ./LOG
+cd ./LOG || exit
+
+if [ $# == 4 ]
+then
+ mpirun -allow-run-as-root -n $1 --output-filename log_output --merge-stderr-to-stdout \
+ python train.py \
+ --distribute=True \
+ --device_num=$RANK_SIZE \
+ --workers=$process_cores \
+ --config_path=$CONFIG_PATH \
+ --mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
+fi
+
+if [ $# == 6 ]
+then
+ mpirun -allow-run-as-root -n $1 --output-filename log_output --merge-stderr-to-stdout \
+ python train.py \
+ --distribute=True \
+ --device_num=$RANK_SIZE \
+ --workers=$process_cores \
+ --config_path=$CONFIG_PATH \
+ --mindrecord_dir=$MINDRECORD_DIR\
+ --pre_trained=$PRE_TRAINED \
+ --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE > log.txt 2>&1 &
+fi
+
+cd ../
diff --git a/official/cv/retinanet/scripts/run_eval.sh b/official/cv/retinanet/scripts/run_eval.sh
index 9875cbbc1d309a1a8adefb327cced32b0f2a7961..d7b57a7ba39482c908b5361f1f14ad12c9172399 100644
--- a/official/cv/retinanet/scripts/run_eval.sh
+++ b/official/cv/retinanet/scripts/run_eval.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,9 +14,9 @@
# limitations under the License.
# ============================================================================
-if [ $# != 5 ]
+if [ $# != 6 ]
then
- echo "Usage: sh scripts/run_eval.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [checkpoint_path] [instances_set]"
+ echo "Usage: sh scripts/run_eval.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [checkpoint_path] [instances_set] [CONFIG_PATH]"
exit 1
fi
@@ -24,6 +24,7 @@ DATASET=$2
MINDRECORD_DIR=$3
CHECKPOINT_PATH=$4
INSTANCE_SET=$5
+CONFIG_PATH=$6
echo $DATASET
export DEVICE_NUM=1
@@ -31,6 +32,7 @@ export DEVICE_ID=$1
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
+
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
cd $BASE_PATH/../ || exit
@@ -42,7 +44,7 @@ fi
mkdir ./eval$2
cp ./*.py ./eval$2
cp -r ./src ./eval$2
-cp ./*yaml ./eval$2
+cp ./config/*yaml ./eval$2
cd ./eval$2 || exit
env > env.log
echo "start inferring for device $DEVICE_ID"
@@ -50,5 +52,6 @@ python eval.py \
--dataset=$DATASET \
--checkpoint_path=$CHECKPOINT_PATH \
--instances_set=$INSTANCE_SET \
+ --config_path=$CONFIG_PATH \
--mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
cd ..
diff --git a/official/cv/retinanet/scripts/run_eval_gpu.sh b/official/cv/retinanet/scripts/run_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0cc22236ad4e4d8386be7bb7e15c4a1d963c27a3
--- /dev/null
+++ b/official/cv/retinanet/scripts/run_eval_gpu.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+if [ $# != 6 ]
+then
+ echo "Usage: sh scripts/run_eval_gpu.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [checkpoint_path] [instances_set] [CONFIG_PATH]"
+exit 1
+fi
+
+DATASET=$2
+MINDRECORD_DIR=$3
+CHECKPOINT_PATH=$4
+INSTANCE_SET=$5
+CONFIG_PATH=$6
+echo $DATASET
+
+export DEVICE_NUM=1
+export DEVICE_ID=$1
+export RANK_SIZE=$DEVICE_NUM
+export RANK_ID=0
+
+
+BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
+cd $BASE_PATH/../ || exit
+
+if [ -d "eval$2" ];
+then
+ rm -rf ./eval$2
+fi
+
+mkdir ./eval$2
+cp ./*.py ./eval$2
+cp -r ./src ./eval$2
+cp ./config/*yaml ./eval$2
+cd ./eval$2 || exit
+env > env.log
+echo "start inferring for device $DEVICE_ID"
+python eval.py \
+ --dataset=$DATASET \
+ --checkpoint_path=$CHECKPOINT_PATH \
+ --instances_set=$INSTANCE_SET \
+ --config_path=$CONFIG_PATH \
+ --mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
+cd ..
diff --git a/official/cv/retinanet/scripts/run_single_train.sh b/official/cv/retinanet/scripts/run_single_train.sh
index 3d14d36f18e12ea0a4891b2f44c321f9b0a73098..3feaf753a46d657b3f0df0b02b0a4605d51935e5 100644
--- a/official/cv/retinanet/scripts/run_single_train.sh
+++ b/official/cv/retinanet/scripts/run_single_train.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,15 +16,15 @@
echo "=============================================================================================================="
echo "Please run the script as: "
-echo "sh scripts/run_single_train.sh DEVICE_ID MINDRECORD_DIR PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
-echo "for example: sh scripts/run_single_train.sh 0 /cache/mindrecord_dir/ /opt/retinanet-500_458.ckpt(optional) 200(optional)"
+echo "sh scripts/run_single_train.sh DEVICE_ID MINDRECORD_DIR CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "for example: sh scripts/run_single_train.sh 0 /cache/mindrecord_dir/ /config/default_config.yaml /opt/retinanet-500_458.ckpt(optional) 200(optional)"
echo "It is better to use absolute path."
echo "================================================================================================================="
-if [ $# != 2 ] && [ $# != 4 ]
+if [ $# != 3 ] && [ $# != 5 ]
then
echo "Usage: sh scripts/run_single_train.sh [DEVICE_ID] [MINDRECORD_DIR] \
-[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+[CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
exit 1
fi
@@ -37,29 +37,32 @@ echo "After running the script, the network runs in the background. The log will
export DEVICE_ID=$1
MINDRECORD_DIR=$2
-PRE_TRAINED=$3
-PRE_TRAINED_EPOCH_SIZE=$4
+CONFIG_PATH=$3
+PRE_TRAINED=$4
+PRE_TRAINED_EPOCH_SIZE=$5
rm -rf LOG$1
mkdir ./LOG$1
cp ./*.py ./LOG$1
cp -r ./src ./LOG$1
-cp ./*yaml ./LOG$1
+cp ./config/*yaml ./LOG$1
cd ./LOG$1 || exit
echo "start training for device $1"
env > env.log
-if [ $# == 2 ]
+if [ $# == 3 ]
then
python train.py \
--distribute=False \
+ --config_path=$CONFIG_PATH \
--mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
fi
-if [ $# == 4 ]
+if [ $# == 5 ]
then
- python train,py \
+ python train.py \
--distribute=False \
--mindrecord_dir=$MINDRECORD_DIR \
+ --config_path=$CONFIG_PATH \
--pre_trained=$PRE_TRAINED \
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE > log.txt 2>&1 &
fi
diff --git a/official/cv/retinanet/scripts/run_single_train_gpu.sh b/official/cv/retinanet/scripts/run_single_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a8b52962fe64c56b447fbd0c8b3bff477b5831f9
--- /dev/null
+++ b/official/cv/retinanet/scripts/run_single_train_gpu.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+echo "=============================================================================================================="
+echo "Please run the script as: "
+echo "sh scripts/run_single_train_gpu.sh DEVICE_ID MINDRECORD_DIR CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "for example: sh scripts/run_single_train_gpu.sh 0 /cache/mindrecord_dir/ /config/default_config_gpu.yaml /opt/retinanet-500_458.ckpt(optional) 200(optional)"
+echo "It is better to use absolute path."
+echo "================================================================================================================="
+
+if [ $# != 3 ] && [ $# != 5 ]
+then
+ echo "Usage: sh scripts/run_single_train_gpu.sh [DEVICE_ID] [MINDRECORD_DIR] \
+[CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+ exit 1
+fi
+
+core_num=`cat /proc/cpuinfo |grep "processor"|wc -l`
+process_cores=$core_num
+
+echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
+
+export DEVICE_ID=$1
+MINDRECORD_DIR=$2
+CONFIG_PATH=$3
+PRE_TRAINED=$4
+PRE_TRAINED_EPOCH_SIZE=$5
+
+rm -rf LOG$1
+mkdir ./LOG$1
+cp ./*.py ./LOG$1
+cp -r ./src ./LOG$1
+cp ./config/*yaml ./LOG$1
+cd ./LOG$1 || exit
+echo "start training for device $1"
+env > env.log
+if [ $# == 3 ]
+then
+ python train.py \
+ --distribute=False \
+ --config_path=$CONFIG_PATH \
+ --workers=$process_cores \
+ --mindrecord_dir=$MINDRECORD_DIR > log.txt 2>&1 &
+fi
+
+if [ $# == 5 ]
+then
+ python train.py \
+ --distribute=False \
+ --mindrecord_dir=$MINDRECORD_DIR \
+ --config_path=$CONFIG_PATH \
+ --workers=$process_cores \
+ --pre_trained=$PRE_TRAINED \
+ --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE > log.txt 2>&1 &
+fi
+
+cd ../
+
diff --git a/official/cv/retinanet/src/dataset.py b/official/cv/retinanet/src/dataset.py
index 563c44a25b1e5fc0aa64df6127e9203d4c4d0e55..c5105e9c4ba047ec2780deebd4c6a2ef2180c0b1 100644
--- a/official/cv/retinanet/src/dataset.py
+++ b/official/cv/retinanet/src/dataset.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -387,8 +387,8 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="reti
writer.commit()
-def create_retinanet_dataset(mindrecord_file, batch_size, device_num=1, rank=0,
- is_training=True, num_parallel_workers=24):
+def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num=1, rank=0,
+ is_training=True, num_parallel_workers=8):
"""Creatr retinanet dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
diff --git a/official/cv/retinanet/train.py b/official/cv/retinanet/train.py
index fc584c92101495f3c08e8880b52d7fdf1af074db..fedb5d9f9ba81f6601994153e546f20dc45047de 100644
--- a/official/cv/retinanet/train.py
+++ b/official/cv/retinanet/train.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ import ast
import time
import mindspore.nn as nn
from mindspore import context, Tensor
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor, Callback
from mindspore.train import Model
from mindspore.context import ParallelMode
@@ -34,7 +34,6 @@ from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
-
set_seed(1)
@@ -113,6 +112,11 @@ def modelarts_pre_process():
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+def set_graph_kernel_context(device_target):
+ if device_target == "GPU":
+ # Enable graph kernel for RetinaNet on the GPU back-end.
+ context.set_context(enable_graph_kernel=True,
+ graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D")
@moxing_wrapper(pre_process=modelarts_pre_process)
def main():
@@ -126,7 +130,7 @@ def main():
if os.getenv("DEVICE_ID", "not_set").isdigit():
context.set_context(device_id=get_device_id())
init()
- device_num = get_device_num()
+ device_num = config.device_num
rank = get_rank_id()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
@@ -135,8 +139,22 @@ def main():
device_num = 1
context.set_context(device_id=get_device_id())
- # Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE
- context.set_context(mempool_block_size="31GB")
+ elif config.device_target == "GPU":
+ context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+ set_graph_kernel_context(config.device_target)
+ if config.distribute:
+ if os.getenv("DEVICE_ID", "not_set").isdigit():
+ context.set_context(device_id=get_device_id())
+ init()
+ device_num = config.device_num
+ rank = get_rank()
+ context.reset_auto_parallel_context()
+ context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
+ device_num=device_num)
+ else:
+ rank = 0
+ device_num = 1
+ context.set_context(device_id=get_device_id())
else:
raise ValueError("Unsupported platform.")
@@ -145,10 +163,9 @@ def main():
loss_scale = float(config.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as retinanet.mindrecord0.
- dataset = create_retinanet_dataset(mindrecord_file, num_parallel_workers=config.workers,
- batch_size=config.batch_size,
- device_num=device_num,
- rank=rank)
+ dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1,
+ num_parallel_workers=8,
+ batch_size=config.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")