diff --git a/official/recommend/tbnet/README.md b/official/recommend/tbnet/README.md
index d650a02a28645389c6fc7fbd5ec95504bb293975..dbae5154353cea8a690ccba633419e1853c0b829 100644
--- a/official/recommend/tbnet/README.md
+++ b/official/recommend/tbnet/README.md
@@ -54,10 +54,17 @@ Note that the \<item\> needs to traverse candidate items (all items by default)
 #format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item  # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
 ```
 
+Download the data package and extract it under the current project path.
+
+```bash
+wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
+tar -xf tbnet_data.tar.gz
+```
+
 # [Environment Requirements](#contents)
 
-- Hardware (GPU)
-    - Prepare hardware environment with GPU processor.
+- Hardware (NVIDIA GPU or Ascend NPU)
+    - Prepare hardware environment with an NVIDIA GPU or Ascend NPU processor.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
@@ -70,51 +77,57 @@ After installing MindSpore via the official website, you can start training and
 
 - Data preprocessing
 
-Process the data to the format in chapter [Dataset](#Dataset) (e.g. 'steam' dataset), and then run code as follows.
+Download the data package (e.g. the 'steam' dataset) and extract it under the current project path.
+
+```bash
+wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
+tar -xf tbnet_data.tar.gz
+cd scripts
+```
+
+Then run the code as follows.
 
 - Training
 
 ```bash
-python train.py \
-  --dataset [DATASET] \
-  --epochs [EPOCHS]
+bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
 ```
 
 Example:
 
 ```bash
-python train.py \
-  --dataset steam \
-  --epochs 20
+bash run_standalone_train.sh steam 0 Ascend
 ```
 
 - Evaluation
 
+Evaluate the model on the test dataset.
+
 ```bash
-python eval.py \
-  --dataset [DATASET] \
-  --checkpoint_id [CHECKPOINT_ID]
+bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
 ```
 
-Argument `--checkpoint_id` is required.
+Argument `[CHECKPOINT_ID]` is required.
 
 Example:
 
 ```bash
-python eval.py \
-  --dataset steam \
-  --checkpoint_id 8
+bash run_eval.sh 19 steam 0 Ascend
 ```
 
 - Inference and Explanation
 
+Recommend items to the user specified by `user`; the number of items is determined by `items`.
+
 ```bash
 python infer.py \
   --dataset [DATASET] \
   --checkpoint_id [CHECKPOINT_ID] \
   --user [USER] \
   --items [ITEMS] \
-  --explanations [EXPLANATIONS]
+  --explanations [EXPLANATIONS] \
+  --csv [CSV] \
+  --device_target [DEVICE_TARGET]
 ```
 
 Arguments `--checkpoint_id` and `--user` are required.
@@ -124,10 +137,12 @@ Example:
 ```bash
 python infer.py \
   --dataset steam \
-  --checkpoint_id 8 \
-  --user 1 \
+  --checkpoint_id 19 \
+  --user 2 \
   --items 1 \
-  --explanations 3
+  --explanations 3 \
+  --csv test.csv \
+  --device_target Ascend
 ```
 
 # [Script Description](#contents)
@@ -139,14 +154,16 @@ python infer.py \
 └─tbnet
   ├─README.md
   ├── scripts
-  │   └─run_infer_310.sh    # Ascend310 inference script
+      ├─run_infer_310.sh                  # Ascend310 inference script
+      ├─run_standalone_train.sh           # NVIDIA GPU or Ascend NPU training script
+      └─run_eval.sh                       # NVIDIA GPU or Ascend NPU evaluation script
   ├─data
     ├─steam
         ├─config.json               # data and training parameter configuration
-        ├─infer.csv                 # inference and explanation dataset
-        ├─test.csv                  # evaluation dataset
-        ├─train.csv                 # training dataset
-        └─trainslate.json           # explanation configuration
+        ├─src_infer.csv             # inference and explanation dataset
+        ├─src_test.csv              # evaluation dataset
+        ├─src_train.csv             # training dataset
+        └─id_maps.json              # explanation configuration
   ├─src
     ├─aggregator.py                 # inference result aggregation
     ├─config.py                     # parsing parameter configuration
@@ -156,6 +173,7 @@ python infer.py \
     ├─steam.py                      # 'steam' dataset text explainer
     └─tbnet.py                      # TB-Net model
   ├─export.py                         # export mindir script
+  ├─preprocess_dataset.py           # dataset preprocess script
   ├─preprocess.py                         # inference data preprocess script
   ├─postprocess.py                         # inference result calculation script
   ├─eval.py                         # evaluation
@@ -165,6 +183,14 @@ python infer.py \
 
 ## [Script Parameters](#contents)
 
+- preprocess_dataset.py parameters
+
+```text
+--dataset         'steam' dataset is supported currently
+--device_target   run code on GPU or Ascend NPU
+--same_relation   only generate paths whose relation1 is the same as relation2
+```
+
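+A minimal run sketch (assuming the extracted `data/steam` folder from the data package, with `src_train.csv`, `src_test.csv` and `src_infer.csv`, sits under the project root):
+
+```bash
+python preprocess_dataset.py --dataset steam --device_target GPU
+```
+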
 - train.py parameters
 
 ```text
@@ -173,7 +199,7 @@ python infer.py \
 --test_csv        the test csv datafile inside the dataset folder
 --device_id       device id
 --epochs          number of training epochs
---device_target   run code on GPU
+--device_target   run code on GPU or Ascend NPU
 --run_mode        run code by GRAPH mode or PYNATIVE mode
 ```
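+
+A direct `train.py` invocation sketch (restoring the removed standalone example, extended with the new `--device_target` flag):
+
+```bash
+python train.py \
+  --dataset steam \
+  --epochs 20 \
+  --device_target GPU
+```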
 
@@ -184,7 +210,7 @@ python infer.py \
 --csv             the csv datafile inside the dataset folder (e.g. test.csv)
 --checkpoint_id   use which checkpoint(.ckpt) file to eval
 --device_id       device id
---device_target   run code on GPU
+--device_target   run code on GPU or Ascend NPU
 --run_mode        run code by GRAPH mode or PYNATIVE mode
 ```
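+
+A hedged `eval.py` example (checkpoint id 19 as in the example above; `test.csv` is assumed to sit inside the dataset folder):
+
+```bash
+python eval.py \
+  --dataset steam \
+  --csv test.csv \
+  --checkpoint_id 19 \
+  --device_target GPU
+```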
 
@@ -198,7 +224,7 @@ python infer.py \
 --items           no. of items to be recommended
 --reasons         no. of recommendation reasons to be shown
 --device_id       device id
---device_target   run code on GPU
+--device_target   run code on GPU or Ascend NPU
 --run_mode        run code by GRAPH mode or PYNATIVE mode
 ```
 
@@ -215,6 +241,17 @@ python export.py --config_path [CONFIG_PATH] --checkpoint_path [CKPT_PATH] --dev
 - `DEVICE` should be in ['Ascend', 'GPU'].
 - `FILE_FORMAT` should be in ['MINDIR', 'AIR'].
 
+Example:
+
+```bash
+python export.py \
+  --config_path ./data/steam/config.json \
+  --checkpoint_path ./checkpoints/tbnet_epoch19.ckpt \
+  --device_target Ascend \
+  --file_name model \
+  --file_format MINDIR
+```
+
 ### [Infer on Ascend310](#contents)
 
 Before performing inference, the mindir file must be exported by `export.py` script. We only provide an example of inference using MINDIR model.
@@ -228,6 +265,12 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
 - `DATA_PATH` specifies path of test.csv.
 - `DEVICE_ID` is optional, default value is 0.
 
+Example:
+
+```bash
+bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
+```
+
 ### [Result](#contents)
 
 Inference result is saved in current path, you can find result like this in acc.log file.
@@ -242,35 +285,35 @@ auc: 0.8251359368836292
 
 ### Training Performance
 
-| Parameters                 | GPU                                                         |
-| -------------------------- | ----------------------------------------------------------- |
-| Model Version              | TB-Net                                                      |
-| Resource                   |Tesla V100-SXM2-32GB                                         |
-| Uploaded Date              | 2021-08-01                                                  |
-| MindSpore Version          | 1.3.0                                                       |
-| Dataset                    | steam                                                       |
-| Training Parameter         | epoch=20, batch_size=1024, lr=0.001                         |
-| Optimizer                  | Adam                                                        |
-| Loss Function              | Sigmoid Cross Entropy                                       |
-| Outputs                    | AUC=0.8596, Accuracy=0.7761                                 |
-| Loss                       | 0.57                                                        |
-| Speed                      | 1pc: 90ms/step                                              |
-| Total Time                 | 1pc: 297s                                                   |
-| Checkpoint for Fine Tuning | 104.66M (.ckpt file)                                        |
-| Scripts                    | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
+| Parameters                 | GPU                                                                                        | Ascend NPU                                   |
+| -------------------------- |--------------------------------------------------------------------------------------------| ---------------------------------------------|
+| Model Version              | TB-Net                                                                                     | TB-Net                                       |
+| Resource                   | NVIDIA RTX 3090                                                                            | Ascend 910                                   |
+| Uploaded Date              | 2022-07-14                                                                                 | 2022-06-30                                   |
+| MindSpore Version          | 1.6.1                                                                                      | 1.6.1                                        |
+| Dataset                    | steam                                                                                      | steam                                        |
+| Training Parameter         | epoch=20, batch_size=1024, lr=0.001                                                        | epoch=20, batch_size=1024, lr=0.001          |
+| Optimizer                  | Adam                                                                                       | Adam                                         |
+| Loss Function              | Sigmoid Cross Entropy                                                                      | Sigmoid Cross Entropy                        |
+| Outputs                    | AUC=0.8573, Accuracy=0.7733                                                                | AUC=0.8592, Accuracy=0.7741                  |
+| Loss                       | 0.57                                                                                       | 0.59                                         |
+| Speed                      | 1pc: 90ms/step                                                                             | 1pc: 80ms/step                               |
+| Total Time                 | 1pc: 297s                                                                                  | 1pc: 336s                                    |
+| Checkpoint for Fine Tuning | 686.3K (.ckpt file)                                                                        | 671K (.ckpt file)                            |
+| Scripts                    | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet)  | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
 
 ### Evaluation Performance
 
-| Parameters                | GPU                           |
-| ------------------------- | ----------------------------- |
-| Model Version             | TB-Net                        |
-| Resource                  | Tesla V100-SXM2-32GB          |
-| Uploaded Date             | 2021-08-01                    |
-| MindSpore Version         | 1.3.0                         |
-| Dataset                   | steam                         |
-| Batch Size                | 1024                          |
-| Outputs                   | AUC=0.8252, Accuracy=0.7503   |
-| Total Time                | 1pc: 5.7s                     |
+| Parameters                | GPU                        | Ascend NPU                    |
+| ------------------------- |----------------------------| ----------------------------- |
+| Model Version             | TB-Net                     | TB-Net                        |
+| Resource                  | NVIDIA RTX 3090            | Ascend 910                    |
+| Uploaded Date             | 2022-07-14                 | 2022-06-30                    |
+| MindSpore Version         | 1.6.1                      | 1.6.1                         |
+| Dataset                   | steam                      | steam                         |
+| Batch Size                | 1024                       | 1024                          |
+| Outputs                   | AUC=0.8487, Accuracy=0.7699 | AUC=0.8486, Accuracy=0.7704   |
+| Total Time                | 1pc: 5.7s                  | 1pc: 1.1s                     |
 
 ### Inference and Explanation Performance
 
diff --git a/official/recommend/tbnet/README_CN.md b/official/recommend/tbnet/README_CN.md
index 49342bad853ffede77a606e70b5d0aa355f250dc..fca40f2c00856d4fb3bb9ff721c8de3b8fd9bef5 100644
--- a/official/recommend/tbnet/README_CN.md
+++ b/official/recommend/tbnet/README_CN.md
@@ -58,8 +58,8 @@ TB-Net represents the user-item interaction information and the item attribute information in a knowledge graph
 
 # [Environment Requirements](#contents)

-- Hardware (GPU)
-    - Prepare the hardware environment with a GPU processor.
+- Hardware (NVIDIA GPU or Ascend NPU)
+    - Prepare the hardware environment with an NVIDIA GPU or Ascend NPU processor.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install)
 - For more information, please check the resources below:
@@ -72,51 +72,57 @@ TB-Net represents the user-item interaction information and the item attribute information in a knowledge graph
 
 - Data preparation

-Process the data into the format described in the [Dataset](#dataset) section above (taking the 'steam' dataset as an example), then run the code as follows.
+Download the sample data package (taking the 'steam' dataset as an example) and extract it under the current project path.
+
+```bash
+wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
+tar -xf tbnet_data.tar.gz
+cd scripts
+```
+
+Then run the code as follows.
 
 - 璁粌
 
 ```bash
-python train.py \
-  --dataset [DATASET] \
-  --epochs [EPOCHS]
+bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
 ```
 
 Example:
 
 ```bash
-python train.py \
-  --dataset steam \
-  --epochs 20
+bash run_standalone_train.sh steam 0 Ascend
 ```
 
 - Evaluation

+Evaluate the model's metrics on the test dataset.
+
 ```bash
-python eval.py \
-  --dataset [DATASET] \
-  --checkpoint_id [CHECKPOINT_ID]
+bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
 ```
 
-The argument `--checkpoint_id` is required.
+The argument `[CHECKPOINT_ID]` is required.
 
 Example:
 
 ```bash
-python eval.py \
-  --dataset steam \
-  --checkpoint_id 8
+bash run_eval.sh 19 steam 0 Ascend
 ```
 
 - Inference and Explanation

+Recommend items to the user specified by `user`; the number of items is determined by `items`.
+
 ```bash
 python infer.py \
   --dataset [DATASET] \
   --checkpoint_id [CHECKPOINT_ID] \
   --user [USER] \
   --items [ITEMS] \
-  --explanations [EXPLANATIONS]
+  --explanations [EXPLANATIONS] \
+  --csv [CSV] \
+  --device_target [DEVICE_TARGET]
 ```
 
 Arguments `--checkpoint_id` and `--user` are required.
@@ -126,10 +132,12 @@ python infer.py \
 ```bash
 python infer.py \
   --dataset steam \
-  --checkpoint_id 8 \
-  --user 1 \
+  --checkpoint_id 19 \
+  --user 2 \
   --items 1 \
-  --explanations 3
+  --explanations 3 \
+  --csv test.csv \
+  --device_target Ascend
 ```
 
 # [Script Description](#contents)
@@ -141,14 +149,16 @@ python infer.py \
 └─tbnet
   ├─README.md
   ├── scripts
-  │   └─run_infer_310.sh    # Ascend310 inference script
+      ├─run_infer_310.sh                  # Ascend310 inference script
+      ├─run_standalone_train.sh           # NVIDIA GPU or Ascend NPU training script
+      └─run_eval.sh                       # NVIDIA GPU or Ascend NPU evaluation script
   ├─data
     ├─steam
         ├─config.json               # data and training parameter configuration
-        ├─infer.csv                 # inference and explanation dataset
-        ├─test.csv                  # test dataset
-        ├─train.csv                 # training dataset
-        └─trainslate.json           # explanation output configuration
+        ├─src_infer.csv             # inference and explanation dataset
+        ├─src_test.csv              # test dataset
+        ├─src_train.csv             # training dataset
+        └─id_maps.json              # explanation output configuration
   ├─src
     ├─aggregator.py                 # inference result aggregation
     ├─config.py                     # parameter configuration parsing
@@ -157,9 +167,10 @@ python infer.py \
     ├─metrics.py                    # model metrics
     ├─steam.py                      # 'steam' dataset text explainer
     └─tbnet.py                      # TB-Net network
-  ├─export.py                         # export MINDIR script
-  ├─preprocess.py                         # inference data preprocess script
-  ├─postprocess.py                         # inference result calculation script
+  ├─export.py                       # export MINDIR script
+  ├─preprocess_dataset.py           # dataset preprocess script
+  ├─preprocess.py                   # inference data preprocess script
+  ├─postprocess.py                  # inference result calculation script
   ├─eval.py                         # evaluation network
   ├─infer.py                        # inference and explanation
   └─train.py                        # training network
@@ -167,6 +178,14 @@ python infer.py \
 
 ## [Script Parameters](#contents)

+- preprocess_dataset.py parameters
+
+```text
+--dataset         'steam' dataset is supported currently
+--device_target   run code on GPU or Ascend NPU
+--same_relation   only generate paths whose relation1 is the same as relation2
+```
+
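+A minimal run sketch (assuming the extracted `data/steam` folder from the data package, with `src_train.csv`, `src_test.csv` and `src_infer.csv`, sits under the project root):
+
+```bash
+python preprocess_dataset.py --dataset steam --device_target GPU
+```
+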
 - train.py parameters
 
 ```text
@@ -175,7 +194,7 @@ python infer.py \
 --test_csv        the test csv datafile inside the dataset folder
 --device_id       device id
 --epochs          number of training epochs
---device_target   run code on GPU
+--device_target   run code on GPU or Ascend NPU
 --run_mode        run code by GRAPH mode or PYNATIVE mode
 ```
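+
+A direct `train.py` invocation sketch (restoring the removed standalone example, extended with the new `--device_target` flag):
+
+```bash
+python train.py \
+  --dataset steam \
+  --epochs 20 \
+  --device_target GPU
+```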
 
@@ -186,7 +205,7 @@ python infer.py \
 --csv             the csv datafile inside the dataset folder (e.g. test.csv)
 --checkpoint_id   use which checkpoint(.ckpt) file to eval
 --device_id       device id
---device_target   run code on GPU
+--device_target   run code on GPU or Ascend NPU
 --run_mode        run code by GRAPH mode or PYNATIVE mode
 ```
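+
+A hedged `eval.py` example (checkpoint id 19 as in the example above; `test.csv` is assumed to sit inside the dataset folder):
+
+```bash
+python eval.py \
+  --dataset steam \
+  --csv test.csv \
+  --checkpoint_id 19 \
+  --device_target GPU
+```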
 
@@ -200,7 +219,7 @@ python infer.py \
 --items           no. of items to be recommended
 --reasons         no. of recommendation reasons to be shown
 --device_id       device id
---device_target   run code on GPU
+--device_target   run code on GPU or Ascend NPU
 --run_mode        run code by GRAPH mode or PYNATIVE mode
 ```
 
@@ -209,7 +228,12 @@ python infer.py \
 ### Export MindIR
 
 ```shell
-python export.py --config_path [CONFIG_PATH] --checkpoint_path [CKPT_PATH] --device_target [DEVICE] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
+python export.py \
+  --config_path [CONFIG_PATH] \
+  --checkpoint_path [CKPT_PATH] \
+  --device_target [DEVICE] \
+  --file_name [FILE_NAME] \
+  --file_format [FILE_FORMAT]
 ```
 
 - `CKPT_PATH` is required.
@@ -217,6 +241,17 @@ python export.py --config_path [CONFIG_PATH] --checkpoint_path [CKPT_PATH] --dev
 - `DEVICE` should be in ['Ascend', 'GPU'].
 - `FILE_FORMAT` should be in ['MINDIR', 'AIR'].
 
+Example:
+
+```bash
+python export.py \
+  --config_path ./data/steam/config.json \
+  --checkpoint_path ./checkpoints/tbnet_epoch19.ckpt \
+  --device_target Ascend \
+  --file_name model \
+  --file_format MINDIR
+```
+
 ### Infer on Ascend310

 Before performing inference, the mindir file must be exported by the `export.py` script. The following shows an example of running inference with the MINDIR model.
@@ -230,6 +265,12 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
 - `DATA_PATH` specifies the path of the inference dataset test.csv.
 - `DEVICE_ID` is optional, default value is 0.
 
+Example:
+
+```bash
+bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
+```
+
 ### Result

 The inference result is saved in the current execution path, and you can find the following accuracy result in the acc.log file.
@@ -244,35 +285,35 @@ auc: 0.8251359368836292
 
 ### [璁粌鎬ц兘](#鐩綍)
 
-| Parameters                 | GPU                                                 |
-| -------------------------- | --------------------------------------------------- |
-| Model Version              | TB-Net                                              |
-| Resource                   | Tesla V100-SXM2-32GB                                |
-| Uploaded Date              | 2021-08-01                                          |
-| MindSpore Version          | 1.3.0                                               |
-| Dataset                    | steam                                               |
-| Training Parameters        | epoch=20, batch_size=1024, lr=0.001                 |
-| Optimizer                  | Adam                                                |
-| Loss Function              | Sigmoid Cross Entropy                               |
-| Outputs                    | AUC=0.8596, Accuracy=0.7761                         |
-| Loss                       | 0.57                                                |
-| Speed                      | 1pc: 90ms/step                                      |
-| Total Time                 | 1pc: 297s                                           |
-| Checkpoint for Fine Tuning | 104.66M (.ckpt file)                                |
+| Parameters                 | GPU                                                                                 | Ascend NPU                          |
+| -------------------------- |-------------------------------------------------------------------------------------|-------------------------------------|
+| Model Version              | TB-Net                                                                              | TB-Net                              |
+| Resource                   | NVIDIA RTX 3090                                                                     | Ascend 910                          |
+| Uploaded Date              | 2022-07-14                                                                          | 2022-06-30                          |
+| MindSpore Version          | 1.6.1                                                                               | 1.6.1                               |
+| Dataset                    | steam                                                                               | steam                               |
+| Training Parameters        | epoch=20, batch_size=1024, lr=0.001                                                 | epoch=20, batch_size=1024, lr=0.001 |
+| Optimizer                  | Adam                                                                                | Adam                                |
+| Loss Function              | Sigmoid Cross Entropy                                                               | Sigmoid Cross Entropy               |
+| Outputs                    | AUC=0.8573, Accuracy=0.7733                                                         | AUC=0.8592, Accuracy=0.7741         |
+| Loss                       | 0.57                                                                                | 0.59                                |
+| Speed                      | 1pc: 90ms/step                                                                      | 1pc: 80ms/step                      |
+| Total Time                 | 1pc: 297s                                                                           | 1pc: 336s                           |
+| Checkpoint for Fine Tuning | 686.3K (.ckpt file)                                                                 | 671K (.ckpt file)                   |
 | Scripts                    | [TB-Net scripts](https://gitee.com/mindspore/models/tree/master/official/recommend/tbnet) |
 
 ### [Evaluation Performance](#contents)
 
-| Parameters                  | GPU                           |
-| --------------------------- | ----------------------------- |
-| Model Version               | TB-Net                        |
-| Resource                    | Tesla V100-SXM2-32GB          |
-| Uploaded Date               | 2021-08-01                    |
-| MindSpore Version           | 1.3.0                         |
-| Dataset                     | steam                         |
-| Batch Size                  | 1024                          |
-| Outputs                     | AUC=0.8252, Accuracy=0.7503   |
-| Total Time                  | 1pc: 5.7s                     |
+| Parameters                  | GPU                         | Ascend NPU                    |
+| --------------------------- |-----------------------------| ----------------------------- |
+| Model Version               | TB-Net                      | TB-Net                        |
+| Resource                    | NVIDIA RTX 3090             | Ascend 910                    |
+| Uploaded Date               | 2022-07-14                  | 2022-06-30                    |
+| MindSpore Version           | 1.6.1                       | 1.6.1                         |
+| Dataset                     | steam                       | steam                         |
+| Batch Size                  | 1024                        | 1024                          |
+| Outputs                     | AUC=0.8487, Accuracy=0.7699 | AUC=0.8486, Accuracy=0.7704   |
+| Total Time                  | 1pc: 5.7s                   | 1pc: 1.1s                     |
 
 ### [Inference and Explanation Performance](#contents)
 
diff --git a/official/recommend/tbnet/data/steam/config.json b/official/recommend/tbnet/data/steam/config.json
deleted file mode 100644
index dcaa1740e7f60a10a7116d4cf2830617dbf4fe0b..0000000000000000000000000000000000000000
--- a/official/recommend/tbnet/data/steam/config.json
+++ /dev/null
@@ -1,12 +0,0 @@
-{
-  "num_item": 3005,
-  "num_relation": 5,
-  "num_entity": 5138,
-  "per_item_num_paths": 39,
-  "embedding_dim": 26,
-  "batch_size": 1024,
-  "lr": 0.001,
-  "kge_weight": 0.05,
-  "node_weight": 0.002,
-  "l2_weight": 1e-6
-}
\ No newline at end of file
diff --git a/official/recommend/tbnet/data/steam/infer.csv b/official/recommend/tbnet/data/steam/infer.csv
deleted file mode 100644
index 6579425f1b27f857ef753dd4803a0cd4db8a0a0c..0000000000000000000000000000000000000000
--- a/official/recommend/tbnet/data/steam/infer.csv
+++ /dev/null
@@ -1 +0,0 @@
-#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item  # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
diff --git a/official/recommend/tbnet/data/steam/test.csv b/official/recommend/tbnet/data/steam/test.csv
deleted file mode 100644
index 0b9eec9ea07d80cb24e9057de5adca1787db3dd8..0000000000000000000000000000000000000000
--- a/official/recommend/tbnet/data/steam/test.csv
+++ /dev/null
@@ -1 +0,0 @@
-#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item  # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
\ No newline at end of file
diff --git a/official/recommend/tbnet/data/steam/train.csv b/official/recommend/tbnet/data/steam/train.csv
deleted file mode 100644
index 0b9eec9ea07d80cb24e9057de5adca1787db3dd8..0000000000000000000000000000000000000000
--- a/official/recommend/tbnet/data/steam/train.csv
+++ /dev/null
@@ -1 +0,0 @@
-#format:user,item,rating,relation1,entity,relation2,hist_item,relation1,entity,relation2,hist_item,...,relation1,entity,relation2,hist_item  # module [relation1,entity,relation2,hist_item] repeats PER_ITEM_NUM_PATHS times
\ No newline at end of file
diff --git a/official/recommend/tbnet/data/steam/translate.json b/official/recommend/tbnet/data/steam/translate.json
deleted file mode 100644
index a348188cf5b90ab7a48af0a39d0493c398f90213..0000000000000000000000000000000000000000
--- a/official/recommend/tbnet/data/steam/translate.json
+++ /dev/null
@@ -1,14 +0,0 @@
-{
-  "item": {
-    "0": "Star Wars",
-    "1": "Battlefield 1"
-  },
-  "relation": {
-    "0": "Developer",
-    "1": "Genre"
-  },
-  "entity": {
-    "425": "EA Games",
-    "426": "Shooting"
-  }
-}
\ No newline at end of file
diff --git a/official/recommend/tbnet/eval.py b/official/recommend/tbnet/eval.py
index 893ee60d071e7075a9a2734632b66fe37384c9c5..7f97cdecc1258bf3509fc7e27c02dbca0b50c777 100644
--- a/official/recommend/tbnet/eval.py
+++ b/official/recommend/tbnet/eval.py
@@ -16,8 +16,10 @@
 
 import os
 import argparse
+import math
 
 from mindspore import context, Model, load_checkpoint, load_param_into_net
+import mindspore.common.dtype as mstype
 
 from src import tbnet, config, metrics, dataset
 
@@ -62,8 +64,8 @@ def get_args():
         type=str,
         required=False,
         default='GPU',
-        choices=['GPU'],
-        help="run code on GPU"
+        choices=['GPU', 'Ascend'],
+        help="run code on GPU or Ascend NPU"
     )
 
     parser.add_argument(
@@ -95,10 +97,15 @@ def eval_tbnet():
 
     print(f"creating dataset from {test_csv_path}...")
     net_config = config.TBNetConfig(config_path)
-    eval_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
+    if args.device_target == 'Ascend':
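+        # round per_item_paths and embedding_dim up to multiples of 16 to suit Ascend float16 operators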
+        net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
+        net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
+    eval_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
 
     print(f"creating TBNet from checkpoint {args.checkpoint_id} for evaluation...")
     network = tbnet.TBNet(net_config)
+    if args.device_target == 'Ascend':
+        network.to_float(mstype.float16)
     param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
     load_param_into_net(network, param_dict)
 
diff --git a/official/recommend/tbnet/export.py b/official/recommend/tbnet/export.py
index dd6275cfe33d2c8a751135b4a675320868a63c26..4110028d448167457e1d0a02a11000ff33b73dab 100644
--- a/official/recommend/tbnet/export.py
+++ b/official/recommend/tbnet/export.py
@@ -16,6 +16,7 @@
 
 import os
 import argparse
+import math
 import numpy as np
 
 from mindspore import context, load_checkpoint, load_param_into_net, Tensor, export
@@ -103,20 +104,23 @@ def export_tbnet():
         context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
 
     net_config = config.TBNetConfig(config_path)
-
+    if args.device_target == 'Ascend':
+        net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
+        net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
     network = tbnet.TBNet(net_config)
     param_dict = load_checkpoint(ckpt_path)
     load_param_into_net(network, param_dict)
     eval_net = tbnet.PredictWithSigmoid(network)
 
     item = Tensor(np.ones((1,)).astype(np.int))
-    rl1 = Tensor(np.ones((1, 39)).astype(np.int))
-    ety = Tensor(np.ones((1, 39)).astype(np.int))
-    rl2 = Tensor(np.ones((1, 39)).astype(np.int))
-    his = Tensor(np.ones((1, 39)).astype(np.int))
+    rl1 = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
+    ety = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
+    rl2 = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
+    his = Tensor(np.ones((1, net_config.per_item_paths)).astype(np.int))
     rate = Tensor(np.ones((1,)).astype(np.float32))
     inputs = [item, rl1, ety, rl2, his, rate]
     export(eval_net, *inputs, file_name=args.file_name, file_format=args.file_format)
 
+
 if __name__ == '__main__':
     export_tbnet()
diff --git a/official/recommend/tbnet/infer.py b/official/recommend/tbnet/infer.py
index e630e0f9a225506cfcf9bf69c6a66ef5b9908f50..1e48b5b6c6c30376088fa4d668beeac41fc1405c 100644
--- a/official/recommend/tbnet/infer.py
+++ b/official/recommend/tbnet/infer.py
@@ -16,8 +16,10 @@
 
 import os
 import argparse
+import math
 
 from mindspore import load_checkpoint, load_param_into_net, context
+import mindspore.common.dtype as mstype
 from src.config import TBNetConfig
 from src.tbnet import TBNet
 from src.aggregator import InferenceAggregator
@@ -88,8 +90,8 @@ def get_args():
         type=str,
         required=False,
         default='GPU',
-        choices=['GPU'],
-        help="run code on GPU"
+        choices=['GPU', 'Ascend'],
+        help="run code on GPU or Ascend NPU"
     )
 
     parser.add_argument(
@@ -121,12 +123,17 @@ def infer_tbnet():
 
     print(f"creating TBNet from checkpoint {args.checkpoint_id}...")
     config = TBNetConfig(config_path)
+    if args.device_target == 'Ascend':
+        config.per_item_paths = math.ceil(config.per_item_paths / 16) * 16
+        config.embedding_dim = math.ceil(config.embedding_dim / 16) * 16
     network = TBNet(config)
+    if args.device_target == 'Ascend':
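+        # run the network computation in float16 on Ascend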
+        network.to_float(mstype.float16)
     param_dict = load_checkpoint(os.path.join(ckpt_path, f'tbnet_epoch{args.checkpoint_id}.ckpt'))
     load_param_into_net(network, param_dict)
 
     print(f"creating dataset from {data_path}...")
-    infer_ds = dataset.create(data_path, config.per_item_num_paths, train=False, users=args.user)
+    infer_ds = dataset.create(data_path, config.per_item_paths, train=False, users=args.user)
     infer_ds = infer_ds.batch(config.batch_size)
 
     print("inferring...")
diff --git a/official/recommend/tbnet/preprocess.py b/official/recommend/tbnet/preprocess.py
index 94c24dfe3f25bfde79880b1f011c4cbb936c57d9..f22d5c6d49c28a3478036bb18ea902cb2484dfbe 100644
--- a/official/recommend/tbnet/preprocess.py
+++ b/official/recommend/tbnet/preprocess.py
@@ -17,6 +17,7 @@
 import os
 import argparse
 import shutil
+import math
 import numpy as np
 
 from mindspore import context
@@ -44,7 +45,6 @@ def get_args():
         help="the csv datafile inside the dataset folder (e.g. test.csv)"
     )
 
-
     parser.add_argument(
         '--device_id',
         type=int,
@@ -58,8 +58,8 @@ def get_args():
         type=str,
         required=False,
         default='Ascend',
-        choices=['Ascend'],
-        help="run code on GPU"
+        choices=['Ascend', 'GPU'],
+        help="run code on GPU or Ascend NPU"
     )
 
     parser.add_argument(
@@ -90,7 +90,9 @@ def preprocess_tbnet():
 
     print(f"creating dataset from {test_csv_path}...")
     net_config = config.TBNetConfig(config_path)
-    eval_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(1)
+    if args.device_target == 'Ascend':
+        net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
+    eval_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(1)
     item_path = os.path.join('./preprocess_Result/', '00_item')
     rl1_path = os.path.join('./preprocess_Result/', '01_rl1')
     ety_path = os.path.join('./preprocess_Result/', '02_ety')
@@ -134,5 +136,7 @@ def preprocess_tbnet():
         rate_rst.tofile(rate_real_path)
 
         idx += 1
+
+
 if __name__ == '__main__':
     preprocess_tbnet()
diff --git a/official/recommend/tbnet/preprocess_dataset.py b/official/recommend/tbnet/preprocess_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e32b525bf072ae1995eebe2ead01819611217761
--- /dev/null
+++ b/official/recommend/tbnet/preprocess_dataset.py
@@ -0,0 +1,117 @@
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Data Preprocessing app."""
+
+# This script should be run directly with 'python <script> <args>'.
+
+import os
+import io
+import argparse
+import json
+import math
+
+from src.path_gen import PathGen
+from src.config import TBNetConfig
+
+
+def get_args():
+    """Parse commandline arguments."""
+    parser = argparse.ArgumentParser(description='Preprocess TB-Net data.')
+
+    parser.add_argument(
+        '--dataset',
+        type=str,
+        required=False,
+        default='steam',
+        help="'steam' dataset is supported currently"
+    )
+
+    parser.add_argument(
+        '--device_target',
+        type=str,
+        required=False,
+        default='GPU',
+        choices=['GPU', 'Ascend'],
+        help="run code on GPU or Ascend NPU"
+    )
+
+    parser.add_argument(
+        '--same_relation',
+        required=False,
+        action='store_true',
+        default=False,
+        help="only generate paths whose relation1 is the same as relation2"
+    )
+
+    return parser.parse_args()
+
+
+def preprocess_csv(path_gen, data_home, src_name, out_name):
+    """Pre-process a csv file."""
+    src_path = os.path.join(data_home, src_name)
+    out_path = os.path.join(data_home, out_name)
+    print(f'converting {src_path} to {out_path} ...')
+    rows = path_gen.generate(src_path, out_path)
+    print(f'{rows} rows of path data generated.')
+
+
+def preprocess_data():
+    """Pre-process the dataset."""
+    args = get_args()
+    home = os.path.dirname(os.path.realpath(__file__))
+
+    data_home = os.path.join(home, 'data', args.dataset)
+    config_path = os.path.join(data_home, 'config.json')
+    id_maps_path = os.path.join(data_home, 'id_maps.json')
+
+    cfg = TBNetConfig(config_path)
+    if args.device_target == 'Ascend':
+        cfg.per_item_paths = math.ceil(cfg.per_item_paths / 16) * 16
+    path_gen = PathGen(per_item_paths=cfg.per_item_paths, same_relation=args.same_relation)
+
+    preprocess_csv(path_gen, data_home, 'src_train.csv', 'train.csv')
+
+    # save id maps for later use by Recommender in infer.py
+    with io.open(id_maps_path, mode="w", encoding="utf-8") as f:
+        json.dump(path_gen.id_maps(), f, indent=4)
+
+    # count distinct objects from the training set
+    cfg.num_items = path_gen.num_items
+    cfg.num_references = path_gen.num_references
+    cfg.num_relations = path_gen.num_relations
+    cfg.save(config_path)
+
+    print(f'{config_path} updated.')
+    print(f'num_items: {cfg.num_items}')
+    print(f'num_references: {cfg.num_references}')
+    print(f'num_relations: {cfg.num_relations}')
+
+    # treat new items and references in test and infer set as unseen entities
+    # dummy internal id 0 will be assigned to them
+    path_gen.grow_id_maps = False
+
+    preprocess_csv(path_gen, data_home, 'src_test.csv', 'test.csv')
+
+    # for inference, only take interacted('c') and other('x') items as candidate items,
+    # the purchased('p') items won't be recommended.
+    # assume there is only one user in src_infer.csv
+    path_gen.subject_ratings = "cx"
+    preprocess_csv(path_gen, data_home, 'src_infer.csv', 'infer.csv')
+
+    print(f'Dataset {data_home} processed.')
+
+
+if __name__ == '__main__':
+    preprocess_data()
diff --git a/official/recommend/tbnet/scripts/run_eval.sh b/official/recommend/tbnet/scripts/run_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..054e70139fd46ebae89ee49782b742ed33a6434a
--- /dev/null
+++ b/official/recommend/tbnet/scripts/run_eval.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [[ $# -lt 3 || $# -gt 4 ]]; then
+    echo "Usage: bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+    CHECKPOINT_ID means the model checkpoint id.
+    DATA_NAME means the dataset name, its value is 'steam'.
+    DEVICE_ID means the device id to run on.
+    DEVICE_TARGET is optional, its value is in ['GPU', 'Ascend'], default 'GPU'."
+exit 1
+fi
+
+CHECKPOINT_ID=$1
+DATA_NAME=$2
+DEVICE_ID=$3
+
+DEVICE_TARGET='GPU'
+if [ $# == 4 ]; then
+    DEVICE_TARGET=$4
+fi
+
+python ../eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
+       --device_id $DEVICE_ID  &> eval_standalone_log &
\ No newline at end of file
diff --git a/official/recommend/tbnet/scripts/run_standalone_train.sh b/official/recommend/tbnet/scripts/run_standalone_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9333260dcb1717afaa9788b6b79a5609dbad9946
--- /dev/null
+++ b/official/recommend/tbnet/scripts/run_standalone_train.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+if [[ $# -lt 2 || $# -gt 3 ]]; then
+    echo "Usage: bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+    DATA_NAME means the dataset name, its value is 'steam'.
+    DEVICE_ID means the device id to run on.
+    DEVICE_TARGET is optional, its value is in ['GPU', 'Ascend'], default 'GPU'."
+exit 1
+fi
+
+DATA_NAME=$1
+DEVICE_ID=$2
+
+DEVICE_TARGET='GPU'
+if [ $# == 3 ]; then
+    DEVICE_TARGET=$3
+fi
+
+python ../preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> train_standalone_log &&
+python ../train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> train_standalone_log &
\ No newline at end of file
diff --git a/official/recommend/tbnet/src/config.py b/official/recommend/tbnet/src/config.py
index bf1c0ecd66d57162de1744dc7f1f0072c469a461..0ad2a8cc5e4404b73bb5b224f6b3b0454ebebace 100644
--- a/official/recommend/tbnet/src/config.py
+++ b/official/recommend/tbnet/src/config.py
@@ -27,13 +27,17 @@ class TBNetConfig:
     def __init__(self, config_path):
         with open(config_path) as f:
             json_dict = json.load(f)
-        self.num_item = int(json_dict['num_item'])
-        self.num_relation = int(json_dict['num_relation'])
-        self.num_entity = int(json_dict['num_entity'])
-        self.per_item_num_paths = int(json_dict['per_item_num_paths'])
+        self.num_items = int(json_dict['num_items'])
+        self.num_relations = int(json_dict['num_relations'])
+        self.num_references = int(json_dict['num_references'])
+        self.per_item_paths = int(json_dict['per_item_paths'])
         self.embedding_dim = int(json_dict['embedding_dim'])
         self.batch_size = int(json_dict['batch_size'])
         self.lr = float(json_dict['lr'])
         self.kge_weight = float(json_dict['kge_weight'])
         self.node_weight = float(json_dict['node_weight'])
         self.l2_weight = float(json_dict['l2_weight'])
+
+    def save(self, config_path):
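+        """Save the parameter configuration to a json file."""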
+        with open(config_path, 'w') as f:
+            json.dump(self.__dict__, f, indent=4)
diff --git a/official/recommend/tbnet/src/dataset.py b/official/recommend/tbnet/src/dataset.py
index 2dfaea47cc067ef7efc4899fe1aadc20193ac7ec..34f4ac6d5d6ffe520b18bf9ba8940775b272397b 100644
--- a/official/recommend/tbnet/src/dataset.py
+++ b/official/recommend/tbnet/src/dataset.py
@@ -14,10 +14,12 @@
 # ============================================================================
 """Dataset loader."""
 
+import os
 from functools import partial
 
 import numpy as np
-from mindspore.dataset import GeneratorDataset
+import mindspore.dataset as ds
+import mindspore.mindrecord as record
 
 
 def create(data_path, per_item_num_paths, train, users=None, **kwargs):
@@ -39,13 +41,71 @@ def create(data_path, per_item_num_paths, train, users=None, **kwargs):
     """
     if isinstance(users, int):
         users = (users,)
-    kwargs['source'] = partial(csv_generator, data_path, per_item_num_paths, users, train)
 
     if train:
-        kwargs['column_names'] = ['item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
+        kwargs['columns_list'] = ['item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
     else:
-        kwargs['column_names'] = ['user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
-    return GeneratorDataset(**kwargs)
+        kwargs['columns_list'] = ['user', 'item', 'relation1', 'entity', 'relation2', 'hist_item', 'rating']
+    mindrecord_file_path = csv_dataset(partial(csv_generator, data_path, per_item_num_paths, users, train), data_path,
+                                       train)
+    return ds.MindDataset(mindrecord_file_path, **kwargs)
+
+
+def csv_dataset(generator, csv_path, train):
+    """Dataset for csv datafile."""
+    file_name = os.path.basename(csv_path)
+    mindrecord_file_path = os.path.join(os.path.dirname(csv_path), file_name[0:file_name.rfind('.')] + '.mindrecord')
+
+    if os.path.exists(mindrecord_file_path):
+        os.remove(mindrecord_file_path)
+
+    if os.path.exists(mindrecord_file_path + ".db"):
+        os.remove(mindrecord_file_path + ".db")
+
+    data_schema = {
+        "item": {"type": "int32", "shape": []},
+        "relation1": {"type": "int32", "shape": [-1]},
+        "entity": {"type": "int32", "shape": [-1]},
+        "relation2": {"type": "int32", "shape": [-1]},
+        "hist_item": {"type": "int32", "shape": [-1]},
+        "rating": {"type": "float32", "shape": []},
+    }
+    if not train:
+        data_schema["user"] = {"type": "int32", "shape": []}
+
+    writer = record.FileWriter(file_name=mindrecord_file_path, shard_num=1)
+    writer.add_schema(data_schema, "Preprocessed dataset.")
+
+    data = []
+    for i, row in enumerate(generator()):
+        if train:
+            sample = {
+                "item": row[0],
+                "relation1": row[1],
+                "entity": row[2],
+                "relation2": row[3],
+                "hist_item": row[4],
+                "rating": row[5],
+            }
+        else:
+            sample = {
+                "user": row[0],
+                "item": row[1],
+                "relation1": row[2],
+                "entity": row[3],
+                "relation2": row[4],
+                "hist_item": row[5],
+                "rating": row[6],
+            }
+        data.append(sample)
+
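+        # flush the accumulated samples to the MindRecord writer in small batches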
+        if i % 10 == 0:
+            writer.write_raw_data(data)
+            data = []
+    if data:
+        writer.write_raw_data(data)
+    writer.commit()
+    return mindrecord_file_path
 
 
 def csv_generator(csv_path, per_item_num_paths, users, train):
@@ -81,8 +141,8 @@ def csv_generator(csv_path, per_item_num_paths, users, train):
         if train:
             # item, relation1, entity, relation2, hist_item, rating
             yield np.array(item, dtype=np.int), relation1, entity, relation2, hist_item, \
-                np.array(rating, dtype=np.float32)
+                  np.array(rating, dtype=np.float32)
         else:
             # user, item, relation1, entity, relation2, hist_item, rating
-            yield np.array(user, dtype=np.int), np.array(item, dtype=np.int),\
-                relation1, entity, relation2, hist_item, np.array(rating, dtype=np.float32)
+            yield np.array(user, dtype=np.int), np.array(item, dtype=np.int), \
+                  relation1, entity, relation2, hist_item, np.array(rating, dtype=np.float32)
diff --git a/official/recommend/tbnet/src/path_gen.py b/official/recommend/tbnet/src/path_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53378ccad7b599997ddfa0b834c08850658daf9
--- /dev/null
+++ b/official/recommend/tbnet/src/path_gen.py
@@ -0,0 +1,435 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Relation path data generator."""
+import io
+import random
+import csv
+import warnings
+
+
+class _UserRec:
+    """User record, helper class for path generation."""
+
+    def __init__(self, src_id, intern_id):
+        self.src_id = src_id
+        self.intern_id = intern_id
+        self.positive_items = dict()
+        self.interact_items = dict()
+        self.other_items = dict()
+        self.has_unseen_ref = False
+
+    def add_item(self, item_rec, rating):
+        """Add an item."""
+        if rating == 'p':
+            item_dict = self.positive_items
+        elif rating == 'c':
+            item_dict = self.interact_items
+        else:
+            item_dict = self.other_items
+        item_dict[item_rec.intern_id] = item_rec
+
+
+class _ItemRec:
+    """Item record, helper class for path generation."""
+
+    def __init__(self, src_id, intern_id, ref_src_ids, ref_ids):
+        self.src_id = src_id
+        self.intern_id = intern_id
+        self.ref_src_ids = ref_src_ids
+        self.ref_ids = ref_ids
+
+
+class PathGen:
+    """
+    Generate relation path csv from the source csv table.
+
+    Args:
+        per_item_paths (int): Number of relation paths per subject item, must be positive.
+        same_relation (bool): True to only generate paths whose relation1 is the same as relation2, usually faster.
+        id_maps (dict[str, Union[dict[str, int], int]], Optional): Object id maps used as the internal id
+            baseline; new user, item and entity IDs will be allocated based on it. If it is None or empty,
+            grow_id_maps defaults to True.
+    """
+
+    def __init__(self, per_item_paths, same_relation=False, id_maps=None):
+
+        self._per_item_paths = per_item_paths
+        self._same_relation = same_relation
+
+        self._user_id_counter = 1
+        self._entity_id_counter = 1
+        self._num_relations = 0
+        self._rows_generated = 0
+        self._user_rec = None
+
+        if id_maps:
+            self._item_id_map = id_maps.get('item', dict())
+            self._ref_id_map = id_maps.get('reference', dict())
+            self._rl_id_map = id_maps.get('relation', None)
+            self._user_id_counter = id_maps.get('_user_id_counter', self._user_id_counter)
+            max_item_id = max(self._item_id_map.values()) if self._item_id_map else 0
+            max_ref_id = max(self._ref_id_map.values()) if self._ref_id_map else 0
+            self._entity_id_counter = max(max_item_id, max_ref_id) + 1
+        else:
+            self._item_id_map = dict()
+            self._ref_id_map = dict()
+            self._rl_id_map = None
+
+        self.grow_id_maps = not (bool(self._item_id_map) and bool(self._ref_id_map))
+        self.subject_ratings = ""
+
+        self._unseen_items = 0
+        self._unseen_refs = 0
+
+    @property
+    def num_users(self):
+        """int, the number of distinct users."""
+        return self._user_id_counter - 1
+
+    @property
+    def num_references(self):
+        """int, the number of distinct references."""
+        return len(self._ref_id_map)
+
+    @property
+    def num_items(self):
+        """int, the number of distinct items."""
+        return len(self._item_id_map)
+
+    @property
+    def num_relations(self):
+        """int, the number of distinct relations."""
+        return self._num_relations
+
+    @property
+    def rows_generated(self):
+        """int, total number of rows generated to the output CSVs."""
+        return self._rows_generated
+
+    @property
+    def per_item_paths(self):
+        """int, the number of path per subject item."""
+        return self._per_item_paths
+
+    @property
+    def same_relation(self):
+        """bool, only generate paths with the same relation on both sides."""
+        return self._same_relation
+
+    @property
+    def unseen_items(self):
+        """int, total number of unseen items has encountered."""
+        return self._unseen_items
+
+    @property
+    def unseen_refs(self):
+        """int, total number of unseen references has encountered."""
+        return self._unseen_refs
+
+    def id_maps(self):
+        """dict, object ID maps."""
+        maps = {
+            "item": dict(self._item_id_map),
+            "reference": dict(self._ref_id_map),
+            "_user_id_counter": self._user_id_counter
+        }
+        if self._rl_id_map is not None:
+            maps["relation"] = dict(self._rl_id_map)
+        return maps
+
+    def generate(self, in_csv, out_csv, in_sep=',', in_mv_sep=';', in_encoding='utf-8'):
+        """
+        Generate paths csv from the source CSV files.
+
+        Args:
+            in_csv (Union[str, TextIOBase]): The input source csv path or stream.
+            out_csv (Union[str, TextIOBase]): The output csv path or stream.
+            in_sep (str): Separator of the input csv.
+            in_mv_sep (str): Multi-value separator of the input csv in a single column.
+            in_encoding (str): Encoding of the input source csv, ignored if in_csv is a text stream already.
+
+        Returns:
+            int, the number of rows generated to the output csv in this call.
+        """
+        if not isinstance(in_csv, (str, io.TextIOBase)):
+            raise TypeError(f"Unexpected in_csv type:{type(in_csv)}")
+        if not isinstance(out_csv, (str, io.TextIOBase)):
+            raise TypeError(f"Unexpected out_csv type:{type(out_csv)}")
+
+        opened_files = []
+        try:
+            if isinstance(in_csv, str):
+                in_csv = io.open(in_csv, mode="r", encoding=in_encoding)
+                opened_files.append(in_csv)
+            in_csv = csv.reader(in_csv, delimiter=in_sep)
+            col_indices = self._pre_generate(in_csv, None)
+
+            if isinstance(out_csv, str):
+                out_csv = io.open(out_csv, mode="w", encoding="ascii")
+                opened_files.append(out_csv)
+            rows_generated = self._do_generate(in_csv, out_csv, in_mv_sep, col_indices)
+
+        except (IOError, ValueError, RuntimeError, PermissionError, KeyError) as e:
+            raise e
+        finally:
+            for f in opened_files:
+                f.close()
+        return rows_generated
+
+    def _pre_generate(self, in_csv, in_col_map):
+        """Prepare for the path generation."""
+        if in_col_map is not None:
+            expected_cols = self._default_abstract_header(len(in_col_map) - 3)
+            map_values = list(in_col_map.values())
+            for col in expected_cols:
+                if col not in map_values:
+                    raise ValueError(f"col_map has no '{col}' value.")
+
+        header = self._read_header(in_csv)
+        if len(header) < 4:
+            raise IOError(f"No. of in_csv columns:{len(header)} is less than 4.")
+        num_relations = len(header) - 3
+        if self._num_relations > 0:
+            if num_relations != self._num_relations:
+                raise IOError("Inconsistent no. of in_csv relations.")
+        else:
+            self._num_relations = num_relations
+
+        col_indices = self._get_col_indices(header, in_col_map)
+        rl_id_map = self._to_relation_id_map(header, col_indices)
+
+        if not self._rl_id_map:
+            self._rl_id_map = rl_id_map
+        elif rl_id_map != self._rl_id_map:
+            raise IOError("Inconsistent in_csv relations.")
+
+        return col_indices
+
+    def _do_generate(self, in_csv, out_csv, in_mv_sep, col_indices):
+        """Do generate the paths."""
+        old_rows_generated = self._rows_generated
+        old_unseen_items = self._unseen_items
+        old_unseen_refs = self._unseen_refs
+
+        col_count = len(col_indices)
+        self._user_rec = None
+        for line in in_csv:
+            values = list(map(lambda x: x.strip(), line))
+            if len(values) != col_count:
+                raise IOError(f"No. of in_csv columns:{len(values)} is not {col_count}.")
+            self._process_line(values, in_mv_sep, col_indices, out_csv)
+
+        if self._user_rec is not None:
+            self._process_user_rec(self._user_rec, out_csv)
+            self._user_rec = None
+
+        delta_unseen_items = self._unseen_items - old_unseen_items
+        delta_unseen_refs = self._unseen_refs - old_unseen_refs
+        if delta_unseen_items > 0:
+            warnings.warn(f"{delta_unseen_items} unseen items had their internal IDs set to 0; "
+                          f"set grow_id_maps to True to assign them new internal IDs.", RuntimeWarning)
+        if delta_unseen_refs > 0:
+            warnings.warn(f"{delta_unseen_refs} unseen references had their internal IDs set to 0; "
+                          f"set grow_id_maps to True to assign them new internal IDs.", RuntimeWarning)
+
+        return self._rows_generated - old_rows_generated
+
+    def _process_line(self, values, in_mv_sep, col_indices, out_csv):
+        """Process a line from the input CSV."""
+        user_src = values[col_indices[0]]
+        item_src = values[col_indices[1]]
+        rating = values[col_indices[2]].lower()
+        if rating not in ('p', 'c', 'x'):
+            raise IOError(f"Unrecognized rating:'{rating}', must be one of 'p', 'c' or 'x'.")
+        ref_srcs = [values[col_indices[i]] for i in range(3, len(col_indices))]
+
+        if in_mv_sep:
+            ref_srcs = list(map(lambda x: list(map(lambda y: y.strip(), x.split(in_mv_sep))), ref_srcs))
+        else:
+            ref_srcs = list(map(lambda x: [x], ref_srcs))
+
+        if self._user_rec is not None and user_src != self._user_rec.src_id:
+            # The user changed; rows are expected to be grouped by user, so flush
+            # the finished user's record before starting a new one.
+            self._process_user_rec(self._user_rec, out_csv)
+            self._user_rec = None
+
+        if self._user_rec is None:
+            self._user_rec = _UserRec(user_src, self._user_id_counter)
+            self._user_id_counter += 1
+
+        item_rec, has_unseen_ref = self._to_item_rec(item_src, ref_srcs)
+        self._user_rec.add_item(item_rec, rating)
+        self._user_rec.has_unseen_ref |= has_unseen_ref
+
+    def _process_user_rec(self, user_rec, out_csv):
+        """Generate paths for an user."""
+        positive_count = 0
+
+        subject_items = []
+
+        if self.subject_ratings == "":
+            subject_items.extend(user_rec.positive_items.values())
+            subject_items.extend(user_rec.other_items.values())
+            positive_count = len(user_rec.positive_items)
+        else:
+            if 'p' in self.subject_ratings:
+                subject_items.extend(user_rec.positive_items.values())
+                positive_count = len(user_rec.positive_items)
+            if 'c' in self.subject_ratings:
+                subject_items.extend(user_rec.interact_items.values())
+            if 'x' in self.subject_ratings:
+                subject_items.extend(user_rec.other_items.values())
+
+        hist_items = []
+        hist_items.extend(user_rec.positive_items.values())
+        hist_items.extend(user_rec.interact_items.values())
+
+        for i, subject in enumerate(subject_items):
+
+            paths = []
+            for hist in hist_items:
+                if hist.src_id == subject.src_id:
+                    continue
+                self._find_paths(not user_rec.has_unseen_ref, subject, hist, paths)
+
+            if not paths:
+                continue
+
+            paths = random.sample(paths, min(len(paths), self._per_item_paths))
+
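+            # Row layout: user, subject item (internal ID), label, then the tuple
+            # (relation1, entity, relation2, hist_item) repeated per_item_paths
+            # times; unfilled path slots stay 0.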
+            row = [0] * (3 + self._per_item_paths * 4)
+            row[0] = user_rec.src_id
+            row[1] = subject.intern_id  # subject item
+            row[2] = 1 if i < positive_count else 0  # label
+            for p, path in enumerate(paths):
+                offset = 3 + p * 4
+                for j in range(4):
+                    row[offset + j] = path[j]
+            out_csv.write(','.join(map(str, row)))
+            out_csv.write('\n')
+            self._rows_generated += 1
+
+    def _find_paths(self, by_intern_id, subject_item, hist_item, paths):
+        """Find paths between the subject and historical item."""
+        if by_intern_id:
+            for i, ref_list in enumerate(subject_item.ref_ids):
+                for ref in ref_list:
+                    self._find_paths_by_intern_id(i, ref, hist_item, paths)
+        else:
+            for i, (ref_src_list, ref_list) in enumerate(zip(subject_item.ref_src_ids,
+                                                             subject_item.ref_ids)):
+                for src_ref, ref in zip(ref_src_list, ref_list):
+                    self._find_paths_by_src(i, src_ref, ref, hist_item, paths)
+
+    def _find_paths_by_intern_id(self, subject_ridx, ref_id, hist_item, paths):
+        """Find paths by internal reference ID, a bit faster."""
+        if self._same_relation:
+            if ref_id in hist_item.ref_ids[subject_ridx]:
+                relation_id = self._ridx_to_relation_id(subject_ridx)
+                paths.append((relation_id,
+                              ref_id,
+                              relation_id,
+                              hist_item.intern_id))
+        else:
+            for hist_ridx, hist_ref_list in enumerate(hist_item.ref_ids):
+                if ref_id in hist_ref_list:
+                    paths.append((self._ridx_to_relation_id(subject_ridx),
+                                  ref_id,
+                                  self._ridx_to_relation_id(hist_ridx),
+                                  hist_item.intern_id))
+
+    def _find_paths_by_src(self, subject_ridx, ref_src_id, ref_id, hist_item, paths):
+        """Find paths by source reference ID."""
+        if self._same_relation:
+            if ref_src_id in hist_item.ref_src_ids[subject_ridx]:
+                relation_id = self._ridx_to_relation_id(subject_ridx)
+                paths.append((relation_id,
+                              ref_id,
+                              relation_id,
+                              hist_item.intern_id))
+        else:
+            for hist_ridx, hist_ref_src_list in enumerate(hist_item.ref_src_ids):
+                if ref_src_id in hist_ref_src_list:
+                    paths.append((self._ridx_to_relation_id(subject_ridx),
+                                  ref_id,
+                                  self._ridx_to_relation_id(hist_ridx),
+                                  hist_item.intern_id))
+
+    def _ridx_to_relation_id(self, idx):
+        """Relation index to id."""
+        return idx
+
+    def _to_relation_id_map(self, header, col_indices):
+        """Convert input csv header to a relation id map."""
+        id_map = {}
+        id_counter = 0
+        for i in range(3, len(col_indices)):
+            id_map[header[col_indices[i]]] = id_counter
+            id_counter += 1
+        if len(id_map) < len(header) - 3:
+            raise IOError("Duplicated column!")
+        return id_map
+
+    def _to_item_rec(self, item_src, ref_srcs):
+        """Convert the item src id and the source reference to an item record."""
+        item_id = self._item_id_map.get(item_src, -1)
+        if item_id == -1:
+            if not self.grow_id_maps:
+                item_id = 0
+                self._unseen_items += 1
+            else:
+                item_id = self._entity_id_counter
+                self._item_id_map[item_src] = item_id
+                self._entity_id_counter += 1
+
+        has_unseen_ref = False
+        ref_ids = [[] for _ in range(len(ref_srcs))]
+        for i, ref_src_list in enumerate(ref_srcs):
+            for ref_src in ref_src_list:
+                ref_id = self._ref_id_map.get(ref_src, -1)
+                if ref_id == -1:
+                    if not self.grow_id_maps:
+                        ref_id = 0
+                        self._unseen_refs += 1
+                        has_unseen_ref = True
+                    else:
+                        ref_id = self._entity_id_counter
+                        self._ref_id_map[ref_src] = ref_id
+                        self._entity_id_counter += 1
+                ref_ids[i].append(ref_id)
+
+        return _ItemRec(item_src, item_id, ref_srcs, ref_ids), has_unseen_ref
+
+    def _get_col_indices(self, header, col_map):
+        """Find the column indices base on the mapping."""
+        if col_map:
+            mapped = [col_map[col] for col in header]
+            default_header = self._default_abstract_header(len(header) - 3)
+            return [mapped.index(col) for col in default_header]
+        return list(range(len(header)))
+
+    @staticmethod
+    def _read_header(in_csv):
+        """Read the CSV header."""
+        line = next(in_csv)
+        return list(map(lambda x: x.strip(), line))
+
+    @staticmethod
+    def _default_abstract_header(num_relation):
+        """Get the default abstract header."""
+        abstract_header = ["user", "item", "rating"]
+        abstract_header.extend([f"r{i + 1}" for i in num_relation])
+        return abstract_header
diff --git a/official/recommend/tbnet/src/tbnet.py b/official/recommend/tbnet/src/tbnet.py
index c06166073f384300b3d3c45ceb187baf43ef7123..bc67715b8c6ed425dc57f3366b642e13fc88232d 100644
--- a/official/recommend/tbnet/src/tbnet.py
+++ b/official/recommend/tbnet/src/tbnet.py
@@ -14,16 +14,25 @@
 # ============================================================================
 """TB-Net Model."""
 
-from mindspore import nn
+from mindspore import nn, Tensor
 from mindspore import ParameterTuple
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from mindspore.context import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
+import mindspore.common.dtype as mstype
 
 from src.embedding import EmbeddingMatrix
 
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+
+
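+# Scales one gradient tensor by `scale`; registered on (Tensor, Tensor) so that
+# HyperMap can apply it across the entire gradient tuple when un-scaling.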
+@grad_scale.register("Tensor", "Tensor")
+def gradient_scale(scale, grad):
+    return grad * F.cast(scale, F.dtype(grad))
+
 
 class TBNet(nn.Cell):
     """
@@ -68,8 +77,8 @@ class TBNet(nn.Cell):
     def _parse_config(self, config):
         """Argument parsing."""
 
-        self.num_entity = config.num_entity
-        self.num_relation = config.num_relation
+        self.num_entity = config.num_items + config.num_references + 1
+        self.num_relation = config.num_relations
         self.dim = config.embedding_dim
         self.kge_weight = config.kge_weight
         self.node_weight = config.node_weight
@@ -279,7 +288,7 @@ class NetWithLossClass(nn.Cell):
 class TrainStepWrap(nn.Cell):
     """TrainStepWrap definition."""
 
-    def __init__(self, network, lr, sens=1):
+    def __init__(self, network, lr, sens=1, loss_scale=False):
         super(TrainStepWrap, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_train()
@@ -294,11 +303,13 @@ class TrainStepWrap(nn.Cell):
                                  loss_scale=sens)
 
         self.hyper_map = C.HyperMap()
+        self.reciprocal_sense = Tensor(1 / sens, mstype.float32)
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = sens
 
         self.reducer_flag = False
         self.grad_reducer = None
+        self.loss_scale = loss_scale
         parallel_mode = _get_parallel_mode()
         if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
@@ -307,6 +318,10 @@ class TrainStepWrap(nn.Cell):
             degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
 
+    def scale_grad(self, gradients):
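+        """Un-scale gradients: multiply each one by 1/sens via HyperMap."""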
+        gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_sense), gradients)
+        return gradients
+
     def construct(self, items, relation1, mid_entity, relation2, hist_item, labels):
         """
         Args:
@@ -325,11 +340,14 @@ class TrainStepWrap(nn.Cell):
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(items, relation1, mid_entity, relation2, hist_item, labels, sens)
 
+        if self.loss_scale:
+            grads = self.scale_grad(grads)
+
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-        self.optimizer(grads)
-        return loss
+
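+        # Attach the optimizer update as a dependency of the returned loss so the
+        # update is not pruned or reordered in graph mode.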
+        return F.depend(loss, self.optimizer(grads))
 
 
 class PredictWithSigmoid(nn.Cell):
diff --git a/official/recommend/tbnet/train.py b/official/recommend/tbnet/train.py
index 10d6d2f4eb40b9aec3ebd4fff316f129c7d5948b..eda71b1b62ae5b67fff3938be39d19ee1343bb7b 100644
--- a/official/recommend/tbnet/train.py
+++ b/official/recommend/tbnet/train.py
@@ -16,11 +16,13 @@
 
 import os
 import argparse
+import math
 
 import numpy as np
 from mindspore import context, Model, Tensor
 from mindspore.train.serialization import save_checkpoint
 from mindspore.train.callback import Callback, TimeMonitor
+import mindspore.common.dtype as mstype
 
 from src import tbnet, config, metrics, dataset
 
@@ -104,8 +106,8 @@ def get_args():
         type=str,
         required=False,
         default='GPU',
-        choices=['GPU'],
-        help="run code on GPU"
+        choices=['GPU', 'Ascend'],
+        help="run code on GPU or Ascend NPU"
     )
 
     parser.add_argument(
@@ -141,13 +143,21 @@ def train_tbnet():
 
     print(f"creating dataset from {train_csv_path}...")
     net_config = config.TBNetConfig(config_path)
-    train_ds = dataset.create(train_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
-    test_ds = dataset.create(test_csv_path, net_config.per_item_num_paths, train=True).batch(net_config.batch_size)
+    if args.device_target == 'Ascend':
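+        # The loss network runs in float16 on Ascend (see below); round both sizes
+        # up to multiples of 16 to match the 16-element granularity preferred by
+        # Ascend's matrix units.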
+        net_config.per_item_paths = math.ceil(net_config.per_item_paths / 16) * 16
+        net_config.embedding_dim = math.ceil(net_config.embedding_dim / 16) * 16
+    train_ds = dataset.create(train_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
+    test_ds = dataset.create(test_csv_path, net_config.per_item_paths, train=True).batch(net_config.batch_size)
 
     print("creating TBNet for training...")
     network = tbnet.TBNet(net_config)
     loss_net = tbnet.NetWithLossClass(network, net_config)
-    train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
+    if args.device_target == 'Ascend':
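+        # Cast the loss network to float16 on Ascend and enable loss scaling in
+        # the train step to reduce the risk of gradient underflow.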
+        loss_net.to_float(mstype.float16)
+        train_net = tbnet.TrainStepWrap(loss_net, net_config.lr, loss_scale=True)
+    else:
+        train_net = tbnet.TrainStepWrap(loss_net, net_config.lr)
+
     train_net.set_train()
     eval_net = tbnet.PredictWithSigmoid(network)
     time_callback = TimeMonitor(data_size=train_ds.get_dataset_size())
@@ -161,7 +171,8 @@ def train_tbnet():
         test_out = model.eval(test_ds, dataset_sink_mode=False)
         print(f'Train AUC:{train_out["auc"]} ACC:{train_out["acc"]}  Test AUC:{test_out["auc"]} ACC:{test_out["acc"]}')
 
-        save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))
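+        # Only save checkpoints for the final few epochs to limit disk usage.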
+        if i >= args.epochs - 5:
+            save_checkpoint(network, os.path.join(ckpt_path, f'tbnet_epoch{i}.ckpt'))
 
 
 if __name__ == '__main__':