Skip to content
Snippets Groups Projects
Commit ff84c5a0 authored by anzhengqi's avatar anzhengqi
Browse files

modify tinybert and ctpn scripts

parent a13e11ac
No related branches found
No related tags found
No related merge requests found
...@@ -52,12 +52,11 @@ cp -r ./src ./train ...@@ -52,12 +52,11 @@ cp -r ./src ./train
cd ./train || exit cd ./train || exit
export DEVICE_NUM=1 export DEVICE_NUM=1
export DEVICE_ID=$3
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
echo "start training for device $DEVICE_ID" echo "start training for device $3"
export CUDA_VISIBLE_DEVICES=$DEVICE_ID export CUDA_VISIBLE_DEVICES=$3
env > env.log env > env.log
python train.py --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH --device_target="GPU" &> log & python train.py --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH --device_target="GPU" &> log &
cd .. cd ..
...@@ -98,6 +98,7 @@ class EvalCallBack(Callback): ...@@ -98,6 +98,7 @@ class EvalCallBack(Callback):
input_ids, input_mask, token_type_id, label_ids = input_data input_ids, input_mask, token_type_id, label_ids = input_data
self.network.set_train(False) self.network.set_train(False)
logits = self.network(input_ids, token_type_id, input_mask) logits = self.network(input_ids, token_type_id, input_mask)
self.network.set_train(True)
callback.update(logits, label_ids) callback.update(logits, label_ids)
acc = callback.acc_num / callback.total_num acc = callback.acc_num / callback.total_num
with open("./eval.log", "a+") as f: with open("./eval.log", "a+") as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment