Skip to content
Snippets Groups Projects
Commit c7456963 authored by lihaoyang's avatar lihaoyang
Browse files

fix bug in tbnet shell script, GPU version (#I5MZ3S)

parent 9f522f37
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,6 @@ Download the data package(e.g. 'steam' dataset) and put it underneath the curren
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
cd scripts
```
and then run code as follows.
......@@ -90,13 +89,13 @@ and then run code as follows.
- Training
```bash
bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
bash scripts/run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Example:
```bash
bash run_standalone_train.sh steam 0 Ascend
bash scripts/run_standalone_train.sh steam 0 Ascend
```
- Evaluation
......@@ -104,7 +103,7 @@ bash run_standalone_train.sh steam 0 Ascend
Evaluation model on test dataset.
```bash
bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
bash scripts/run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
Argument `[CHECKPOINT_ID]` is required.
......@@ -112,7 +111,7 @@ Argument `[CHECKPOINT_ID]` is required.
Example:
```bash
bash run_eval.sh 19 steam 0 Ascend
bash scripts/run_eval.sh 19 steam 0 Ascend
```
- Inference and Explanation
......@@ -281,6 +280,7 @@ Before performing inference, the mindir file must be exported by `export.py` scr
```shell
# Ascend310 inference
cd scripts
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
......@@ -291,6 +291,7 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
Example:
```bash
cd scripts
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
......
......@@ -77,7 +77,6 @@ TB-Net将用户和物品的交互信息以及物品的属性信息在知识图
```bash
wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz
tar -xf tbnet_data.tar.gz
cd scripts
```
然后按照以下步骤运行代码。
......@@ -85,13 +84,13 @@ cd scripts
- 训练
```bash
bash run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
bash scripts/run_standalone_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
示例:
```bash
bash run_standalone_train.sh steam 0 Ascend
bash scripts/run_standalone_train.sh steam 0 Ascend
```
- 评估
......@@ -99,7 +98,7 @@ bash run_standalone_train.sh steam 0 Ascend
评估模型在测试集上的指标。
```bash
bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
bash scripts/run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
```
参数`[CHECKPOINT_ID]`是必填项。
......@@ -107,7 +106,7 @@ bash run_eval.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
示例:
```bash
bash run_eval.sh 19 steam 0 Ascend
bash scripts/run_eval.sh 19 steam 0 Ascend
```
- 推理和解释
......@@ -282,6 +281,7 @@ python export.py \
```shell
# Ascend310 inference
cd scripts
bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
```
......@@ -292,6 +292,7 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID]
示例:
```bash
cd scripts
bash run_infer_310.sh ../model.mindir ../data/steam/test.csv 0
```
......
......@@ -14,9 +14,10 @@
# limitations under the License.
# ============================================================================
if [[ $# -lt 3 || $# -gt 4 ]]; then
echo "Usage: bash run_train.sh [CHECKPOINT_ID] [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
echo "Usage: bash run_train.sh [CHECKPOINT_ID] [DATA_NAME] [CUDA_VISIBLE_DEVICES]/[DEVICE_ID] [DEVICE_TARGET]
CHECKPOINT_ID means model checkpoint id.
DATA_NAME means dataset name, it's value is 'steam'.
CUDA_VISIBLE_DEVICES means cuda visible device id.
DEVICE_ID means device id, it can be set by environment variable DEVICE_ID.
DEVICE_TARGET is optional, it's value is ['GPU', 'Ascend'], default 'GPU'."
exit 1
......@@ -24,13 +25,22 @@ fi
CHECKPOINT_ID=$1
DATA_NAME=$2
DEVICE_ID=$3
DEVICE_TARGET='GPU'
if [ $# == 4 ]; then
DEVICE_TARGET=$4
fi
cd ..
python eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
--device_id $DEVICE_ID &> scripts/eval_standalone_gpu_log &
\ No newline at end of file
if [ "$DEVICE_TARGET" = "GPU" ];
then
export CUDA_VISIBLE_DEVICES=$3
python eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
--device_id 0 &> scripts/eval_standalone_gpu_log &
fi
if [ "$DEVICE_TARGET" = "Ascend" ];
then
export DEVICE_ID=$3
python eval.py --checkpoint_id $CHECKPOINT_ID --dataset $DATA_NAME --device_target $DEVICE_TARGET \
--device_id $DEVICE_ID &> scripts/eval_standalone_gpu_log &
fi
\ No newline at end of file
......@@ -14,21 +14,31 @@
# limitations under the License.
# ============================================================================
if [[ $# -lt 2 || $# -gt 3 ]]; then
echo "Usage: bash run_train.sh [DATA_NAME] [DEVICE_ID] [DEVICE_TARGET]
echo "Usage: bash run_train.sh [DATA_NAME] [CUDA_VISIBLE_DEVICES]/[DEVICE_ID] [DEVICE_TARGET]
DATA_NAME means dataset name, it's value is 'steam'.
CUDA_VISIBLE_DEVICES means cuda visible device id.
DEVICE_ID means device id, it can be set by environment variable DEVICE_ID.
DEVICE_TARGET is optional, it's value is ['GPU', 'Ascend'], default 'GPU'."
exit 1
fi
DATA_NAME=$1
DEVICE_ID=$2
DEVICE_TARGET='GPU'
if [ $# == 3 ]; then
DEVICE_TARGET=$3
fi
cd ..
python preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> scripts/train_standalone_log &&
python train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> scripts/train_standalone_log &
\ No newline at end of file
python preprocess_dataset.py --dataset $DATA_NAME --device_target $DEVICE_TARGET &> scripts/train_standalone_log &
if [ "$DEVICE_TARGET" = "GPU" ];
then
export CUDA_VISIBLE_DEVICES=$2
python train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id 0 &>> scripts/train_standalone_log &
fi
if [ "$DEVICE_TARGET" = "Ascend" ];
then
export DEVICE_ID=$2
python train.py --dataset $DATA_NAME --device_target $DEVICE_TARGET --device_id $DEVICE_ID &>> scripts/train_standalone_log &
fi
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment