diff --git a/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh b/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
index 56ae20a185d2462e5f1e6897b38e35e9c1bd8329..3277b90c9e671f13272aa8a600f91f9a47a504a8 100644
--- a/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
+++ b/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
@@ -60,25 +60,15 @@ do
 done
 export MS_ROLE=MS_WORKER
-if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
-  rm -rf ${execute_path}/worker/
-  mkdir ${execute_path}/worker/
-  cd ${execute_path}/worker/ || exit
-  mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
-    python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
-    --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-    --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker.log 2>&1 &
-else
-  for((i=0;i<$MS_WORKER_NUM;i++));
-  do
-    rm -rf ${execute_path}/worker_$i/
-    mkdir ${execute_path}/worker_$i/
-    cd ${execute_path}/worker_$i/ || exit
-    export RANK_ID=$i
-    export DEVICE_ID=$i
-    python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
-    --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-    --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker_$i.log 2>&1 &
-  done
-fi
+for((i=0;i<$MS_WORKER_NUM;i++));
+do
+  rm -rf ${execute_path}/worker_$i/
+  mkdir ${execute_path}/worker_$i/
+  cd ${execute_path}/worker_$i/ || exit
+  export RANK_ID=$i
+  export DEVICE_ID=$i
+  python -s ${self_path}/../train_and_eval_parameter_server_distribute.py \
+  --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
+  --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker_$i.log 2>&1 &
+done
diff --git a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
index 0cffc8d1ae1f60a57473725b1d1d94ea29c7f809..639d29b75126be458444527429df8478a92328fc 100644
--- a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
+++ b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
@@ -92,7 +92,7 @@ def train_and_eval(config):
     if cache_enable:
         config.full_batch = True
     print("epochs is {}".format(epochs))
-    if config.full_batch:
+    if config.full_batch and os.getenv("MS_ROLE") == "MS_WORKER":
         context.set_auto_parallel_context(full_batch=True)
         ds.config.set_seed(1)
         ds_train = create_dataset(data_path, train_mode=True,
@@ -160,8 +160,9 @@ def train_wide_and_deep():
     context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
     if cache_enable:
-        context.set_auto_parallel_context(
-            parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
+        if os.getenv("MS_ROLE") == "MS_WORKER":
+            context.set_auto_parallel_context(
+                parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
     else:
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                           device_num=get_group_size())
diff --git a/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py b/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py
index f30263f7e522cb63ad46ed69d17d5c97ae58bfb5..ddb6f6c8bf5d41e9cf3aa00d4f51523b6f8dd05c 100644
--- a/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py
+++ b/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py
@@ -21,6 +21,7 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.common import set_seed
+from mindspore.communication.management import init
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
 from src.datasets import create_dataset, DataType
@@ -125,12 +126,14 @@ cache_enable = cfg.vocab_cache_size > 0
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def train_wide_and_deep():
     """ train_wide_and_deep """
+    context.set_ps_context(enable_ps=True)
+    init()
+
     if not cache_enable:
         cfg.sparse = True
     if cfg.device_target == "GPU":
         context.set_context(enable_graph_kernel=True)
         context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
-    context.set_ps_context(enable_ps=True)
     train_and_eval(cfg)