diff --git a/research/cv/tgcn/README_CN.md b/research/cv/tgcn/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..1f1f9d241c145cedadb71876cebff6c92435adca --- /dev/null +++ b/research/cv/tgcn/README_CN.md @@ -0,0 +1,377 @@ +# 鐩綍 + +- [T-GCN姒傝堪](#T-GCN姒傝堪) +- [妯″瀷鏋舵瀯](#妯″瀷鏋舵瀯) +- [鏁版嵁闆哴(#鏁版嵁闆�) +- [鐜瑕佹眰](#鐜瑕佹眰) +- [蹇€熷紑濮媇(#蹇€熷紑濮�) +- [鑴氭湰璇存槑](#鑴氭湰璇存槑) + - [鑴氭湰鍙婃牱渚嬩唬鐮乚(#鑴氭湰鍙婃牱渚嬩唬鐮�) + - [鑴氭湰鍙傛暟](#鑴氭湰鍙傛暟) + - [璁粌娴佺▼](#璁粌娴佺▼) + - [杩愯](#杩愯) + - [缁撴灉](#缁撴灉) + - [璇勪及娴佺▼](#璇勪及娴佺▼) + - [杩愯](#杩愯) + - [缁撴灉](#缁撴灉) + - [MINDIR妯″瀷瀵煎嚭娴佺▼](#MINDIR妯″瀷瀵煎嚭娴佺▼) + - [杩愯](#杩愯) + - [缁撴灉](#缁撴灉) +- [妯″瀷璇存槑](#妯″瀷璇存槑) + - [璁粌鎬ц兘](#璁粌鎬ц兘) + - [璇勪及鎬ц兘](#璇勪及鎬ц兘) +- [闅忔満鎯呭喌璇存槑](#闅忔満鎯呭喌璇存槑) +- [ModelZoo涓婚〉](#ModelZoo涓婚〉) + +# [T-GCN姒傝堪](#鐩綍) + +鏃堕棿鍥惧嵎绉綉缁滐紙Temporal Graph Convolutional Network锛孴-GCN锛夋ā鍨嬶紝绠€绉癟-GCN妯″瀷锛屾槸Zhao L绛変汉鎻愬嚭鐨勪竴绉嶉€傜敤浜庡煄甯傞亾璺氦閫氶娴嬬殑妯″瀷銆傛墍璋撲氦閫氶娴嬶紝鍗冲熀浜庨亾璺巻鍙蹭氦閫氫俊鎭紝瀵逛竴瀹氭椂鏈熷唴鐨勪氦閫氫俊鎭繘琛岄娴嬶紝鍖呮嫭浣嗕笉闄愪簬浜ら€氶€熷害銆佷氦閫氭祦閲忋€佷氦閫氬瘑搴︾瓑淇℃伅銆� + +[璁烘枃](https://arxiv.org/pdf/1811.05320.pdf)锛歓hao L, Song Y, Zhang C, et al. T-gcn: A temporal graph convolutional network for traffic prediction[J]. IEEE Transactions on Intelligent Transportation Systems, 2019, 21(9): 3848-3858. 
+ +# [妯″瀷鏋舵瀯](#鐩綍) + +T-GCN妯″瀷涓昏鐢变袱澶фā鍧楁瀯鎴愶紝鍒嗗埆涓哄浘鍗风Н缃戠粶锛圙raph Convolutional Network锛孏CN锛変笌闂ㄦ帶寰幆鍗曞厓锛圙ated Recurrent Unit锛孏RU锛夈€� + +妯″瀷鏁翠綋澶勭悊娴佺▼濡備笅锛氳緭鍏缁勫巻鍙叉椂闂村簭鍒楁暟鎹紝鍒╃敤鍥惧嵎绉綉缁滄崟鑾峰煄甯傝矾缃戞嫇鎵戠粨鏋勶紝浠ヨ幏鍙栨暟鎹殑绌洪棿鐗瑰緛銆傚啀灏嗗緱鍒扮殑鍏锋湁绌洪棿鐗瑰緛鐨勬暟鎹緭鍏ラ棬鎺у惊鐜崟鍏冿紝鍒╃敤鍗曞厓闂寸殑淇℃伅浼犻€掓崟鑾锋暟鎹殑鍔ㄦ€佸彉鍖栵紝浠ヨ幏鍙栨暟鎹殑鏃堕棿鐗瑰緛銆傛渶鍚庯紝缁忚繃鍏ㄨ繛鎺ュ眰锛岃緭鍑烘渶缁堥娴嬬粨鏋溿€� + +鍏朵腑锛孏CN妯″潡閫氳繃鍦ㄥ倕閲屽彾鍩熶腑鏋勯€犱竴涓綔鐢ㄤ簬鍥炬暟鎹殑鑺傜偣鍙婂叾涓€闃堕偦鍩熺殑婊ゆ尝鍣ㄦ潵鎹曡幏鑺傜偣闂寸殑绌洪棿鐗瑰緛锛屼箣鍚庡湪鍏朵笂鍙犲姞澶氫釜鍗风Н灞傛潵瀹炵幇銆侴CN妯″潡鍙鍩庡競涓績閬撹矾涓庡叾鍛ㄥ洿閬撹矾闂寸殑鎷撴墤缁撴瀯鍙婇亾璺睘鎬у疄鐜扮紪鐮侊紝鎹曡幏鏁版嵁鐨勭┖闂寸浉鍏虫€с€傝€孏RU妯″潡鍒欐槸浣滀负涓€绉嶇粡鍏哥殑閫掑綊绁炵粡缃戠粶鍙樹綋鏉ユ崟鑾蜂氦閫氭祦閲忔暟鎹腑鐨勬椂闂寸浉鍏虫€с€傝妯″潡浣跨敤闂ㄦ帶鏈哄埗鏉ヨ蹇嗗敖鍙兘澶氱殑闀挎湡淇℃伅锛屼笖缁撴瀯鐩稿绠€鍗曪紝鍙傛暟杈冨皯锛岃缁冮€熷害杈冨揩锛屽彲浠ュ湪鎹曡幏褰撳墠鏃跺埢浜ら€氫俊鎭殑鍚屾椂锛屼粛鐒朵繚鎸佸巻鍙蹭氦閫氫俊鎭殑鍙樺寲瓒嬪娍锛屽叿鏈夋崟鑾锋暟鎹殑鏃堕棿鐩稿叧鎬х殑鑳藉姏銆� + +# [鏁版嵁闆哴(#鐩綍) + +- 鏁版嵁闆嗭細瀹為獙鍩轰簬涓ゅぇ鐢辩幇瀹為噰闆嗙殑[SZ-taxi鏁版嵁闆哴(https://github.com/lehaifeng/T-GCN/tree/master/T-GCN/T-GCN-PyTorch/data)涓嶽Los-loop鏁版嵁闆哴(https://github.com/lehaifeng/T-GCN/tree/master/T-GCN/T-GCN-PyTorch/data)銆� + +锛�1锛塖Z-taxi鏁版嵁闆嗛€夊彇娣卞湷甯傜綏婀栧尯鐨�156鏉′富瑕佸煄甯傞亾璺负鐮旂┒鍖哄煙锛岃褰曚簡2015骞�1鏈�1鏃ヨ嚦1鏈�31鏃ョ殑鍑虹杞﹁繍琛岃建杩广€傝鏁版嵁闆嗕富瑕佸寘鍚袱涓儴鍒嗭紝涓€鏄褰曚簡鍩庡競閬撹矾闂存嫇鎵戝叧绯荤殑涓€涓�156*156澶у皬鐨勯偦鎺ョ煩闃碉紝鍏朵腑姣忚浠h〃涓€鏉¢亾璺紝鐭╅樀涓殑鍊艰〃绀洪亾璺棿鐨勮繛鎺ャ€備簩鏄褰曚簡姣忎竴鏉¢亾璺笂閫熷害鍊奸殢鏃堕棿鍙樺寲鐨勭壒寰佺煩闃碉紝鍏朵腑姣忚浠h〃涓€鏉¢亾璺紝姣忓垪涓轰笉鍚屾椂闂存閬撹矾涓婄殑浜ら€氶€熷害锛屾瘡15鍒嗛挓璁板綍涓€娆°€� + +锛�2锛塋os-loop鏁版嵁闆嗙敱娲涙潐鐭堕珮閫熷叕璺笂鍏辫207涓幆褰㈡帰娴嬪櫒浜�2012骞�3鏈�1鏃ヨ嚦2012骞�3鏈�7鏃ュ疄鏃堕噰闆嗗緱鍒帮紝鏁版嵁姣�5鍒嗛挓璁板綍涓€娆°€備笌SZ-taxi鏁版嵁闆嗙浉浼硷紝璇ユ暟鎹泦涓昏鍖呭惈閭绘帴鐭╅樀涓庣壒寰佺煩闃典袱涓儴鍒嗭紝閭绘帴鐭╅樀涓殑鍊肩敱鎺㈡祴鍣ㄤ箣闂寸殑璺濈璁$畻寰楀埌銆傜敱浜庤鏁版嵁闆嗕腑瀛樺湪鏁版嵁缂哄け锛屽洜姝よ鏂囦綔鑰呴噰鐢ㄧ嚎鎬ф彃鍊肩殑鏂规硶杩涜浜嗙己澶卞€煎~鍏呫€� + +- 鏁版嵁澶勭悊锛氳緭鍏ユ暟鎹褰掍竴鍖栬嚦[0,1]鍖洪棿锛屽苟鍒掑垎鍏朵腑鐨�80%浣滆缁冮泦锛�20%浣滄祴璇曢泦锛屾潵鍒嗗埆棰勬祴鏈潵15鍒嗛挓銆�30鍒嗛挓銆�45鍒嗛挓銆�60鍒嗛挓鐨勪氦閫氶€熷害銆� + +# [鐜瑕佹眰](#鐩綍) + +- 纭欢锛圓scend / GPU锛� + - 闇€瑕佸噯澶囧叿鏈堿scend鎴朑PU澶勭悊鑳藉姏鐨勭‖浠剁幆澧冦€� +- 妗嗘灦 + - [MindSpore](https://www.mindspore.cn/) +- 濡傞渶鑾峰彇鏇村淇℃伅锛岃鏌ョ湅濡備笅閾炬帴锛� + - [MindSpore 鏁欑▼](https://www.mindspore.cn/tutorials/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/docs/api/zh-CN/master/index.html) + +# [蹇€熷紑濮媇(#鐩綍) + 
+閫氳繃瀹樻柟鎸囧崡[瀹夎MindSpore](https://www.mindspore.cn/install)鍚庯紝涓嬭浇[鏁版嵁闆哴(https://github.com/lehaifeng/T-GCN/tree/master/T-GCN/T-GCN-PyTorch/data)锛屽皢涓嬭浇濂界殑鏁版嵁闆嗘寜濡備笅鐩綍缁撴瀯杩涜缁勭粐锛屼篃鍙寜姝ょ粨鏋勮嚜琛屾坊鍔犳暟鎹泦锛� + +```python +. +鈹斺攢tgcn + 鈹溾攢data + 鈹溾攢SZ-taxi + 鈹溾攢adj.csv # 閭绘帴鐭╅樀 + 鈹斺攢feature.csv # 鐗瑰緛鐭╅樀 + 鈹溾攢Los-loop + 鈹溾攢adj.csv # 閭绘帴鐭╅樀 + 鈹斺攢feature.csv # 鐗瑰緛鐭╅樀 +... +``` + +鍑嗗濂芥暟鎹泦鍚庯紝鍗冲彲鎸夐『搴忎緷娆¤繘琛屾ā鍨嬭缁冧笌璇勪及/瀵煎嚭鎿嶄綔锛� + +- 璁粌锛� + +```python +# 鍗曞崱璁粌 +bash ./scripts/run_standalone_train.sh [DEVICE_ID] + +# Ascend澶氬崱璁粌 +bash ./scripts/run_distributed_train_ascend.sh [RANK_TABLE] [RANK_SIZE] [DEVICE_START] [DATA_PATH] +``` + +绀轰緥锛� + + ```python + # 鍗曞崱璁粌 + bash ./scripts/run_standalone_train.sh 0 + + # Ascend澶氬崱璁粌锛�8鍗★級 + bash ./scripts/run_distributed_train_ascend.sh ./rank_table_8pcs.json 8 0 ./data + ``` + +- 璇勪及锛� + +```python +# 璇勪及 +bash ./scripts/run_eval.sh [DEVICE_ID] +``` + +绀轰緥锛� + + ```python + # 璇勪及 + bash ./scripts/run_eval.sh 0 + ``` + +- MINDIR妯″瀷瀵煎嚭 + +```python +# MINDIR妯″瀷瀵煎嚭 +bash ./scripts/run_export.sh [DEVICE_ID] +``` + +绀轰緥锛� + +```python +# MINDIR妯″瀷瀵煎嚭 +bash ./scripts/run_export.sh 0 +``` + +# [鑴氭湰璇存槑](#鐩綍) + +## [鑴氭湰鍙婃牱渚嬩唬鐮乚(#鐩綍) + +```python +. 
+鈹斺攢tgcn + 鈹溾攢README_CN.md # 涓枃鎸囧崡 + 鈹溾攢requirements.txt # pip渚濊禆鏂囦欢 + 鈹溾攢scripts + 鈹溾攢run_distributed_train_ascend.sh # Ascend澶氬崱璁粌杩愯鑴氭湰 + 鈹溾攢run_eval.sh # 璇勪及杩愯鑴氭湰 + 鈹溾攢run_export.sh # MINDIR妯″瀷瀵煎嚭杩愯鑴氭湰 + 鈹斺攢run_standalone_train.sh # 鍗曞崱璁粌杩愯鑴氭湰 + 鈹溾攢src + 鈹溾攢model + 鈹溾攢__init__.py + 鈹溾攢graph_conv.py # 鍥惧嵎绉绠� + 鈹溾攢loss.py # 鑷畾涔夋崯澶卞嚱鏁� + 鈹斺攢tgcn.py # T-GCN妯″瀷鏋舵瀯 + 鈹溾攢__init__.py + 鈹溾攢callback.py # 鑷畾涔夊洖璋冨嚱鏁� + 鈹溾攢config.py # 妯″瀷鍙傛暟璁惧畾 + 鈹溾攢dataprocess.py # 鏁版嵁澶勭悊妯″潡 + 鈹溾攢metrics.py # 妯″瀷璇勪及鎸囨爣 + 鈹斺攢task.py # 鐩戠潱棰勬祴浠诲姟 + 鈹溾攢eval.py # 璇勪及 + 鈹溾攢export.py # MINDIR妯″瀷瀵煎嚭 + 鈹斺攢train.py # 璁粌 +``` + +## [鑴氭湰鍙傛暟](#鐩綍) + +- 璁粌銆佽瘎浼般€丮INDIR妯″瀷瀵煎嚭绛夋搷浣滅浉鍏冲弬鏁扮殕鍦╜config.py`鑴氭湰涓瀹氾細 + +```python +class ConfigTGCN: + device = 'Ascend' + seed = 1 + dataset = 'SZ-taxi' + hidden_dim = 100 + seq_len = 4 + pre_len = 1 + train_split_rate = 0.8 + epochs = 3000 + batch_size = 64 + learning_rate = 0.001 + weight_decay = 1.5e-3 + data_sink = True +``` + +濡傞渶鏌ラ槄鐩稿叧鍙傛暟淇℃伅璇存槑锛岃鍙傞槄`config.py`鑴氭湰鍐呭銆� + +## [璁粌娴佺▼](#鐩綍) + +### [杩愯](#鐩綍) + +寮€濮嬭缁冨墠锛岃纭宸插湪`config.py`鑴氭湰涓畬鎴愮浉鍏宠缁冨弬鏁拌瀹氾紝鍦ㄥ悓涓€浠诲姟涓嬶紝鍚庣画璇勪及娴佺▼涓嶮INDIR妯″瀷瀵煎嚭娴佺▼璇蜂繚鎸佸弬鏁颁竴鑷淬€� + +```python +# 鍗曞崱璁粌 +# 鐢ㄦ硶锛� +bash ./scripts/run_standalone_train.sh [DEVICE_ID] +# 绀轰緥锛� +bash ./scripts/run_standalone_train.sh 0 + +# Ascend澶氬崱璁粌 +# 鐢ㄦ硶锛� +bash ./scripts/run_distributed_train_ascend.sh [RANK_TABLE] [RANK_SIZE] [DEVICE_START] [DATA_PATH] +# 绀轰緥锛�8鍗★級锛� +bash ./scripts/run_distributed_train_ascend.sh ./rank_table_8pcs.json 8 0 ./data +``` + +鍗曞崱璁粌涓璥[DEVICE_ID]`涓鸿缁冩墍璋冪敤鍗$殑鍗″彿銆� + +Ascend澶氬崱璁粌涓璥[RANK_TABLE]`涓虹浉搴擱ANK_TABLE_FILE鏂囦欢璺緞锛堝8鍗¤缁冧娇鐢ㄧ殑`./rank_table_8pcs.json`锛夛紝RANK_TABLE_FILE鍙寜[姝ゆ柟娉昡(https://gitee.com/mindspore/models/tree/master/utils/hccl_tools)鐢熸垚銆俙[RANK_SIZE]`涓鸿缁冩墍璋冪敤鍗$殑鏁伴噺锛宍[DEVICE_START]`涓鸿捣濮嬪崱鍙凤紝`[DATA_PATH]`涓烘暟鎹泦瀛樻斁鏍圭洰褰曘€� + +### [缁撴灉](#鐩綍) + +璁粌鏃讹紝褰撳墠璁粌杞鏁帮紝妯″瀷鎹熷け鍊硷紝姣忚疆娆¤繍琛屾椂闂寸瓑鏈夊叧淇℃伅浼氫互濡備笅褰㈠紡灞曠ず锛� + + ```python + ==========Training Start========== + epoch: 1 step: 37, loss is 47.07869 + epoch time: 20385.370 ms, per step time: 550.956 ms 
+ RMSE eval: 8.408103 + Best checkpoint saved! + epoch: 2 step: 37, loss is 26.325077 + epoch time: 607.063 ms, per step time: 16.407 ms + RMSE eval: 6.355909 + Best checkpoint saved! + epoch: 3 step: 37, loss is 24.1607 + epoch time: 606.936 ms, per step time: 16.404 ms + RMSE eval: 6.126811 + Best checkpoint saved! + epoch: 4 step: 37, loss is 23.835127 + epoch time: 606.999 ms, per step time: 16.405 ms + RMSE eval: 6.077283 + Best checkpoint saved! + epoch: 5 step: 37, loss is 23.536343 + epoch time: 606.879 ms, per step time: 16.402 ms + RMSE eval: 6.035936 + Best checkpoint saved! + epoch: 6 step: 37, loss is 23.218105 + epoch time: 606.861 ms, per step time: 16.402 ms + RMSE eval: 5.993234 + Best checkpoint saved! + ... + ``` + +鍗曞崱璁粌灏嗕細鎶婁笂杩颁俊鎭互杩愯鏃ュ織鐨勫舰寮忎繚瀛樿嚦`./logs/train.log`锛屼笖妯″瀷浼氫互瑕嗙洊鐨勫舰寮忚嚜鍔ㄤ繚瀛樻渶浼樻鏌ョ偣锛�.ckpt 鏂囦欢锛変簬`./checkpoints`鐩綍涓嬶紝渚涘悗缁瘎浼颁笌妯″瀷瀵煎嚭娴佺▼鍔犺浇浣跨敤锛堝`./checkpoints/SZ-taxi_1.ckpt`锛夈€� + +Ascend澶氬崱璁粌涓庡崟鍗¤缁冩墍灞曠ず淇℃伅鐨勫舰寮忓熀鏈竴鑷达紝杩愯鏃ュ織鍙婃渶浼樻鏌ョ偣灏嗕繚瀛樺湪浠ュ搴斿崱鍙稩D鍛藉悕鐨刞./device{ID}`鐩綍涓嬶紙濡俙./device0/logs/train.log`涓巂./device0/checkpoints/SZ-taxi_1.ckpt`锛夈€� + +## [璇勪及娴佺▼](#鐩綍) + +### [杩愯](#鐩綍) + +鍦ㄥ畬鎴愯缁冩祦绋嬬殑鍩虹涓婏紝璇勪及娴佺▼灏嗚嚜鍔ㄤ粠`./checkpoints`鐩綍鍔犺浇瀵瑰簲浠诲姟鐨勬渶浼樻鏌ョ偣锛�.ckpt 鏂囦欢锛夌敤浜庢ā鍨嬭瘎浼般€� + +```python +# 璇勪及 +# 鐢ㄦ硶锛� +bash ./scripts/run_eval.sh [DEVICE_ID] +# 绀轰緥锛� +bash ./scripts/run_eval.sh 0 +``` + +### [缁撴灉](#鐩綍) + +璁粌鍚庢ā鍨嬪湪楠岃瘉闆嗕笂鐨勭浉鍏虫寚鏍囪瘎浼扮粨鏋滃皢浠ュ涓嬪舰寮忓睍绀猴紝涓斾互杩愯鏃ュ織鐨勫舰寮忎繚瀛樿嚦`./logs/eval.log`锛� + + ```python + =====Evaluation Results===== + RMSE: 4.083120 + MAE: 2.730229 + Accuracy: 0.715577 + R2: 0.847140 + Var: 0.847583 + ============================ + ``` + +## [MINDIR妯″瀷瀵煎嚭娴佺▼](#鐩綍) + +### [杩愯](#鐩綍) + +鍦ㄥ畬鎴愯缁冩祦绋嬬殑鍩虹涓婏紝MINDIR妯″瀷瀵煎嚭娴佺▼灏嗚嚜鍔ㄤ粠`./checkpoints`鐩綍鍔犺浇瀵瑰簲浠诲姟鐨勬渶浼樻鏌ョ偣锛�.ckpt 鏂囦欢锛夌敤浜庡搴擬INDIR妯″瀷瀵煎嚭銆� + +```python +# MINDIR妯″瀷瀵煎嚭 +# 鐢ㄦ硶锛� +bash ./scripts/run_export.sh [DEVICE_ID] +# 绀轰緥锛� +bash ./scripts/run_export.sh 0 +``` + +### [缁撴灉](#鐩綍) + +鑻ユā鍨嬪鍑烘垚鍔燂紝绋嬪簭灏嗕互濡備笅褰㈠紡灞曠ず锛屼笖浠ヨ繍琛屾棩蹇楃殑褰㈠紡淇濆瓨鑷砢./logs/export.log`锛� + +```python 
+========================================== +SZ-taxi_1.mindir exported successfully! +========================================== +``` + +鍚屾椂MINDIR妯″瀷鏂囦欢灏嗗鍑鸿嚦`./outputs`鐩綍涓嬶紝渚涘悗缁繘涓€姝ヤ娇鐢紙濡俙./outputs/SZ-taxi_1.mindir`锛夈€� + +# [妯″瀷璇存槑](#鐩綍) + +## [璁粌鎬ц兘](#鐩綍) + +- 涓嬭〃涓缁冩€ц兘鐢盩-GCN妯″瀷鍩轰簬SZ-taxi鏁版嵁闆嗗垎鍒娴嬫湭鏉�15鍒嗛挓銆�30鍒嗛挓銆�45鍒嗛挓銆�60鍒嗛挓锛堝嵆pre_len鍒嗗埆鍙�1銆�2銆�3銆�4锛夌殑浜ら€氶€熷害寰楀埌锛岀浉鍏虫寚鏍囦负4缁勮缁冧换鍔″钩鍧囧€硷細 + +| 鍙傛暟 | Ascend | +| -------------------------- | -----------------------------------------------------------| +| 妯″瀷鍚嶇О | T-GCN | +| 杩愯鐜 | 鎿嶄綔绯荤粺 Euler 2.8锛汚scend 910锛涘鐞嗗櫒 2.60GHz锛�192鏍稿績锛涘唴瀛橈紝755G | +| 涓婁紶鏃ユ湡 | 2021-09-30 | +| MindSpore鐗堟湰 | 1.3.0 | +| 鏁版嵁闆� | SZ-taxi锛坔idden_dim=100锛泂eq_len=4锛� | +| 璁粌鍙傛暟 | seed=1锛沞poch=3000锛沚atch_size = 64锛沴r=0.001锛泃rain_split_rate = 0.8锛泈eight_decay = 1.5e-3 | +| 浼樺寲鍣� | Adam with Weight Decay | +| 鎹熷け鍑芥暟 | 鑷畾涔夋崯澶卞嚱鏁� | +| 杈撳嚭 | 浜ら€氶€熷害棰勬祴鍊� | +| 骞冲潎妫€鏌ョ偣锛�.ckpt 鏂囦欢锛夊ぇ灏� | 839 KB | +| 骞冲潎鎬ц兘 | 鍗曞崱锛�23姣/姝ワ紝871姣/杞紱8鍗★細25姣/姝ワ紝101姣/杞� | +| 骞冲潎鎬昏€楁椂 | 鍗曞崱锛�49鍒�19绉掞紱8鍗★細11鍒�35绉� | +| 鑴氭湰 | [璁粌鑴氭湰](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/train.py) | + +- 涓嬭〃涓缁冩€ц兘鐢盩-GCN妯″瀷鍩轰簬Los-loop鏁版嵁闆嗗垎鍒娴嬫湭鏉�15鍒嗛挓銆�30鍒嗛挓銆�45鍒嗛挓銆�60鍒嗛挓锛堝嵆pre_len鍒嗗埆鍙�3銆�6銆�9銆�12锛夌殑浜ら€氶€熷害寰楀埌锛岀浉鍏虫寚鏍囦负4缁勮缁冧换鍔″钩鍧囧€硷細 + +| 鍙傛暟 | Ascend | +| -------------------------- | -----------------------------------------------------------| +| 妯″瀷鍚嶇О | T-GCN | +| 杩愯鐜 | 鎿嶄綔绯荤粺 Euler 2.8锛汚scend 910锛涘鐞嗗櫒 2.60GHz锛�192鏍稿績锛涘唴瀛橈紝755G | +| 涓婁紶鏃ユ湡 | 2021-09-30 | +| MindSpore鐗堟湰 | 1.3.0 | +| 鏁版嵁闆� | Los-loop锛坔idden_dim=64锛泂eq_len=12锛� | +| 璁粌鍙傛暟 | seed=1锛沞poch=3000锛沚atch_size = 64锛沴r=0.001锛泃rain_split_rate = 0.8锛泈eight_decay = 1.5e-3 | +| 浼樺寲鍣� | Adam with Weight Decay | +| 鎹熷け鍑芥暟 | 鑷畾涔夋崯澶卞嚱鏁� | +| 杈撳嚭 | 浜ら€氶€熷害棰勬祴鍊� | +| 骞冲潎妫€鏌ョ偣锛�.ckpt 鏂囦欢锛夊ぇ灏� | 993KB | +| 骞冲潎鎬ц兘 | 鍗曞崱锛�44姣/姝ワ紝1066姣/杞紱8鍗★細46姣/姝ワ紝139姣/杞� | +| 骞冲潎鎬昏€楁椂 | 鍗曞崱锛�1鏃�00鍒�40绉掞紱8鍗★細15鍒�05绉� | +| 鑴氭湰 | [璁粌鑴氭湰](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/train.py) | + +## [璇勪及鎬ц兘](#鐩綍) + +- 
涓嬭〃涓瘎浼版€ц兘鐢盩-GCN妯″瀷鍩轰簬SZ-taxi鏁版嵁闆嗗垎鍒娴嬫湭鏉�15鍒嗛挓銆�30鍒嗛挓銆�45鍒嗛挓銆�60鍒嗛挓锛堝嵆pre_len鍒嗗埆鍙�1銆�2銆�3銆�4锛夌殑浜ら€氶€熷害寰楀埌锛岀浉鍏虫寚鏍囦负4缁勮瘎浼颁换鍔″钩鍧囧€硷細 + +| 鍙傛暟 | Ascend| +| ------------------- | ---------------------------| +| 妯″瀷鍚嶇О | T-GCN | +| 杩愯鐜 | 鎿嶄綔绯荤粺 Euler 2.8锛汚scend 910锛涘鐞嗗櫒 2.60GHz锛�192鏍稿績锛涘唴瀛橈紝755G | +| 涓婁紶鏃ユ湡 | 2021-09-30 | +| MindSpore鐗堟湰 | 1.3.0 | +| 鏁版嵁闆� | SZ-taxi锛坔idden_dim=100锛泂eq_len=4锛� | +| 杈撳嚭 | 浜ら€氶€熷害棰勬祴鍊� | +| 鍧囨柟鏍硅宸紙RMSE锛夊钩鍧囧€� | 4.1003 | +| 骞冲潎缁濆璇樊锛圡AE锛夊钩鍧囧€� | 2.7498 | +| 棰勬祴鍑嗙‘鐜囷紙Accuracy锛夊钩鍧囧€� | 0.7144 | +| R骞虫柟锛�$R^2$锛夊钩鍧囧€� | 0.8458 | +| 鍙噴鏂瑰樊锛圗xplained Variance锛夊钩鍧囧€� | 0.8461 | +| 鑴氭湰 | [璇勪及鑴氭湰](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/eval.py) | + +- 涓嬭〃涓瘎浼版€ц兘鐢盩-GCN妯″瀷鍩轰簬Los-loop鏁版嵁闆嗗垎鍒娴嬫湭鏉�15鍒嗛挓銆�30鍒嗛挓銆�45鍒嗛挓銆�60鍒嗛挓锛堝嵆pre_len鍒嗗埆鍙�3銆�6銆�9銆�12锛夌殑浜ら€氶€熷害寰楀埌锛岀浉鍏虫寚鏍囦负4缁勮瘎浼颁换鍔″钩鍧囧€硷細 + +| 鍙傛暟 | Ascend| +| ------------------- | ---------------------------| +| 妯″瀷鍚嶇О | T-GCN | +| 杩愯鐜 | 鎿嶄綔绯荤粺 Euler 2.8锛汚scend 910锛涘鐞嗗櫒 2.60GHz锛�192鏍稿績锛涘唴瀛橈紝755G | +| 涓婁紶鏃ユ湡 | 2021-09-30 | +| MindSpore鐗堟湰 | 1.3.0 | +| 鏁版嵁闆� | Los-loop锛坔idden_dim=64锛泂eq_len=12锛� | +| 杈撳嚭 | 浜ら€氶€熷害棰勬祴鍊� | +| 鍧囨柟鏍硅宸紙RMSE锛夊钩鍧囧€� | 6.1869 | +| 骞冲潎缁濆璇樊锛圡AE锛夊钩鍧囧€� | 3.8552 | +| 棰勬祴鍑嗙‘鐜囷紙Accuracy锛夊钩鍧囧€� | 0.8946 | +| R骞虫柟锛�$R^2$锛夊钩鍧囧€� | 0.8000 | +| 鍙噴鏂瑰樊锛圗xplained Variance锛夊钩鍧囧€� | 0.8002 | +| 鑴氭湰 | [璇勪及鑴氭湰](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn/eval.py) | + +# [闅忔満鎯呭喌璇存槑](#鐩綍) + +`train.py`鑴氭湰涓娇鐢╜mindspore.set_seed()`瀵瑰叏灞€闅忔満绉嶅瓙杩涜浜嗗浐瀹氾紙榛樿鍊间负1锛夛紝鍙湪`config.py`鑴氭湰涓繘琛屼慨鏀广€� + +# [ModelZoo涓婚〉](#鐩綍) + + [T-GCN](https://gitee.com/mindspore/models/tree/master/research/cv/tgcn) \ No newline at end of file diff --git a/research/cv/tgcn/eval.py b/research/cv/tgcn/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..bdfcecbbac78d05be4230e2d010e1ee80594c891 --- /dev/null +++ b/research/cv/tgcn/eval.py @@ -0,0 +1,48 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# 
def parse_args():
    """
    Parse the command-line options of the evaluation script

    Returns:
        argparse.Namespace with the integer field `device_id` (Default: 0)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', help="DEVICE_ID", type=int, default=0)
    return parser.parse_args()


def main():
    """
    Load the checkpoint selected by the configuration and evaluate it

    The checkpoint name is built from `config.dataset` and `config.pre_len`
    and is looked up in the local 'checkpoints' folder.

    Raises:
        FileNotFoundError: if the expected checkpoint file does not exist
    """
    # Parsing happens here (not at import time) so that importing this module has no side effects
    args = parse_args()
    # Config initialization
    config = ConfigTGCN()
    # Runtime
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device, device_id=args.device_id)
    # Create network
    net = SupervisedForecastTask(load_adj_matrix(config.dataset), config.hidden_dim, config.pre_len)
    # Load parameters from checkpoint into network
    ckpt_file_name = config.dataset + "_" + str(config.pre_len) + ".ckpt"
    ckpt_path = os.path.join('checkpoints', ckpt_file_name)
    if not os.path.isfile(ckpt_path):
        # Fail early with a readable message; the run is usually detached and only the log is inspected
        raise FileNotFoundError("Checkpoint not found: {}".format(ckpt_path))
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    # Evaluation: only the evaluation split is used, max_val is passed on to evaluate_network
    feat, max_val = load_feat_matrix(config.dataset)
    _, _, eval_inputs, eval_targets = generate_dataset_np(feat, config.seq_len, config.pre_len,
                                                          config.train_split_rate)
    evaluate_network(net, max_val, eval_inputs, eval_targets)


if __name__ == '__main__':
    main()
def parse_args():
    """
    Parse the command-line options of the export script

    Returns:
        argparse.Namespace with the integer field `device_id` (Default: 0)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', help="DEVICE_ID", type=int, default=0)
    return parser.parse_args()


def main():
    """
    Export the checkpoint selected by the configuration into a MINDIR model file

    The checkpoint 'checkpoints/<dataset>_<pre_len>.ckpt' is loaded and the
    exported model is written to 'outputs/<dataset>_<pre_len>' (MINDIR format).
    """
    # Parsing happens here (not at import time) so that importing this module has no side effects
    args = parse_args()
    # Config initialization
    config = ConfigTGCN()
    # Runtime
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device, device_id=args.device_id)
    # Create network
    adj = load_adj_matrix(config.dataset)
    net = SupervisedForecastTask(adj, config.hidden_dim, config.pre_len)
    # Load parameters from checkpoint into network
    model_name = config.dataset + "_" + str(config.pre_len)
    param_dict = load_checkpoint(os.path.join('checkpoints', model_name + ".ckpt"))
    load_param_into_net(net, param_dict)
    # Initialize dummy inputs of shape (batch_size, seq_len, number of rows in adj)
    inputs = np.random.uniform(0.0, 1.0, size=[config.batch_size, config.seq_len, adj.shape[0]]).astype(np.float32)
    # Export network into MINDIR model file
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir
    os.makedirs('outputs', exist_ok=True)
    path = os.path.join('outputs', model_name)
    export(net, Tensor(inputs), file_name=path, file_format='MINDIR')
    print("==========================================")
    print(model_name + ".mindir exported successfully!")
    print("==========================================")


if __name__ == '__main__':
    main()
# Launch one background training process per device for Ascend distributed training.
# Arguments: [RANK_TABLE] [RANK_SIZE] [DEVICE_START] [DATA_PATH]
if [[ $# -ne 4 ]]; then
    echo "Usage: bash ./scripts/run_distributed_train_ascend.sh [RANK_TABLE] [RANK_SIZE] [DEVICE_START] [DATA_PATH]"
    exit 1
fi

ulimit -u unlimited
export RANK_SIZE=$2
# Resolve to absolute paths: every rank runs from its own ./device<i> directory
RANK_TABLE_FILE=$(realpath "$1")
DATA_PATH=$(realpath "$4")
export RANK_TABLE_FILE
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
echo "DATA_PATH=${DATA_PATH}"

# Fail before spawning anything if the inputs are missing
if [[ ! -f "${RANK_TABLE_FILE}" ]]; then
    echo "Error: rank table file not found: ${RANK_TABLE_FILE}"
    exit 1
fi
if [[ ! -d "${DATA_PATH}" ]]; then
    echo "Error: data directory not found: ${DATA_PATH}"
    exit 1
fi

device_start=$3
for ((i = 0; i < RANK_SIZE; i++))
do
    export DEVICE_ID=$((device_start + i))
    export RANK_ID=$i
    # Each rank gets a fresh working directory with its own copy of the sources
    rm -rf "./device$i"
    mkdir "./device$i"
    cp -r ./src "./device$i"
    cp ./train.py "./device$i"
    # Abort instead of launching the job from the wrong directory
    cd "./device$i" || exit 1
    mkdir ./logs
    env > ./logs/env.log
    nohup python -u train.py --device_id="$DEVICE_ID" --data_path="$DATA_PATH" --distributed True > ./logs/train.log 2>&1 &
    echo "Start training for rank $RANK_ID, device $DEVICE_ID. PID: $!"
    echo $! > ./logs/train.pid
    cd ..
done
# ---- scripts/run_eval.sh ----
# Start eval.py in the background on the given device; output goes to ./logs/eval.log
if [[ $# -ne 1 ]]; then
    echo "Usage: bash ./scripts/run_eval.sh [DEVICE_ID]"
    exit 1
fi

# -p: no error if the directory already exists, no check-then-create race
mkdir -p logs

nohup python -u eval.py --device_id="$1" > ./logs/eval.log 2>&1 &
echo "Evaluation started on device $1 ! PID: $!"

# ---- scripts/run_export.sh ----
# Start export.py in the background on the given device; output goes to ./logs/export.log
if [[ $# -ne 1 ]]; then
    echo "Usage: bash ./scripts/run_export.sh [DEVICE_ID]"
    exit 1
fi

mkdir -p logs

nohup python -u export.py --device_id="$1" > ./logs/export.log 2>&1 &
echo "Export started on device $1 ! PID: $!"
# Start train.py in the background on a single device; output goes to ./logs/train.log
if [[ $# -ne 1 ]]; then
    echo "Usage: bash ./scripts/run_standalone_train.sh [DEVICE_ID]"
    exit 1
fi

# -p: no error if the directory already exists, no check-then-create race
mkdir -p logs

nohup python -u train.py --device_id="$1" > ./logs/train.log 2>&1 &
echo "Training started on device $1 ! PID: $!"
# Keep the PID of the background job so it can be inspected or stopped later
echo $! > ./logs/train.pid
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Module initialization +""" diff --git a/research/cv/tgcn/src/callback.py b/research/cv/tgcn/src/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..e39022d5775953c2b1e45ea37c3e40d356187e6b --- /dev/null +++ b/research/cv/tgcn/src/callback.py @@ -0,0 +1,82 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class RMSE(Metric):
    """
    RMSE metric for choosing the best checkpoint

    Squared errors and element counts are accumulated over `update` calls;
    `eval` returns the root of their ratio multiplied by `max_val`.

    Args:
        max_val: factor applied to the RMSE in `eval`
    """

    def __init__(self, max_val):
        super(RMSE, self).__init__()
        self.clear()
        self.max_val = max_val

    def clear(self):
        """Clears the internal evaluation result"""
        self._squared_error_sum = 0
        self._samples_num = 0

    def update(self, *inputs):
        """
        Calculate and update internal result

        Args:
            inputs: exactly two items, (preds, targets)

        Raises:
            ValueError: if the number of inputs is not 2
        """
        if len(inputs) != 2:
            raise ValueError('RMSE metric need 2 inputs (preds, targets), but got {}'.format(len(inputs)))
        preds = self._convert_data(inputs[0])
        targets = self._convert_data(inputs[1])
        # Merge the leading axes of targets, keeping the size of axis 2 as the last axis,
        # before subtracting preds (presumably preds already has that 2-D layout - TODO confirm)
        targets = targets.reshape((-1, targets.shape[2]))
        self._squared_error_sum += np.square(targets - preds).sum()
        self._samples_num += np.size(targets)

    def eval(self):
        """
        Calculate evaluation result at the end of each epoch

        Raises:
            RuntimeError: if no sample has been accumulated
        """
        if self._samples_num == 0:
            raise RuntimeError('The number of input samples must not be 0.')
        return np.sqrt(self._squared_error_sum / self._samples_num) * self.max_val


class SaveCallback(Callback):
    """
    Save the best checkpoint (minimum RMSE) during training

    Args:
        eval_model: object whose `eval(ds_eval)` returns a dict containing the key 'RMSE'
        ds_eval: evaluation dataset passed to `eval_model.eval`
        config: configuration providing `dataset` and `pre_len` (used in the checkpoint name)
    """

    def __init__(self, eval_model, ds_eval, config):
        super(SaveCallback, self).__init__()
        self.model = eval_model
        self.ds_eval = ds_eval
        # Any finite RMSE beats the initial value; no dependency on a library-internal constant
        self.rmse = float('inf')
        self.config = config

    def epoch_end(self, run_context):
        """Evaluate the network and save the best checkpoint (minimum RMSE)"""
        cb_params = run_context.original_args()
        result = self.model.eval(self.ds_eval)
        print('Eval RMSE:', '{:.6f}'.format(result['RMSE']))
        # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir
        os.makedirs('checkpoints', exist_ok=True)
        if result['RMSE'] < self.rmse:
            self.rmse = result['RMSE']
            # Same file name every time: the previous best checkpoint is overwritten
            file_name = self.config.dataset + '_' + str(self.config.pre_len) + '.ckpt'
            save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=os.path.join('checkpoints', file_name))
            print("Best checkpoint saved!")
class ConfigTGCN:
    """
    Parameter container shared by the training, evaluation and export scripts
    """

    # ---- Runtime ----
    device = 'Ascend'       # one of ['Ascend', 'GPU']
    seed = 1                # global random seed

    # ---- Data ----
    dataset = 'SZ-taxi'     # one of ['SZ-taxi', 'Los-loop', etc]

    # ---- Model ----
    hidden_dim = 100        # 100 for 'SZ-taxi'; 64 for 'Los-loop'
    seq_len = 4             # 4 for 'SZ-taxi'; 12 for 'Los-loop'
    pre_len = 1             # one of [1, 2, 3, 4] for 'SZ-taxi'; one of [3, 6, 9, 12] for 'Los-loop'

    # ---- Training ----
    train_split_rate = 0.8
    epochs = 3000
    batch_size = 64
    learning_rate = 1e-3
    weight_decay = 0.0015
    data_sink = True
class TGCNDataset:
    """
    Custom T-GCN datasets

    Random-access source of (inputs, targets) pairs, used as the source
    object of `ds.GeneratorDataset`.

    Args:
        inputs: indexable collection of input samples
        targets: indexable collection of target samples, same length as `inputs`

    Raises:
        ValueError: if `inputs` and `targets` differ in length
    """

    def __init__(self, inputs, targets):
        if len(inputs) != len(targets):
            raise ValueError('inputs and targets must have the same length, but got {} and {}'
                             .format(len(inputs), len(targets)))
        self.inputs = inputs
        self.targets = targets

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]

    def __len__(self):
        return len(self.inputs)


def _csv_path(dataset, file_name, abs_path=None):
    """Return '<abs_path>/<dataset>/<file_name>', falling back to the local 'data' folder as root"""
    root = abs_path if abs_path is not None else 'data'
    return os.path.join(root, dataset, file_name)


def load_adj_matrix(dataset, abs_path=None, dtype=np.float32):
    """
    Load adjacency matrix from corresponding csv file

    Args:
        dataset(str): name of dataset (the same as folder name)
        abs_path(str): absolute data directory path
        dtype(type): data type (Default: np.float32)

    Returns:
        adj: adjacency matrix in ndarray
    """
    # header=None: the first row of adj.csv is data, not column names
    adj_df = pd.read_csv(_csv_path(dataset, 'adj.csv', abs_path), header=None)
    return np.array(adj_df, dtype=dtype)


def load_feat_matrix(dataset, abs_path=None, dtype=np.float32):
    """
    Load feature matrix from corresponding csv file

    Args:
        dataset(str): name of dataset (the same as folder name)
        abs_path(str): absolute data directory path
        dtype(type): data type (Default: np.float32)

    Returns:
        feat: feature matrix in ndarray
        max_val: max value in feature matrix
    """
    # Read with the pandas default header handling (unlike adj.csv, which passes header=None)
    feat_df = pd.read_csv(_csv_path(dataset, 'feature.csv', abs_path))
    feat = np.array(feat_df, dtype=dtype)
    max_val = np.max(feat)
    return feat, max_val


def _sliding_windows(data, seq_len, pre_len):
    """Cut `data` along axis 0 into (inputs, targets) windows of seq_len / pre_len consecutive rows"""
    # NOTE(review): one more complete window (start index len - seq_len - pre_len) would fit;
    # the count is kept as it was so that dataset sizes stay unchanged - confirm whether intended
    num_windows = max(len(data) - seq_len - pre_len, 0)
    tail_shape = data.shape[1:]
    # Pre-sized outputs keep a consistent rank even when there is no window at all
    inputs = np.empty((num_windows, seq_len) + tail_shape, dtype=data.dtype)
    targets = np.empty((num_windows, pre_len) + tail_shape, dtype=data.dtype)
    for i in range(num_windows):
        inputs[i] = data[i: i + seq_len]
        targets[i] = data[i + seq_len: i + seq_len + pre_len]
    return inputs, targets


def generate_dataset_np(feat, seq_len, pre_len, split_ratio, normalize=True, time_len=None):
    """
    Generate ndarrays from matrixes

    Args:
        feat(ndarray): feature matrix
        seq_len(int): length of the train data sequence
        pre_len(int): length of the prediction data sequence
        split_ratio(float): proportion of the training set
        normalize(bool): scale the data to (0, 1], divide by the maximum value in the data
        time_len(int): length of the time series in total

    Returns:
        Train set (inputs, targets) and evaluation set (inputs, targets) in ndarrays.
        Inputs have shape (windows, seq_len, ...) and targets (windows, pre_len, ...);
        a split too short for a single window yields arrays with 0 windows.

    Raises:
        ValueError: if `seq_len` or `pre_len` is not positive
    """
    if seq_len <= 0 or pre_len <= 0:
        raise ValueError('seq_len and pre_len must be positive, but got {} and {}'.format(seq_len, pre_len))
    feat = np.asarray(feat)
    if time_len is None:
        time_len = feat.shape[0]
    if normalize:
        # The maximum is taken over the whole matrix, also when time_len truncates the series
        max_val = np.max(feat)
        # All-zero data is left untouched instead of being turned into NaN by 0 / 0
        if max_val != 0:
            feat = feat / max_val
    train_size = int(time_len * split_ratio)
    train_inputs, train_targets = _sliding_windows(feat[0:train_size], seq_len, pre_len)
    eval_inputs, eval_targets = _sliding_windows(feat[train_size:time_len], seq_len, pre_len)
    return train_inputs, train_targets, eval_inputs, eval_targets


def _build_generator(config, training, abs_path=None):
    """Load the feature matrix and wrap the training or evaluation split in a TGCNDataset"""
    feat, _ = load_feat_matrix(config.dataset, abs_path)
    train_inputs, train_targets, eval_inputs, eval_targets = generate_dataset_np(
        feat, config.seq_len, config.pre_len, config.train_split_rate)
    if training:
        return TGCNDataset(train_inputs, train_targets)
    return TGCNDataset(eval_inputs, eval_targets)


def generate_dataset_ms(config, training, abs_path=None):
    """
    Generate MindSpore dataset from ndarrays

    Args:
        config(ConfigTGCN): configuration of parameters
        training(bool): generate training dataset or evaluation dataset
        abs_path(str): absolute data directory path (Default: None, use the local 'data' folder)

    Returns:
        dataset: MindSpore dataset for training/evaluation
    """
    dataset_generator = _build_generator(config, training, abs_path)
    dataset = ds.GeneratorDataset(dataset_generator, ["inputs", "targets"], shuffle=False)
    dataset = dataset.batch(config.batch_size, drop_remainder=True)
    return dataset


def generate_dataset_ms_distributed(config, training, abs_path=None):
    """
    Generate MindSpore dataset from ndarrays in distributed training

    Args:
        config(ConfigTGCN): configuration of parameters
        training(bool): generate training dataset or evaluation dataset
        abs_path(str): absolute data directory path

    Returns:
        dataset: MindSpore dataset for training/evaluation (distributed)
    """
    # Evaluation uses a batch size of 1, training the configured batch size
    batch_size = config.batch_size if training else 1

    # Get rank_id and rank_size
    rank_id = get_rank()
    rank_size = get_group_size()

    dataset_generator = _build_generator(config, training, abs_path)
    dataset = ds.GeneratorDataset(dataset_generator, ["inputs", "targets"], shuffle=False,
                                  num_shards=rank_size, shard_id=rank_id)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
def accuracy(preds, targets):
    """
    Calculate the accuracy between predictions and targets

    Args:
        preds(Tensor): predictions (anything exposing ``asnumpy()``)
        targets(Tensor): ground truth (anything exposing ``asnumpy()``)

    Returns:
        accuracy: defined as 1 - (norm(targets - preds) / norm(targets)),
            using the Frobenius norm
    """
    # Convert the ground truth once; it is needed for both norms
    targets_np = targets.asnumpy()
    residual_norm = np.linalg.norm(targets_np - preds.asnumpy(), 'fro')
    return 1 - residual_norm / np.linalg.norm(targets_np, 'fro')


def r2(preds, targets):
    """
    Calculate R square between predictions and targets

    Args:
        preds(Tensor): predictions
        targets(Tensor): ground truth

    Returns:
        R square: coefficient of determination,
            1 - sum((targets - preds)^2) / sum((targets - mean(targets))^2)
    """
    residual_ss = P.ReduceSum()((targets - preds) ** 2)
    # The total sum of squares is centred on the mean of the ground truth;
    # centring it on the mean of the predictions does not yield R^2.
    total_ss = P.ReduceSum()((targets - P.ReduceMean()(targets)) ** 2)
    return (1 - residual_ss / total_ss).asnumpy()


def explained_variance(preds, targets):
    """
    Calculate the explained variance between predictions and targets

    Args:
        preds(Tensor): predictions
        targets(Tensor): ground truth

    Returns:
        Var: explained variance, 1 - var(targets - preds) / var(targets)
    """
    residual_var = (targets - preds).var()
    return (1 - residual_var / targets.var()).asnumpy()


def evaluate_network(net, max_val, eval_inputs, eval_targets):
    """
    Evaluate the performance of network and print the metrics

    Args:
        net: callable network mapping the input tensor to predictions
        max_val: scale factor applied to RMSE and MAE before printing
            (presumably the maximum used for normalization -- confirm with caller)
        eval_inputs: evaluation inputs, converted to a float32 Tensor
        eval_targets: evaluation targets, converted to a float32 Tensor and
            flattened to (-1, eval_targets.shape[2])

    Returns:
        None; results are written to standard output
    """
    eval_inputs = Tensor(eval_inputs, mstype.float32)
    eval_preds = net(eval_inputs)
    eval_targets = Tensor(eval_targets, mstype.float32)
    # Flatten the targets to two dimensions, keeping the last axis
    eval_targets = eval_targets.reshape((-1, eval_targets.shape[2]))

    rmse = P.Sqrt()(nn.MSELoss()(eval_preds, eval_targets)).asnumpy()
    mae = nn.MAELoss()(eval_preds, eval_targets).asnumpy()
    acc = accuracy(eval_preds, eval_targets)
    r_2 = r2(eval_preds, eval_targets)
    var = explained_variance(eval_preds, eval_targets)

    print("=====Evaluation Results=====")
    # RMSE and MAE are rescaled by max_val; the ratio metrics are scale-free
    print('RMSE:', '{:.6f}'.format(rmse * max_val))
    print('MAE:', '{:.6f}'.format(mae * max_val))
    print('Accuracy:', '{:.6f}'.format(acc))
    print('R2:', '{:.6f}'.format(r_2))
    print('Var:', '{:.6f}'.format(var))
    print("============================")
def calculate_laplacian_with_self_loop(matrix, matmul):
    """
    Calculate the symmetrically normalized matrix with self loops

    With A' = matrix + I and D the diagonal matrix of the row sums of A',
    the result is (A' D^-1/2)^T D^-1/2.

    Args:
        matrix(Tensor): square input matrix
        matmul(MatMul): the MatMul operator for mixed precision

    Returns:
        normalized_laplacian: normalized laplacian matrix
    """
    num_nodes = matrix.shape[0]
    # Add self loops: A' = A + I
    identity = P.Eye()(num_nodes, num_nodes, mstype.float32)
    adj_with_loops = matrix + identity
    # Row sums raised to the power -1/2
    degree = adj_with_loops.sum(1)
    inv_sqrt_degree = P.Pow()(degree, -0.5).flatten()
    # A zero row sum produces inf; replace those entries with 0
    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.0
    degree_mat = np.diag(inv_sqrt_degree)
    scaled = matmul(adj_with_loops, degree_mat)
    return matmul(scaled.transpose(0, 1), degree_mat)
class TGCNLoss(nn.Cell):
    """
    Custom T-GCN loss cell: half of the summed squared error
    """

    def construct(self, predictions, targets):
        """
        Calculate loss

        Args:
            predictions(Tensor): predictions from models, two-dimensional with
                the same last axis as targets
            targets(Tensor): ground truth, three-dimensional

        Returns:
            loss: sum((predictions - targets) ** 2) / 2
        """
        # Collapse the leading axes of the targets so they line up with the predictions
        last_dim = targets.shape[2]
        flat_targets = targets.reshape((-1, last_dim))
        squared_error = (predictions - flat_targets) ** 2
        return np.sum(squared_error) / 2
class TGCNGraphConvolution(nn.Cell):
    """
    T-GCN graph convolution layer

    Computes A[x, h]W + b, where A is the normalized matrix derived from the
    adjacency matrix, x the current inputs and h the hidden state.
    """

    def __init__(self, adj, num_gru_units: int, output_dim: int, bias: float = 0.0):
        super().__init__()
        self._num_gru_units = num_gru_units
        self._output_dim = output_dim
        self._bias_init_value = bias
        # A single MatMul cell is shared by every product in this layer, so the
        # caller can change its compute precision in one place
        self.matmul = nn.MatMul()
        laplacian = calculate_laplacian_with_self_loop(Tensor(adj, mstype.float32), self.matmul)
        self.laplacian = Parameter(laplacian, name='laplacian', requires_grad=False)
        weight_shape = [self._num_gru_units + 1, self._output_dim]
        self.weights = Parameter(initializer(XavierUniform(), weight_shape, mstype.float32),
                                 name='weights')
        bias_shape = [self._output_dim]
        self.biases = Parameter(initializer(Constant(self._bias_init_value), bias_shape, mstype.float32),
                                name='biases')

    def construct(self, inputs, hidden_state):
        """
        Calculate graph convolution outputs

        Args:
            inputs(Tensor): network inputs, (batch_size, num_nodes)
            hidden_state(Tensor): hidden state, reshaped to
                (batch_size, num_nodes, num_gru_units)

        Returns:
            outputs: TGCNGraphConvolution outputs, (batch_size, num_nodes * output_dim)
        """
        batch_size, num_nodes = inputs.shape
        feat_dim = self._num_gru_units + 1
        # x: (batch_size, num_nodes) -> (batch_size, num_nodes, 1)
        x = inputs.reshape((batch_size, num_nodes, 1))
        # h: (batch_size, num_nodes, num_gru_units)
        h = hidden_state.reshape((batch_size, num_nodes, self._num_gru_units))
        # [x, h]: (batch_size, num_nodes, feat_dim)
        stacked = P.Concat(axis=2)((x, h))
        # Move the node axis to the front and fold the rest: (num_nodes, feat_dim * batch_size)
        stacked = stacked.transpose(1, 2, 0)
        stacked = stacked.reshape((num_nodes, feat_dim * batch_size))
        # A[x, h]: (num_nodes, feat_dim * batch_size)
        propagated = self.matmul(self.laplacian, stacked)
        # Unfold and restore batch-first order: (batch_size, num_nodes, feat_dim)
        propagated = propagated.reshape((num_nodes, feat_dim, batch_size))
        propagated = propagated.transpose(2, 0, 1)
        # (batch_size * num_nodes, feat_dim)
        propagated = propagated.reshape((batch_size * num_nodes, feat_dim))
        # A[x, h]W + b: (batch_size * num_nodes, output_dim)
        projected = self.matmul(propagated, self.weights) + self.biases
        # (batch_size, num_nodes, output_dim)
        projected = projected.reshape((batch_size, num_nodes, self._output_dim))
        # (batch_size, num_nodes * output_dim)
        return projected.reshape((batch_size, num_nodes * self._output_dim))


class TGCNCell(nn.Cell):
    """
    T-GCN cell: a GRU-style update whose gates are produced by graph convolutions
    """

    def __init__(self, adj, input_dim: int, hidden_dim: int):
        super().__init__()
        self._input_dim = input_dim
        self._hidden_dim = hidden_dim
        self.adj = Parameter(Tensor(adj, mstype.float32), name='adj', requires_grad=False)
        # Gate convolution: twice the hidden width, bias initialized to 1.0
        self.graph_conv1 = TGCNGraphConvolution(self.adj, self._hidden_dim, self._hidden_dim * 2, bias=1.0)
        # Candidate-state convolution
        self.graph_conv2 = TGCNGraphConvolution(self.adj, self._hidden_dim, self._hidden_dim)

    def construct(self, inputs, hidden_state):
        """
        Calculate hidden states

        Args:
            inputs(Tensor): network inputs for one time step
            hidden_state(Tensor): hidden state, (batch_size, num_nodes * num_gru_units)

        Returns:
            The new hidden state, returned twice as (output, new_hidden_state)
        """
        # [r, u] = sigmoid(A[x, h]W + b), (batch_size, num_nodes * (2 * num_gru_units))
        gates = P.Sigmoid()(self.graph_conv1(inputs, hidden_state))
        # Two equal halves along axis 1, each (batch_size, num_nodes * num_gru_units)
        reset_gate, update_gate = P.Split(axis=1, output_num=2)(gates)
        # c = tanh(A[x, (r * h)]W + b), (batch_size, num_nodes * num_gru_units)
        candidate = P.Tanh()(self.graph_conv2(inputs, reset_gate * hidden_state))
        # h := u * h + (1 - u) * c, (batch_size, num_nodes * num_gru_units)
        new_hidden_state = update_gate * hidden_state + (1.0 - update_gate) * candidate
        return new_hidden_state, new_hidden_state


class TGCN(nn.Cell):
    """
    T-GCN network: unrolls a TGCNCell over the sequence axis
    """

    def __init__(self, adj, hidden_dim: int, **kwargs):
        super().__init__()
        # One input feature per graph node
        self._input_dim = adj.shape[0]
        self._hidden_dim = hidden_dim
        self.adj = Parameter(Tensor(adj, mstype.float32), name='adj', requires_grad=False)
        self.tgcn_cell = TGCNCell(self.adj, self._input_dim, self._hidden_dim)

    def construct(self, inputs):
        """
        Calculate the final output

        Args:
            inputs(Tensor): network inputs, (batch_size, seq_len, num_nodes)

        Returns:
            output: TGCN output of the last step, (batch_size, num_nodes, hidden_dim)
        """
        batch_size, seq_len, num_nodes = inputs.shape
        # The recurrence starts from an all-zero hidden state
        state = P.Zeros()((batch_size, num_nodes * self._hidden_dim), mstype.float32)
        last_output = None
        for step in range(seq_len):
            last_output, state = self.tgcn_cell(inputs[:, step, :], state)
        # Only the output of the final time step is returned
        return last_output.reshape((batch_size, num_nodes, self._hidden_dim))
class SupervisedForecastTask(nn.Cell):
    """
    T-GCN applied to supervised forecast task

    A TGCN backbone followed by a dense layer that maps each node's hidden
    vector to pre_len outputs.
    """

    def __init__(self, adj, hidden_dim: int, pre_len: int):
        super().__init__()
        self.adj = Parameter(Tensor(adj, mstype.float32), name='adj', requires_grad=False)
        self.tgcn = TGCN(self.adj, hidden_dim)
        self.fcn = nn.Dense(hidden_dim, pre_len)

    def construct(self, inputs):
        """
        Calculate network predictions for supervised forecast task

        Args:
            inputs(Tensor): network inputs, (batch_size, seq_len, num_nodes)

        Returns:
            predictions: predictions of supervised forecast task, reshaped to
                (-1, num_nodes)
        """
        batch_size, _, num_nodes = inputs.shape
        # (batch_size, num_nodes, hidden_dim)
        node_states = self.tgcn(inputs)
        # (batch_size * num_nodes, hidden_dim)
        flat_states = node_states.reshape((-1, node_states.shape[2]))
        # (batch_size * num_nodes, pre_len) -> (batch_size, num_nodes, pre_len)
        per_node = self.fcn(flat_states).reshape((batch_size, num_nodes, -1))
        # Swap the last two axes and flatten so each row holds one value per node,
        # which is the layout used when computing metrics
        return per_node.transpose(0, 2, 1).reshape((-1, num_nodes))
def str2bool(value):
    """
    Parse a command-line boolean flag value

    ``type=bool`` treats every non-empty string as True, so ``--distributed False``
    would enable distributed training. This converter interprets the text instead.

    Args:
        value(str|bool): raw argument value

    Returns:
        bool: parsed flag

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognized boolean
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under type=bool
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def build_arg_parser():
    """
    Create the command-line parser of the training script

    Returns:
        argparse.ArgumentParser: parser with universal and ModelArts arguments
    """
    parser = argparse.ArgumentParser()
    # Set universal arguments
    parser.add_argument('--device_id', help="DEVICE_ID", type=int, default=0)
    parser.add_argument('--distributed', help="distributed training", type=str2bool, default=False)
    parser.add_argument('--data_path', help="directory of datasets", type=str, default='./data')
    # Set ModelArts arguments
    parser.add_argument('--run_modelarts', help="ModelArts runtime", type=str2bool, default=False)
    parser.add_argument('--data_url', help='ModelArts location of data', type=str, default=None)
    parser.add_argument('--train_url', help='ModelArts location of training outputs', type=str, default=None)
    return parser


def run_train(args):
    """
    Run training

    Args:
        args(argparse.Namespace): parsed arguments; the fields read here are
            run_modelarts, distributed, device_id, data_path, data_url and train_url
    """
    # Config initialization
    config = ConfigTGCN()
    # Set global seed for MindSpore and NumPy
    set_seed(config.seed)

    # Runtime initialization
    if args.run_modelarts:
        # ModelArts runtime: the device comes from the environment and the data
        # is copied from data_url into ./data
        import moxing as mox
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
        mox.file.copy_parallel(src_url=args.data_url, dst_url='./data')
        if args.distributed:
            device_num = int(os.getenv('RANK_SIZE'))
            init()
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
    elif args.distributed:
        # Offline distributed runtime: DEVICE_ID from the environment is the
        # device that ends up being used, so it is set directly
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
        init()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
    else:
        # Offline single-device runtime
        context.set_context(mode=context.GRAPH_MODE, device_target=config.device, device_id=args.device_id)

    # Datasets and network initialization
    if args.distributed:
        training_set = generate_dataset_ms_distributed(config, training=True, abs_path=args.data_path)
        eval_set = generate_dataset_ms_distributed(config, training=False, abs_path=args.data_path)
        _, max_val = load_feat_matrix(config.dataset, args.data_path)
        adj = load_adj_matrix(config.dataset, args.data_path)
    else:
        training_set = generate_dataset_ms(config, training=True)
        eval_set = generate_dataset_ms(config, training=False)
        _, max_val = load_feat_matrix(config.dataset)
        adj = load_adj_matrix(config.dataset)
    net = SupervisedForecastTask(adj, config.hidden_dim, config.pre_len)

    # Mixed precision: the graph-convolution matrix products run in float16
    net.tgcn.tgcn_cell.graph_conv1.matmul.to_float(mstype.float16)
    net.tgcn.tgcn_cell.graph_conv2.matmul.to_float(mstype.float16)
    # Loss function
    loss_fn = TGCNLoss()
    # Optimizer
    optimizer = nn.Adam(net.trainable_params(), config.learning_rate, weight_decay=config.weight_decay)
    # Create model
    model = Model(net, loss_fn, optimizer, {'RMSE': RMSE(max_val)})

    # Training
    banner = "Distributed Training" if args.distributed else "Training"
    print("==========" + banner + " Start==========")
    time_start = time.time()
    model.train(config.epochs, training_set,
                callbacks=[LossMonitor(), TimeMonitor(), SaveCallback(model, eval_set, config)],
                dataset_sink_mode=config.data_sink)
    time_end = time.time()
    print("==========" + banner + " End==========")
    print("Training time in total:", '{:.6f}'.format(time_end - time_start), "s")

    # Save outputs (checkpoints) on ModelArts
    if args.run_modelarts:
        mox.file.copy_parallel(src_url='./checkpoints', dst_url=args.train_url)


if __name__ == '__main__':
    run_args = build_arg_parser().parse_args()
    # Training
    run_train(run_args)