Skip to content
Snippets Groups Projects
Commit baf9f58e authored by chenweitao_295's avatar chenweitao_295
Browse files

fix shufflenetv1 gpu bug

parent 9c321da6
No related branches found
No related tags found
No related merge requests found
...@@ -25,7 +25,7 @@ keep_checkpoint_max: 5 ...@@ -25,7 +25,7 @@ keep_checkpoint_max: 5
save_ckpt_path: "./" save_ckpt_path: "./"
save_checkpoint_epochs: 1 save_checkpoint_epochs: 1
save_checkpoint: True save_checkpoint: True
amp_level: "O2" amp_level: "O0"
is_distributed: False is_distributed: False
train_dataset_path: "" train_dataset_path: ""
resume: "" resume: ""
......
...@@ -30,11 +30,11 @@ export CUDA_VISIBLE_DEVICES=$DEVICE_ID ...@@ -30,11 +30,11 @@ export CUDA_VISIBLE_DEVICES=$DEVICE_ID
BASEPATH=$(cd "`dirname $0`" || exit; pwd) BASEPATH=$(cd "`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASEPATH}/../gpu_default_config.yaml" CONFIG_FILE="${BASEPATH}/../gpu_default_config.yaml"
if [ -d "../eval" ]; then if [ -d "eval" ]; then
rm -rf ../eval rm -rf eval
fi fi
mkdir ../eval mkdir eval
cd ../eval || exit cd eval || exit
python ${BASEPATH}/../eval.py \ python ${BASEPATH}/../eval.py \
--config_path=$CONFIG_FILE \ --config_path=$CONFIG_FILE \
......
...@@ -58,6 +58,8 @@ def train(): ...@@ -58,6 +58,8 @@ def train():
group_size = 1 group_size = 1
context.set_context(device_id=config.device_id) context.set_context(device_id=config.device_id)
if config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
# define network # define network
net = ShuffleNetV1(model_size=config.model_size, n_class=config.num_classes) net = ShuffleNetV1(model_size=config.model_size, n_class=config.num_classes)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment