Skip to content
Snippets Groups Projects
Commit dbe5a1c2 authored by lizhenyu's avatar lizhenyu
Browse files

change wide_and_deep script for embedding cache mode

parent c1bb6424
No related branches found
No related tags found
No related merge requests found
......@@ -60,25 +60,15 @@ do
done
export MS_ROLE=MS_WORKER
if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
rm -rf ${execute_path}/worker/
mkdir ${execute_path}/worker/
cd ${execute_path}/worker/ || exit
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
--device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker.log 2>&1 &
else
for((i=0;i<$MS_WORKER_NUM;i++));
do
rm -rf ${execute_path}/worker_$i/
mkdir ${execute_path}/worker_$i/
cd ${execute_path}/worker_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
--device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker_$i.log 2>&1 &
done
fi
# Launch MS_WORKER_NUM training worker processes, one per device.
# Each worker gets a fresh private working directory (worker_$i) so that
# per-rank artifacts (graphs, checkpoints, logs) do not collide.
for((i=0;i<$MS_WORKER_NUM;i++));
do
rm -rf ${execute_path}/worker_$i/
mkdir ${execute_path}/worker_$i/
cd ${execute_path}/worker_$i/ || exit
# RANK_ID / DEVICE_ID pin worker i to device i (single-host layout assumed
# here — TODO confirm for multi-host runs).
export RANK_ID=$i
export DEVICE_ID=$i
# Run distributed parameter-server training in the background; stdout and
# stderr are captured in each worker's own worker_$i.log.
python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
--device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
--vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker_$i.log 2>&1 &
done
......@@ -92,7 +92,7 @@ def train_and_eval(config):
if cache_enable:
config.full_batch = True
print("epochs is {}".format(epochs))
if config.full_batch:
if config.full_batch and os.getenv("MS_ROLE") == "MS_WORKER":
context.set_auto_parallel_context(full_batch=True)
ds.config.set_seed(1)
ds_train = create_dataset(data_path, train_mode=True,
......@@ -160,8 +160,9 @@ def train_wide_and_deep():
context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
if cache_enable:
context.set_auto_parallel_context(
parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
if os.getenv("MS_ROLE") == "MS_WORKER":
context.set_auto_parallel_context(
parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
else:
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=get_group_size())
......
......@@ -21,6 +21,7 @@ from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.common import set_seed
from mindspore.communication.management import init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, DataType
......@@ -125,12 +126,14 @@ cache_enable = cfg.vocab_cache_size > 0
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_wide_and_deep():
""" train_wide_and_deep """
context.set_ps_context(enable_ps=True)
init()
if not cache_enable:
cfg.sparse = True
if cfg.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
context.set_ps_context(enable_ps=True)
train_and_eval(cfg)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment