diff --git a/research/cv/SPPNet/README_CN.md b/research/cv/SPPNet/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..2d97dd6289aefc266ff27cf099215d147a51b9cf --- /dev/null +++ b/research/cv/SPPNet/README_CN.md @@ -0,0 +1,428 @@ +# 鐩綍 + +<!-- TOC --> + +- [SPPNet鎻忚堪](#spatial_pyramid_pooling鎻忚堪) +- [妯″瀷鏋舵瀯](#妯″瀷鏋舵瀯) +- [鏁版嵁闆哴(#鏁版嵁闆�) +- [鐜瑕佹眰](#鐜瑕佹眰) +- [蹇€熷叆闂╙(#蹇€熷叆闂�) +- [鑴氭湰璇存槑](#鑴氭湰璇存槑) + - [鑴氭湰鍙婃牱渚嬩唬鐮乚(#鑴氭湰鍙婃牱渚嬩唬鐮�) + - [鑴氭湰鍙傛暟](#鑴氭湰鍙傛暟) + - [璁粌杩囩▼](#璁粌杩囩▼) + - [璁粌](#璁粌) + - [璇勪及杩囩▼](#璇勪及杩囩▼) + - [璇勪及](#璇勪及) +- [鎺ㄧ悊杩囩▼](#鎺ㄧ悊杩囩▼) + - [瀵煎嚭MindIR](#瀵煎嚭MindIR) + - [鍦ˋscend310鎵ц鎺ㄧ悊](#鍦ˋscend310鎵ц鎺ㄧ悊) + - [缁撴灉](#缁撴灉) +- [妯″瀷鎻忚堪](#妯″瀷鎻忚堪) + - [鎬ц兘](#鎬ц兘) + - [璇勪及鎬ц兘](#璇勪及鎬ц兘) +- [闅忔満鎯呭喌璇存槑](#闅忔満鎯呭喌璇存槑) +- [ModelZoo涓婚〉](#modelzoo涓婚〉) + +<!-- /TOC --> + +# SPPNET鎻忚堪 + +SPPNET鏄綍鍑槑绛変汉2015骞存彁鍑恒€傝缃戠粶鍦ㄦ渶鍚庝竴灞傚嵎绉悗鍔犲叆浜嗙┖闂撮噾瀛楀姹犲寲灞�(Spatial Pyramid Pooling layer)鏇挎崲鍘熸潵鐨勬睜鍖栧眰(Pooling layer),浣跨綉缁滄帴鍙椾笉鍚岀殑灏哄鐨刦eature maps骞惰緭鍑虹浉鍚屽ぇ灏忕殑feature maps锛屼粠鑰岃В鍐充簡Resize瀵艰嚧鍥剧墖鍨嬪彉鐨勯棶棰樸€� + +[璁烘枃](https://arxiv.org/pdf/1406.4729.pdf)锛� K. He, X. Zhang, S. Ren and J. Sun, "Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 37, no. 9, pp. 1904-1916, 1 Sept. 2015, doi: 10.1109/TPAMI.2015.2389824. 
+ +# 妯″瀷鏋舵瀯 + +SPPNET鍩轰簬ZFNET锛孼FNET鐢�5涓嵎绉眰鍜�3涓叏杩炴帴灞傜粍鎴愶紝SPPNET鍦ㄥ師鏉ョ殑ZFNET鐨刢onv5涔嬪悗鍔犲叆浜哠patial Pyramid Pooling layer銆� + +# 鏁版嵁闆� + +浣跨敤鐨勬暟鎹泦锛歔ImageNet2012](http://www.image-net.org/) + +- 鏁版嵁闆嗗ぇ灏忥細鍏�1000涓被銆�224*224褰╄壊鍥惧儚 + - 璁粌闆嗭細鍏�1,281,167寮犲浘鍍� + - 娴嬭瘯闆嗭細鍏�50,000寮犲浘鍍� + +- 鏁版嵁鏍煎紡锛欽PEG + - 娉細鏁版嵁鍦╠ataset.py涓鐞嗐€� + +- 涓嬭浇鏁版嵁闆嗐€傜洰褰曠粨鏋勫涓嬶細 + +```text +鈹斺攢dataset + 鈹溾攢ilsvrc # 璁粌鏁版嵁闆� + 鈹斺攢validation_preprocess # 璇勪及鏁版嵁闆� +``` + +# 鐜瑕佹眰 + +- 纭欢锛圓scend锛� + - 鍑嗗Ascend澶勭悊鍣ㄦ惌寤虹‖浠剁幆澧冦€� + +- 妗嗘灦 + - [MindSpore](https://www.mindspore.cn/install) + +- 濡傞渶鏌ョ湅璇︽儏锛岃鍙傝濡備笅璧勬簮锛� + - [MindSpore鏁欑▼](https://www.mindspore.cn/tutorials/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html) + +# 蹇€熷叆闂� + +閫氳繃瀹樻柟缃戠珯瀹夎MindSpore鍚庯紝鎮ㄥ彲浠ユ寜鐓у涓嬫楠よ繘琛岃缁冨拰璇勪及锛� + + ```bash + # 杩涘叆鑴氭湰鐩綍锛岃缁僑PPNET瀹炰緥 + bash run_standalone_train_ascend.sh [TRAIN_DATA_PATH] [EVAL_DATA_PATH] [DEVICE_ID] [TRAIN_MODEL] + + # 杩涘叆鑴氭湰鐩綍锛岃瘎浼癝PPNET瀹炰緥 + bash run_standalone_eval_ascend.sh [TEST_DATA_PATH] [CKPT_PATH] [DEVICE_ID] [TEST_MODEL] + + # 杩愯鍒嗗竷寮忚缁冨疄渚� + bash run_distribution_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATA_PATH] [EVAL_DATA_PATH] [TRAIN_MODEL] + ``` + +# 鑴氭湰璇存槑 + +## 鑴氭湰鍙婃牱渚嬩唬鐮� + +```bash +鈹溾攢鈹€ cv + 鈹溾攢鈹€ sppnet + 鈹溾攢鈹€ README.md // SPPNet鐩稿叧璇存槑 + 鈹溾攢鈹€ scripts + 鈹� 鈹溾攢鈹€run_distribution_ascend.sh // Ascend澶氬崱璁粌+鎺ㄧ悊鐨剆hell鑴氭湰 + 鈹� 鈹溾攢鈹€run_standalone_eval_ascend.sh // Ascend鍗曞崱鎺ㄧ悊鐨剆hell鑴氭湰 + 鈹� 鈹溾攢鈹€run_standalone_train_ascend.sh // Ascend鍗曞崱璁粌+鎺ㄧ悊鐨剆hell鑴氭湰 + 鈹溾攢鈹€ src + 鈹� 鈹溾攢鈹€dataset.py // 鍒涘缓鏁版嵁闆� + 鈹� 鈹溾攢鈹€sppnet.py // sppnet/zfnet鏋舵瀯 + 鈹� 鈹溾攢鈹€spatial_pyramid_pooling.py // 閲戝瓧濉旀睜鍖栧眰鏋舵瀯 + 鈹� 鈹溾攢鈹€generator_lr.py // 鐢熸垚姣忎釜姝ラ鐨勫涔犵巼 + 鈹� 鈹溾攢鈹€eval_callback.py // 璁粌鏃惰繘琛屾帹鐞嗙殑鑴氭湰 + 鈹� 鈹溾攢鈹€config.py // 鍙傛暟閰嶇疆 + 鈹溾攢鈹€ train.py // 璁粌鑴氭湰 + 鈹溾攢鈹€ eval.py // 璇勪及鑴氭湰 + 鈹溾攢鈹€ export.py // 妯″瀷杞崲灏哻heckpoint鏂囦欢瀵煎嚭鍒癮ir/mindir +``` + +## 鑴氭湰鍙傛暟 + +鍦╟onfig.py涓彲浠ュ悓鏃堕厤缃缁冨弬鏁板拰璇勪及鍙傛暟锛� + + ```bash + # zfnet閰嶇疆鍙傛暟 + 'num_classes': 1000, # 鏁版嵁闆嗙被鍒暟閲� + 'momentum': 0.9, # 鍔ㄩ噺 + 
'epoch_size': 150, # epoch澶у皬 + 'batch_size': 256, # 杈撳叆寮犻噺鐨勬壒娆″ぇ灏� + 'image_height': 224, # 鍥剧墖闀垮害 + 'image_width': 224, # 鍥剧墖瀹藉害 + 'warmup_epochs' : 5, # 鐑韩鍛ㄦ湡鏁� + 'iteration_max': 150, # 浣欏鸡閫€鐏渶澶ц凯浠f鏁� + 'lr_init': 0.035, # 鍒濆瀛︿範鐜� + 'lr_min': 0.0, # 鏈€灏忓涔犵巼 + 'weight_decay': 0.0001, # 鏉冮噸琛板噺 + 'loss_scale': 1024, # 鎹熷け绛夌骇 + 'is_dynamic_loss_scale': 0, # 鏄惁鍔ㄦ€佽皟鑺傛崯澶� + + # sppnet(single train)閰嶇疆鍙傛暟 + 'num_classes': 1000, # 鏁版嵁闆嗙被鍒暟閲� + 'momentum': 0.9, # 鍔ㄩ噺 + 'epoch_size': 160, # epoch澶у皬 + 'batch_size': 256, # 杈撳叆寮犻噺鐨勬壒娆″ぇ灏� + 'image_height': 224, # 鍥剧墖闀垮害 + 'image_width': 224, # 鍥剧墖瀹藉害 + 'warmup_epochs' : 0, # 鐑韩鍛ㄦ湡鏁� + 'iteration_max': 150, # 浣欏鸡閫€鐏渶澶ц凯浠f鏁� + 'lr_init': 0.01, # 鍒濆瀛︿範鐜� + 'lr_min': 0.0, # 鏈€灏忓涔犵巼 + 'weight_decay': 0.0001, # 鏉冮噸琛板噺 + 'loss_scale': 1024, # 鎹熷け绛夌骇 + 'is_dynamic_loss_scale': 0, # 鏄惁鍔ㄦ€佽皟鑺傛崯澶� + + # sppnet(mult train)閰嶇疆鍙傛暟 + 'num_classes': 1000, # 鏁版嵁闆嗙被鍒暟閲� + 'momentum': 0.9, # 鍔ㄩ噺 + 'epoch_size': 160, # epoch澶у皬 + 'batch_size': 128, # 杈撳叆寮犻噺鐨勬壒娆″ぇ灏� + 'image_height': 224, # 鍥剧墖闀垮害 + 'image_width': 224, # 鍥剧墖瀹藉害 + 'warmup_epochs' : 0, # 鐑韩鍛ㄦ湡鏁� + 'iteration_max': 150, # 浣欏鸡閫€鐏渶澶ц凯浠f鏁� + 'lr_init': 0.01, # 鍒濆瀛︿範鐜� + 'lr_min': 0.0, # 鏈€灏忓涔犵巼 + 'weight_decay': 0.0001, # 鏉冮噸琛板噺 + 'loss_scale': 1024, # 鎹熷け绛夌骇 + 'is_dynamic_loss_scale': 0, # 鏄惁鍔ㄦ€佽皟鑺傛崯澶� + ``` + +train.py涓富瑕佸弬鏁板涓嬶細 + + ```bash + --train_model: 璁粌鐨勬ā鍨嬶紝鍙€夊€间负"zfnet"銆�"sppnet_single"銆�"sppnet_mult"锛岄粯璁ゅ€间负"sppnet_single" + --train_path: 鍒拌缁冩暟鎹泦鐨勭粷瀵瑰畬鏁磋矾寰勶紝榛樿鍊间负"./imagenet_original/train" + --eval_path: 鍒拌瘎浼版暟鎹泦鐨勭粷瀵瑰畬鏁磋矾寰勶紝榛樿鍊间负"./imagenet_original/val" + --device_target: 瀹炵幇浠g爜鐨勮澶囷紝榛樿鍊间负"Ascend" + --ckpt_path: 璁粌鍚庝繚瀛樼殑妫€鏌ョ偣鏂囦欢鐨勭粷瀵瑰畬鏁磋矾寰勶紝榛樿鍊间负"./ckpt" + --dataset_sink_mode: 鏄惁杩涜鏁版嵁涓嬫矇锛岄粯璁ゅ€间负True + --device_id: 浣跨敤璁惧鐨勫崱鍙凤紝榛樿鍊间负0 + --device_num: 浣跨敤璁惧鐨勬暟閲忥紝榛樿鍊间负1 + ``` + +## 璁粌杩囩▼ + +### 璁粌 + +- Ascend澶勭悊鍣ㄧ幆澧冭繍琛� + + ```bash + # 鍗曞崱璁粌zfnet + python train.py --train_path ./imagenet/train --eval_path ./imagenet/val --device_id 0 --train_model zfnet > log 2>&1 & + + # 
鎴栬繘鍏ヨ剼鏈洰褰曪紝鎵ц鑴氭湰 + bash run_standalone_train_ascend.sh ./imagenet_original/train ./imagenet_original/val 0 zfnet + + # 鍒嗗竷寮忚缁儂fnet锛岃繘鍏ヨ剼鏈洰褰曪紝鎵ц鑴氭湰 + bash run_distribution_ascend.sh ./hccl.json ./imagenet_original/train ./imagenet_original/val zfnet + + # 鍗曞崱璁粌sppnet(single train) + python train.py --train_path ./imagenet/train --eval_path ./imagenet/val --device_id 0 > log 2>&1 & + + # 鎴栬繘鍏ヨ剼鏈洰褰曪紝鎵ц鑴氭湰 + bash run_standalone_train_ascend.sh ./imagenet_original/train ./imagenet_original/val 0 sppnet_single + + # 鍒嗗竷寮忚缁僺ppnet(single train)锛岃繘鍏ヨ剼鏈洰褰曪紝鎵ц鑴氭湰 + bash run_distribution_ascend.sh ./hccl.json ./imagenet_original/train ./imagenet_original/val sppnet_single + + # 鍗曞崱璁粌sppnet(mult train) + python train.py --train_path ./imagenet/train --eval_path ./imagenet/val --device_id 0 --train_model sppnet_mult > log 2>&1 & + + # 鎴栬繘鍏ヨ剼鏈洰褰曪紝鎵ц鑴氭湰 + bash run_standalone_train_ascend.sh ./imagenet_original/train ./imagenet_original/val 0 sppnet_mult + + # 鍒嗗竷寮忚缁僺ppnet(mult train)锛岃繘鍏ヨ剼鏈洰褰曪紝鎵ц鑴氭湰 + bash run_distribution_ascend.sh ./hccl.json ./imagenet_original/train ./imagenet_original/val sppnet_mult + ``` + +- 浣跨敤ImageNet2012鏁版嵁闆嗗崟鍗¤繘琛岃缁儂fnet + + 缁忚繃璁粌鍚庯紝鎹熷け鍊煎涓嬶細 + + ```bash + ============== Starting Training ============== + epoch: 1 step: 5004, loss is 6.906126 + epoch time: 571750.162 ms, per step time: 114.259 ms + epoch: 1, {'top_5_accuracy', 'top_1_accuracy'}: {'top_5_accuracy': 0.005809294871794872, 'top_1_accuracy': 0.0010216346153846154}, eval_cost:19.47 + epoch: 2 step: 5004, loss is 5.69701 + epoch time: 531087.048 ms, per step time: 106.133 ms + epoch: 2, {'top_5_accuracy', 'top_1_accuracy'}: {'top_5_accuracy': 0.1386017628205128, 'top_1_accuracy': 0.04453125}, eval_cost:14.53 + epoch: 3 step: 5004, loss is 4.6244116 + epoch time: 530828.240 ms, per step time: 106.081 ms + epoch: 3, {'top_5_accuracy', 'top_1_accuracy'}: {'top_5_accuracy': 0.36738782051282054, 'top_1_accuracy': 0.1619591346153846}, eval_cost:13.73 + + ... 
+ + epoch: 149 step: 5004, loss is 1.448152 + epoch time: 531029.101 ms, per step time: 106.121 ms + epoch: 149, {'top_5_accuracy', 'top_1_accuracy'}: {'top_5_accuracy': 0.8547876602564103, 'top_1_accuracy': 0.6478966346153846}, eval_cost:14.25 + update best result: {'top_5_accuracy': 0.8547876602564103, 'top_1_accuracy': 0.6478966346153846} + update best checkpoint at: ./ckpt/best.ckpt + epoch: 150 step: 5004, loss is 1.5808313 + epoch time: 530946.874 ms, per step time: 106.104 ms + epoch: 150, {'top_5_accuracy', 'top_1_accuracy'}: {'top_5_accuracy': 0.8547876602564103, 'top_1_accuracy': 0.6483173076923077}, eval_cost:15.02 + update best result: {'top_5_accuracy': 0.8547876602564103, 'top_1_accuracy': 0.6483173076923077} + update best checkpoint at: ./ckpt/best.ckpt + End training, the best {'top_5_accuracy', 'top_1_accuracy'} is: {'top_1_accuracy': 0.6483173076923077, 'top_5_accuracy': 0.8547876602564103}, the best {'top_5_accuracy', 'top_1_accuracy'} epoch is 150 + ``` + + 妯″瀷妫€鏌ョ偣淇濆瓨鍦ㄥ綋鍓嶇洰褰昪kpt涓€� + +- 浣跨敤ImageNet2012鏁版嵁闆嗗崟鍗¤繘琛屽崟灏哄害璁粌sppnet(single train) + + 缁忚繃璁粌鍚庯紝鎹熷け鍊煎涓嬶細 + + ```bash + ============== Starting Training ============== + epoch: 1 step: 5004, loss is 6.754609 + epoch time: 1065948.526 ms, per step time: 213.019 ms + epoch: 1, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.002864583333333333, 'top_5_accuracy': 0.014082532051282052}, eval_cost:20.34 + epoch: 2 step: 5004, loss is 5.5111685 + epoch time: 1021084.963 ms, per step time: 204.054 ms + epoch: 2, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.0616386217948718, 'top_5_accuracy': 0.1776642628205128}, eval_cost:13.30 + epoch: 3 step: 5004, loss is 4.6289835 + epoch time: 1020991.373 ms, per step time: 204.035 ms + epoch: 3, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.15853365384615384, 'top_5_accuracy': 0.35985576923076923}, eval_cost:13.60 + + ... 
+ + epoch: 159, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.6475560897435897, 'top_5_accuracy': 0.8568309294871795}, eval_cost:13.35 + epoch: 160 step: 5004, loss is 1.7843108 + epoch time: 1020822.415 ms, per step time: 204.001 ms + epoch: 160, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.64765625, 'top_5_accuracy': 0.8556891025641026}, eval_cost:13.28 + End training, the best {'top_1_accuracy', 'top_5_accuracy'} is: {'top_1_accuracy': 0.6489783653846154, 'top_5_accuracy': 0.8572516025641026}, the best {'top_1_accuracy', 'top_5_accuracy'} epoch is 146 + ``` + + 妯″瀷妫€鏌ョ偣淇濆瓨鍦ㄥ綋鍓嶇洰褰昪kpt涓€� + +- 浣跨敤ImageNet2012鏁版嵁闆嗗崟鍗″灏哄害璁粌SPPNET(mult train) + + 缁忚繃璁粌鍚庯紝鎹熷け鍊煎涓嬶細 + + ```bash + ============== Starting Training ============== + epoch: 1 step: 10009, loss is 6.8730383 + epoch time: 1529142.058 ms, per step time: 152.777 ms + epoch: 1, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.0015825320512820513, 'top_5_accuracy': 0.009094551282051283}, cost:21.83 + update best result: {'top_1_accuracy': 0.0015825320512820513, 'top_5_accuracy': 0.009094551282051283} + update best checkpoint at: ./ckpt/best.ckpt + ================================================= + ================ Epoch:2 ================== + epoch: 1 step: 10009, loss is 5.8023987 + epoch time: 2104207.357 ms, per step time: 210.232 ms + ================ Epoch:3 ================== + epoch: 1 step: 10009, loss is 4.779583 + epoch time: 1506529.824 ms, per step time: 150.518 ms + epoch: 1, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.1718349358974359, 'top_5_accuracy': 0.3845753205128205}, cost:21.55 + update best result: {'top_1_accuracy': 0.1718349358974359, 'top_5_accuracy': 0.3845753205128205} + update best checkpoint at: ./ckpt/best.ckpt + ================================================= + + ... 
+ + ================ Epoch:148 ================== + epoch: 1 step: 10009, loss is 1.8939599 + epoch time: 2134993.076 ms, per step time: 213.307 ms + ================ Epoch:149 ================== + epoch: 1 step: 10009, loss is 1.6252799 + epoch time: 1552970.284 ms, per step time: 155.157 ms + epoch: 1, {'top_1_accuracy', 'top_5_accuracy'}: {'top_1_accuracy': 0.6448918269230769, 'top_5_accuracy': 0.8544270833333333}, cost:21.66 + ================================================= + + ... + + ``` + + 妯″瀷妫€鏌ョ偣淇濆瓨鍦ㄥ綋鍓嶇洰褰昪kpt涓€� + +## 璇勪及杩囩▼ + +### 璇勪及 + +鍦ㄨ繍琛屼互涓嬪懡浠や箣鍓嶏紝璇锋鏌ョ敤浜庤瘎浼扮殑妫€鏌ョ偣璺緞銆� + +- Ascend澶勭悊鍣ㄧ幆澧冭繍琛� + + ```bash + + python eval.py --data_path ./imagenet_original/val --ckpt_path ./ckpt/best.ckpt --device_id 0 --train_model sppnet_single > eval_log.txt 2>&1 & + + # 鎴栬繘鍏ヨ剼鏈洰褰曪紝鎵ц鑴氭湰 + + bash run_standalone_eval_ascend.sh ./imagenet_original/val ./ckpt/best.ckpt 0 sppnet_single + + ``` + + 鍙€氳繃"eval_log鈥濇枃浠舵煡鐪嬬粨鏋溿€傛祴璇曟暟鎹泦鐨勫噯纭巼濡備笅锛� + + ```bash + ============== Starting Testing ============== + load checkpoint from [./ckpt/best.ckpt]. + result : {'top_5_accuracy': 0.8577724358974359, 'top_1_accuracy': 0.6503605769230769} + ``` + +# 鎺ㄧ悊杩囩▼ + +## 瀵煎嚭MindIR + + ```bash + python export.py --ckpt_file [CKPT_PATH] --export_model [EXPORT_MODEL] --device_id [DEVICE_ID] + ``` + +鍙傛暟ckpt_file涓哄繀濉」锛� export_model蹇呴』鍦╗"zfnet", "sppnet_single", "sppnet_mult"]涓€夋嫨銆� + +## 鍦ˋscend310鎵ц鎺ㄧ悊 + +鍦ㄦ墽琛屾帹鐞嗗墠锛宮indir鏂囦欢蹇呴』閫氳繃export.py鑴氭湰瀵煎嚭銆備互涓嬪睍绀轰簡浣跨敤mindir妯″瀷鎵ц鎺ㄧ悊鐨勭ず渚嬨€� 鐩墠imagenet2012鏁版嵁闆嗕粎鏀寔batch_Size涓�1鐨勬帹鐞嗐€� + + ```bash + # Ascend310 inference + bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] + ``` + +## 缁撴灉 + +鎺ㄧ悊缁撴灉淇濆瓨鍦ㄨ剼鏈墽琛岀殑褰撳墠璺緞锛屼綘鍙互鍦╝cc.log涓湅鍒颁互涓嬬簿搴﹁绠楃粨鏋溿€� + + ```bash + # zfnet.mindir 310鎺ㄧ悊鐨刟cc.log璁$畻缁撴灉濡備笅 + Total data: 50000, top1 accuracy: 0.6546, top5 accuracy: 0.85934. 
+ ``` + +# 妯″瀷鎻忚堪 + +## 鎬ц兘 + +### 璇勪及鎬ц兘 + +#### Imagenet2012涓婄殑zfnet + +| 鍙傛暟 | Ascend | +| -------------------------- | ------------------------------------------------------------| +| 璧勬簮 | Ascend 910锛汣PU 2.60GHz, 192鏍革紱鍐呭瓨锛�755G | +| 涓婁紶鏃ユ湡 | 2021-09-21 | +| MindSpore鐗堟湰 | 1.2.0-beta | +| 鏁版嵁闆� | ImageNet2012 | +| 璁粌鍙傛暟 | epoch=150, step_per_epoch=5004, batch_size=256, lr=0.0035 | +| 浼樺寲鍣� | 鍔ㄩ噺 | +| 鎹熷け鍑芥暟 | Softmax浜ゅ弶鐔� | +| 杈撳嚭 | 姒傜巼 | 姒傜巼 | +| 鎹熷け | 1.58 | +| 閫熷害 | 106姣/姝� | +| 鎬绘椂闂� | 22灏忔椂 | +| 寰皟妫€鏌ョ偣 | 594M 锛�.ckpt鏂囦欢锛� | +| 鑴氭湰 | <https://gitee.com/mindspore/models/tree/r1.2/research/cv/SPPNet> | + +#### Imagenet2012涓婄殑sppnet(single train) + +| 鍙傛暟 | Ascend | +| -------------------------- | ------------------------------------------------------------| +| 璧勬簮 | Ascend 910锛汣PU 2.60GHz, 192鏍革紱鍐呭瓨锛�755G | +| 涓婁紶鏃ユ湡 | 2021-09-21 | +| MindSpore鐗堟湰 | 1.2.0-beta | +| 鏁版嵁闆� | ImageNet2012 | +| 璁粌鍙傛暟 | epoch=160, step_per_epoch=5004, batch_size=256, lr=0.001 | +| 浼樺寲鍣� | 鍔ㄩ噺 | +| 鎹熷け鍑芥暟 | Softmax浜ゅ弶鐔� | +| 杈撳嚭 | 姒傜巼 | 姒傜巼 | +| 鎹熷け | 1.55 | +| 閫熷害 | 203姣/姝� | +| 鎬绘椂闂� | 200灏忔椂 | +| 寰皟妫€鏌ョ偣 | 594M 锛�.ckpt鏂囦欢锛� | +| 鑴氭湰 | <https://gitee.com/mindspore/models/tree/r1.2/research/cv/SPPNet> | + +#### Imagenet2012涓婄殑sppnet(single mult) + +| 鍙傛暟 | Ascend | +| -------------------------- | ------------------------------------------------------------| +| 璧勬簮 | Ascend 910锛汣PU 2.60GHz, 192鏍革紱鍐呭瓨锛�755G | +| 涓婁紶鏃ユ湡 | 2021-09-21 | +| MindSpore鐗堟湰 | 1.2.0-beta | +| 鏁版嵁闆� | ImageNet2012 | +| 璁粌鍙傛暟 | epoch=160, step_per_epoch=10009, batch_size=128, lr=0.001 | +| 浼樺寲鍣� | 鍔ㄩ噺 | +| 鎹熷け鍑芥暟 | Softmax浜ゅ弶鐔� | +| 杈撳嚭 | 姒傜巼 | 姒傜巼 | +| 鎹熷け | 1.78 | +| 閫熷害 | 180姣/姝� | +| 鎬绘椂闂� | 200灏忔椂 | +| 寰皟妫€鏌ョ偣 | 601M 锛�.ckpt鏂囦欢锛� | +| 鑴氭湰 | <https://gitee.com/mindspore/models/tree/r1.2/research/cv/SPPNet> | + +# 闅忔満鎯呭喌璇存槑 + +dataset.py涓缃簡train.py涓殑闅忔満绉嶅瓙銆� + +# ModelZoo涓婚〉 + +璇锋祻瑙堝畼缃慬涓婚〉](https://gitee.com/mindspore/models)銆� diff --git a/research/cv/SPPNet/ascend310_infer/inc/utils.h 
b/research/cv/SPPNet/ascend310_infer/inc/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f8ae1e5b473d869b77af8d725a280d7c7665527c --- /dev/null +++ b/research/cv/SPPNet/ascend310_infer/inc/utils.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INFERENCE_UTILS_H_ +#define MINDSPORE_INFERENCE_UTILS_H_ + +#include <sys/stat.h> +#include <dirent.h> +#include <vector> +#include <string> +#include <memory> +#include "include/api/types.h" + +std::vector<std::string> GetAllFiles(std::string_view dirName); +DIR *OpenDir(std::string_view dirName); +std::string RealPath(std::string_view path); +mindspore::MSTensor ReadFileToTensor(const std::string &file); +int WriteResult(const std::string& imageFile, const std::vector<mindspore::MSTensor> &outputs); +std::vector<std::string> GetAllFiles(std::string dir_name); +std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name); + +#endif diff --git a/research/cv/SPPNet/ascend310_infer/src/CMakeLists.txt b/research/cv/SPPNet/ascend310_infer/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..14e676821a4936c03e98b9299b3b5f5e4496a8ea --- /dev/null +++ b/research/cv/SPPNet/ascend310_infer/src/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.14.1) +project(MindSporeCxxTestcase[CXX]) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) 
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -std=c++17 -Werror -Wall -fPIE -Wl,--allow-shlib-undefined") +set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/) +option(MINDSPORE_PATH "mindspore install path" "") +include_directories(${MINDSPORE_PATH}) +include_directories(${MINDSPORE_PATH}/include) +include_directories(${PROJECT_SRC_ROOT}/../) +find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib) +file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*) +find_package(gflags REQUIRED) +add_executable(main main.cc utils.cc) +target_link_libraries(main ${MS_LIB} ${MD_LIB} gflags) diff --git a/research/cv/SPPNet/ascend310_infer/src/build.sh b/research/cv/SPPNet/ascend310_infer/src/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..7fac9cff3a98c83bce7e8f66053fab2ecebab86d --- /dev/null +++ b/research/cv/SPPNet/ascend310_infer/src/build.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +cmake . 
-DMINDSPORE_PATH="`pip3.7 show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath`" +make \ No newline at end of file diff --git a/research/cv/SPPNet/ascend310_infer/src/main.cc b/research/cv/SPPNet/ascend310_infer/src/main.cc new file mode 100644 index 0000000000000000000000000000000000000000..41e8ff7b966718c72d1aab95694f0796344e1de8 --- /dev/null +++ b/research/cv/SPPNet/ascend310_infer/src/main.cc @@ -0,0 +1,146 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include <sys/time.h> +#include <gflags/gflags.h> +#include <dirent.h> +#include <iostream> +#include <string> +#include <algorithm> +#include <iosfwd> +#include <vector> +#include <fstream> +#include <sstream> + +#include "include/api/model.h" +#include "include/api/context.h" +#include "include/api/types.h" +#include "include/api/serialization.h" +#include "include/dataset/vision_ascend.h" +#include "include/dataset/execute.h" +#include "include/dataset/transforms.h" +#include "include/dataset/vision.h" +#include "inc/utils.h" + +using mindspore::dataset::vision::Decode; +using mindspore::dataset::vision::Resize; +using mindspore::dataset::vision::CenterCrop; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::vision::HWC2CHW; +using mindspore::dataset::TensorTransform; +using mindspore::Context; +using mindspore::Serialization; +using mindspore::Model; +using mindspore::Status; +using mindspore::ModelType; +using mindspore::GraphCell; +using mindspore::kSuccess; +using mindspore::MSTensor; +using mindspore::dataset::Execute; + + +DEFINE_string(mindir_path, "", "mindir path"); +DEFINE_string(dataset_path, ".", "dataset path"); +DEFINE_int32(device_id, 0, "device id"); + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (RealPath(FLAGS_mindir_path).empty()) { + std::cout << "Invalid mindir" << std::endl; + return 1; + } + + auto context = std::make_shared<Context>(); + auto ascend310 = std::make_shared<mindspore::Ascend310DeviceInfo>(); + ascend310->SetDeviceID(FLAGS_device_id); + context->MutableDeviceInfo().push_back(ascend310); + mindspore::Graph graph; + Serialization::Load(FLAGS_mindir_path, ModelType::kMindIR, &graph); + Model model; + Status ret = model.Build(GraphCell(graph), context); + if (ret != kSuccess) { + std::cout << "ERROR: Build failed." 
<< std::endl; + return 1; + } + + auto all_files = GetAllInputData(FLAGS_dataset_path); + if (all_files.empty()) { + std::cout << "ERROR: no input data." << std::endl; + return 1; + } + + std::map<double, double> costTime_map; + size_t size = all_files.size(); + // Define transform + std::vector<int32_t> crop_paras = {224}; + std::vector<int32_t> resize_paras = {256}; + std::vector<float> mean = {0.485 * 255, 0.456 * 255, 0.406 * 255}; + std::vector<float> std = {0.229 * 255, 0.224 * 255, 0.225 * 255}; + + auto decode = Decode(); + auto resize = Resize(resize_paras); + auto centercrop = CenterCrop(crop_paras); + auto normalize = Normalize(mean, std); + auto hwc2chw = HWC2CHW(); + + mindspore::dataset::Execute SingleOp({decode, resize, centercrop, normalize, hwc2chw}); + + for (size_t i = 0; i < size; ++i) { + for (size_t j = 0; j < all_files[i].size(); ++j) { + struct timeval start = {0}; + struct timeval end = {0}; + double startTimeMs; + double endTimeMs; + std::vector<MSTensor> inputs; + std::vector<MSTensor> outputs; + std::cout << "Start predict input files:" << all_files[i][j] <<std::endl; + auto imgDvpp = std::make_shared<MSTensor>(); + SingleOp(ReadFileToTensor(all_files[i][j]), imgDvpp.get()); + + inputs.emplace_back(imgDvpp->Name(), imgDvpp->DataType(), imgDvpp->Shape(), + imgDvpp->Data().get(), imgDvpp->DataSize()); + gettimeofday(&start, nullptr); + ret = model.Predict(inputs, &outputs); + gettimeofday(&end, nullptr); + if (ret != kSuccess) { + std::cout << "Predict " << all_files[i][j] << " failed." 
<< std::endl; + return 1; + } + startTimeMs = (1.0 * start.tv_sec * 1000000 + start.tv_usec) / 1000; + endTimeMs = (1.0 * end.tv_sec * 1000000 + end.tv_usec) / 1000; + costTime_map.insert(std::pair<double, double>(startTimeMs, endTimeMs)); + WriteResult(all_files[i][j], outputs); + } + } + double average = 0.0; + int inferCount = 0; + + for (auto iter = costTime_map.begin(); iter != costTime_map.end(); iter++) { + double diff = 0.0; + diff = iter->second - iter->first; + average += diff; + inferCount++; + } + average = average / inferCount; + std::stringstream timeCost; + timeCost << "NN inference cost average time: "<< average << " ms of infer_count " << inferCount << std::endl; + std::cout << "NN inference cost average time: "<< average << "ms of infer_count " << inferCount << std::endl; + std::string fileName = "./time_Result" + std::string("/test_perform_static.txt"); + std::ofstream fileStream(fileName.c_str(), std::ios::trunc); + fileStream << timeCost.str(); + fileStream.close(); + costTime_map.clear(); + return 0; +} diff --git a/research/cv/SPPNet/ascend310_infer/src/utils.cc b/research/cv/SPPNet/ascend310_infer/src/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..d71f388b83d23c2813d8bfc883dbcf2e7e0e4ef0 --- /dev/null +++ b/research/cv/SPPNet/ascend310_infer/src/utils.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <fstream> +#include <algorithm> +#include <iostream> +#include "inc/utils.h" + +using mindspore::MSTensor; +using mindspore::DataType; + + +std::vector<std::vector<std::string>> GetAllInputData(std::string dir_name) { + std::vector<std::vector<std::string>> ret; + + DIR *dir = OpenDir(dir_name); + if (dir == nullptr) { + return {}; + } + struct dirent *filename; + /* read all the files in the dir ~ */ + std::vector<std::string> sub_dirs; + while ((filename = readdir(dir)) != nullptr) { + std::string d_name = std::string(filename->d_name); + // get rid of "." and ".." + if (d_name == "." || d_name == ".." || d_name.empty()) { + continue; + } + std::string dir_path = RealPath(std::string(dir_name) + "/" + filename->d_name); + struct stat s; + lstat(dir_path.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + continue; + } + + sub_dirs.emplace_back(dir_path); + } + std::sort(sub_dirs.begin(), sub_dirs.end()); + + (void)std::transform(sub_dirs.begin(), sub_dirs.end(), std::back_inserter(ret), + [](const std::string &d) { return GetAllFiles(d); }); + + return ret; +} + + +std::vector<std::string> GetAllFiles(std::string dir_name) { + struct dirent *filename; + DIR *dir = OpenDir(dir_name); + if (dir == nullptr) { + return {}; + } + + std::vector<std::string> res; + while ((filename = readdir(dir)) != nullptr) { + std::string d_name = std::string(filename->d_name); + if (d_name == "." || d_name == ".." || d_name.size() <= 3) { + continue; + } + res.emplace_back(std::string(dir_name) + "/" + filename->d_name); + } + std::sort(res.begin(), res.end()); + + return res; +} + + +std::vector<std::string> GetAllFiles(std::string_view dirName) { + struct dirent *filename; + DIR *dir = OpenDir(dirName); + if (dir == nullptr) { + return {}; + } + std::vector<std::string> res; + while ((filename = readdir(dir)) != nullptr) { + std::string dName = std::string(filename->d_name); + if (dName == "." || dName == ".." 
|| filename->d_type != DT_REG) { + continue; + } + res.emplace_back(std::string(dirName) + "/" + filename->d_name); + } + std::sort(res.begin(), res.end()); + for (auto &f : res) { + std::cout << "image file: " << f << std::endl; + } + return res; +} + + +int WriteResult(const std::string& imageFile, const std::vector<MSTensor> &outputs) { + std::string homePath = "./result_Files"; + for (size_t i = 0; i < outputs.size(); ++i) { + size_t outputSize; + std::shared_ptr<const void> netOutput; + netOutput = outputs[i].Data(); + outputSize = outputs[i].DataSize(); + int pos = imageFile.rfind('/'); + std::string fileName(imageFile, pos + 1); + fileName.replace(fileName.find('.'), fileName.size() - fileName.find('.'), '_' + std::to_string(i) + ".bin"); + std::string outFileName = homePath + "/" + fileName; + FILE *outputFile = fopen(outFileName.c_str(), "wb"); + fwrite(netOutput.get(), outputSize, sizeof(char), outputFile); + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} + +mindspore::MSTensor ReadFileToTensor(const std::string &file) { + if (file.empty()) { + std::cout << "Pointer file is nullptr" << std::endl; + return mindspore::MSTensor(); + } + + std::ifstream ifs(file); + if (!ifs.good()) { + std::cout << "File: " << file << " is not exist" << std::endl; + return mindspore::MSTensor(); + } + + if (!ifs.is_open()) { + std::cout << "File: " << file << "open failed" << std::endl; + return mindspore::MSTensor(); + } + + ifs.seekg(0, std::ios::end); + size_t size = ifs.tellg(); + mindspore::MSTensor buffer(file, mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size); + + ifs.seekg(0, std::ios::beg); + ifs.read(reinterpret_cast<char *>(buffer.MutableData()), size); + ifs.close(); + + return buffer; +} + + +DIR *OpenDir(std::string_view dirName) { + if (dirName.empty()) { + std::cout << " dirName is null ! 
" << std::endl; + return nullptr; + } + std::string realPath = RealPath(dirName); + struct stat s; + lstat(realPath.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dirName is not a valid directory !" << std::endl; + return nullptr; + } + DIR *dir; + dir = opendir(realPath.c_str()); + if (dir == nullptr) { + std::cout << "Can not open dir " << dirName << std::endl; + return nullptr; + } + std::cout << "Successfully opened the dir " << dirName << std::endl; + return dir; +} + +std::string RealPath(std::string_view path) { + char realPathMem[PATH_MAX] = {0}; + char *realPathRet = nullptr; + realPathRet = realpath(path.data(), realPathMem); + if (realPathRet == nullptr) { + std::cout << "File: " << path << " is not exist."; + return ""; + } + + std::string realPath(realPathMem); + std::cout << path << " realpath is: " << realPath << std::endl; + return realPath; +} diff --git a/research/cv/SPPNet/create_imagenet2012_label.py b/research/cv/SPPNet/create_imagenet2012_label.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae8e763aec753c6283a74d928d6f46739bec5e1 --- /dev/null +++ b/research/cv/SPPNet/create_imagenet2012_label.py @@ -0,0 +1,49 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""create_imagenet2012_label""" +import os +import json +import argparse + +parser = argparse.ArgumentParser(description="resnet imagenet2012 label") +parser.add_argument("--img_path", type=str, required=True, help="imagenet2012 file path.") +args = parser.parse_args() + + +def create_label(file_path): + """Create imagenet2012 label""" + print("[WARNING] Create imagenet label. Currently only use for Imagenet2012!") + dirs = os.listdir(file_path) + file_list = [] + for file in dirs: + file_list.append(file) + file_list = sorted(file_list) + + total = 0 + img_label = {} + for i, file_dir in enumerate(file_list): + files = os.listdir(os.path.join(file_path, file_dir)) + for f in files: + img_label[f] = i + total += len(files) + + with open("imagenet_label.json", "w+") as label: + json.dump(img_label, label) + + print("[INFO] Completed! Total {} data.".format(total)) + + +if __name__ == '__main__': + create_label(args.img_path) diff --git a/research/cv/SPPNet/eval.py b/research/cv/SPPNet/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea3ca84dc1f699a87b5e2a499ef19ffed561575 --- /dev/null +++ b/research/cv/SPPNet/eval.py @@ -0,0 +1,77 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +######################## eval sppnet example ######################## +eval sppnet according to model file: +python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt --device_id YourAscendId --train_model model +""" + +import ast +import argparse +import mindspore.nn as nn +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train import Model +from src.config import sppnet_mult_cfg, sppnet_single_cfg, zfnet_cfg +from src.dataset import create_dataset_imagenet +from src.sppnet import SppNet + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='MindSpore SPPNet Example') + parser.add_argument('--device_target', type=str, default="Ascend", + help='device where the code will be implemented (default: Ascend)') + parser.add_argument('--test_model', type=str, default='sppnet_single', help='chose the training model', + choices=['sppnet_single', 'sppnet_mult', 'zfnet']) + parser.add_argument('--data_path', type=str, default="", help='path where the dataset is saved') + parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ + path where the trained ckpt file') + parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=True, help='dataset_sink_mode is False or True') + parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. 
(Default: 0)')
+    args = parser.parse_args()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    context.set_context(device_id=args.device_id)
+    print("============== Starting Testing ==============")
+
+    if args.test_model == "zfnet":
+        cfg = zfnet_cfg
+        ds_eval = create_dataset_imagenet(args.data_path, 'zfnet', cfg.batch_size, training=False)
+        network = SppNet(cfg.num_classes, phase='test', train_model=args.test_model)
+
+    elif args.test_model == "sppnet_single":
+        cfg = sppnet_single_cfg
+        ds_eval = create_dataset_imagenet(args.data_path, args.test_model, cfg.batch_size, training=False)
+        network = SppNet(cfg.num_classes, phase='test', train_model=args.test_model)
+
+    else:
+        cfg = sppnet_mult_cfg
+        ds_eval = create_dataset_imagenet(args.data_path, args.test_model, cfg.batch_size, training=False)
+        network = SppNet(cfg.num_classes, phase='test', train_model=args.test_model)
+
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+
+    param_dict = load_checkpoint(args.ckpt_path)
+    print("load checkpoint from [{}].".format(args.ckpt_path))
+    load_param_into_net(network, param_dict)
+    network.set_train(False)
+
+    model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+
+    if ds_eval.get_dataset_size() == 0:
+        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
+
+    result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
+    print("result : {}".format(result))
diff --git a/research/cv/SPPNet/export.py b/research/cv/SPPNet/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6d856898e960d7f4bd5da5c288173333ae07ad0
--- /dev/null
+++ b/research/cv/SPPNet/export.py
@@ -0,0 +1,56 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+##############export checkpoint file into air, onnx, mindir models#################
+python export.py
+"""
+import argparse
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
+from src.config import sppnet_mult_cfg, sppnet_single_cfg, zfnet_cfg
+from src.sppnet import SppNet
+
+parser = argparse.ArgumentParser(description='Classification')
+parser.add_argument("--device_id", type=int, default=0, help="Device id")
+parser.add_argument("--batch_size", type=int, default=1, help="batch size")
+parser.add_argument('--device_target', type=str, default="Ascend",
+                    help='device where the code will be implemented (default: Ascend)')
+parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
+parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
+parser.add_argument('--export_model', type=str, default='sppnet_single', help='chose the training model',
+                    choices=['sppnet_single', 'sppnet_mult', 'zfnet'])
+args = parser.parse_args()
+
+context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
+
+if __name__ == '__main__':
+
+    if args.export_model == "zfnet":
+        cfg = zfnet_cfg
+        network = SppNet(cfg.num_classes, train_model=args.export_model)
+
+    elif args.export_model == "sppnet_single":
+        cfg = sppnet_single_cfg
+        network = SppNet(cfg.num_classes,
+                         train_model=args.export_model)
+
+    else:
+        cfg = sppnet_mult_cfg
+        network = SppNet(cfg.num_classes, train_model=args.export_model)
+
+    param_dict = load_checkpoint(args.ckpt_file)
+    load_param_into_net(network, param_dict)
+    input_arr = Tensor(np.zeros([args.batch_size, 3, cfg.image_height, cfg.image_width]), ms.float32)
+    export(network, input_arr, file_name=args.export_model, file_format=args.file_format)
diff --git a/research/cv/SPPNet/postprocess.py b/research/cv/SPPNet/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..3412e5ff7762c832bc3599f12353b3f577049200
--- /dev/null
+++ b/research/cv/SPPNet/postprocess.py
@@ -0,0 +1,52 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+# ============================================================================ +"""post process for 310 inference""" +import os +import json +import argparse +import numpy as np + + +batch_size = 1 +parser = argparse.ArgumentParser(description="resnet inference") +parser.add_argument("--result_path", type=str, required=True, help="result files path.") +parser.add_argument("--label_path", type=str, required=True, help="image file path.") +args = parser.parse_args() + + +def get_result(result_path, label_path): + """get top1 acc / top5 acc result""" + files = os.listdir(result_path) + with open(label_path, "r") as label: + labels = json.load(label) + + top1 = 0 + top5 = 0 + total_data = len(files) + for file in files: + img_ids_name = file.split('_0.')[0] + data_path = os.path.join(result_path, img_ids_name + "_0.bin") + result = np.fromfile(data_path, dtype=np.float32).reshape(batch_size, 1000) + for batch in range(batch_size): + predict = np.argsort(-result[batch], axis=-1) + if labels[img_ids_name+".JPEG"] == predict[0]: + top1 += 1 + if labels[img_ids_name+".JPEG"] in predict[:5]: + top5 += 1 + print(f"Total data: {total_data}, top1 accuracy: {top1/total_data}, top5 accuracy: {top5/total_data}.") + + +if __name__ == '__main__': + get_result(args.result_path, args.label_path) diff --git a/research/cv/SPPNet/scripts/run_distribution_ascend.sh b/research/cv/SPPNet/scripts/run_distribution_ascend.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e2224bb58b12b8fc56d291ae345d95385eb6031 --- /dev/null +++ b/research/cv/SPPNet/scripts/run_distribution_ascend.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# an simple tutorial as follows, more parameters can be setting +if [ $# != 4 ] +then + echo "Usage: sh run_distribution_ascend.sh [RANK_TABLE_FILE] [TRAIN_DATA_PATH] [EVAL_DATA_PATH] [TRAIN_MODEL]" +exit 1 +fi + +if [ ! -f $1 ] +then + echo "error: RANK_TABLE_FILE=$1 is not a file" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $2) +PATH2=$(get_real_path $3) + + +if [ ! -d $PATH1 ] +then + echo "error: TRAIN_DATA_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -d $PATH2 ] +then + echo "error: EVAL_DATA_PATH=$PATH2 is not a directory" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +RANK_TABLE_FILE=$(realpath $1) +export RANK_TABLE_FILE +export TRAIN_PATH=$2 +export EVAL_PATH=$3 +export TRAIN_MODEL=$4 +export BASE_PATH=${TRAIN_MODEL}"_train_parallel" +echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}" + +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$((rank_start + i)) + rm -rf ./$BASE_PATH$i + mkdir ./$BASE_PATH$i + cp -r ../src ./$BASE_PATH$i + cp ../train.py ./$BASE_PATH$i + echo "start training for rank $RANK_ID, device $DEVICE_ID" + cd ./$BASE_PATH$i ||exit + env > env.log + python train.py --device_id=$i --train_path=$TRAIN_PATH --eval_path=$EVAL_PATH --device_num=$DEVICE_NUM --train_model=$TRAIN_MODEL > log 2>&1 & + cd .. 
+done \ No newline at end of file diff --git a/research/cv/SPPNet/scripts/run_infer_310.sh b/research/cv/SPPNet/scripts/run_infer_310.sh new file mode 100644 index 0000000000000000000000000000000000000000..c5586cd7fd4bb719e9fc7bafb83a54a0d860feef --- /dev/null +++ b/research/cv/SPPNet/scripts/run_infer_310.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [[ $# -lt 2 || $# -gt 3 ]]; then + echo "Usage: sh run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] + DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +model=$(get_real_path $1) +data_path=$(get_real_path $2) + +device_id=0 +if [ $# == 3 ]; then + device_id=$3 +fi + +echo "mindir name: "$model +echo "dataset path: "$data_path +echo "device id: "$device_id + +export ASCEND_HOME=/usr/local/Ascend/ +if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then + export PATH=$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH + export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export 
TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe + export PYTHONPATH=${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp +else + export PATH=$ASCEND_HOME/atc/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH + export LD_LIBRARY_PATH=/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH + export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH + export ASCEND_OPP_PATH=$ASCEND_HOME/opp +fi + +function compile_app() +{ + cd ../ascend310_infer/src/ || exit + if [ -f "Makefile" ]; then + make clean + fi + sh build.sh &> build.log +} + +function infer() +{ + cd - || exit + if [ -d result_Files ]; then + rm -rf ./result_Files + fi + if [ -d time_Result ]; then + rm -rf ./time_Result + fi + mkdir result_Files + mkdir time_Result + ../ascend310_infer/src/main --mindir_path=$model --dataset_path=$data_path --device_id=$device_id &> infer.log +} + +function cal_acc() +{ + python3.7 ../create_imagenet2012_label.py --img_path=$data_path + python3.7 ../postprocess.py --result_path=./result_Files --label_path=./imagenet_label.json &> acc.log & +} + +compile_app +if [ $? -ne 0 ]; then + echo "compile app code failed" + exit 1 +fi +infer +if [ $? -ne 0 ]; then + echo " execute inference failed" + exit 1 +fi +cal_acc +if [ $? 
-ne 0 ]; then + echo "calculate accuracy failed" + exit 1 +fi \ No newline at end of file diff --git a/research/cv/SPPNet/scripts/run_standalone_eval_ascend.sh b/research/cv/SPPNet/scripts/run_standalone_eval_ascend.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e7edb90bbae4c201131bd921533cfc845a6931c --- /dev/null +++ b/research/cv/SPPNet/scripts/run_standalone_eval_ascend.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# an simple tutorial as follows, more parameters can be setting +if [ $# != 4 ] +then + echo "Usage: sh run_standalone_eval_ascend.sh [TEST_DATA_PATH] [CKPT_PATH] [DEVICE_ID] [TEST_MODEL]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) + +if [ ! -d $PATH1 ] +then + echo "error: TEST_DATA_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! 
-f $2 ] +then + echo "error: CKPT_PATH=$2 is not a file" +exit 1 +fi + + +export DATA_PATH=$1 +export CKPT_PATH=$2 +export DEVICE_ID=$3 +export TEST_MODEL=$4 +echo "start evaluating for $TEST_MODEL" +python ../eval.py --data_path=$DATA_PATH --ckpt_path=$CKPT_PATH --device_id=$DEVICE_ID --test_model=$TEST_MODEL > eval_log 2>&1 & \ No newline at end of file diff --git a/research/cv/SPPNet/scripts/run_standalone_train_ascend.sh b/research/cv/SPPNet/scripts/run_standalone_train_ascend.sh new file mode 100644 index 0000000000000000000000000000000000000000..fe44105fa3b9ba2718c35b0cc49b3f538a663928 --- /dev/null +++ b/research/cv/SPPNet/scripts/run_standalone_train_ascend.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# an simple tutorial as follows, more parameters can be setting +if [ $# != 4 ] +then + echo "Usage: sh run_standalone_train_ascend.sh [TRAIN_DATA_PATH] [EVAL_DATA_PATH] [DEVICE_ID] [TRAIN_MODEL]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +if [ ! -d $PATH1 ] +then + echo "error: TRAIN_DATA_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! 
-d $PATH2 ] +then + echo "error: EVAL_DATA_PATH=$PATH2 is not a directory" +exit 1 +fi + + +export TRAIN_PATH=$1 +export EVAL_PATH=$2 +export DEVICE_ID=$3 +export TRAIN_MODEL=$4 +rm -rf ./$TRAIN_MODEL +mkdir ./$TRAIN_MODEL +cp -r ../src ./$TRAIN_MODEL +cp ../train.py ./$TRAIN_MODEL +echo "start training for $TRAIN_MODEL" +cd ./$TRAIN_MODEL ||exit +python train.py --device_id=$DEVICE_ID --train_path=$TRAIN_PATH --eval_path=$EVAL_PATH --train_model=$TRAIN_MODEL > log 2>&1 & \ No newline at end of file diff --git a/research/cv/SPPNet/src/__init__.py b/research/cv/SPPNet/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00d293a9a25ddbe585526ef9a76c8af5312f0b95 --- /dev/null +++ b/research/cv/SPPNet/src/__init__.py @@ -0,0 +1 @@ +"""init""" diff --git a/research/cv/SPPNet/src/config.py b/research/cv/SPPNet/src/config.py new file mode 100644 index 0000000000000000000000000000000000000000..4e93f44afa3376901f0dd5833269f94c52468109 --- /dev/null +++ b/research/cv/SPPNet/src/config.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +network config setting, will be used in train.py +""" + +from easydict import EasyDict as Edict + +zfnet_cfg = Edict({ + 'num_classes': 1000, + 'momentum': 0.9, + 'epoch_size': 150, + 'batch_size': 256, + 'image_height': 224, + 'image_width': 224, + 'warmup_epochs': 5, + 'iteration_max': 150, + "lr_init": 0.035, + "lr_min": 0.0, + # opt + 'weight_decay': 0.0001, + 'loss_scale': 1024, + # lr + 'is_dynamic_loss_scale': 0, +}) + +sppnet_single_cfg = Edict({ + 'num_classes': 1000, + 'momentum': 0.9, + 'epoch_size': 160, + 'batch_size': 256, + 'image_height': 224, + 'image_width': 224, + 'warmup_epochs': 0, + 'iteration_max': 150, + "lr_init": 0.01, + "lr_min": 0.0, + # opt + 'weight_decay': 0.0001, + 'loss_scale': 1024, + # lr + 'is_dynamic_loss_scale': 0, +}) + +sppnet_mult_cfg = Edict({ + 'num_classes': 1000, + 'momentum': 0.9, + 'epoch_size': 160, + 'batch_size': 128, + 'image_height': 224, + 'image_width': 224, + 'warmup_epochs': 2, + 'iteration_max': 150, + "lr_init": 0.01, + "lr_min": 0.0, + # opt + 'weight_decay': 0.0001, + 'loss_scale': 1024, + # lr + 'is_dynamic_loss_scale': 0, +}) diff --git a/research/cv/SPPNet/src/dataset.py b/research/cv/SPPNet/src/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..108114c962ccd51851f5e0b29d9b5eedec04a744 --- /dev/null +++ b/research/cv/SPPNet/src/dataset.py @@ -0,0 +1,105 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Produce the dataset +""" + +import os +import mindspore.dataset as ds +import mindspore.dataset.vision.c_transforms as CV +from mindspore.communication.management import get_rank, get_group_size + + +def create_dataset_imagenet(dataset_path, train_model_name='sppnet_single', + batch_size=256, training=True, + num_samples=None, workers=12, + shuffle=None, class_indexing=None, + sampler=None, image_size=224): + """ + create a train or eval imagenet2012 dataset for Sppnet + + Args: + dataset_path(string): the path of dataset. + train_model_name(string): model name for training + training(bool): whether dataset is used for train or eval. + batch_size(int): the batch size of dataset. Default: 128 + target(str): the device target. Default: Ascend + Returns: + dataset + """ + + rank_size = int(os.environ.get("RANK_SIZE", 1)) + num_parallel_workers = workers + + if rank_size > 1: + device_num = get_group_size() + rank_id = get_rank() + else: + device_num = 1 + rank_id = 0 + + if device_num == 1: + num_parallel_workers = 16 + ds.config.set_prefetch_size(8) + else: + ds.config.set_numa_enable(True) + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=4, + num_samples=num_samples, shuffle=shuffle, + sampler=sampler, class_indexing=class_indexing, + num_shards=device_num, shard_id=rank_id) + + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations + if training and image_size == 224: + if train_model_name == 'zfnet': + transform_img = [ + CV.RandomCropDecodeResize((224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.333)), + CV.RandomHorizontalFlip(prob=0.5), + CV.Normalize(mean=mean, std=std), + CV.HWC2CHW() + ] + else: + transform_img = [ + CV.RandomCropDecodeResize((224, 224), scale=(0.08, 1.0), ratio=(0.75, 
1.333)), + CV.RandomHorizontalFlip(prob=0.5), + CV.RandomColorAdjust(0.4, 0.4, 0.4, 0.1), + CV.Normalize(mean=mean, std=std), + CV.HWC2CHW() + ] + elif training and image_size == 180: + transform_img = [ + CV.RandomCropDecodeResize((224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.333)), + CV.Resize(180), + CV.RandomHorizontalFlip(prob=0.5), + CV.RandomColorAdjust(0.4, 0.4, 0.4, 0.1), + CV.Normalize(mean=mean, std=std), + CV.HWC2CHW() + ] + else: + transform_img = [ + CV.Decode(), + CV.Resize((256, 256)), + CV.CenterCrop(224), + CV.Normalize(mean=mean, std=std), + CV.HWC2CHW() + ] + + data_set = data_set.map(input_columns="image", operations=transform_img, + num_parallel_workers=num_parallel_workers) + data_set = data_set.batch(batch_size, drop_remainder=True) + + return data_set diff --git a/research/cv/SPPNet/src/eval_callback.py b/research/cv/SPPNet/src/eval_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9a773987fbfdd05e3c44e5bf5f06dd5ade8256 --- /dev/null +++ b/research/cv/SPPNet/src/eval_callback.py @@ -0,0 +1,176 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Evaluation callback when training""" + +import os +import stat +import time +from mindspore import save_checkpoint +from mindspore import log as logger +from mindspore.train.callback import Callback + + +class EvalCallBack(Callback): + """ + Evaluation callback when training. + + Args: + eval_function (function): evaluation function. + eval_param_dict (dict): evaluation parameters' configure dict. + interval (int): run evaluation interval, default is 1. + eval_start_epoch (int): evaluation start epoch, default is 1. + save_best_ckpt (bool): Whether to save best checkpoint, default is True. + besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`. + metrics_name (str): evaluation metrics name, default is `top1_acc , top5 acc`. + + Returns: + None + + Examples: + >>> EvalCallBack(eval_function, eval_param_dict) + """ + + def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, + save_best_ckpt=True, train_model_name="sppnet_single", + ckpt_directory="./ckpt", besk_ckpt_name="best.ckpt"): + super(EvalCallBack, self).__init__() + self.eval_param_dict = eval_param_dict + self.eval_function = eval_function + self.eval_start_epoch = eval_start_epoch + if interval < 1: + raise ValueError("interval should >= 1.") + self.interval = interval + self.save_best_ckpt = save_best_ckpt + self.best_epoch = 0 + if train_model_name == "sppnet_single": + self.best_res = {'top_1_accuracy': 0.6482, 'top_5_accuracy': 0.8566} + else: + self.best_res = {'top_1_accuracy': 0.6381, 'top_5_accuracy': 0.8504} + if not os.path.isdir(ckpt_directory): + os.makedirs(ckpt_directory) + self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) + self.metrics_name = {'top_1_accuracy', 'top_5_accuracy'} + + def remove_ckpoint_file(self, file_name): + """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" + try: + 
os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) + + def epoch_end(self, run_context): + """Callback when epoch end.""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: + eval_start = time.time() + res = self.eval_function(self.eval_param_dict) + eval_cost = time.time() - eval_start + print("epoch: {}, {}: {}, eval_cost:{:.2f}" + .format(cur_epoch, self.metrics_name, res, eval_cost), flush=True) + if res['top_1_accuracy'] >= self.best_res['top_1_accuracy'] \ + and res['top_5_accuracy'] >= self.best_res['top_5_accuracy']: + self.best_res['top_1_accuracy'] = res['top_1_accuracy'] + self.best_res['top_5_accuracy'] = res['top_5_accuracy'] + self.best_epoch = cur_epoch + print("update best result: {}".format(res), flush=True) + if self.save_best_ckpt: + if os.path.exists(self.bast_ckpt_path): + self.remove_ckpoint_file(self.bast_ckpt_path) + save_checkpoint(cb_params.train_network, self.bast_ckpt_path) + print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) + + def end(self, run_context): + print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name, + self.best_res, + self.best_epoch), flush=True) + + +class EvalCallBackMult(Callback): + """ + sppnet mult rain Evaluation callback when training. + + Args: + eval_function (function): evaluation function. + eval_param_dict (dict): evaluation parameters' configure dict. + interval (int): run evaluation interval, default is 1. + eval_start_epoch (int): evaluation start epoch, default is 1. + save_best_ckpt (bool): Whether to save best checkpoint, default is True. 
+ besk_ckpt_name (str): bast checkpoint name, default is `best.ckpt`. + metrics_name (str): evaluation metrics name, default is `top1_acc , top5_acc`. + + Returns: + None + + Examples: + >>> EvalCallBack(eval_function, eval_param_dict) + """ + + def __init__(self, eval_function, eval_param_dict, interval=1, + eval_start_epoch=1, save_best_ckpt=True, + ckpt_directory="./ckpt", besk_ckpt_name="best.ckpt"): + super(EvalCallBackMult, self).__init__() + self.eval_param_dict = eval_param_dict + self.eval_function = eval_function + self.eval_start_epoch = eval_start_epoch + if interval < 1: + raise ValueError("interval should >= 1.") + self.interval = interval + self.save_best_ckpt = save_best_ckpt + self.best_epoch = 0 + self.best_res = {'top_1_accuracy': 0, 'top_5_accuracy': 0} + if not os.path.isdir(ckpt_directory): + os.makedirs(ckpt_directory) + self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name) + self.metrics_name = {'top_1_accuracy', 'top_5_accuracy'} + + def remove_ckpoint_file(self, file_name): + """ + Remove the specified checkpoint file from this checkpoint manager and also from the directory. 
+ """ + try: + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) + + def epoch_end(self, run_context): + """Callback when epoch end.""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: + eval_start = time.time() + res = self.eval_function(self.eval_param_dict) + eval_cost = time.time() - eval_start + print("epoch: {}, {}: {}, cost:{:.2f}".format(cur_epoch, self.metrics_name, res, eval_cost), flush=True) + if res['top_1_accuracy'] >= self.best_res['top_1_accuracy'] and res['top_5_accuracy'] >= \ + self.best_res['top_5_accuracy']: + self.best_res['top_1_accuracy'] = res['top_1_accuracy'] + self.best_res['top_5_accuracy'] = res['top_5_accuracy'] + self.best_epoch = cur_epoch + print("update best result: {}".format(res), flush=True) + if self.save_best_ckpt: + if os.path.exists(self.bast_ckpt_path): + self.remove_ckpoint_file(self.bast_ckpt_path) + save_checkpoint(cb_params.train_network, self.bast_ckpt_path) + print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True) + + def end(self, run_context): + print("=================================================") diff --git a/research/cv/SPPNet/src/generator_lr.py b/research/cv/SPPNet/src/generator_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..cadf09de75f8c02cb2d45bbf0c92e69d2b52c20a --- /dev/null +++ b/research/cv/SPPNet/src/generator_lr.py @@ -0,0 +1,43 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate generator""" +import math +import numpy as np + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + """Linear learning rate""" + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, iteration_max, lr_min=0): + """ warmup cosine annealing lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr_min + (base_lr - lr_min) * (1. + math.cos(math.pi * last_epoch / iteration_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) diff --git a/research/cv/SPPNet/src/get_param_groups.py b/research/cv/SPPNet/src/get_param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..5260d241f3ca2d3a520f9ff3594461e4758e14db --- /dev/null +++ b/research/cv/SPPNet/src/get_param_groups.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""get parameters for Momentum optimizer"""
+
+
+def get_param_groups(network, lr):
+    """Split trainable parameters into weight-decay and no-weight-decay groups."""
+    decay_params = []
+    no_decay_params = []
+    for x in network.trainable_params():
+        parameter_name = x.name
+        if parameter_name.endswith('.bias'):
+            # bias parameters do not use weight decay
+            no_decay_params.append(x)
+        elif parameter_name.endswith('.gamma'):
+            # BN gamma (scale) does not use weight decay; note: x may not include BN layers yet
+            no_decay_params.append(x)
+        elif parameter_name.endswith('.beta'):
+            # BN beta (shift) does not use weight decay; note: x may not include BN layers yet
+            no_decay_params.append(x)
+        else:
+            decay_params.append(x)
+
+    return [{'params': no_decay_params, 'weight_decay': 0.0, "lr": lr}, {'params': decay_params, "lr": lr}]
diff --git a/research/cv/SPPNet/src/spatial_pyramid_pooling.py b/research/cv/SPPNet/src/spatial_pyramid_pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..514ccc35bad43ba6c0de6311a21433fe5d4442ff
--- /dev/null
+++ b/research/cv/SPPNet/src/spatial_pyramid_pooling.py
@@ -0,0 +1,73 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pool""" +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.ops import operations as P + + +class SpatialPyramidPool(nn.Cell): + """ + SpatialPyramidPool + """ + def __init__(self, previous_conv_size=13, out_pool_size=None): + ''' + args: + previous_conv_size(int): input feature map size + out_pool_size(tuple): output pooling size + e.g: input_size: (6, 3, 2, 1) out_pool_size=(6*6+3*3+2*2+1*1) * batch_size + ''' + super(SpatialPyramidPool, self).__init__() + + self.previous_conv_size = previous_conv_size + self.out_pool_size = out_pool_size + self.cat = ops.Concat(axis=1) + self.reshape = ops.Reshape() + self.maxpool_1x1 = ops.ReduceMax(keep_dims=True) + self.padding = P.Pad(((0, 0), (0, 0), (1, 1), (1, 1))) + + def construct(self, x): + """ + input: x (last_out_channels ,previous_conv_size,previous_conv_size) + + output: x (Vector) + """ + B, _, _, _ = ops.Shape()(x) + spp = None + for pool_count in range(len(self.out_pool_size)): + size = self.previous_conv_size / self.out_pool_size[pool_count] + + if size > size // 1: + size = size // 1 + 1 + else: + size = size // 1 + + stride = self.previous_conv_size / self.out_pool_size[pool_count] + stride = stride // 1 + + if self.out_pool_size[pool_count] == 1: + spp_temp = self.maxpool_1x1(x, (2, 3)) + elif self.out_pool_size[pool_count] == 6 and self.previous_conv_size == 10: + x_pad = self.padding(x) + spp_temp = nn.MaxPool2d(2, 2, "valid")(x_pad) + else: + spp_temp = nn.MaxPool2d(size, stride, "valid")(x) + 
+ if pool_count == 0: + spp = self.reshape(spp_temp, (B, -1)) + else: + spp = self.cat((spp, self.reshape(spp_temp, (B, -1)))) + + return spp diff --git a/research/cv/SPPNet/src/sppnet.py b/research/cv/SPPNet/src/sppnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a59fb39d2c3c23a6e588d53cb00a1b9a69033c25 --- /dev/null +++ b/research/cv/SPPNet/src/sppnet.py @@ -0,0 +1,114 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +ZFNet +Original paper: 'Visualizing and Understanding Convolutional Networks,' https://arxiv.org/abs/1311.2901. 
+input size should be : (3 x 224 x 224) +""" +import mindspore.nn as nn +from mindspore.ops import operations as P +import mindspore.ops as ops +from src.spatial_pyramid_pooling import SpatialPyramidPool + + +class SppNet(nn.Cell): + """ + SppNet + base on zfnet + """ + def __init__(self, num_classes=10, channel=3, phase='train', include_top=True, train_model="sppnet_single"): + ''' + :param num_classes: picture classes + :param channel: obvious is 3 + :param phase: train or test + :param include_top: True + ''' + super(SppNet, self).__init__() + self.conv1 = nn.Conv2d(in_channels=channel, out_channels=96, kernel_size=7, stride=2, + pad_mode="same", has_bias=True) + self.conv2 = nn.Conv2d(96, 256, 5, stride=2, pad_mode="same", has_bias=True) + self.conv3 = nn.Conv2d(256, 384, 3, padding=1, pad_mode="pad", has_bias=True) + self.conv4 = nn.Conv2d(384, 384, 3, padding=1, pad_mode="pad", has_bias=True) + self.conv5 = nn.Conv2d(384, 256, 3, padding=1, pad_mode="pad", has_bias=True) + + self.LRN = P.LRN() + self.relu = P.ReLU() + + self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid') + + self.train_model = train_model + self.include_top = include_top + + self.spp_pool_224 = SpatialPyramidPool(13, (6, 3, 2, 1)) + self.spp_pool_180 = SpatialPyramidPool(10, (6, 3, 2, 1)) + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(in_channels=(6 * 6 + 3 * 3 + 2 * 2 + 1 * 1) * 256, out_channels=4096) + self.fc2 = nn.Dense(in_channels=4096, out_channels=4096) + self.fc3 = nn.Dense(in_channels=4096, out_channels=num_classes) + + if self.train_model == "zfnet": + dropout_ratio = 0.65 + self.fc1 = nn.Dense(in_channels=6 * 6 * 256, out_channels=4096) + elif self.train_model == "sppnet_single": + dropout_ratio = 0.59 + else: + dropout_ratio = 0.58 + if phase == 'test': + dropout_ratio = 1.0 + + self.dropout = nn.Dropout(dropout_ratio) + + def construct(self, x): + """ + input: x (3 * 224 * 224 or 3 * 180 * 180) + + output: x (1000) + """ + _, _, H, _ = ops.Shape()(x) + 
+ x = self.conv1(x) + x = self.relu(x) + x = self.LRN(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.LRN(x) + x = self.max_pool2d(x) + x = self.conv3(x) + x = self.relu(x) + x = self.conv4(x) + x = self.relu(x) + x = self.conv5(x) + x = self.relu(x) + + if self.train_model == "zfnet": + x = self.max_pool2d(x) + x = self.flatten(x) + else: + if H == 224: + x = self.spp_pool_224(x) + elif H == 180: + x = self.spp_pool_180(x) + + if not self.include_top: + return x + x = self.fc1(x) + x = self.relu(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.relu(x) + x = self.dropout(x) + x = self.fc3(x) + return x diff --git a/research/cv/SPPNet/train.py b/research/cv/SPPNet/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f00564b3876603bdd7d9e630aa384debe45c8554 --- /dev/null +++ b/research/cv/SPPNet/train.py @@ -0,0 +1,185 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +######################## train SPPNet example ######################## +train SPPnet and get network model files(.ckpt) : +python train.py --train_path /YourDataPath --eval_path /YourValPath --device_id YourAscendId --train_model model +""" +import ast +import argparse +import os +import mindspore.nn as nn +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore import dataset as de +from mindspore import context +from mindspore import Tensor +from mindspore.train import Model +from mindspore.context import ParallelMode +from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager +from mindspore.train.callback import LossMonitor, TimeMonitor +from mindspore.common import set_seed +from src.config import sppnet_mult_cfg, sppnet_single_cfg, zfnet_cfg +from src.dataset import create_dataset_imagenet +from src.generator_lr import warmup_cosine_annealing_lr +from src.sppnet import SppNet +from src.eval_callback import EvalCallBack, EvalCallBackMult + + +set_seed(44) +de.config.set_seed(44) +parser = argparse.ArgumentParser(description='MindSpore SPPNet') +parser.add_argument('--sink_size', type=int, default=-1, help='control the amount of data in each sink') +parser.add_argument('--train_model', type=str, default='sppnet_single', help='chose the training model', + choices=['sppnet_single', 'sppnet_mult', 'zfnet']) +parser.add_argument('--device_target', type=str, default="Ascend", + help='device where the code will be implemented (default: Ascend)') +parser.add_argument('--train_path', type=str, + default="./imagenet_original/train", + help='path where the train dataset is saved') +parser.add_argument('--eval_path', type=str, + default="./imagenet_original/val", + help='path where the validate dataset is saved') +parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\ + path 
where the trained ckpt file') +parser.add_argument('--dataset_sink_mode', type=ast.literal_eval, + default=True, help='dataset_sink_mode is False or True') +parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)') +parser.add_argument('--device_num', type=int, default=1) +args = parser.parse_args() + + +def apply_eval(eval_param): + """construct eval function""" + eval_model = eval_param["model"] + eval_ds = eval_param["dataset"] + res = eval_model.eval(eval_ds) + return res + + +if __name__ == "__main__": + + device_num = args.device_num + device_target = args.device_target + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + context.set_context(save_graphs=False) + + if device_target == "Ascend": + context.set_context(device_id=args.device_id) + + if device_num > 1: + init() + device_num = get_group_size() + print("device_num:", device_num) + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + global_rank=args.device_id, + gradients_mean=True) + else: + raise ValueError("Unsupported platform.") + + if args.train_model == "zfnet": + cfg = zfnet_cfg + ds_train = create_dataset_imagenet(args.train_path, 'zfnet', cfg.batch_size) + network = SppNet(cfg.num_classes, phase='train', train_model=args.train_model) + prefix = "checkpoint_zfnet" + elif args.train_model == "sppnet_single": + cfg = sppnet_single_cfg + ds_train = create_dataset_imagenet(args.train_path, cfg.batch_size) + network = SppNet(cfg.num_classes, phase='train', train_model=args.train_model) + prefix = "checkpoint_sppnet" + else: + cfg = sppnet_mult_cfg + ds_train = create_dataset_imagenet(args.train_path, 'sppnet_mult', cfg.batch_size) + network = SppNet(cfg.num_classes, phase='train', train_model=args.train_model) + prefix = "checkpoint_sppnet" + + if ds_train.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") + + 
loss_scale_manager = None + metrics = {'top_1_accuracy', 'top_5_accuracy'} + step_per_epoch = ds_train.get_dataset_size() if args.sink_size == -1 else args.sink_size + + # loss function + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + + # learning rate generator + lr = Tensor(warmup_cosine_annealing_lr(lr=cfg.lr_init, steps_per_epoch=step_per_epoch, + warmup_epochs=cfg.warmup_epochs, max_epoch=cfg.epoch_size, + iteration_max=cfg.iteration_max, lr_min=cfg.lr_min)) + + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + no_decay_params.append(x) + else: + decay_params.append(x) + + params = [{'params': no_decay_params, 'weight_decay': 0.0, "lr": lr}, {'params': decay_params, "lr": lr}] + + # Optimizer + opt = nn.Momentum(params=params, + learning_rate=lr, + momentum=cfg.momentum, + weight_decay=cfg.weight_decay, + loss_scale=cfg.loss_scale) + + if cfg.is_dynamic_loss_scale == 1: + loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) + else: + loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False) + + if device_target == "Ascend": + model = Model(network, loss_fn=loss, optimizer=opt, metrics=metrics, amp_level="O2", keep_batchnorm_fp32=False, + loss_scale_manager=loss_scale_manager) + else: + raise ValueError("Unsupported platform.") + + if device_num > 1: + ckpt_save_dir = os.path.join(args.ckpt_path + "_" + str(get_rank())) + else: + ckpt_save_dir = os.path.join(args.ckpt_path) + + # callback + eval_dataset = create_dataset_imagenet(args.eval_path, cfg.batch_size, training=False) + evalParamDict = {"model": model, "dataset": eval_dataset} + if args.train_model == "sppnet_mult": + eval_cb = EvalCallBackMult(apply_eval, 
evalParamDict, eval_start_epoch=1) + else: + eval_cb = EvalCallBack(apply_eval, evalParamDict, eval_start_epoch=1, train_model_name=args.train_model) + loss_cb = LossMonitor(per_print_times=step_per_epoch) + time_cb = TimeMonitor(data_size=step_per_epoch) + + print("============== Starting Training ==============") + + if args.train_model == "sppnet_mult": + ds_train_mult_size = create_dataset_imagenet(args.train_path, 'sppnet_mult', cfg.batch_size, + training=True, image_size=180) + for per_epoch in range(cfg.epoch_size): + print("================ Epoch:{} ==================".format(per_epoch+1)) + if per_epoch % 2 == 0: + cb = [time_cb, loss_cb, eval_cb] + model.train(1, ds_train, callbacks=cb, dataset_sink_mode=False, sink_size=args.sink_size) + else: + cb = [time_cb, loss_cb] + model.train(1, ds_train_mult_size, callbacks=cb, dataset_sink_mode=False, sink_size=args.sink_size) + else: + cb = [time_cb, loss_cb, eval_cb] + model.train(cfg.epoch_size, ds_train, callbacks=cb, dataset_sink_mode=True, sink_size=args.sink_size)