diff --git a/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh b/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
index 56ae20a185d2462e5f1e6897b38e35e9c1bd8329..3277b90c9e671f13272aa8a600f91f9a47a504a8 100644
--- a/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
+++ b/official/recommend/wide_and_deep/script/run_parameter_server_train_distribute.sh
@@ -60,25 +60,15 @@ do
 done
 
 export MS_ROLE=MS_WORKER
-if [[ "X$DEVICE_TARGET" == "XGPU" ]]; then
-  rm -rf ${execute_path}/worker/
-  mkdir ${execute_path}/worker/
-  cd ${execute_path}/worker/ || exit
-  mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout      \
-    python -s ${self_path}/../train_and_eval_parameter_server_distribute.py                           \
-      --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1   \
-      --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker.log 2>&1 &
-else
-  for((i=0;i<$MS_WORKER_NUM;i++));
-  do
-    rm -rf ${execute_path}/worker_$i/
-    mkdir ${execute_path}/worker_$i/
-    cd ${execute_path}/worker_$i/ || exit
-    export RANK_ID=$i
-    export DEVICE_ID=$i
-    python -s ${self_path}/../train_and_eval_parameter_server_distribute.py                         \
-      --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
-      --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker_$i.log 2>&1 &
-  done
-fi
+for((i=0;i<$MS_WORKER_NUM;i++));
+do
+  rm -rf ${execute_path}/worker_$i/
+  mkdir ${execute_path}/worker_$i/
+  cd ${execute_path}/worker_$i/ || exit
+  export RANK_ID=$i
+  export DEVICE_ID=$i
+  python -s ${self_path}/../train_and_eval_parameter_server_distribute.py                         \
+    --device_target=$DEVICE_TARGET --data_path=$DATASET --epochs=$EPOCH_SIZE --parameter_server=1 \
+    --vocab_cache_size=$VOCAB_CACHE_SIZE --sparse=$SPARSE --dropout_flag=True >worker_$i.log 2>&1 &
+done
 
diff --git a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
index 0cffc8d1ae1f60a57473725b1d1d94ea29c7f809..639d29b75126be458444527429df8478a92328fc 100644
--- a/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
+++ b/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py
@@ -92,7 +92,7 @@ def train_and_eval(config):
     if cache_enable:
         config.full_batch = True
     print("epochs is {}".format(epochs))
-    if config.full_batch:
+    if config.full_batch and os.getenv("MS_ROLE") == "MS_WORKER":
         context.set_auto_parallel_context(full_batch=True)
         ds.config.set_seed(1)
         ds_train = create_dataset(data_path, train_mode=True,
@@ -160,8 +160,9 @@ def train_wide_and_deep():
     context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank()))
 
     if cache_enable:
-        context.set_auto_parallel_context(
-            parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
+        if os.getenv("MS_ROLE") == "MS_WORKER":
+            context.set_auto_parallel_context(
+                parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
     else:
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                           device_num=get_group_size())
diff --git a/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py b/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py
index f30263f7e522cb63ad46ed69d17d5c97ae58bfb5..ddb6f6c8bf5d41e9cf3aa00d4f51523b6f8dd05c 100644
--- a/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py
+++ b/official/recommend/wide_and_deep/train_and_eval_parameter_server_standalone.py
@@ -21,6 +21,7 @@ from mindspore import Model, context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.common import set_seed
 
+from mindspore.communication.management import init
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
 from src.datasets import create_dataset, DataType
@@ -125,12 +126,14 @@ cache_enable = cfg.vocab_cache_size > 0
 @moxing_wrapper(pre_process=modelarts_pre_process)
 def train_wide_and_deep():
     """ train_wide_and_deep """
+    context.set_ps_context(enable_ps=True)
+    init()
+
     if not cache_enable:
         cfg.sparse = True
     if cfg.device_target == "GPU":
         context.set_context(enable_graph_kernel=True)
         context.set_context(graph_kernel_flags="--enable_cluster_ops=MatMul")
-    context.set_ps_context(enable_ps=True)
 
     train_and_eval(cfg)