Commit dbe5a1c2 authored by lizhenyu

change wide_and_deep script for embedding cache mode

parent c1bb6424
@@ -60,15 +60,6 @@ do
 done
 export MS_ROLE=MS_WORKER
-if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
-  rm -rf ${execute_path}/worker/
-  mkdir ${execute_path}/worker/
-  cd ${execute_path}/worker/ || exit
-  mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
-  python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
-    --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-    --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker.log 2>&1 &
-else
 for((i=0;i<$MS_WORKER_NUM;i++));
 do
   rm -rf ${execute_path}/worker_$i/
@@ -80,5 +71,4 @@ else
     --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
     --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker_$i.log 2>&1 &
 done
-fi
@@ -92,7 +92,7 @@ def train_and_eval(config):
     if cache_enable:
         config.full_batch = True
     print("epochs is {}".format(epochs))
-    if config.full_batch:
+    if config.full_batch and os.getenv("MS_ROLE") == "MS_WORKER":
         context.set_auto_parallel_context(full_batch=True)
     ds.config.set_seed(1)
     ds_train = create_dataset(data_path, train_mode=True,
@@ -160,6 +160,7 @@ def train_wide_and_deep():
         context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
     if cache_enable:
+        if os.getenv("MS_ROLE") == "MS_WORKER":
            context.set_auto_parallel_context(
                parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
     else:
...
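Note on the two hunks above: in embedding cache mode the parallel-context setup is now gated on the MS_ROLE environment variable, so only worker processes configure full-batch and auto-parallel execution, while the other roles leave the parallel context alone. The snippet below is a minimal sketch of that gating, assuming the MS_ROLE value exported by the launch script; configure_worker_parallel is a hypothetical helper written for this note, not code from the repository.

import os
from mindspore import context
from mindspore.context import ParallelMode

def configure_worker_parallel(cache_enable, full_batch):
    """Hypothetical helper: apply parallel settings only on worker processes."""
    # Non-worker roles (parameter server, scheduler) skip this, mirroring the
    # MS_ROLE checks added in the diff above.
    if not (cache_enable and os.getenv("MS_ROLE") == "MS_WORKER"):
        return
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
    if full_batch:
        context.set_auto_parallel_context(full_batch=True)

The launch script exports MS_ROLE=MS_WORKER before starting each worker process, which is what makes this branch fire only there.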
@@ -21,6 +21,7 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.common import set_seed
+from mindspore.communication.management import init
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
 from src.datasets import create_dataset, DataType
@@ -125,12 +126,14 @@ cache_enable = cfg.vocab_cache_size > 0
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def train_wide_and_deep():
     """ train_wide_and_deep """
+    context.set_ps_context(enable_ps=True)
+    init()
     if not cache_enable:
         cfg.sparse = True
     if cfg.device_target == "GPU":
         context.set_context(enable_graph_kernel=True)
         context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
-    context.set_ps_context(enable_ps=True)
     train_and_eval(cfg)
...
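Note on the last hunk: context.set_ps_context(enable_ps=True) moves to the top of train_wide_and_deep and is paired with the newly imported init() call, so parameter-server mode and the communication layer are set up before any other context configuration. Below is a minimal sketch of that ordering, using only the calls visible in the diff; setup_ps_training is a hypothetical wrapper name used here for illustration.

from mindspore import context
from mindspore.communication.management import init

def setup_ps_training():
    """Hypothetical wrapper showing the initialisation order from the diff."""
    # Enable parameter-server training first, then initialise communication;
    # both now run before graph-kernel and parallel-context settings, matching
    # the reordering in the hunk above.
    context.set_ps_context(enable_ps=True)
    init()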