diff --git a/research/recommend/mmoe/README_CN.md b/research/recommend/mmoe/README_CN.md
index c8511393b8c03c3d3260cb6df69dd9e1497bcfc7..17e6a2d8cacea390ba318ea8dd8f1e9f3d275b3c 100644
--- a/research/recommend/mmoe/README_CN.md
+++ b/research/recommend/mmoe/README_CN.md
@@ -13,11 +13,11 @@
     - [鑴氭湰鍙傛暟](#鑴氭湰鍙傛暟)
     - [璁粌杩囩▼](#璁粌杩囩▼)
         - [鐢ㄦ硶](#鐢ㄦ硶)
-        - [Ascend澶勭悊鍣ㄧ幆澧冭繍琛宂(#ascend澶勭悊鍣ㄧ幆澧冭繍琛�)
+        - [Ascend澶勭悊鍣ㄦ垨GPU鐜杩愯](#Ascend澶勭悊鍣ㄦ垨GPU鐜杩愯)
         - [缁撴灉](#缁撴灉)
 - [璇勪及杩囩▼](#璇勪及杩囩▼)
-    - [鐢ㄦ硶](#鐢ㄦ硶)
-    - [Ascend澶勭悊鍣ㄧ幆澧冭繍琛宂(#ascend澶勭悊鍣ㄧ幆澧冭繍琛�)
+    - [璇勪及鐢ㄦ硶](#璇勪及鐢ㄦ硶)
+    - [Ascend澶勭悊鍣ㄦ垨GPU鐜杩愯璇勪及](#Ascend澶勭悊鍣ㄦ垨GPU鐜杩愯璇勪及)
     - [缁撴灉](#缁撴灉)
 - [Ascend310鎺ㄧ悊杩囩▼](#鎺ㄧ悊杩囩▼)
     - [瀵煎嚭MindIR](#瀵煎嚭MindIR)
@@ -26,7 +26,7 @@
 - [妯″瀷鎻忚堪](#妯″瀷鎻忚堪)
     - [鎬ц兘](#鎬ц兘)
         - [璇勪及鎬ц兘](#璇勪及鎬ц兘)
-            - [cifar10涓婄殑WideResNet](#cifar10涓婄殑wideresnet)
+            - [census-income涓婄殑MMoE](#census-income涓婄殑MMoE)
 - [闅忔満鎯呭喌璇存槑](#闅忔満鎯呭喌璇存槑)
 - [ModelZoo涓婚〉](#modelzoo涓婚〉)
 
@@ -53,8 +53,12 @@ MMoE鐨勬€讳綋缃戠粶鏋舵瀯濡備笅锛�
 
 - 鏁版嵁闆嗗ぇ灏忥細鍏�9.4Mb銆�299,285鏉℃暟鎹�
     - 璁粌闆嗭細鍏�6.3Mb锛�199,523鏉℃暟鎹�
+
     - 娴嬭瘯闆嗭細鍏�3.1Mb锛�99726鏉℃暟鎹�
+
     - 娉細鏁版嵁鍦╠ata.py涓鐞嗘垚mindrecord鏍煎紡銆�
+
+      浣跨敤鍛戒护 python data.py --local_data_path  ./Census-income
 - 涓嬭浇鍘熷鏁版嵁闆嗭紝鐩綍缁撴瀯濡備笅锛�
 
 ```text
@@ -77,29 +81,43 @@ MMoE鐨勬€讳綋缃戠粶鏋舵瀯濡備笅锛�
 
 閫氳繃瀹樻柟缃戠珯瀹夎MindSpore鍚庯紝鎮ㄥ彲浠ユ寜鐓у涓嬫楠よ繘琛岃缁冨拰璇勪及锛�
 
-- Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
+- Ascend澶勭悊鍣ㄦ垨GPU鐜杩愯
 
 ```Shell
-# 鍒嗗竷寮忚缁�
+# 鍒嗗竷寮忚缁冿紙Ascend锛�
 Usage: bash run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [CKPT_PATH] [CONFIG_FILE]
 [RANK_TABLE_FILE]鏄鍗$殑鍏蜂綋淇℃伅銆�
 [DATA_PATH]鏄暟鎹泦鐨勮矾寰勩€�
 [CKPT_PATH]鏄灏哻kpt淇濆瓨鐨勪綅缃€�
 [CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
 
-# 鍗曟満璁粌
+# 鍗曟満璁粌(Ascend)
 Usage: bash run_standalone_train_ascend.sh [DATA_PATH] [DEVICE_ID] [CKPT_PATH] [CONFIG_FILE]
 [DATA_PATH]鏄暟鎹泦鐨勮矾寰勩€�
 [CKPT_PATH]鏄灏哻kpt淇濆瓨鐨勪綅缃€�
 [DEVICE_ID]涓烘墽琛宼rain.py鐨処D鍙枫€�
 [CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
 
-# 杩愯璇勪及绀轰緥
+# 鍗曟満璁粌(GPU)
+Usage: bash run_standalone_train_gpu.sh [DATA_PATH] [DEVICE_ID] [CKPT_PATH] [CONFIG_FILE]
+[DATA_PATH]鏄暟鎹泦鐨勮矾寰�(mindrecord鏂囦欢鎵€鍦ㄧ殑鐩綍)銆�
+[CKPT_PATH]鏄灏哻kpt淇濆瓨鐨勪綅缃€�
+[DEVICE_ID]涓烘墽琛宼rain.py鐨処D鍙枫€�
+[CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
+
+# 杩愯璇勪及绀轰緥锛圓scend锛�
 Usage: bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONFIG_FILE]
 [DATA_PATH]鏄暟鎹泦鐨勮矾寰勩€�
 [CKPT_PATH]鏄繚瀛榗kpt鐨勪綅缃€�
 [DEVICE_ID]涓烘墽琛宔val.py鐨処D鍙枫€�
 [CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
+
+# 杩愯璇勪及绀轰緥锛圙PU锛�
+Usage: bash run_standalone_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONFIG_FILE]
+[DATA_PATH]鏄暟鎹泦鐨勮矾寰�(mindrecord鏂囦欢鎵€鍦ㄧ殑鐩綍)銆�
+[CKPT_PATH]鏄繚瀛榗kpt鐨勪綅缃€�
+[DEVICE_ID]涓烘墽琛宔val.py鐨処D鍙枫€�
+[CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
 ```
 
 # 鑴氭湰璇存槑
@@ -120,21 +138,25 @@ Usage: bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [C
   鈹溾攢鈹€ scripts
     鈹溾攢鈹€ run_distribute_ascend.sh            # 鍚姩Ascend鍒嗗竷寮忚缁冿紙8鍗★級
     鈹溾攢鈹€ run_standalone_eval_ascend.sh       # 鍚姩Ascend910璇勪及
+    鈹溾攢鈹€ run_standalone_eval_gpu.sh          # 鍚姩GPU璇勪及
     鈹溾攢鈹€ run_infer_310.sh                    # 鍚姩Ascend310璇勪及
-    鈹斺攢鈹€ run_standalone_train_ascend.sh      # 鍚姩Ascend鍗曟満璁粌锛堝崟鍗★級
+    鈹溾攢鈹€ run_standalone_train_ascend.sh      # 鍚姩Ascend鍗曟満璁粌锛堝崟鍗★級
+    鈹斺攢鈹€ run_standalone_train_gpu.sh         # 鍚姩GPU鍗曟満璁粌锛堝崟鍗★級
   鈹溾攢鈹€ src
     鈹溾攢鈹€ model_utils
         鈹溾攢鈹€ config.py                        # 鍙傛暟閰嶇疆
         鈹溾攢鈹€ device_adapter.py                # 閫傞厤浜戜笂鎴栫嚎涓�
         鈹溾攢鈹€ local_adapter.py                 # 绾夸笅閰嶇疆
         鈹溾攢鈹€ moxing_adapter.py                # 浜戜笂閰嶇疆
+    鈹溾攢鈹€ callback.py                          # 璁粌杩囩▼涓繘琛岃瘎浼扮殑鍥炶皟  
     鈹溾攢鈹€ data.py                              # 鏁版嵁棰勫鐞�
     鈹溾攢鈹€ load_dataset.py                      # 鍔犺浇澶勭悊濂界殑mindrecord鏍煎紡鏁版嵁
     鈹溾攢鈹€ get_lr.py                            # 鐢熸垚姣忎釜姝ラ鐨勫涔犵巼
     鈹溾攢鈹€ mmoe.py                              # 妯″瀷鏁翠綋鏋舵瀯
     鈹斺攢鈹€ mmoe_utils.py                        # 姣忎竴灞傛灦鏋�
   鈹溾攢鈹€ eval.py                                # 910璇勪及缃戠粶
-  鈹溾攢鈹€ default_config.yaml                    # 鍙傛暟閰嶇疆
+  鈹溾攢鈹€ default_config.yaml                    # 榛樿鐨勫弬鏁伴厤缃�
+  鈹溾攢鈹€ default_config_gpu.yaml                # 閽堝GPU鐜榛樿鐨勫弬鏁伴厤缃�
   鈹溾攢鈹€ export.py                              # 910瀵煎嚭缃戠粶
   鈹溾攢鈹€ postprocess.py                         # 310鎺ㄧ悊绮惧害璁$畻
   鈹溾攢鈹€ preprocess.py                          # 310鎺ㄧ悊鍓嶆暟鎹鐞�
@@ -153,7 +175,7 @@ Usage: bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [C
 "units":4,                         # 姣忎竴灞傜殑unit鏁�
 "batch_size":32,                   # 杈撳叆寮犻噺鐨勬壒娆″ぇ灏�
 "epoch_size":100,                  # 璁粌鍛ㄦ湡澶у皬
-"lr":0.001,                        # 鍒濆瀛︿範鐜�
+"learning_rate":0.001,             # 鍒濆瀛︿範鐜�
 "save_checkpoint":True,            # 鏄惁淇濆瓨妫€鏌ョ偣
 "save_checkpoint_epochs":1,        # 涓や釜妫€鏌ョ偣涔嬮棿鐨勫懆鏈熼棿闅旓紱榛樿鎯呭喌涓嬶紝鏈€鍚庝竴涓鏌ョ偣灏嗗湪鏈€鍚庝竴涓懆鏈熷畬鎴愬悗淇濆瓨
 "keep_checkpoint_max":10,          # 鍙繚瀛樻渶鍚庝竴涓猭eep_checkpoint_max妫€鏌ョ偣
@@ -164,29 +186,43 @@ Usage: bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [C
 
 ## 鐢ㄦ硶
 
-## Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
+## Ascend澶勭悊鍣ㄦ垨GPU鐜杩愯
 
 ```Shell
-# 鍒嗗竷寮忚缁�
+# 鍒嗗竷寮忚缁冿紙Ascend锛�
 Usage: bash run_distribution_ascend.sh [RANK_TABLE_FILE] [DATA_PATH] [CKPT_PATH] [CONFIG_FILE]
 [RANK_TABLE_FILE]鏄鍗$殑鍏蜂綋淇℃伅銆�
 [DATA_PATH]鏄暟鎹泦鐨勮矾寰勩€�
 [CKPT_PATH]鏄灏哻kpt淇濆瓨鐨勪綅缃€�
 [CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
 
-# 鍗曟満璁粌
+# 鍗曟満璁粌(Ascend)
 Usage: bash run_standalone_train_ascend.sh [DATA_PATH] [DEVICE_ID] [CKPT_PATH] [CONFIG_FILE]
 [DATA_PATH]鏄暟鎹泦鐨勮矾寰勩€�
 [CKPT_PATH]鏄灏哻kpt淇濆瓨鐨勪綅缃€�
 [DEVICE_ID]涓烘墽琛宼rain.py鐨処D鍙枫€�
 [CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
 
-# 杩愯璇勪及绀轰緥
+# 鍗曟満璁粌(GPU)
+Usage: bash run_standalone_train_gpu.sh [DATA_PATH] [DEVICE_ID] [CKPT_PATH] [CONFIG_FILE]
+[DATA_PATH]鏄暟鎹泦鐨勮矾寰�(mindrecord鏂囦欢鎵€鍦ㄧ殑鐩綍)銆�
+[CKPT_PATH]鏄灏哻kpt淇濆瓨鐨勪綅缃€�
+[DEVICE_ID]涓烘墽琛宼rain.py鐨処D鍙枫€�
+[CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
+
+# 杩愯璇勪及绀轰緥锛圓scend锛�
 Usage: bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONFIG_FILE]
 [DATA_PATH]鏄暟鎹泦鐨勮矾寰勩€�
 [CKPT_PATH]鏄繚瀛榗kpt鐨勪綅缃€�
 [DEVICE_ID]涓烘墽琛宔val.py鐨処D鍙枫€�
 [CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
+
+# 杩愯璇勪及绀轰緥锛圙PU锛�
+Usage: bash run_standalone_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONFIG_FILE]
+[DATA_PATH]鏄暟鎹泦鐨勮矾寰�(mindrecord鏂囦欢鎵€鍦ㄧ殑鐩綍)銆�
+[CKPT_PATH]鏄繚瀛榗kpt鐨勪綅缃€�
+[DEVICE_ID]涓烘墽琛宔val.py鐨処D鍙枫€�
+[CONFIG_FILE]鏄ā鍨嬪強杩愯鐨勬暣浣撳弬鏁般€�
 ```
 
 鍒嗗竷寮忚缁冮渶瑕佹彁鍓嶅垱寤篔SON鏍煎紡鐨凥CCL閰嶇疆鏂囦欢銆�
@@ -200,7 +236,7 @@ Usage: bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [C
 - 浣跨敤census-income鏁版嵁闆嗚缁僊MoE
 
 ```text
-# 鍒嗗竷寮忚缁冪粨鏋滐紙8P锛�
+# 鍒嗗竷寮忚缁冪粨鏋滐紙8P Ascend锛�
 epoch: 1 step: 780, loss is 0.5584626
 epoch: 2 step: 780, loss is 0.72126234
 epoch: 3 step: 780, loss is 0.28167123
@@ -212,14 +248,35 @@ epoch: 8 step: 780, loss is 0.15461507
 epoch: 9 step: 780, loss is 0.37079066
 epoch: 10 step: 780, loss is 0.2761521
 
+...
+
+# 鍗曞崱GPU璁粌缁撴灉
+epoch: 1 step: 1558, loss is 0.7738624215126038
+epoch time: 23271.168 ms, per step time: 14.937 ms
+start infer...
+infer data finished, start eval...
+result : income_auc=0.956143804122577, marital_auc=0.8883598309142848, use time 2s
+The best income_auc is 0.956143804122577,             the best marital_auc is 0.8883598309142848,             the best income_marital_auc_avg is 0.9222518175184309
+epoch: 2 step: 1558, loss is 0.4517086148262024
+epoch time: 17804.081 ms, per step time: 11.428 ms
+start infer...
+infer data finished, start eval...
+result : income_auc=0.9856142129882843, marital_auc=0.9194419616798691, use time 1s
+The best income_auc is 0.9856142129882843,             the best marital_auc is 0.9194419616798691,             the best income_marital_auc_avg is 0.9525280873340767
+epoch: 3 step: 1558, loss is 0.41103610396385193
+epoch time: 17853.932 ms, per step time: 11.460 ms
+start infer...
+infer data finished, start eval...
+result : income_auc=0.9876599788311389, marital_auc=0.9663552616198483, use time 1s
+The best income_auc is 0.9876599788311389,             the best marital_auc is 0.9663552616198483,             the best income_marital_auc_avg is 0.9770076202254936
 ...
 ```
 
 # 璇勪及杩囩▼
 
-## 鐢ㄦ硶
+## 璇勪及鐢ㄦ硶
 
-### Ascend澶勭悊鍣ㄧ幆澧冭繍琛�
+### Ascend澶勭悊鍣ㄦ垨GPU鐜杩愯璇勪及
 
 ```Shell
 # 璇勪及
@@ -283,23 +340,24 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
 
 #### census-income涓婄殑MMoE
 
-| 鍙傛暟 | Ascend 910  |
-|---|---|
-| 妯″瀷鐗堟湰  | MMoE  |
-| 璧勬簮  |  Ascend 910锛汣PU锛�2.60GHz锛�192鏍革紱鍐呭瓨锛�755G |
-| 涓婁紶鏃ユ湡  |2021-11-12 ;  |
-| MindSpore鐗堟湰  | 1.3.0 |
-| 鏁版嵁闆�  | census-income |
-| 璁粌鍙傛暟  | epoch=100, batch_size = 32  |
-| 浼樺寲鍣�  | Adam  |
-| 鎹熷け鍑芥暟  | BCELoss |
-| 杈撳嚭  | 姒傜巼 |
-|  鎹熷け | 0.20949207 |
-|閫熷害|0.671姣/姝� |
-|鎬绘椂闀�   |  17鍒嗛挓 |
-|鍙傛暟(M)   | 23.55 |
-|  寰皟妫€鏌ョ偣 | 2.66M锛�.ckpt鏂囦欢锛�  |
-| 鑴氭湰  | [閾炬帴](https://gitee.com/mindspore/models/tree/master/research/recommend/mmoe)  |
+| 鍙傛暟 | Ascend 910  | V100 GPU |
+|---|---|---|
+| 妯″瀷鐗堟湰  | MMoE  |MMoE|
+| 璧勬簮  |  Ascend 910锛汣PU锛�2.60GHz锛�192鏍革紱鍐呭瓨锛�755G |V100 GPU锛汣PU锛�8鏍革紱鍐呭瓨锛�64G|
+| 涓婁紶鏃ユ湡  |2021-11-12 ;  |2022-2-19|
+| MindSpore鐗堟湰  | 1.3.0 |1.6.0|
+| 鏁版嵁闆�  | census-income |census-income|
+| 璁粌鍙傛暟  | epoch=100, batch_size = 32  |epoch=100, batch_size = 128|
+| 浼樺寲鍣�  | Adam  |Adam|
+| 鎹熷け鍑芥暟  | BCELoss |BCELoss|
+| 杈撳嚭  | 姒傜巼 |姒傜巼|
+|  鎹熷け | 0.20949207 |0.21848808228969574|
+|閫熷害|0.671姣/姝� |11.399姣/姝
+|鎬绘椂闀�   |  17鍒嗛挓 |32鍒嗛挓|
+|鍙傛暟   | 23.55KB |23.55KB|
+|绮惧害鎸囨爣   | best income_auc:0.9895    best marital_auc:0.9837 |best income_auc:0.9892    best marital_auc:0.9826|
+|  寰皟妫€鏌ョ偣 | 2.66MB锛�.ckpt鏂囦欢锛�  |893.8KB锛�.ckpt鏂囦欢锛墊
+| 鑴氭湰  | [閾炬帴](https://gitee.com/mindspore/models/tree/master/research/recommend/mmoe)  |[閾炬帴](https://gitee.com/mindspore/models/tree/master/research/recommend/mmoe)|
 
 # 闅忔満鎯呭喌璇存槑
 
diff --git a/research/recommend/mmoe/default_config_gpu.yaml b/research/recommend/mmoe/default_config_gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..680246d721b116b82e3404834f86f61df7bf8f72
--- /dev/null
+++ b/research/recommend/mmoe/default_config_gpu.yaml
@@ -0,0 +1,63 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+data_path: "/cache/data/"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+checkpoint_path: './checkpoint/'
+checkpoint_file: './checkpoint/MMoE_train-50_6236.ckpt'
+device_target: GPU
+enable_profiling: False
+
+ckpt_path: "/cache/train"
+ckpt_file: "/cache/train/MMoE_train-50_6236.ckpt"
+# ==============================================================================
+# Training options
+epoch_size: 100
+keep_checkpoint_max: 10
+learning_rate: 0.0005
+batch_size: 128
+num_features: 499
+num_experts: 8
+units: 4
+MINDIR_name: 'MMoE.MINDIR'
+ckpt_file_path: './MMoE_train-50_6236.ckpt'
+
+dataset_name: 'census_income'
+dataset_sink_mode: True
+run_distribute: False
+device_id: 0
+save_checkpoint: True
+save_checkpoint_epochs: 1
+lr: 0.0005
+local_data_path: '../data'
+
+# Model Description
+model_name: MMoE
+file_name: 'MMoE'
+file_format: 'MINDIR'
+
+# 'preprocess.'
+result_path: './preprocess_Result'
+
+# 'postprocess.'
+label1_path: './scripts/preprocess_Result/01_label1'
+label2_path: './scripts/preprocess_Result/02_label2'
+result_bin_path: './scripts/result_Files'
+income_path: './result_Files/income_output'
+marital_path: './result_Files/marital_output'
+---
+# Config description for each option
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+
+device_target: 'Target device type'
+enable_profiling: 'Whether enable profiling while training, default: False'
+
+---
+device_target: ['Ascend', 'GPU', 'CPU']
diff --git a/research/recommend/mmoe/eval.py b/research/recommend/mmoe/eval.py
index b55d768329e8f6e15026300987bdb01b35de719f..1700ca58d5178899cef63bebb9e356168f6c18bc 100644
--- a/research/recommend/mmoe/eval.py
+++ b/research/recommend/mmoe/eval.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -41,7 +41,8 @@ def eval_mmoe():
     """MMoE eval"""
     device_num = get_device_num()
     if device_num > 1:
-        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target='Ascend', save_graphs=False)
         if config.device_target == "Ascend":
             context.set_context(device_id=get_device_id())
             init()
@@ -50,44 +51,50 @@ def eval_mmoe():
 
     ds_eval = create_dataset(data_path=config.data_path, batch_size=config.batch_size,
                              training=False, target=config.device_target)
+
     eval_dataloader = ds_eval.create_tuple_iterator()
     if ds_eval.get_dataset_size() == 0:
-        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
+        raise ValueError(
+            "Please check dataset size > 0 and batch_size <= dataset size")
     print("ds_eval_size", ds_eval.get_dataset_size())
 
-    net = MMoE(num_features=config.num_features, num_experts=config.num_experts, units=config.units)
+    net = MMoE(num_features=config.num_features,
+               num_experts=config.num_experts, units=config.units)
 
     param_dict = load_checkpoint(config.ckpt_path)
     print("load checkpoint from [{}].".format(config.ckpt_path))
     load_param_into_net(net, param_dict)
     net.set_train(False)
 
-    income_auc = 0
-    marital_auc = 0
-    for data, income_label, marital_label in eval_dataloader:
-        output = net(Tensor(data, mstype.float16))
-        income_output = output[0].asnumpy()
-        income_output = income_output.flatten().tolist()
+    income_output_list = []
+    marital_output_list = []
 
-        marital_output = output[1].asnumpy()
-        marital_output = marital_output.flatten().tolist()
+    income_label_list = []
+    marital_label_list = []
 
-        income_label = income_label.asnumpy()
-        income_label = income_label.flatten().tolist()
+    data_type = mstype.float16 if config.device_target == 'Ascend' else mstype.float32
+
+    print('start infer...')
+    for data, income_label, marital_label in eval_dataloader:
+        output = net(Tensor(data, data_type))
 
-        marital_label = marital_label.asnumpy()
-        marital_label = marital_label.flatten().tolist()
+        income_output_list.extend(output[0].asnumpy().flatten().tolist())
+        marital_output_list.extend(output[1].asnumpy().flatten().tolist())
 
-        if len(income_output) != len(income_label):
-            raise RuntimeError('income_output.size() is not equal income_label.size().')
-        if len(marital_output) != len(marital_label):
-            raise RuntimeError('marital_output.size is not equal marital_label.size().')
+        income_label_list.extend(income_label.asnumpy().flatten().tolist())
+        marital_label_list.extend(marital_label.asnumpy().flatten().tolist())
 
-        income_auc = roc_auc_score(income_label, income_output)
-        marital_auc = roc_auc_score(marital_label, marital_output)
+    if len(income_output_list) != len(income_label_list):
+        raise RuntimeError(
+            'income_output.size() is not equal income_label.size().')
+    if len(marital_output_list) != len(marital_label_list):
+        raise RuntimeError(
+            'marital_output.size is not equal marital_label.size().')
+    print('infer data finished, start eval...')
+    income_auc = roc_auc_score(income_label_list, income_output_list)
+    marital_auc = roc_auc_score(marital_label_list, marital_output_list)
 
-    results = [[income_auc], [marital_auc]]
-    print("result : {}".format(results))
+    print(f"result : income_auc={income_auc}, marital_auc={marital_auc}")
 
 
 if __name__ == "__main__":
diff --git a/research/recommend/mmoe/scripts/run_standalone_eval_gpu.sh b/research/recommend/mmoe/scripts/run_standalone_eval_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2fe3dfcad73831ecf18b8aaa3c19455287c28c67
--- /dev/null
+++ b/research/recommend/mmoe/scripts/run_standalone_eval_gpu.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# A simple tutorial follows; more parameters can be set.
+if [ $# != 4 ]
+then
+    echo "Usage: bash run_standalone_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONFIG_FILE]"
+    exit 1
+fi
+
+export DATA_PATH=$1
+export CKPT_PATH=$2
+export DEVICE_ID=$3
+export CONFIG_FILE=$4
+
+python ../eval.py --data_path=$DATA_PATH --ckpt_path=$CKPT_PATH --device_id=$DEVICE_ID \
+--config_path=$CONFIG_FILE  > eval_log.txt 2>&1 &
diff --git a/research/recommend/mmoe/scripts/run_standalone_train_gpu.sh b/research/recommend/mmoe/scripts/run_standalone_train_gpu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2ef2b699db0776e4596224ac012d2127e4253c72
--- /dev/null
+++ b/research/recommend/mmoe/scripts/run_standalone_train_gpu.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# A simple tutorial follows; more parameters can be set.
+if [ $# != 4 ]
+then
+    echo "Usage: bash run_standalone_train_gpu.sh [DATA_PATH] [DEVICE_ID] [CKPT_PATH] [CONFIG_FILE] "
+    exit 1
+fi
+
+ulimit -u unlimited
+export DATA_PATH=$1
+export DEVICE_ID=$2
+export CKPT_PATH=$3
+export CONFIG_FILE=$4
+export DEVICE_NUM=1
+
+cpus=`cat /proc/cpuinfo| grep "processor"| wc -l`
+avg=`expr $cpus \/ $DEVICE_NUM`
+gap=`expr $avg \- 1`
+
+start=`expr 0 \* $avg`
+end=`expr $start \+ $gap`
+cmdopt=$start"-"$end
+
+echo "start training"
+taskset -c $cmdopt python ../train.py --data_path=$DATA_PATH --device_id=$DEVICE_ID --ckpt_path=./$CKPT_PATH \
+--config_path=$CONFIG_FILE > log.txt 2>&1 &
\ No newline at end of file
diff --git a/research/recommend/mmoe/src/callback.py b/research/recommend/mmoe/src/callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..e35f2bda53eaee1745ea24cd58745ad56a9269d9
--- /dev/null
+++ b/research/recommend/mmoe/src/callback.py
@@ -0,0 +1,117 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Defined eval callback for mmoe.
+"""
+import os
+import time
+
+from sklearn.metrics import roc_auc_score
+
+from mindspore import Tensor
+from mindspore.common import dtype as mstype
+from mindspore.train.callback import Callback
+from mindspore.train.serialization import save_checkpoint
+
+from src.model_utils.config import config
+
+
+class EvalCallBack(Callback):
+    """
+    Evaluation callback for MMoE training.
+    At each epoch end, runs inference over ds_eval, computes the
+    income and marital AUC scores, and saves the checkpoints with the
+    best income_auc, best marital_auc, and best average of the two.
+    """
+
+    def __init__(self, net, ds_eval, ckpt_path, rank_id):
+        super(EvalCallBack, self).__init__()
+        self.net = net
+        self.ds_eval = ds_eval
+        self.ckpt_path = ckpt_path
+        self.rank_id = rank_id
+        self.max_income_auc = 0
+        self.max_marital_auc = 0
+        self.max_income_marital_auc_avg = 0
+
+    def epoch_end(self, run_context):
+        start_time = time.time()
+        eval_dataloader = self.ds_eval.create_tuple_iterator()
+
+        income_output_list = []
+        marital_output_list = []
+
+        income_label_list = []
+        marital_label_list = []
+
+        data_type = mstype.float16 if config.device_target == 'Ascend' else mstype.float32
+
+        print('start infer...')
+        self.net.set_train(False)
+        for data, income_label, marital_label in eval_dataloader:
+            output = self.net(Tensor(data, data_type))
+
+            income_output_list.extend(output[0].asnumpy().flatten().tolist())
+            marital_output_list.extend(output[1].asnumpy().flatten().tolist())
+
+            income_label_list.extend(income_label.asnumpy().flatten().tolist())
+            marital_label_list.extend(
+                marital_label.asnumpy().flatten().tolist())
+
+        if len(income_output_list) != len(income_label_list):
+            raise RuntimeError(
+                'income_output.size() is not equal income_label.size().')
+        if len(marital_output_list) != len(marital_label_list):
+            raise RuntimeError(
+                'marital_output.size is not equal marital_label.size().')
+        print('infer data finished, start eval...')
+        income_auc = roc_auc_score(income_label_list, income_output_list)
+        marital_auc = roc_auc_score(marital_label_list, marital_output_list)
+
+        eval_time = int(time.time() - start_time)
+        print(
+            f"result : income_auc={income_auc}, marital_auc={marital_auc}, use time {eval_time}s")
+
+        cb_params = run_context.original_args()
+
+        if income_auc > self.max_income_auc:
+            self.max_income_auc = income_auc
+            ckpt_file_name = 'best_income_auc.ckpt'
+            if not os.path.exists(self.ckpt_path):
+                os.makedirs(self.ckpt_path)
+            save_path = os.path.join(self.ckpt_path, ckpt_file_name)
+            save_checkpoint(cb_params.train_network, save_path)
+
+        if marital_auc > self.max_marital_auc:
+            self.max_marital_auc = marital_auc
+            ckpt_file_name = 'best_marital_auc.ckpt'
+            if not os.path.exists(self.ckpt_path):
+                os.makedirs(self.ckpt_path)
+            save_path = os.path.join(self.ckpt_path, ckpt_file_name)
+            save_checkpoint(cb_params.train_network, save_path)
+
+        income_marital_auc_avg = (income_auc + marital_auc) / 2
+        if income_marital_auc_avg > self.max_income_marital_auc_avg:
+            self.max_income_marital_auc_avg = income_marital_auc_avg
+            ckpt_file_name = 'best_income_marital_auc_avg.ckpt'
+            if not os.path.exists(self.ckpt_path):
+                os.makedirs(self.ckpt_path)
+            save_path = os.path.join(self.ckpt_path, ckpt_file_name)
+            save_checkpoint(cb_params.train_network, save_path)
+
+        print(
+            f'The best income_auc is {self.max_income_auc}, \
+            the best marital_auc is {self.max_marital_auc}, \
+            the best income_marital_auc_avg is {self.max_income_marital_auc_avg}')
diff --git a/research/recommend/mmoe/src/data.py b/research/recommend/mmoe/src/data.py
index 34caf9cbccacf7d3e6662a77bcb666e500b42c9b..03a7b251e5bbe44355bc52e69c056030ef5ecf61 100644
--- a/research/recommend/mmoe/src/data.py
+++ b/research/recommend/mmoe/src/data.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """Generate data in mindrecord format."""
+import os
+
 import pandas as pd
 import numpy as np
 
@@ -52,28 +54,36 @@ def generate_npy(data_path, do_train):
             index_col=None,
             names=column_names
         )
-    ds_transformed = pd.get_dummies(ds.drop(label_columns, axis=1), columns=categorical_columns)
+    ds_transformed = pd.get_dummies(
+        ds.drop(label_columns, axis=1), columns=categorical_columns)
     if not do_train:
         ds_transformed['det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily'] = 0
     data = ds_transformed
-    np.save('data.npy', np.array(data), allow_pickle=False)
+    np.save(data_path + '/data.npy', np.array(data), allow_pickle=False)
 
     ds_raw_labels = ds[label_columns]
     ds_raw_labels['marital_stat'] = ds_raw_labels['marital_stat'].apply(
         lambda x: 'never married' if x == ' Never married' else 'married')
 
     income_labels = pd.get_dummies(ds_raw_labels['income_50k'])
-    np.save('income_labels.npy', np.array(income_labels), allow_pickle=False)
+    np.save(data_path + '/income_labels.npy',
+            np.array(income_labels), allow_pickle=False)
 
     married_labels = pd.get_dummies(ds_raw_labels['marital_stat'])
-    np.save('married_labels.npy', np.array(married_labels), allow_pickle=False)
+    np.save(data_path + '/married_labels.npy',
+            np.array(married_labels), allow_pickle=False)
+
+    data = np.load(data_path + '/data.npy').astype(np.float32)
+    income = np.load(data_path + '/income_labels.npy').astype(np.float32)
+    married = np.load(data_path + '/married_labels.npy').astype(np.float32)
+
+    mindrecord_path = data_path + "/mindrecord"
 
-    data = np.load('data.npy').astype(np.float32)
-    income = np.load('income_labels.npy').astype(np.float32)
-    married = np.load('married_labels.npy').astype(np.float32)
+    if not os.path.exists(mindrecord_path):
+        os.mkdir(mindrecord_path)
 
     if do_train:
-        MINDRECORD_FILE = "../data/train.mindrecord"
+        MINDRECORD_FILE = mindrecord_path + "/train.mindrecord"
         writer = FileWriter(file_name=MINDRECORD_FILE, shard_num=1)
 
         nlp_schema = {"data": {"type": "float32", "shape": [-1]},
@@ -85,10 +95,13 @@ def generate_npy(data_path, do_train):
                       "income_labels": income[i],
                       "married_labels": married[i]}
 
+            if i % 10000 == 0:
+                print(f'write {i} lines.')
+
             writer.write_raw_data([sample])
         writer.commit()
     else:
-        MINDRECORD_FILE = "../data/eval.mindrecord"
+        MINDRECORD_FILE = mindrecord_path + "/eval.mindrecord"
         writer = FileWriter(file_name=MINDRECORD_FILE, shard_num=1)
 
         nlp_schema = {"data": {"type": "float32", "shape": [-1]},
@@ -100,9 +113,13 @@ def generate_npy(data_path, do_train):
                       "income_labels": income[i],
                       "married_labels": married[i]}
 
+            if i % 10000 == 0:
+                print(f'write {i} lines.')
+
             writer.write_raw_data([sample])
         writer.commit()
 
 
 if __name__ == '__main__':
     generate_npy(data_path=config.local_data_path, do_train=True)
+    generate_npy(data_path=config.local_data_path, do_train=False)
diff --git a/research/recommend/mmoe/src/load_dataset.py b/research/recommend/mmoe/src/load_dataset.py
index c80eda3dbf624dbc7a46ec3c24071b689ae40ab4..602a26d726788630bb207dfd5387edbba7b7b7d3 100644
--- a/research/recommend/mmoe/src/load_dataset.py
+++ b/research/recommend/mmoe/src/load_dataset.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ from mindspore.communication.management import get_rank, get_group_size
 import mindspore.dataset.transforms.c_transforms as C
 import mindspore.common.dtype as mstype
 
+
 def _get_rank_info(run_distribute):
     """get rank size and rank id"""
     rank_size = int(os.environ.get("RANK_SIZE", 1))
@@ -49,36 +50,41 @@ def create_dataset(data_path,
     if target != "Ascend" or device_num == 1:
         if training:
             ds = de.MindDataset(input_file,
-                                columns_list=['data', 'income_labels', 'married_labels'],
-                                num_parallel_workers=32,
+                                columns_list=[
+                                    'data', 'income_labels', 'married_labels'],
+                                num_parallel_workers=8,
                                 shuffle=True)
         else:
             ds = de.MindDataset(input_file,
-                                columns_list=['data', 'income_labels', 'married_labels'],
-                                num_parallel_workers=32,
+                                columns_list=[
+                                    'data', 'income_labels', 'married_labels'],
+                                num_parallel_workers=8,
                                 shuffle=False)
     else:
         if training:
             ds = de.MindDataset(input_file,
-                                columns_list=['data', 'income_labels', 'married_labels'],
+                                columns_list=[
+                                    'data', 'income_labels', 'married_labels'],
                                 num_parallel_workers=4,
                                 shuffle=True,
                                 num_shards=device_num,
                                 shard_id=rank_id)
         else:
             ds = de.MindDataset(input_file,
-                                columns_list=['data', 'income_labels', 'married_labels'],
+                                columns_list=[
+                                    'data', 'income_labels', 'married_labels'],
                                 num_parallel_workers=4,
                                 shuffle=False,
                                 num_shards=device_num,
                                 shard_id=rank_id)
-    ds_label = [
-        C.TypeCast(mstype.float16)
-    ]
-    ds = ds.map(operations=ds_label, input_columns=["data"])
-    ds = ds.map(operations=ds_label, input_columns=["income_labels"])
-    ds = ds.map(operations=ds_label, input_columns=["married_labels"])
-    ds = ds.batch(batch_size)
+    if target == 'Ascend':
+        ds_label = [
+            C.TypeCast(mstype.float16)
+        ]
+        ds = ds.map(operations=ds_label, input_columns=["data"])
+        ds = ds.map(operations=ds_label, input_columns=["income_labels"])
+        ds = ds.map(operations=ds_label, input_columns=["married_labels"])
+    ds = ds.batch(batch_size, drop_remainder=True)
     return ds
 
 
diff --git a/research/recommend/mmoe/src/mmoe.py b/research/recommend/mmoe/src/mmoe.py
index d3a3907808bc57f30eeb46ab87be05fc049048ec..6396032f73ac8435009a0c209262c7cd7248e05b 100644
--- a/research/recommend/mmoe/src/mmoe.py
+++ b/research/recommend/mmoe/src/mmoe.py
@@ -1,4 +1,4 @@
-# Copyright 2021-2022 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,11 +26,13 @@ from mindspore import Parameter, Tensor
 from mindspore import context
 from mindspore.context import ParallelMode
 
+from src.model_utils.config import config
 from src.mmoe_utils import expert, gate, shared_output, tower, output
 
 
 class MMoE_Layer(nn.Cell):
     """MMoE network"""
+
     def __init__(self, input_size, num_experts, units):
         super(MMoE_Layer, self).__init__()
         self.input_size = input_size
@@ -39,8 +41,10 @@ class MMoE_Layer(nn.Cell):
         self.expert = expert(self.input_size, self.units, self.num_experts)
         self.gate1 = gate(self.input_size, self.num_experts)
         self.gate2 = gate(self.input_size, self.num_experts)
-        self.shared_output1 = shared_output(self.input_size, self.num_experts, self.units)
-        self.shared_output2 = shared_output(self.input_size, self.num_experts, self.units)
+        self.shared_output1 = shared_output(
+            self.input_size, self.num_experts, self.units)
+        self.shared_output2 = shared_output(
+            self.input_size, self.num_experts, self.units)
         self.tower_layer1 = tower(4, 8)
         self.tower_layer2 = tower(4, 8)
         self.output_layer1 = output(8, 2)
@@ -66,13 +70,15 @@ class MMoE_Layer(nn.Cell):
 
 def MMoE(num_features, num_experts, units):
     """MMoE call function"""
-    net = MMoE_Layer(input_size=num_features, num_experts=num_experts, units=units)
+    net = MMoE_Layer(input_size=num_features,
+                     num_experts=num_experts, units=units)
 
     return net
 
 
 class LossForMultiLabel(LossBase):
     """loss for two labels"""
+
     def __init__(self):
         super(LossForMultiLabel, self).__init__()
         self.bceloss = BCELoss()
@@ -89,6 +95,7 @@ class LossForMultiLabel(LossBase):
 
 class NetWithLossClass(nn.Cell):
     """net with loss"""
+
     def __init__(self, model, loss_fn):
         super(NetWithLossClass, self).__init__()
         self.model = model
@@ -184,7 +191,7 @@ def _grad_div(val, grad):
 class TrainStepWrap(nn.Cell):
     """TrainStepWrap definition"""
 
-    def __init__(self, network, optimizer, scale_update_cell):  # 16384.0
+    def __init__(self, network, optimizer, scale_update_cell, device_target):  # 16384.0
         super(TrainStepWrap, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_grad()
@@ -235,6 +242,9 @@ class TrainStepWrap(nn.Cell):
             mean = _get_gradients_mean()
             self.grad_reducer = DistributedGradReducer(
                 self.weights, mean, self.degree)
+        self.device_target = device_target
+        self.grad_scale_sense_type = mindspore.float16 \
+            if device_target == 'Ascend' else mindspore.float32
 
     def construct(self, data, label1, label2):
         """construct"""
@@ -243,11 +253,14 @@ class TrainStepWrap(nn.Cell):
 
         scale_sense = self.loss_scale
 
-        init = self.alloc_status()
-        init = F.depend(init, loss)
+        if self.device_target == 'Ascend':
+            init = self.alloc_status()
+            init = F.depend(init, loss)
 
-        clear_status = self.clear_before_grad(init)
-        scale_sense = F.depend(scale_sense, clear_status)
+            clear_status = self.clear_before_grad(init)
+            scale_sense = F.depend(scale_sense, clear_status)
+        else:
+            init = False
 
         grads = self.grad(
             self.network,
@@ -255,7 +268,7 @@ class TrainStepWrap(nn.Cell):
                 data,
                 label1,
                 label2,
-                scale_sense.astype(mindspore.float16))
+                scale_sense.astype(self.grad_scale_sense_type))
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(
             F.partial(
@@ -270,10 +283,15 @@ class TrainStepWrap(nn.Cell):
                 GRADIENT_CLIP_VALUE),
             grads)
 
-        init = F.depend(init, grads)
-        get_status = self.get_status(init)
-        init = F.depend(init, get_status)
-        flag_sum = self.reduce_sum(init, (0,))
+        if self.device_target == 'Ascend':
+            init = F.depend(init, grads)
+            get_status = self.get_status(init)
+            init = F.depend(init, get_status)
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = P.AddN()(flag_sum)
+            flag_sum = P.Reshape()(flag_sum, (()))
 
         if self.is_distributed:
             flag_reduce = self.all_reduce(flag_sum)
@@ -283,7 +301,10 @@ class TrainStepWrap(nn.Cell):
 
         overflow = self.loss_scaling_manager(self.loss_scale, cond)
 
-        if not overflow:
-            self.optimizer(grads)
+        if overflow:
+            succ = False
+        else:
+            succ = self.optimizer(grads)
+
         ret = (loss, scale_sense)
-        return ret
+        return F.depend(ret, succ)
diff --git a/research/recommend/mmoe/src/mmoe_utils.py b/research/recommend/mmoe/src/mmoe_utils.py
index ff1571e52492e96710640dc95cef5c63a6067fcf..49971bbb80d6f537860cf5d5d4ba85ec8b995cd2 100644
--- a/research/recommend/mmoe/src/mmoe_utils.py
+++ b/research/recommend/mmoe/src/mmoe_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,9 +19,17 @@ from mindspore import Parameter
 from mindspore.common.initializer import initializer, TruncatedNormal, Zero
 import mindspore.ops as P
 
+from src.model_utils.config import config
+
+if config.device_target == 'Ascend':
+    use_mstype = ms.float16
+else:
+    use_mstype = ms.float32
+
 
 class expert(nn.Cell):
     """expert network"""
+
     def __init__(self,
                  input_size,
                  units,
@@ -37,7 +45,8 @@ class expert(nn.Cell):
         self.bias_add = P.Add()
         self.relu = nn.ReLU()
         self.expert_kernels = Parameter(initializer(TruncatedNormal(),
-                                                    (self.input_size, self.units, self.num_experts),
+                                                    (self.input_size, self.units,
+                                                     self.num_experts),
                                                     ms.float32),
                                         requires_grad=True)
         if self.use_expert_bias:
@@ -52,15 +61,18 @@ class expert(nn.Cell):
     def construct(self, x):
         """construct of expert network"""
         # expert_output = self.mul(x.astype(ms.float16), self.expert_kernels.astype(ms.float16), (1, 0))
-        expert_output = self.mul(x, self.expert_kernels.astype(ms.float16), (1, 0))
+        expert_output = self.mul(
+            x, self.expert_kernels.astype(use_mstype), (1, 0))
         if self.use_expert_bias:
-            expert_output = self.bias_add(expert_output, self.expert_bias.astype(ms.float16))
+            expert_output = self.bias_add(
+                expert_output, self.expert_bias.astype(use_mstype))
         expert_output = self.relu(expert_output)
         return expert_output
 
 
 class gate(nn.Cell):
     """gate network"""
+
     def __init__(self,
                  input_size,
                  num_experts,
@@ -73,7 +85,8 @@ class gate(nn.Cell):
         self.bias_add = P.BiasAdd()
         self.softmax = nn.Softmax()
         self.gate_kernel = Parameter(initializer(TruncatedNormal(),
-                                                 (self.input_size, self.num_experts),
+                                                 (self.input_size,
+                                                  self.num_experts),
                                                  ms.float32),
                                      requires_grad=True)
 
@@ -89,15 +102,17 @@ class gate(nn.Cell):
     def construct(self, x):
         """construct of gate network"""
         # gate_output = P.dot(x1=x.astype(ms.float16), x2=self.gate_kernel.astype(ms.float16))
-        gate_output = P.dot(x1=x, x2=self.gate_kernel.astype(ms.float16))
+        gate_output = P.dot(x1=x, x2=self.gate_kernel.astype(use_mstype))
         if self.use_gate_bias:
-            gate_output = self.bias_add(gate_output, self.gate_bias.astype(ms.float16))
+            gate_output = self.bias_add(
+                gate_output, self.gate_bias.astype(use_mstype))
         gate_output = self.softmax(gate_output)
         return gate_output
 
 
 class shared_output(nn.Cell):
     """Gate controls the weights of different experts for different tasks"""
+
     def __init__(self,
                  input_size,
                  num_experts,
@@ -114,7 +129,9 @@ class shared_output(nn.Cell):
     def construct(self, x, x1):
         """construct of shared output"""
         expanded_gate_output = self.expand_dims(x1, 1)
-        weighted_expert_output = x * self.repeat_elements(x=expanded_gate_output, rep=self.units, axis=1)
+        weighted_expert_output = x * \
+            self.repeat_elements(x=expanded_gate_output,
+                                 rep=self.units, axis=1)
         final_outputs = self.sum(weighted_expert_output, 2)
 
         return final_outputs
@@ -122,6 +139,7 @@ class shared_output(nn.Cell):
 
 class tower(nn.Cell):
     """dense with TRelu activation"""
+
     def __init__(self, in_channels, out_channels):
         super(tower, self).__init__()
         self.relu = nn.ReLU()
@@ -129,7 +147,7 @@ class tower(nn.Cell):
                                     out_channels,
                                     weight_init=TruncatedNormal(),
                                     activation=self.relu)
-        self.tower_layer.to_float(ms.float16)
+        self.tower_layer.to_float(use_mstype)
         self.print = P.Print()
 
     def construct(self, x):
@@ -140,6 +158,7 @@ class tower(nn.Cell):
 
 class output(nn.Cell):
     """dense with TSoftmax activation"""
+
     def __init__(self, in_channels, out_channels):
         super(output, self).__init__()
         self.softmax = nn.Softmax()
@@ -147,7 +166,7 @@ class output(nn.Cell):
                                out_channels,
                                weight_init=TruncatedNormal(),
                                activation=self.softmax)
-        self.output.to_float(ms.float16)
+        self.output.to_float(use_mstype)
         self.print = P.Print()
 
     def construct(self, x):
diff --git a/research/recommend/mmoe/train.py b/research/recommend/mmoe/train.py
index b715f7d4715165a9c79be64704e1942d33dee763..2f2235855c737914e38dc701cb6e093e3b4533dc 100644
--- a/research/recommend/mmoe/train.py
+++ b/research/recommend/mmoe/train.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ from src.model_utils.config import config
 from src.mmoe import LossForMultiLabel, NetWithLossClass
 from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
 from src.get_lr import get_lr
+from src.callback import EvalCallBack
 
 set_seed(1)
 
@@ -51,11 +52,9 @@ def run_train():
     print('job id:', get_job_id())
 
     device_target = config.device_target
-    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=config.device_target)
     context.set_context(save_graphs=False)
-    if config.device_target == "GPU":
-        context.set_context(enable_graph_kernel=True)
-        context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
 
     device_num = get_device_num()
 
@@ -72,14 +71,18 @@ def run_train():
         context.set_context(device_id=get_device_id())
     print("init finished.")
 
-    ds_train = create_dataset(config.data_path, config.batch_size, training=True, \
-        target=config.device_target, run_distribute=config.run_distribute)
+    ds_train = create_dataset(config.data_path, config.batch_size, training=True,
+                              target=config.device_target, run_distribute=config.run_distribute)
+    ds_eval = create_dataset(config.data_path, config.batch_size,
+                             training=False, target=config.device_target)
 
     if ds_train.get_dataset_size() == 0:
-        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size.")
+        raise ValueError(
+            "Please check dataset size > 0 and batch_size <= dataset size.")
     print("create dataset finished.")
 
-    net = MMoE_Layer(input_size=config.num_features, num_experts=config.num_experts, units=config.units)
+    net = MMoE_Layer(input_size=config.num_features,
+                     num_experts=config.num_experts, units=config.units)
     print("model created.")
     loss = LossForMultiLabel()
     loss_net = NetWithLossClass(net, loss)
@@ -88,9 +91,13 @@ def run_train():
     print("train dataset size:", step_per_size)
 
     if config.run_distribute:
-        learning_rate = get_lr(0.0005, config.epoch_size, step_per_size, step_per_size * 2)
+        learning_rate = get_lr(config.learning_rate / 2,
+                               config.epoch_size,
+                               step_per_size, step_per_size * 2)
     else:
-        learning_rate = get_lr(0.001, config.epoch_size, step_per_size, step_per_size * 5)
+        learning_rate = get_lr(config.learning_rate,
+                               config.epoch_size,
+                               step_per_size, step_per_size * 5)
     opt = Adam(net.trainable_params(),
                learning_rate=learning_rate,
                beta1=0.9,
@@ -98,19 +105,24 @@ def run_train():
                eps=1e-7,
                weight_decay=0.0,
                loss_scale=1.0)
-    scale_update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12,
-                                                   scale_factor=2,
-                                                   scale_window=1000)
-    train_net = TrainStepWrap(loss_net, opt, scale_update_cell)
+    scale_update_cell = DynamicLossScaleUpdateCell(
+        loss_scale_value=2 ** 12 if config.device_target == 'Ascend' else 1.0,
+        scale_factor=2,
+        scale_window=1000)
+    train_net = TrainStepWrap(
+        loss_net, opt, scale_update_cell, config.device_target)
     train_net.set_train()
     model = Model(train_net)
 
     time_cb = TimeMonitor()
-    loss_cb = LossMonitor(step_per_size)
-    config_ck = CheckpointConfig(save_checkpoint_steps=step_per_size, keep_checkpoint_max=100)
-    callbacks_list = [time_cb, loss_cb]
+    loss_cb = LossMonitor()
+    eval_cb = EvalCallBack(net, ds_eval, config.ckpt_path, get_rank_id())
+    config_ck = CheckpointConfig(
+        save_checkpoint_steps=step_per_size, keep_checkpoint_max=config.keep_checkpoint_max)
+    callbacks_list = [time_cb, loss_cb, eval_cb]
     if get_rank_id() == 0:
-        ckpoint_cb = ModelCheckpoint(prefix='MMoE_train', directory=config.ckpt_path, config=config_ck)
+        ckpoint_cb = ModelCheckpoint(
+            prefix='MMoE_train', directory=config.ckpt_path, config=config_ck)
         callbacks_list.append(ckpoint_cb)
 
     print("train start!")