diff --git a/official/cv/retinanet/README_CN.md b/official/cv/retinanet/README_CN.md index f7fb709bcf45f130001591a47c4937b2c9466d5e..d6f85f827a32747f56cb9daf2c749f6f7d6abdb8 100644 --- a/official/cv/retinanet/README_CN.md +++ b/official/cv/retinanet/README_CN.md @@ -1,6 +1,7 @@ + <!-- TOC --> -- [retinanet 鎻忚堪](#retinanet-鎻忚堪) +- <span id="content">[Retinanet 鎻忚堪](#Retinanet-鎻忚堪)</span> - [妯″瀷鏋舵瀯](#妯″瀷鏋舵瀯) - [鏁版嵁闆哴(#鏁版嵁闆�) - [鐜瑕佹眰](#鐜瑕佹眰) @@ -21,19 +22,18 @@ - [鎺ㄧ悊杩囩▼](#鎺ㄧ悊杩囩▼) - [鐢ㄦ硶](#usage) - [杩愯](#running) - - [鍦╫nnx鎵ц鎺ㄧ悊](#span-idonnxrunning鍦╫nnx鎵ц鎺ㄧ悊) - [缁撴灉](#outcome) - - [Onnx缁撴灉](#span-idonnxoutcomeonnx缁撴灉) - [妯″瀷璇存槑](#妯″瀷璇存槑) - [鎬ц兘](#鎬ц兘) - - [璁粌鎬ц兘](#璁粌鎬ц兘) - - [鎺ㄧ悊鎬ц兘](#鎺ㄧ悊鎬ц兘) + - [璁粌鎬ц兘](#璁粌鎬ц兘) + - [鎺ㄧ悊鎬ц兘](#鎺ㄧ悊鎬ц兘) - [闅忔満鎯呭喌鐨勬弿杩癩(#闅忔満鎯呭喌鐨勬弿杩�) - [ModelZoo 涓婚〉](#modelzoo-涓婚〉) +- [杩佺Щ瀛︿範](#杩佺Щ瀛︿範) <!-- /TOC --> -## [retinanet 鎻忚堪](#content) +## [Retinanet 鎻忚堪](#content) RetinaNet绠楁硶婧愯嚜2018骞碏acebook AI Research鐨勮鏂� Focal Loss for Dense Object Detection銆傝璁烘枃鏈€澶х殑璐$尞鍦ㄤ簬鎻愬嚭浜咶ocal Loss鐢ㄤ簬瑙e喅绫诲埆涓嶅潎琛¢棶棰橈紝浠庤€屽垱閫犱簡RetinaNet锛圤ne Stage鐩爣妫€娴嬬畻娉曪級杩欎釜绮惧害瓒呰秺缁忓吀Two Stage鐨凢aster-RCNN鐨勭洰鏍囨娴嬬綉缁溿€� @@ -55,7 +55,16 @@ MSCOCO2017 - 鏁版嵁闆嗗ぇ灏�: 19.3G, 123287寮�80绫诲僵鑹插浘鍍� - 璁粌:19.3G, 118287寮犲浘鐗� + - 娴嬭瘯:1814.3M, 5000寮犲浘鐗� + +- 鏁版嵁鏍煎紡:RGB鍥惧儚. + + - 娉ㄦ剰锛氭暟鎹皢鍦╯rc/dataset.py 涓澶勭悊 + +face-mask-detection(杩佺Щ瀛︿範浣跨敤) + +- 鏁版嵁闆嗗ぇ灏�: 397.65MB, 853寮�3绫诲僵鑹插浘鍍� - 鏁版嵁鏍煎紡:RGB鍥惧儚. - 娉ㄦ剰锛氭暟鎹皢鍦╯rc/dataset.py 涓澶勭悊 @@ -85,10 +94,11 @@ MSCOCO2017 鈹溾攢run_distribute_train_gpu.sh # 浣跨敤GPU鐜鍏崱骞惰璁粌 鈹溾攢run_single_train_gpu.sh # 浣跨敤GPU鐜鍗曞崱璁粌 鈹溾攢run_infer_310.sh # Ascend鎺ㄧ悊shell鑴氭湰 - 鈹溾攢run_onnx_eval.sh # onnx鎺ㄧ悊鐨剆hell鑴氭湰 - 鈹溾攢run_onnx_eval_gpu.sh # 浣跨敤GPU鐜杩愯onnx鎺ㄧ悊鐨剆hell鑴氭湰 鈹溾攢run_eval.sh # 浣跨敤Ascend鐜杩愯鎺ㄧ悊鑴氭湰 鈹溾攢run_eval_gpu.sh # 浣跨敤GPU鐜杩愯鎺ㄧ悊鑴氭湰 + 鈹溾攢config + 鈹溾攢finetune_config.yaml # 杩佺Щ瀛︿範鍙傛暟閰嶇疆 + 鈹斺攢default_config.yaml # 鍙傛暟閰嶇疆 鈹溾攢src 鈹溾攢dataset.py # 鏁版嵁棰勫鐞� 鈹溾攢retinanet.py # 缃戠粶妯″瀷瀹氫箟 @@ -106,8 +116,9 @@ MSCOCO2017 鈹溾攢export.py # 瀵煎嚭 AIR,MINDIR妯″瀷鐨勮剼鏈� 鈹溾攢postprogress.py # 310鎺ㄧ悊鍚庡鐞嗚剼鏈� 鈹斺攢eval.py # 缃戠粶鎺ㄧ悊鑴氭湰 - 鈹斺攢onnx_eval.py # 鐢ㄤ簬onnx鎺ㄧ悊 鈹斺攢create_data.py # 鏋勫缓Mindrecord鏁版嵁闆嗚剼鏈� + 鈹斺攢data_split.py # 杩佺Щ瀛︿範鏁版嵁闆嗗垝鍒嗚剼鏈� + 鈹斺攢quick_start.py # 杩佺Щ瀛︿範鍙鍖栬剼鏈� 鈹斺攢default_config.yaml # 鍙傛暟閰嶇疆 ``` @@ -318,7 +329,7 @@ Epoch time: 164531.610, per step time: 359.239 ### [璇勪及杩囩▼](#content) -#### 鐢ㄦ硶 +#### <span id="usage">鐢ㄦ硶</span> 浣跨敤shell鑴氭湰杩涜璇勪及銆俿hell鑴氭湰鐨勭敤娉曞涓�: @@ -336,7 +347,7 @@ bash scripts/run_eval_gpu.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [CHECKPOINT_ > checkpoint 鍙互鍦ㄨ缁冭繃绋嬩腑浜х敓. -#### 缁撴灉 +#### <span id="outcome">缁撴灉</span> 璁$畻缁撴灉灏嗗瓨鍌ㄥ湪绀轰緥璺緞涓紝鎮ㄥ彲浠ュ湪 `eval.log` 鏌ョ湅. @@ -382,17 +393,17 @@ mAP: 0.34852168035724435 ### [妯″瀷瀵煎嚭](#content) -#### 鐢ㄦ硶 +#### <span id="usage">鐢ㄦ硶</span> 瀵煎嚭妯″瀷鍓嶈淇敼config.py鏂囦欢涓殑checkpoint_path閰嶇疆椤癸紝鍊间负checkpoint鐨勮矾寰勩€� -```bash +```shell python export.py --file_name [RUN_PLATFORM] --file_format[EXPORT_FORMAT] --checkpoint_path [CHECKPOINT PATH] ``` -`EXPORT_FORMAT` 鍙€� ["AIR", "MINDIR", "ONNX"] +`EXPORT_FORMAT` 鍙€� ["AIR", "MINDIR"] -#### 杩愯 +#### <span id="running">杩愯</span> ```杩愯 python export.py --file_name retinanet --file_format MINDIR --checkpoint_path /cache/checkpoint/retinanet_550-458.ckpt @@ -420,7 +431,7 @@ python export.py --file_name retinanet --file_format MINDIR --checkpoint_path / ### [鎺ㄧ悊杩囩▼](#content) -#### 鐢ㄦ硶 +#### <span id="usage">鐢ㄦ硶</span> 鍦ㄦ帹鐞嗕箣鍓嶉渶瑕佸湪鏄囪吘910鐜涓婂畬鎴愭ā鍨嬬殑瀵煎嚭銆傛帹鐞嗘椂瑕佸皢iscrowd涓簍rue鐨勫浘鐗囨帓闄ゆ帀銆傚湪ascend310_infer鐩綍涓嬩繚瀛樹簡鍘绘帓闄ゅ悗鐨勫浘鐗噄d銆� 杩橀渶瑕佷慨鏀筩onfig.py鏂囦欢涓殑coco_root銆乿al_data_type銆乮nstances_set閰嶇疆椤癸紝鍊煎垎鍒彇coco鏁版嵁闆嗙殑鐩綍锛屾帹鐞嗘墍鐢ㄦ暟鎹泦鐨勭洰褰曞悕绉帮紝鎺ㄧ悊瀹屾垚鍚庤绠楃簿搴︾敤鐨刟nnotation鏂囦欢锛宨nstances_set鏄敤val_data_type鎷兼帴璧锋潵鐨勶紝瑕佷繚璇佹枃浠舵纭苟涓斿瓨鍦ㄣ€� @@ -430,36 +441,17 @@ python export.py --file_name retinanet --file_format MINDIR --checkpoint_path / bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [ANN_FILE] [DEVICE_ID] ``` -#### 杩愯 +#### <span id="running">杩愯</span> ```杩愯 bash run_infer_310.sh ./retinanet.mindir ./dataset/coco2017/val2017 ./image_id.txt 0 ``` -#### 鍦╫nnx鎵ц鎺ㄧ悊 - -鍦ㄦ墽琛屾帹鐞嗗墠锛宱nnx鏂囦欢蹇呴』閫氳繃 `export.py`鑴氭湰瀵煎嚭,閫氳繃config_path閫夋嫨閫傜敤浜庝笉鍚屽钩鍙扮殑config鏂囦欢銆備互涓嬪睍绀轰簡浣跨敤onnx妯″瀷鎵ц鎺ㄧ悊鐨勭ず渚嬨€� - -```shell -# Onnx inference -python export.py --file_name [RUN_PLATFORM] --file_format[EXPORT_FORMAT] --checkpoint_path [CHECKPOINT PATH] --config_path [CONFIG PATH] -``` - -EXPORT_FORMAT 閫夋嫨 ["ONNX"] - -浣跨敤shell鑴氭湰杩涜璇勪及銆俿hell鑴氭湰鐨勭敤娉曞涓�: - -```bash -GPU: -bash scripts/run_onnx_eval_gpu.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [ONNX_PATH] [ANN_FILE PATH] [CONFIG_PATH] -# example: bash scripts/run_onnx_eval_gpu.sh 0 coco ./MindRecord_COCO/ /home/retinanet/retinanet.onnx ./cocodataset/annotations/instances_{}.json ./config/default_config_gpu.yaml -``` - -#### 缁撴灉 +#### <span id="outcome">缁撴灉</span> 鎺ㄧ悊鐨勭粨鏋滀繚瀛樺湪褰撳墠鐩綍涓嬶紝鍦╝cc.log鏃ュ織鏂囦欢涓彲浠ユ壘鍒扮被浼间互涓嬬殑缁撴灉銆� -```log +```mAP Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.509 Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.385 @@ -478,63 +470,40 @@ bash scripts/run_onnx_eval_gpu.sh [DEVICE_ID] [DATASET] [MINDRECORD_DIR] [ONNX_P mAP: 0.3499478734634595 ``` -#### Onnx缁撴灉 - -鎺ㄧ悊鐨勭粨鏋滀繚瀛樺湪褰撳墠鐩綍涓嬶紝鍦╨og.txt鏃ュ織鏂囦欢涓彲浠ユ壘鍒扮被浼间互涓嬬殑缁撴灉銆� - -```text - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.350 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.508 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.387 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.133 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.365 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.517 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.304 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.415 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.417 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.151 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.433 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.614 - -======================================== - -mAP: 0.35043225294034314 -``` - ## [妯″瀷璇存槑](#content) ### [鎬ц兘](#content) #### 璁粌鎬ц兘 -| 鍙傛暟 | Ascend | GPU | -| --------------- | ------------------------------------------------------------------------- | ------------------------------------------------------------------------- | -| 妯″瀷鍚嶇О | Retinanet | Retinanet | -| 杩愯鐜 | Ascend 910锛汣PU 2.6GHz锛�192cores锛汳emory 755G锛涚郴缁� Euler2.8 | Rtx3090;Memory 512G | -| 涓婁紶鏃堕棿 | 10/01/2021 | 17/02/2022 | -| MindSpore 鐗堟湰 | 1.2.0 | 1.5.0 | -| 鏁版嵁闆� | 123287 寮犲浘鐗� | 123287 寮犲浘鐗� | -| Batch_size | 32 | 32 | -| 璁粌鍙傛暟 | src/config.py | config/default_config_gpu.yaml | -| 浼樺寲鍣� | Momentum | Momentum | -| 鎹熷け鍑芥暟 | Focal loss | Focal loss | -| 鏈€缁堟崯澶� | 0.582 | 0.57 | -| 绮剧‘搴� (8p) | mAP[0.3475] | mAP[0.3499] | -| 璁粌鎬绘椂闂� (8p) | 23h16m54s | 51h39m6s | -| 鑴氭湰 | [閾炬帴](https://gitee.com/mindspore/models/tree/master/official/cv/retinanet) | [閾炬帴](https://gitee.com/mindspore/models/tree/master/official/cv/retinanet) | +| 鍙傛暟 | Ascend |GPU| +| -------------------------- | ------------------------------------- |------------------------------------- | +| 妯″瀷鍚嶇О | Retinanet |Retinanet | +| 杩愯鐜 | Ascend 910锛汣PU 2.6GHz锛�192cores锛汳emory 755G锛涚郴缁� Euler2.8 | Rtx3090;Memory 512G | +| 涓婁紶鏃堕棿 | 10/01/2021 |17/02/2022 | +| MindSpore 鐗堟湰 | 1.2.0 |1.5.0| +| 鏁版嵁闆� | 123287 寮犲浘鐗� |123287 寮犲浘鐗� | +| Batch_size | 32 |32 | +| 璁粌鍙傛暟 | src/config.py |config/default_config_gpu.yaml +| 浼樺寲鍣� | Momentum |Momentum | +| 鎹熷け鍑芥暟 | Focal loss |Focal loss | +| 鏈€缁堟崯澶� | 0.582 |0.57| +| 绮剧‘搴� (8p) | mAP[0.3475] |mAP[0.3499] | +| 璁粌鎬绘椂闂� (8p) | 23h16m54s |51h39m6s| +| 鑴氭湰 | [閾炬帴](https://gitee.com/mindspore/models/tree/master/official/cv/retinanet) |[閾炬帴](https://gitee.com/mindspore/models/tree/master/official/cv/retinanet) | #### 鎺ㄧ悊鎬ц兘 -| 鍙傛暟 | Ascend | GPU | -| -------------- | ------------------------------------------------------------ | ---------------------- | -| 妯″瀷鍚嶇О | Retinanet | Retinanet | -| 杩愯鐜 | Ascend 910锛汣PU 2.6GHz锛�192cores锛汳emory 755G锛涚郴缁� Euler2.8 | Rtx3090;Memory 512G | -| 涓婁紶鏃堕棿 | 10/01/2021 | 17/02/2022 | -| MindSpore 鐗堟湰 | 1.2.0 | 1.5.0 | -| 鏁版嵁闆� | 5k 寮犲浘鐗� | 5k 寮犲浘鐗� | -| Batch_size | 32 | 32 | -| 绮剧‘搴� | mAP[0.3475] | mAP[0.3499] | -| 鎬绘椂闂� | 10 mins and 50 seconds | 13 mins and 40 seconds | +| 鍙傛暟 | Ascend |GPU| +| ------------------- | --------------------------- |--| +| 妯″瀷鍚嶇О | Retinanet |Retinanet | +| 杩愯鐜 | Ascend 910锛汣PU 2.6GHz锛�192cores锛汳emory 755G锛涚郴缁� Euler2.8|Rtx3090;Memory 512G | +| 涓婁紶鏃堕棿 | 10/01/2021 |17/02/2022 | +| MindSpore 鐗堟湰 | 1.2.0 |1.5.0| +| 鏁版嵁闆� | 5k 寮犲浘鐗� |5k 寮犲浘鐗� | +| Batch_size | 32 |32 | +| 绮剧‘搴� | mAP[0.3475] |mAP[0.3499] | +| 鎬绘椂闂� | 10 mins and 50 seconds |13 mins and 40 seconds | ## [闅忔満鎯呭喌鐨勬弿杩癩(#content) @@ -543,3 +512,125 @@ mAP: 0.35043225294034314 ## [ModelZoo 涓婚〉](#content) 璇锋牳瀵瑰畼鏂� [涓婚〉](https://gitee.com/mindspore/models). + +## [杩佺Щ瀛︿範](#content) + +### [杩佺Щ瀛︿範璁粌娴佺▼](#content) + +#### 鏁版嵁闆嗗鐞� + +[鏁版嵁闆嗕笅杞藉湴鍧€](https://www.kaggle.com/datasets/andrewmvd/face-mask-detection) + +涓嬭浇鏁版嵁闆嗗悗瑙e帇鑷硆etinanet鏍圭洰褰曚笅锛屼娇鐢╠ata_split鑴氭湰鍒掑垎鍑�80%鐨勮缁冮泦鍜�20%鐨勬祴璇曢泦 + +```bash +杩愯鑴氭湰绀轰緥 +python data_split.py +``` + +```text +鏁版嵁闆嗙粨鏋� +鈹斺攢dataset + 鈹溾攢train + 鈹溾攢val + 鈹溾攢annotation + +``` + +```text +璁粌鍓嶏紝鍏堝垱寤篗indRecord鏂囦欢锛屼互face_mask_detection鏁版嵁闆嗕负渚嬶紝yaml鏂囦欢閰嶇疆濂絝acemask鏁版嵁闆嗚矾寰勫拰mindrecord瀛樺偍璺緞 +# your dataset dir +dataset_root: /home/mindspore/retinanet/dataset/ +# mindrecord dataset dir +mindrecord_dir: /home/mindspore/retinanet/mindrecord +``` + +```bash +# 鐢熸垚璁粌鏁版嵁闆� +python create_data.py --config_path +(渚嬪锛歱ython create_data.py --config_path './config/finetune_config.yaml') + +# 鐢熸垚娴嬭瘯鏁版嵁闆� +娴嬭瘯鏁版嵁闆嗗彲浠ュ湪璁粌瀹屾垚鐢眅val鑴氭湰鑷姩鐢熸垚 +``` + +#### 杩佺Щ瀛︿範璁粌杩囩▼ + +闇€瑕佸厛浠嶽Mindspore Hub](https://www.mindspore.cn/resources/hub/details?MindSpore/1.8/retinanet_coco2017)涓嬭浇棰勮缁冪殑ckpt + +```text +# 鍦╢inetune_config.yaml璁剧疆棰勮缁冩ā鍨嬬殑ckpt +pre_trained: "/home/mindspore/retinanet/retinanet_ascend_v170_coco2017_official_cv_acc35.ckpt" +``` + +```bash +#杩愯杩佺Щ瀛︿範璁粌鑴氭湰 +python train.py --config_path './config/finetune_config.yaml' +濡傛灉闇€瑕佷繚瀛樻棩蹇椾俊鎭紝鍙娇鐢ㄥ涓嬪懡浠わ細 +python train.py --config_path ./config/finetune_config.yaml > log.txt 2>&1 +``` + +**缁撴灉灞曠ず** + +璁粌缁撴灉灏嗗瓨鍌ㄥ湪绀轰緥璺緞涓€俢heckpoint灏嗗瓨鍌ㄥ湪 `./ckpt` 璺緞涓嬶紝璁粌loss杈撳嚭绀轰緥濡備笅锛� + +```text +epoch: 1 step: 42, loss is 4.347288131713867 +lr:[0.000088] +Train epoch time: 992053.072 ms, per step time: 23620.311 ms +Epoch time: 164034.415, per step time: 358.154 +epoch: 3 step: 42, loss is 1.8387094736099243 +lr:[0.000495] +Train epoch time: 738396.280 ms, per step time: 17580.864 ms +epoch: 4 step: 42, loss is 1.3805917501449585 +lr:[0.000695] +Train epoch time: 742051.709 ms, per step time: 17667.898 ms +``` + +#### 杩佺Щ瀛︿範鎺ㄧ悊杩囩▼ + +```bash +#杩愯杩佺Щ瀛︿範璁粌鑴氭湰 +python eval.py --config_path './config/finetune_config.yaml' +``` + +**缁撴灉灞曠ず** + +```text + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.538 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.781 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.634 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.420 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.687 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.856 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.284 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.570 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.574 + Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.448 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.737 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.872 + +======================================== + +mAP: 0.5376701115352185 + +``` + +#### 杩佺Щ瀛︿範quick_start + +杩愯eval鑴氭湰鍚庯紝浼氱敓鎴恅instances_val.json` 鍜� `predictions.json`鏂囦欢锛岄渶瑕佷慨鏀筦quick_start.py`鑴氭湰涓璥instances_val.json` 鍜� `predictions.json`鏂囦欢鐨勮矾寰勫悗鍐嶈繍琛� + +```bash +# 杩愯quick_start鑴氭湰绀轰緥 +python quick_start.py --config_path './config/finetune_config.yaml' +``` + +**缁撴灉璇存槑** +鍥句腑棰滆壊鐨勫惈涔夊垎鍒槸锛� + +- 娴呰摑: 鐪熷疄鏍囩鐨刴ask_weared_incorrect +- 娴呯豢: 鐪熷疄鏍囩鐨剋ith_mask +- 娴呯孩: 鐪熷疄鏍囩鐨剋ithout_mask +- 钃濊壊: 棰勬祴鏍囩鐨刴ask_weared_incorrect +- 缁胯壊: 棰勬祴鏍囩鐨剋ith_mask +- 绾㈣壊: 棰勬祴鏍囩鐨剋ithout_mask diff --git a/official/cv/retinanet/config/finetune_config.yaml b/official/cv/retinanet/config/finetune_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d1c323a3d2091cda177159456488608fb7ccec7 --- /dev/null +++ b/official/cv/retinanet/config/finetune_config.yaml @@ -0,0 +1,143 @@ +# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing) +enable_modelarts: False +# url for modelarts +data_url: "" +train_url: "" +checkpoint_url: "" +# path for local +data_path: "./data" +output_path: "./train" +load_path: "" +#device_target: "Ascend" +device_target: "CPU" +enable_profiling: False +need_modelarts_dataset_unzip: False +modelarts_dataset_unzip_name: "MindRecord_COCO" + +# ====================================================================================== +# common options +distribute: False + +# ====================================================================================== +# create dataset +create_dataset: "facemask" +prefix: "retinanet.mindrecord" +is_training: True +python_multiprocessing: False + +# ====================================================================================== +# Training options +img_shape: [600,600] +num_retinanet_boxes: 67995 +match_thershold: 0.5 +nms_thershold: 0.6 +min_score: 0.1 +max_boxes: 100 + +# learning rate settings +lr: 0.009 +global_step: 0 +lr_init: 1e-5 +lr_end_rate: 5e-4 +warmup_epochs1: 0 +warmup_epochs2: 1 +warmup_epochs3: 4 +warmup_epochs4: 12 +warmup_epochs5: 30 +momentum: 0.9 +weight_decay: 1.5e-4 + +# network +num_default: [9, 9, 9, 9, 9] +extras_out_channels: [256, 256, 256, 256, 256] +feature_size: [75, 38, 19, 10, 5] +aspect_ratios: [[0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0], [0.5, 1.0, 2.0]] +steps: [8, 16, 32, 64, 128] +anchor_size: [32, 64, 128, 256, 512] +prior_scaling: [0.1, 0.2] +gamma: 2.0 +alpha: 0.75 +num_classes: 4 + +# `mindrecord_dir` and `coco_root` are better to use absolute path. +mindrecord_dir: "./mindrecord" +dataset_root: "./dataset" +train_data_type: "train" +val_data_type: "val" +instances_set: "annotation/instances_{}.json" +coco_classes: ["background","mask_weared_incorrect", "with_mask", "without_mask"] + + +# The annotation.json position of voc validation dataset +voc_root: "./dataset" +facemask_root: "./dataset" + +# voc original dataset +voc_dir: "./dataset" +facemask_dir: "./dataset" + +# if coco or voc used, `image_dir` and `anno_path` are useless +image_dir: "" +anno_path: "" +save_checkpoint: True +save_checkpoint_epochs: 1 +keep_checkpoint_max: 10 +save_checkpoint_path: "./ckpt" +finish_epoch: 0 + +# optimiter options +workers: 8 +mode: "sink" +epoch_size: 95 +batch_size: 16 +pre_trained: "/home/mindspore/retinanet/retinanet_ascend_v170_coco2017_official_cv_acc35.ckpt" +pre_trained_epoch_size: 90 +loss_scale: 200 +filter_weight: True +finetune: True + +# ====================================================================================== +# Eval options +dataset: "facemask" +checkpoint_path: "./ckpt/retinanet_1-95_42.ckpt" + +# ====================================================================================== +# export options +device_id: 0 +file_format: "MINDIR" +export_batch_size: 1 +file_name: "retinanet" + +# ====================================================================================== +# postprocess options +result_path: "" +img_path: "" +img_id_file: "" + +--- +# Help description for each configuration +enable_modelarts: "Whether training on modelarts default: False" +data_url: "Url for modelarts" +train_url: "Url for modelarts" +data_path: "The location of input data" +output_pah: "The location of the output file" +device_target: "device id of GPU or Ascend. (Default: None)" +enable_profiling: "Whether enable profiling while training default: False" +workers: "Num parallel workers." +lr: "Learning rate, default is 0.1." +mode: "Run sink mode or not, default is sink." +epoch_size: "Epoch size, default is 500." +batch_size: "Batch size, default is 32." +pre_trained: "Pretrained Checkpoint file path." +pre_trained_epoch_size: "Pretrained epoch size." +save_checkpoint_epochs: "Save checkpoint epochs, default is 1." +loss_scale: "Loss scale, default is 1024." +filter_weight: "Filter weight parameters, default is False." +dataset: "Dataset, default is coco." +device_id: "Device id, default is 0." +file_format: "file format choices [AIR, MINDIR]" +file_name: "output file name." +export_batch_size: "batch size" +result_path: "result file path." +img_path: "image file path." +img_id_file: "image id file." diff --git a/official/cv/retinanet/data/facemask/train.txt b/official/cv/retinanet/data/facemask/train.txt new file mode 100644 index 0000000000000000000000000000000000000000..da8c407e749ee36ddebe69ede33d760ea3a4e5e4 --- /dev/null +++ b/official/cv/retinanet/data/facemask/train.txt @@ -0,0 +1,682 @@ +maksssksksss304 +maksssksksss311 +maksssksksss6 +maksssksksss479 +maksssksksss231 +maksssksksss180 +maksssksksss485 +maksssksksss557 +maksssksksss290 +maksssksksss435 +maksssksksss703 +maksssksksss827 +maksssksksss129 +maksssksksss635 +maksssksksss5 +maksssksksss559 +maksssksksss412 +maksssksksss235 +maksssksksss80 +maksssksksss312 +maksssksksss810 +maksssksksss65 +maksssksksss179 +maksssksksss823 +maksssksksss770 +maksssksksss279 +maksssksksss433 +maksssksksss68 +maksssksksss423 +maksssksksss246 +maksssksksss72 +maksssksksss769 +maksssksksss287 +maksssksksss817 +maksssksksss551 +maksssksksss831 +maksssksksss742 +maksssksksss509 +maksssksksss153 +maksssksksss760 +maksssksksss232 +maksssksksss452 +maksssksksss402 +maksssksksss841 +maksssksksss343 +maksssksksss754 +maksssksksss144 +maksssksksss546 +maksssksksss143 +maksssksksss344 +maksssksksss115 +maksssksksss40 +maksssksksss467 +maksssksksss320 +maksssksksss220 +maksssksksss106 +maksssksksss360 +maksssksksss591 +maksssksksss151 +maksssksksss191 +maksssksksss768 +maksssksksss211 +maksssksksss254 +maksssksksss470 +maksssksksss372 +maksssksksss594 +maksssksksss150 +maksssksksss379 +maksssksksss533 +maksssksksss730 +maksssksksss271 +maksssksksss339 +maksssksksss613 +maksssksksss465 +maksssksksss101 +maksssksksss634 +maksssksksss364 +maksssksksss397 +maksssksksss161 +maksssksksss458 +maksssksksss583 +maksssksksss783 +maksssksksss700 +maksssksksss193 +maksssksksss60 +maksssksksss328 +maksssksksss299 +maksssksksss112 +maksssksksss720 +maksssksksss371 +maksssksksss346 +maksssksksss99 +maksssksksss385 +maksssksksss182 +maksssksksss718 +maksssksksss166 +maksssksksss369 +maksssksksss650 +maksssksksss575 +maksssksksss210 +maksssksksss242 +maksssksksss53 +maksssksksss787 +maksssksksss87 +maksssksksss838 +maksssksksss431 +maksssksksss156 +maksssksksss125 +maksssksksss274 +maksssksksss582 +maksssksksss336 +maksssksksss209 +maksssksksss599 +maksssksksss499 +maksssksksss54 +maksssksksss713 +maksssksksss812 +maksssksksss181 +maksssksksss398 +maksssksksss520 +maksssksksss498 +maksssksksss269 +maksssksksss744 +maksssksksss100 +maksssksksss155 +maksssksksss565 +maksssksksss159 +maksssksksss534 +maksssksksss692 +maksssksksss306 +maksssksksss300 +maksssksksss654 +maksssksksss47 +maksssksksss729 +maksssksksss609 +maksssksksss98 +maksssksksss846 +maksssksksss679 +maksssksksss142 +maksssksksss704 +maksssksksss597 +maksssksksss266 +maksssksksss111 +maksssksksss669 +maksssksksss177 +maksssksksss736 +maksssksksss523 +maksssksksss419 +maksssksksss672 +maksssksksss49 +maksssksksss123 +maksssksksss316 +maksssksksss35 +maksssksksss31 +maksssksksss538 +maksssksksss821 +maksssksksss794 +maksssksksss445 +maksssksksss282 +maksssksksss365 +maksssksksss624 +maksssksksss788 +maksssksksss614 +maksssksksss711 +maksssksksss775 +maksssksksss510 +maksssksksss250 +maksssksksss767 +maksssksksss622 +maksssksksss74 +maksssksksss313 +maksssksksss302 +maksssksksss747 +maksssksksss542 +maksssksksss58 +maksssksksss107 +maksssksksss800 +maksssksksss27 +maksssksksss118 +maksssksksss217 +maksssksksss464 +maksssksksss386 +maksssksksss592 +maksssksksss76 +maksssksksss749 +maksssksksss244 +maksssksksss641 +maksssksksss491 +maksssksksss356 +maksssksksss260 +maksssksksss836 +maksssksksss719 +maksssksksss381 +maksssksksss199 +maksssksksss685 +maksssksksss61 +maksssksksss640 +maksssksksss34 +maksssksksss573 +maksssksksss439 +maksssksksss502 +maksssksksss48 +maksssksksss0 +maksssksksss105 +maksssksksss653 +maksssksksss513 +maksssksksss46 +maksssksksss496 +maksssksksss252 +maksssksksss756 +maksssksksss598 +maksssksksss503 +maksssksksss705 +maksssksksss495 +maksssksksss114 +maksssksksss798 +maksssksksss24 +maksssksksss411 +maksssksksss626 +maksssksksss665 +maksssksksss213 +maksssksksss92 +maksssksksss497 +maksssksksss443 +maksssksksss476 +maksssksksss78 +maksssksksss317 +maksssksksss103 +maksssksksss396 +maksssksksss442 +maksssksksss505 +maksssksksss521 +maksssksksss585 +maksssksksss247 +maksssksksss446 +maksssksksss368 +maksssksksss579 +maksssksksss514 +maksssksksss425 +maksssksksss238 +maksssksksss531 +maksssksksss261 +maksssksksss709 +maksssksksss690 +maksssksksss701 +maksssksksss772 +maksssksksss643 +maksssksksss204 +maksssksksss668 +maksssksksss354 +maksssksksss671 +maksssksksss126 +maksssksksss327 +maksssksksss4 +maksssksksss349 +maksssksksss795 +maksssksksss140 +maksssksksss586 +maksssksksss295 +maksssksksss222 +maksssksksss822 +maksssksksss791 +maksssksksss239 +maksssksksss779 +maksssksksss197 +maksssksksss228 +maksssksksss563 +maksssksksss777 +maksssksksss689 +maksssksksss293 +maksssksksss681 +maksssksksss484 +maksssksksss285 +maksssksksss830 +maksssksksss276 +maksssksksss590 +maksssksksss212 +maksssksksss352 +maksssksksss847 +maksssksksss617 +maksssksksss721 +maksssksksss796 +maksssksksss663 +maksssksksss89 +maksssksksss621 +maksssksksss265 +maksssksksss172 +maksssksksss82 +maksssksksss75 +maksssksksss815 +maksssksksss294 +maksssksksss8 +maksssksksss407 +maksssksksss405 +maksssksksss471 +maksssksksss473 +maksssksksss544 +maksssksksss422 +maksssksksss79 +maksssksksss154 +maksssksksss387 +maksssksksss811 +maksssksksss226 +maksssksksss2 +maksssksksss255 +maksssksksss121 +maksssksksss852 +maksssksksss307 +maksssksksss62 +maksssksksss149 +maksssksksss577 +maksssksksss615 +maksssksksss803 +maksssksksss560 +maksssksksss673 +maksssksksss581 +maksssksksss589 +maksssksksss288 +maksssksksss472 +maksssksksss259 +maksssksksss684 +maksssksksss481 +maksssksksss73 +maksssksksss38 +maksssksksss670 +maksssksksss418 +maksssksksss141 +maksssksksss776 +maksssksksss120 +maksssksksss176 +maksssksksss429 +maksssksksss175 +maksssksksss807 +maksssksksss432 +maksssksksss102 +maksssksksss451 +maksssksksss205 +maksssksksss64 +maksssksksss131 +maksssksksss171 +maksssksksss127 +maksssksksss697 +maksssksksss128 +maksssksksss45 +maksssksksss462 +maksssksksss41 +maksssksksss620 +maksssksksss605 +maksssksksss198 +maksssksksss146 +maksssksksss556 +maksssksksss200 +maksssksksss716 +maksssksksss708 +maksssksksss819 +maksssksksss580 +maksssksksss221 +maksssksksss138 +maksssksksss383 +maksssksksss734 +maksssksksss168 +maksssksksss816 +maksssksksss52 +maksssksksss647 +maksssksksss826 +maksssksksss676 +maksssksksss415 +maksssksksss629 +maksssksksss455 +maksssksksss724 +maksssksksss454 +maksssksksss694 +maksssksksss404 +maksssksksss529 +maksssksksss348 +maksssksksss377 +maksssksksss536 +maksssksksss623 +maksssksksss753 +maksssksksss469 +maksssksksss28 +maksssksksss687 +maksssksksss183 +maksssksksss660 +maksssksksss447 +maksssksksss43 +maksssksksss801 +maksssksksss438 +maksssksksss391 +maksssksksss39 +maksssksksss436 +maksssksksss764 +maksssksksss548 +maksssksksss33 +maksssksksss487 +maksssksksss541 +maksssksksss426 +maksssksksss361 +maksssksksss851 +maksssksksss532 +maksssksksss202 +maksssksksss324 +maksssksksss223 +maksssksksss839 +maksssksksss196 +maksssksksss508 +maksssksksss113 +maksssksksss263 +maksssksksss237 +maksssksksss738 +maksssksksss395 +maksssksksss603 +maksssksksss611 +maksssksksss342 +maksssksksss357 +maksssksksss178 +maksssksksss848 +maksssksksss315 +maksssksksss727 +maksssksksss173 +maksssksksss403 +maksssksksss417 +maksssksksss225 +maksssksksss298 +maksssksksss771 +maksssksksss695 +maksssksksss666 +maksssksksss384 +maksssksksss525 +maksssksksss362 +maksssksksss427 +maksssksksss825 +maksssksksss340 +maksssksksss408 +maksssksksss743 +maksssksksss273 +maksssksksss698 +maksssksksss434 +maksssksksss482 +maksssksksss608 +maksssksksss828 +maksssksksss170 +maksssksksss558 +maksssksksss488 +maksssksksss799 +maksssksksss658 +maksssksksss50 +maksssksksss530 +maksssksksss524 +maksssksksss607 +maksssksksss19 +maksssksksss227 +maksssksksss518 +maksssksksss351 +maksssksksss284 +maksssksksss376 +maksssksksss390 +maksssksksss782 +maksssksksss793 +maksssksksss486 +maksssksksss394 +maksssksksss414 +maksssksksss303 +maksssksksss850 +maksssksksss258 +maksssksksss382 +maksssksksss774 +maksssksksss256 +maksssksksss334 +maksssksksss117 +maksssksksss601 +maksssksksss595 +maksssksksss766 +maksssksksss331 +maksssksksss84 +maksssksksss657 +maksssksksss707 +maksssksksss725 +maksssksksss460 +maksssksksss547 +maksssksksss430 +maksssksksss428 +maksssksksss494 +maksssksksss501 +maksssksksss528 +maksssksksss638 +maksssksksss194 +maksssksksss814 +maksssksksss201 +maksssksksss91 +maksssksksss535 +maksssksksss696 +maksssksksss122 +maksssksksss675 +maksssksksss834 +maksssksksss16 +maksssksksss748 +maksssksksss17 +maksssksksss137 +maksssksksss169 +maksssksksss157 +maksssksksss130 +maksssksksss97 +maksssksksss792 +maksssksksss11 +maksssksksss292 +maksssksksss189 +maksssksksss20 +maksssksksss55 +maksssksksss478 +maksssksksss308 +maksssksksss844 +maksssksksss132 +maksssksksss18 +maksssksksss353 +maksssksksss587 +maksssksksss77 +maksssksksss81 +maksssksksss70 +maksssksksss248 +maksssksksss576 +maksssksksss627 +maksssksksss489 +maksssksksss540 +maksssksksss527 +maksssksksss190 +maksssksksss566 +maksssksksss648 +maksssksksss251 +maksssksksss512 +maksssksksss67 +maksssksksss51 +maksssksksss569 +maksssksksss29 +maksssksksss309 +maksssksksss840 +maksssksksss319 +maksssksksss733 +maksssksksss187 +maksssksksss506 +maksssksksss268 +maksssksksss270 +maksssksksss219 +maksssksksss283 +maksssksksss835 +maksssksksss374 +maksssksksss740 +maksssksksss69 +maksssksksss262 +maksssksksss236 +maksssksksss214 +maksssksksss490 +maksssksksss759 +maksssksksss230 +maksssksksss459 +maksssksksss568 +maksssksksss596 +maksssksksss37 +maksssksksss71 +maksssksksss165 +maksssksksss570 +maksssksksss642 +maksssksksss842 +maksssksksss325 +maksssksksss763 +maksssksksss564 +maksssksksss735 +maksssksksss281 +maksssksksss388 +maksssksksss752 +maksssksksss57 +maksssksksss824 +maksssksksss572 +maksssksksss802 +maksssksksss780 +maksssksksss618 +maksssksksss765 +maksssksksss784 +maksssksksss543 +maksssksksss332 +maksssksksss400 +maksssksksss731 +maksssksksss457 +maksssksksss440 +maksssksksss667 +maksssksksss550 +maksssksksss7 +maksssksksss88 +maksssksksss683 +maksssksksss366 +maksssksksss655 +maksssksksss682 +maksssksksss466 +maksssksksss21 +maksssksksss240 +maksssksksss755 +maksssksksss483 +maksssksksss843 +maksssksksss578 +maksssksksss184 +maksssksksss275 +maksssksksss833 +maksssksksss804 +maksssksksss399 +maksssksksss820 +maksssksksss593 +maksssksksss162 +maksssksksss549 +maksssksksss410 +maksssksksss691 +maksssksksss253 +maksssksksss66 +maksssksksss715 +maksssksksss124 +maksssksksss604 +maksssksksss659 +maksssksksss164 +maksssksksss633 +maksssksksss305 +maksssksksss739 +maksssksksss552 +maksssksksss109 +maksssksksss152 +maksssksksss245 +maksssksksss448 +maksssksksss94 +maksssksksss337 +maksssksksss688 +maksssksksss728 +maksssksksss1 +maksssksksss453 +maksssksksss504 +maksssksksss335 +maksssksksss389 +maksssksksss286 +maksssksksss93 +maksssksksss367 +maksssksksss741 +maksssksksss188 +maksssksksss463 +maksssksksss500 +maksssksksss370 +maksssksksss373 +maksssksksss380 +maksssksksss218 +maksssksksss686 +maksssksksss750 +maksssksksss321 +maksssksksss516 +maksssksksss322 +maksssksksss761 +maksssksksss257 +maksssksksss63 +maksssksksss845 +maksssksksss359 +maksssksksss519 +maksssksksss23 +maksssksksss229 +maksssksksss714 +maksssksksss562 +maksssksksss10 +maksssksksss233 +maksssksksss409 +maksssksksss264 +maksssksksss329 +maksssksksss289 +maksssksksss646 +maksssksksss36 +maksssksksss797 +maksssksksss330 +maksssksksss849 +maksssksksss318 +maksssksksss805 +maksssksksss139 +maksssksksss545 +maksssksksss134 +maksssksksss474 +maksssksksss174 +maksssksksss86 +maksssksksss280 +maksssksksss644 +maksssksksss449 diff --git a/official/cv/retinanet/data/facemask/val.txt b/official/cv/retinanet/data/facemask/val.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d4bad29a301cf3d53d7e74e85b647c1cdb628f9 --- /dev/null +++ b/official/cv/retinanet/data/facemask/val.txt @@ -0,0 +1,170 @@ +maksssksksss461 +maksssksksss314 +maksssksksss456 +maksssksksss203 +maksssksksss185 +maksssksksss751 +maksssksksss85 +maksssksksss758 +maksssksksss680 +maksssksksss652 +maksssksksss249 +maksssksksss781 +maksssksksss167 +maksssksksss416 +maksssksksss539 +maksssksksss722 +maksssksksss808 +maksssksksss600 +maksssksksss207 +maksssksksss135 +maksssksksss25 +maksssksksss674 +maksssksksss610 +maksssksksss406 +maksssksksss186 +maksssksksss441 +maksssksksss413 +maksssksksss277 +maksssksksss710 +maksssksksss363 +maksssksksss444 +maksssksksss116 +maksssksksss12 +maksssksksss350 +maksssksksss420 +maksssksksss136 +maksssksksss30 +maksssksksss706 +maksssksksss631 +maksssksksss517 +maksssksksss637 +maksssksksss636 +maksssksksss818 +maksssksksss699 +maksssksksss378 +maksssksksss693 +maksssksksss492 +maksssksksss661 +maksssksksss192 +maksssksksss14 +maksssksksss345 +maksssksksss625 +maksssksksss567 +maksssksksss145 +maksssksksss243 +maksssksksss301 +maksssksksss341 +maksssksksss790 +maksssksksss326 +maksssksksss278 +maksssksksss468 +maksssksksss522 +maksssksksss216 +maksssksksss32 +maksssksksss726 +maksssksksss13 +maksssksksss555 +maksssksksss160 +maksssksksss477 +maksssksksss554 +maksssksksss56 +maksssksksss746 +maksssksksss493 +maksssksksss310 +maksssksksss809 +maksssksksss26 +maksssksksss553 +maksssksksss612 +maksssksksss110 +maksssksksss9 +maksssksksss475 +maksssksksss737 +maksssksksss338 +maksssksksss662 +maksssksksss829 +maksssksksss272 +maksssksksss215 +maksssksksss450 +maksssksksss375 +maksssksksss574 +maksssksksss507 +maksssksksss323 +maksssksksss702 +maksssksksss392 +maksssksksss732 +maksssksksss537 +maksssksksss333 +maksssksksss108 +maksssksksss561 +maksssksksss678 +maksssksksss757 +maksssksksss104 +maksssksksss515 +maksssksksss639 +maksssksksss347 +maksssksksss158 +maksssksksss355 +maksssksksss588 +maksssksksss806 +maksssksksss206 +maksssksksss393 +maksssksksss83 +maksssksksss571 +maksssksksss22 +maksssksksss42 +maksssksksss3 +maksssksksss649 +maksssksksss44 +maksssksksss656 +maksssksksss480 +maksssksksss224 +maksssksksss526 +maksssksksss778 +maksssksksss723 +maksssksksss119 +maksssksksss195 +maksssksksss712 +maksssksksss133 +maksssksksss297 +maksssksksss786 +maksssksksss762 +maksssksksss437 +maksssksksss745 +maksssksksss606 +maksssksksss59 +maksssksksss241 +maksssksksss616 +maksssksksss664 +maksssksksss15 +maksssksksss584 +maksssksksss773 +maksssksksss96 +maksssksksss163 +maksssksksss837 +maksssksksss421 +maksssksksss511 +maksssksksss291 +maksssksksss717 +maksssksksss785 +maksssksksss632 +maksssksksss630 +maksssksksss234 +maksssksksss832 +maksssksksss677 +maksssksksss602 +maksssksksss147 +maksssksksss358 +maksssksksss401 +maksssksksss628 +maksssksksss424 +maksssksksss789 +maksssksksss95 +maksssksksss208 +maksssksksss645 +maksssksksss148 +maksssksksss90 +maksssksksss651 +maksssksksss813 +maksssksksss619 +maksssksksss267 diff --git a/official/cv/retinanet/eval.py b/official/cv/retinanet/eval.py index 493d30a7767cfd5c7aacffa81d25de6d4c14b22a..4df00cf7a5606486b9ec3f41424918bc4c0eccc5 100644 --- a/official/cv/retinanet/eval.py +++ b/official/cv/retinanet/eval.py @@ -24,7 +24,8 @@ from pycocotools.cocoeval import COCOeval from mindspore import context, Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder -from src.dataset import create_retinanet_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord +from src.dataset import create_retinanet_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord, \ + facemask_data_to_mindrecord from src.box_utils import default_boxes from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper @@ -69,6 +70,8 @@ def apply_nms(all_boxes, all_scores, thres, max_boxes): def make_dataset_dir(mindrecord_dir, mindrecord_file, prefix): if config.dataset == "voc": config.coco_root = config.voc_root + if config.dataset == 'facemask': + config.coco_root = config.facemask_root if not os.path.exists(mindrecord_file): if not os.path.isdir(mindrecord_dir): os.makedirs(mindrecord_dir) @@ -86,6 +89,13 @@ def make_dataset_dir(mindrecord_dir, mindrecord_file, prefix): print("Create Mindrecord Done, at {}".format(mindrecord_dir)) else: print("voc_root or voc_dir not exits.") + elif config.dataset == 'facemask': + if os.path.isdir(config.facemask_dir) and os.path.isdir(config.facemask_root): + print("Create Mindrecord.") + facemask_data_to_mindrecord(mindrecord_dir, False, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("facemask_root or facemask_dir not exits.") else: if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): print("Create Mindrecord.") @@ -98,6 +108,7 @@ def make_dataset_dir(mindrecord_dir, mindrecord_file, prefix): def modelarts_pre_process(): '''modelarts pre process function.''' + def unzip(zip_file, save_dir): import zipfile s_time = time.time() @@ -150,7 +161,6 @@ def modelarts_pre_process(): @moxing_wrapper(pre_process=modelarts_pre_process) def retinanet_eval(): - context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id()) prefix = "retinanet_eval.mindrecord" mindrecord_dir = config.mindrecord_dir @@ -179,7 +189,7 @@ def retinanet_eval(): num_classes = config.num_classes coco_root = config.coco_root data_type = config.val_data_type - #Classes need to train or test. + # Classes need to train or test.鈥� val_cls = config.coco_classes val_cls_dict = {} for i, cls in enumerate(val_cls): diff --git a/official/cv/retinanet/quick_start.py b/official/cv/retinanet/quick_start.py new file mode 100644 index 0000000000000000000000000000000000000000..3324f1f592d15e08547e91ec6ef602112e690529 --- /dev/null +++ b/official/cv/retinanet/quick_start.py @@ -0,0 +1,75 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""visualize for retinanet""" + +import os +import cv2 +import matplotlib.pyplot as plt +from pycocotools.coco import COCO +from src.model_utils.config import config + + +def visualize_model(): + # load best ckpt to generate instances_val.json and predictions.json + + dataset_dir = r'./dataset/val/images/' + coco_root = config.voc_root + data_type = config.val_data_type + annotation_file = os.path.join(coco_root, config.instances_set.format(data_type)) + coco = COCO(annotation_file) + catids = coco.getCatIds() + imgids = coco.getImgIds() + coco_res = coco.loadRes('./predictions.json') + catids_res = coco_res.getCatIds() + for i in range(10): + img = coco.loadImgs(imgids[i])[0] + image = cv2.imread(dataset_dir + img['file_name']) + image_res = image + annids = coco.getAnnIds(imgIds=img['id'], catIds=catids, iscrowd=None) + annos = coco.loadAnns(annids) + annids_res = coco_res.getAnnIds(imgIds=img['id'], catIds=catids_res, iscrowd=None) + annos_res = coco_res.loadAnns(annids_res) + plt.figure(figsize=(7, 7)) + for anno in annos: + bbox = anno['bbox'] + x, y, w, h = bbox + if anno['category_id'] == 1: + anno_image = cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (153, 153, 255), 2) + elif anno['category_id'] == 2: + anno_image = cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (153, 255, 153), 2) + else: + anno_image = cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (255, 153, 153), 2) + plt.subplot(1, 2, 1) + plt.plot([-2, 3], [1, 5]) + plt.title('true-label') + plt.imshow(anno_image) + for anno_res in annos_res: + bbox_res = anno_res['bbox'] + x, y, w, h = bbox_res + if anno_res['category_id'] == 1: + res_image = cv2.rectangle(image_res, (int(x), int(y)), (int(x + w), int(y + h)), (0, 0, 255), 2) + elif anno_res['category_id'] == 2: + res_image = cv2.rectangle(image_res, (int(x), int(y)), (int(x + w), int(y + h)), (0, 153, 0), 2) + else: + res_image = cv2.rectangle(image_res, (int(x), int(y)), (int(x + w), int(y + h)), (255, 0, 0), 2) + plt.subplot(1, 2, 2) + plt.title('pred-label') + plt.imshow(res_image) + plt.show() + + +if __name__ == '__main__': + visualize_model() diff --git a/official/cv/retinanet/src/data_split.py b/official/cv/retinanet/src/data_split.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2264474791611574cd9d6616217543f8c7d8de --- /dev/null +++ b/official/cv/retinanet/src/data_split.py @@ -0,0 +1,61 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""data_split""" + +import os +import shutil + +image_original_path = '../images/' +label_original_path = '../annotations/' + +train_image_path = '../dataset/train/images/' +train_label_path = '../dataset/train/annotations/' + +val_image_path = '../dataset/val/images/' +val_label_path = '../dataset/val/annotations/' + + +def mkdir(): + if not os.path.exists(train_image_path): + os.makedirs(train_image_path) + if not os.path.exists(train_label_path): + os.makedirs(train_label_path) + + if not os.path.exists(val_image_path): + os.makedirs(val_image_path) + if not os.path.exists(val_label_path): + os.makedirs(val_label_path) + + +def main(): + mkdir() + with open("./data/facemask/train.txt", 'r') as f: + for line in f: + dst_train_image = train_image_path + line[:-1] + '.jpg' + dst_train_label = train_label_path + line[:-1] + '.xml' + shutil.copyfile(image_original_path + line[:-1] + '.png', dst_train_image) + shutil.copyfile(label_original_path + line[:-1] + '.xml', dst_train_label) + + with open("./data/facemask/val.txt", 'r') as f: + for line in f: + dst_val_image = val_image_path + line[:-1] + '.jpg' + dst_val_label = val_label_path + line[:-1] + '.xml' + shutil.copyfile(image_original_path + line[:-1] + '.png', dst_val_image) + shutil.copyfile(label_original_path + line[:-1] + ".xml", dst_val_label) + + +if __name__ == '__main__': + main() diff --git a/official/cv/retinanet/src/dataset.py b/official/cv/retinanet/src/dataset.py index 1b076e2473a95da150ad791d5c7488b5f22a9fa7..b0ce8681d0bd15767d112f4f62a65a6c424b2c5d 100644 --- a/official/cv/retinanet/src/dataset.py +++ b/official/cv/retinanet/src/dataset.py @@ -19,6 +19,7 @@ from __future__ import division import os import json +import re import xml.etree.ElementTree as et import numpy as np import cv2 @@ -42,6 +43,17 @@ def get_imageId_from_fileName(filename): return id_iter +def get_imageId_from_fackmask(filename): + """Get imageID from fileName""" + filename = os.path.splitext(filename)[0] + regex = re.compile(r'\d+') + iid = regex.search(filename).group(0) + image_id = int(iid) + if filename.isdigit(): + return int(filename) + return image_id + + def random_sample_crop(image, boxes): """Random Crop the image and boxes""" height, width, _ = image.shape @@ -104,6 +116,7 @@ def random_sample_crop(image, boxes): def preprocess_fn(img_id, image, box, is_training): """Preprocess function for dataset.""" cv2.setNumThreads(2) + def _infer_data(image, input_shape): img_h, img_w, _ = image.shape input_h, input_w = input_shape @@ -246,6 +259,98 @@ def create_voc_label(is_training): return images, image_files_dict, image_anno_dict +def create_facemask_label(is_training): + """Get image path and annotation from VOC.""" + facemask_dir = config.voc_dir + cls_map = {name: i for i, name in enumerate(config.coco_classes)} + sub_dir = 'train' if is_training else 'val' + facemask_dir = os.path.join(facemask_dir, sub_dir) + if not os.path.isdir(facemask_dir): + raise ValueError(f'Cannot find {sub_dir} dataset path.') + + image_dir = anno_dir = facemask_dir + if os.path.isdir(os.path.join(facemask_dir, 'images')): + image_dir = os.path.join(facemask_dir, 'images') + if os.path.isdir(os.path.join(facemask_dir, 'annotations')): + anno_dir = os.path.join(facemask_dir, 'annotations') + + if not is_training: + data_dir = config.facemask_root + json_file = os.path.join(data_dir, config.instances_set.format(sub_dir)) + file_dir = os.path.split(json_file)[0] + if not os.path.isdir(file_dir): + os.makedirs(file_dir) + json_dict = {"images": [], "type": "instances", "annotations": [], + "categories": []} + bnd_id = 1 + + image_files_dict = {} + image_anno_dict = {} + images = [] + for anno_file in os.listdir(anno_dir): + print(anno_file) + if not anno_file.endswith('xml'): + continue + tree = et.parse(os.path.join(anno_dir, anno_file)) + root_node = tree.getroot() + file_name = root_node.find('filename').text + file_name = file_name.split('.')[0] + '.jpg' + img_id = get_imageId_from_fackmask(file_name) + image_path = os.path.join(image_dir, file_name) + print(image_path) + if not os.path.isfile(image_path): + print(f'Cannot find image {file_name} according to annotations.') + continue + + labels = [] + for obj in root_node.iter('object'): + cls_name = obj.find('name').text + if cls_name not in cls_map: + print(f'Label "{cls_name}" not in "{config.coco_classes}"') + continue + bnd_box = obj.find('bndbox') + x_min = int(float(bnd_box.find('xmin').text)) - 1 + y_min = int(float(bnd_box.find('ymin').text)) - 1 + x_max = int(float(bnd_box.find('xmax').text)) - 1 + y_max = int(float(bnd_box.find('ymax').text)) - 1 + labels.append([y_min, x_min, y_max, x_max, cls_map[cls_name]]) + + if not is_training: + o_width = abs(x_max - x_min) + o_height = abs(y_max - y_min) + ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': \ + img_id, 'bbox': [x_min, y_min, o_width, o_height], \ + 'category_id': cls_map[cls_name], 'id': bnd_id, \ + 'ignore': 0, \ + 'segmentation': []} + json_dict['annotations'].append(ann) + bnd_id = bnd_id + 1 + + if labels: + images.append(img_id) + image_files_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(labels) + + if not is_training: + size = root_node.find("size") + width = int(size.find('width').text) + height = int(size.find('height').text) + image = {'file_name': file_name, 'height': height, 'width': width, + 'id': img_id} + json_dict['images'].append(image) + + if not is_training: + for cls_name, cid in cls_map.items(): + cat = {'supercategory': 'none', 'id': cid, 'name': cls_name} + json_dict['categories'].append(cat) + json_fp = open(json_file, 'w') + json_str = json.dumps(json_dict) + json_fp.write(json_str) + json_fp.close() + + return images, image_files_dict, image_anno_dict + + def create_coco_label(is_training): """Get image path and annotation from COCO.""" from pycocotools.coco import COCO @@ -359,6 +464,29 @@ def voc_data_to_mindrecord(mindrecord_dir, is_training, prefix="retinanet.mindre writer.commit() +def facemask_data_to_mindrecord(mindrecord_dir, is_training, prefix="retinanet.mindrecord0", file_num=1): + mindrecord_path = os.path.join(mindrecord_dir, prefix + "0") + writer = FileWriter(mindrecord_path, file_num) + images, image_path_dict, image_anno_dict = create_facemask_label(is_training) + + retinanet_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(retinanet_json, "retinanet_json") + + for img_id in images: + image_path = image_path_dict[img_id] + with open(image_path, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + + def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="retinanet.mindrecord", file_num=8): """Create MindRecord file.""" mindrecord_dir = config.mindrecord_dir @@ -395,7 +523,6 @@ def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num decode = C.Decode() ds = ds.map(operations=decode, input_columns=["image"]) change_swap_op = C.HWC2CHW() - # Computed from random subset of ImageNet training images normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) color_adjust_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) @@ -415,6 +542,7 @@ def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num ds = ds.batch(batch_size, drop_remainder=True) return ds + def create_mindrecord(dataset="coco", prefix="retinanet.mindrecord", is_training=True): print("Start create dataset!") @@ -440,6 +568,13 @@ def create_mindrecord(dataset="coco", prefix="retinanet.mindrecord", is_training print("Create Mindrecord Done, at {}".format(mindrecord_dir)) else: print("voc_dir not exits.") + elif dataset == "facemask": + if os.path.isdir(config.facemask_dir): + print("Create Mindrecord.") + facemask_data_to_mindrecord(mindrecord_dir, is_training, prefix) + print("Create Mindrecord Done, at {}".format(mindrecord_dir)) + else: + print("voc_dir not exits.") else: if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path): print("Create Mindrecord.") diff --git a/official/cv/retinanet/src/init_params.py b/official/cv/retinanet/src/init_params.py index 505846757d852d5da6a54f03b0f58c7a6f0c8722..75ca16125f2e5617386fd154162b89f74dc4ab6f 100644 --- a/official/cv/retinanet/src/init_params.py +++ b/official/cv/retinanet/src/init_params.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,8 @@ # ============================================================================ """Parameters utils""" -from mindspore.common.initializer import initializer, TruncatedNormal +from mindspore.common.initializer import initializer, TruncatedNormal, XavierUniform + def init_net_param(network, initialize_mode='TruncatedNormal'): """Init the parameters in net.""" @@ -23,11 +24,12 @@ def init_net_param(network, initialize_mode='TruncatedNormal'): if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: if initialize_mode == 'TruncatedNormal': p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype)) + elif initialize_mode == 'XavierUniform': + p.set_data(initializer(XavierUniform(), p.data.shape, p.data.dtype)) else: p.set_data(initialize_mode, p.data.shape, p.data.dtype) - def filter_checkpoint_parameter(param_dict): """remove useless parameters""" for key in list(param_dict.keys()): diff --git a/official/cv/retinanet/src/lr_schedule.py b/official/cv/retinanet/src/lr_schedule.py index 5020f23429d4503746aeb1e549625e2f024c6d1f..93ce8c4b653aa7bafc19ba66cb80c27a3283803e 100644 --- a/official/cv/retinanet/src/lr_schedule.py +++ b/official/cv/retinanet/src/lr_schedule.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ import math import numpy as np +from src.model_utils.config import config def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs1, warmup_epochs2, @@ -43,26 +44,29 @@ def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs1, warmup_epochs2, warmup_steps3 = warmup_steps2 + steps_per_epoch * warmup_epochs3 warmup_steps4 = warmup_steps3 + steps_per_epoch * warmup_epochs4 warmup_steps5 = warmup_steps4 + steps_per_epoch * warmup_epochs5 + step_radio = [1e-4, 1e-3, 1e-2, 0.1] + if config.finetune: + step_radio = [1e-4, 1e-2, 0.1, 1] for i in range(total_steps): if i < warmup_steps1: - lr = lr_init*(warmup_steps1-i) / (warmup_steps1) + \ - (lr_max*1e-4) * i / (warmup_steps1*3) + lr = lr_init * (warmup_steps1 - i) / (warmup_steps1) + \ + (lr_max * step_radio[0]) * i / (warmup_steps1 * 3) elif warmup_steps1 <= i < warmup_steps2: - lr = 1e-5*(warmup_steps2-i) / (warmup_steps2 - warmup_steps1) + \ - (lr_max*1e-3) * (i-warmup_steps1) / (warmup_steps2 - warmup_steps1) + lr = 1e-5 * (warmup_steps2 - i) / (warmup_steps2 - warmup_steps1) + \ + (lr_max * step_radio[1]) * (i - warmup_steps1) / (warmup_steps2 - warmup_steps1) elif warmup_steps2 <= i < warmup_steps3: - lr = 1e-4*(warmup_steps3-i) / (warmup_steps3 - warmup_steps2) + \ - (lr_max*1e-2) * (i-warmup_steps2) / (warmup_steps3 - warmup_steps2) + lr = 1e-4 * (warmup_steps3 - i) / (warmup_steps3 - warmup_steps2) + \ + (lr_max * step_radio[2]) * (i - warmup_steps2) / (warmup_steps3 - warmup_steps2) elif warmup_steps3 <= i < warmup_steps4: - lr = 1e-3*(warmup_steps4-i) / (warmup_steps4 - warmup_steps3) + \ - (lr_max*1e-1) * (i-warmup_steps3) / (warmup_steps4 - warmup_steps3) + lr = 1e-3 * (warmup_steps4 - i) / (warmup_steps4 - warmup_steps3) + \ + (lr_max * step_radio[3]) * (i - warmup_steps3) / (warmup_steps4 - warmup_steps3) elif warmup_steps4 <= i < warmup_steps5: - lr = 1e-2*(warmup_steps5-i) / (warmup_steps5 - warmup_steps4) + \ - lr_max * (i-warmup_steps4) / (warmup_steps5 - warmup_steps4) + lr = 1e-2 * (warmup_steps5 - i) / (warmup_steps5 - warmup_steps4) + \ + lr_max * (i - warmup_steps4) / (warmup_steps5 - warmup_steps4) else: lr = lr_end + \ - (lr_max - lr_end) * \ - (1. + math.cos(math.pi * (i-warmup_steps5) / (total_steps - warmup_steps5))) / 2. + (lr_max - lr_end) * \ + (1. + math.cos(math.pi * (i - warmup_steps5) / (total_steps - warmup_steps5))) / 2. if lr < 0.0: lr = 0.0 lr_each_step.append(lr) diff --git a/official/cv/retinanet/src/retinanet.py b/official/cv/retinanet/src/retinanet.py index 58557d8dbd89a5013a408844a5cc04a57f41dfa8..28246b77a8464f60445405b7841350536269108c 100644 --- a/official/cv/retinanet/src/retinanet.py +++ b/official/cv/retinanet/src/retinanet.py @@ -181,7 +181,8 @@ class MultiBox(nn.Cell): cls_layers = [] for k, out_channel in enumerate(out_channels): loc_layers += [RegressionModel(in_channel=out_channel, num_anchors=num_default[k])] - cls_layers += [ClassificationModel(in_channel=out_channel, num_anchors=num_default[k])] + cls_layers += [ClassificationModel(in_channel=out_channel, num_anchors=num_default[k], + num_classes=config.num_classes)] self.multi_loc_layers = nn.layer.CellList(loc_layers) self.multi_cls_layers = nn.layer.CellList(cls_layers) diff --git a/official/cv/retinanet/train.py b/official/cv/retinanet/train.py index b0c3e5cb6c9bef6da55cb614ff25435b0b7ea938..99a74bd4926d2544740f561a881adf92d301874b 100644 --- a/official/cv/retinanet/train.py +++ b/official/cv/retinanet/train.py @@ -32,7 +32,7 @@ from src.lr_schedule import get_lr from src.init_params import init_net_param, filter_checkpoint_parameter from src.model_utils.config import config from src.model_utils.moxing_adapter import moxing_wrapper -from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id +from src.model_utils.device_adapter import get_device_id, get_device_num set_seed(1) @@ -58,11 +58,12 @@ class Monitor(Callback): def step_end(self, run_context): cb_params = run_context.original_args() - print("lr:[{:8.6f}]".format(self.lr_init[cb_params.cur_step_num-1]), flush=True) + print("lr:[{:8.6f}]".format(self.lr_init[cb_params.cur_step_num - 1]), flush=True) def modelarts_pre_process(): '''modelarts pre process function.''' + def unzip(zip_file, save_dir): import zipfile s_time = time.time() @@ -112,55 +113,39 @@ def modelarts_pre_process(): print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1)) + def set_graph_kernel_context(device_target): if device_target == "GPU": # Enable graph kernel for default model ssd300 on GPU back-end. context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_parallel_fusion --enable_expand_ops=Conv2D") + @moxing_wrapper(pre_process=modelarts_pre_process) def main(): - config.lr_init = ast.literal_eval(config.lr_init) config.lr_end_rate = ast.literal_eval(config.lr_end_rate) - + device_id = get_device_id() if config.device_target == "Ascend": - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - if config.distribute: - if os.getenv("DEVICE_ID", "not_set").isdigit(): - context.set_context(device_id=get_device_id()) - init() - device_num = get_device_num() - rank = get_rank_id() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, - device_num=device_num) - else: - rank = 0 - device_num = 1 - context.set_context(device_id=get_device_id()) - - # Set mempool block size in PYNATIVE_MODE for improving memory utilization, which will not take effect in GRAPH_MODE - if context.get_context("mode") == context.PYNATIVE_MODE: - context.set_context(mempool_block_size="31GB") - + context.set_context(mempool_block_size="31GB") elif config.device_target == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") set_graph_kernel_context(config.device_target) - if config.distribute: - if os.getenv("DEVICE_ID", "not_set").isdigit(): - context.set_context(device_id=get_device_id()) - init() - device_num = config.device_num - rank = get_rank() - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, - device_num=device_num) - else: - rank = 0 - device_num = 1 - context.set_context(device_id=get_device_id()) + elif config.device_target == "CPU": + device_id = 0 + config.distribute = False else: - raise ValueError("Unsupported platform.") + raise ValueError(f"device_target support ['Ascend', 'GPU', 'CPU'], but get {config.device_target}") + context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) + if config.distribute: + init() + device_num = config.device_num + rank = get_rank() + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) + else: + rank = 0 + device_num = 1 + context.set_context(device_id=device_id) mindrecord_file = os.path.join(config.mindrecord_dir, "retinanet.mindrecord0") @@ -178,6 +163,10 @@ def main(): retinanet = retinanet50(backbone, config) net = retinanetWithLossCell(retinanet, config) init_net_param(net) + if config.finetune: + init_net_param(net, initialize_mode='XavierUniform') + else: + init_net_param(net) if config.pre_trained: if config.pre_trained_epoch_size <= 0: