diff --git a/official/recommend/tbnet/README.md b/official/recommend/tbnet/README.md
index bd734d3d31b126bad86ba5df3329e79e005561c5..84213e62dfb3a1109b37783a82032db546b0746b 100644
--- a/official/recommend/tbnet/README.md
+++ b/official/recommend/tbnet/README.md
@@ -82,7 +82,6 @@ Download the data package(e.g. 'steam' dataset) and put it underneath the curren
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
-cd scripts
```
and then run code as follows.
@@ -90,13 +89,13 @@ and then run code as follows.
- Training
```bash
-bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+bash scripts/run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Example:
```bash
-bash run_standalone_train.sh steam 0 Ascend
+bash scripts/run_standalone_train.sh steam 0 Ascend
```
- Evaluation
@@ -104,7 +103,7 @@ bash run_standalone_train.sh steam 0 Ascend
Evaluation model on test dataset.
```bash
-bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+bash scripts/run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Argument `[CHECKPOINT_ID]` is required.
@@ -112,7 +111,7 @@ Argument `[CHECKPOINT_ID]` is required.
Example:
```bash
-bash run_eval.sh 19 steam 0 Ascend
+bash scripts/run_eval.sh 19 steam 0 Ascend
```
- Inference and Explanation
@@ -281,6 +280,7 @@ Before performing inference, the mindir file must be exported by `export.py` scr
```shell
# Ascend310 inference
+cd scripts
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
@@ -291,6 +291,7 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
Example:
```bash
+cd scripts
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
diff --git a/official/recommend/tbnet/README_CN.md b/official/recommend/tbnet/README_CN.md
index c7102970795ff04567e97cbc65419baa80e99bc0..dd6919f88e0e2c7af7bc419bf80ca42a5d5b5e70 100644
--- a/official/recommend/tbnet/README_CN.md
+++ b/official/recommend/tbnet/README_CN.md
@@ -77,7 +77,6 @@ TB-Net将用户和物品的交互信息以及物品的属性信息在知识图
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
-cd scripts
```
然后按照以下步骤运行代码。
@@ -85,13 +84,13 @@ cd scripts
- 训练
```bash
-bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+bash scripts/run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
示例:
```bash
-bash run_standalone_train.sh steam 0 Ascend
+bash scripts/run_standalone_train.sh steam 0 Ascend
```
- 评估
@@ -99,7 +98,7 @@ bash run_standalone_train.sh steam 0 Ascend
评估模型在测试集上的指标。
```bash
-bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+bash scripts/run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
参数`[CHECKPOINT_ID]`是必填项。
@@ -107,7 +106,7 @@ bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
示例:
```bash
-bash run_eval.sh 19 steam 0 Ascend
+bash scripts/run_eval.sh 19 steam 0 Ascend
```
- 推理和解释
@@ -282,6 +281,7 @@ python export.py \
```shell
# Ascend310 inference
+cd scripts
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
@@ -292,6 +292,7 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
示例:
```bash
+cd scripts
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
diff --git a/official/recommend/tbnet/scripts/run_eval.sh b/official/recommend/tbnet/scripts/run_eval.sh
index 28d78e07139d8bea5eca1519277ae6b4c095005e..45d4060f7ea9b21631042c3a6c8ddffe3016533d 100644
--- a/official/recommend/tbnet/scripts/run_eval.sh
+++ b/official/recommend/tbnet/scripts/run_eval.sh
@@ -14,9 +14,10 @@
# limitations under the License.
# ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then
- echo "Usage: bash run_train.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+ echo "Usage: bash run_train.sh [CHECKPOINT_ID] [DATA_NAME] [CUDA_VISIBLE_DEVICES]/[DEVICE_ID] [DEVICE_TARGET]
CHECKPOINT_ID means model checkpoint id.
DATA_NAME means dataset name, it's value is 'steam'.
+ CUDA_VISIBLE_DEVICES means cuda visible device id.
DEVICE_ID means device id, it can be set by environment variable DEVICE_ID.
DEVICE_TARGET is optional, it's value is ['GPU', 'Ascend'], default 'GPU'."
exit 1
@@ -24,13 +25,22 @@ fi
CHECKPOINT_ID=$1
DATA_NAME=$2
-DEVICE_ID=$3
DEVICE_TARGET='GPU'
if [ $# == 4 ]; then
DEVICE_TARGET=$4
fi
-cd ..
-python eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
- --device_id $DEVICE_ID &> scripts/eval_standalone_gpu_log &
\ No newline at end of file
+if [ "$DEVICE_TARGET" = "GPU" ];
+then
+ export CUDA_VISIBLE_DEVICES=$3
+ python eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
+ --device_id 0 &> scripts/eval_standalone_gpu_log &
+fi
+
+if [ "$DEVICE_TARGET" = "Ascend" ];
+then
+ export DEVICE_ID=$3
+ python eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
+ --device_id $DEVICE_ID &> scripts/eval_standalone_gpu_log &
+fi
\ No newline at end of file
diff --git a/official/recommend/tbnet/scripts/run_standalone_train.sh b/official/recommend/tbnet/scripts/run_standalone_train.sh
index 8699b0b00a36dcc95966cbe11d731797726bf030..aaa8fc6d94dcd59cf77b1cc034632bc9de4cfc8b 100644
--- a/official/recommend/tbnet/scripts/run_standalone_train.sh
+++ b/official/recommend/tbnet/scripts/run_standalone_train.sh
@@ -14,21 +14,31 @@
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
- echo "Usage: bash run_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
+ echo "Usage: bash run_train.sh [DATA_NAME] [CUDA_VISIBLE_DEVICES]/[DEVICE_ID] [DEVICE_TARGET]
DATA_NAME means dataset name, it's value is 'steam'.
+ CUDA_VISIBLE_DEVICES means cuda visible device id.
DEVICE_ID means device id, it can be set by environment variable DEVICE_ID.
DEVICE_TARGET is optional, it's value is ['GPU', 'Ascend'], default 'GPU'."
exit 1
fi
DATA_NAME=$1
-DEVICE_ID=$2
-
DEVICE_TARGET='GPU'
+
if [ $# == 3 ]; then
DEVICE_TARGET=$3
fi
-cd ..
-python preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> scripts/train_standalone_log &&
-python train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> scripts/train_standalone_log &
\ No newline at end of file
+python preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> scripts/train_standalone_log &
+
+if [ "$DEVICE_TARGET" = "GPU" ];
+then
+ export CUDA_VISIBLE_DEVICES=$2
+ python train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id 0 &>> scripts/train_standalone_log &
+fi
+
+if [ "$DEVICE_TARGET" = "Ascend" ];
+then
+ export DEVICE_ID=$2
+ python train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> scripts/train_standalone_log &
+fi
\ No newline at end of file