Skip to content
Snippets Groups Projects
Commit a8687103 authored by zhaoting's avatar zhaoting
Browse files

adapt 1p with mobilenetv3

parent f10e2d71
No related branches found
No related tags found
No related merge requests found
......@@ -46,11 +46,20 @@ run_gpu()
cd ../train || exit
export CUDA_VISIBLE_DEVICES="$3"
mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ${BASEPATH}/../train.py \
if [ $2 -eq 1 ] ; then
python ${BASEPATH}/../train.py \
--dataset_path=$4 \
--device_target=$1 \
&> ../train.log & # dataset train folder
--run_distribute=False \
&> ../train.log &
else
mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ${BASEPATH}/../train.py \
--dataset_path=$4 \
--device_target=$1 \
--run_distribute=True \
&> ../train.log & # dataset train folder
fi;
}
run_cpu()
......
......@@ -49,7 +49,7 @@ parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--device_target', type=str, default="GPU", help='run device_target')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
args_opt = parser.parse_args()
if args_opt.device_target == "GPU":
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment