diff --git a/research/audio/wavenet/evaluate.py b/research/audio/wavenet/evaluate.py
index bc11752f0fa1e0ad81a0891222131cc4712ebeae..6b869d954923af0d56d75f4aa55ce02dd4ee88a6 100644
--- a/research/audio/wavenet/evaluate.py
+++ b/research/audio/wavenet/evaluate.py
@@ -195,10 +195,10 @@ def save_ref_audio(hparam, ref, length, target_wav_path_):
 
 if __name__ == '__main__':
 
-    device_id = int(os.getenv("DEVICE_ID"))
-    if args.platform == 'CPU':
+    if args.platform != 'Ascend':
         context.set_context(mode=0, device_target=args.platform, save_graphs=False)
     else:
+        device_id = int(os.getenv("DEVICE_ID"))  # Ascend runs take the device id from the launch environment
         context.set_context(mode=1, device_target=args.platform, device_id=device_id)
 
     speaker_id = int(args.speaker_id) if args.speaker_id != '' else None
diff --git a/research/audio/wavenet/scripts/run_eval_cpu.sh b/research/audio/wavenet/scripts/run_eval_cpu.sh
index 02b2ba5c20e9a01acaa49f90006e320beebdf4b9..5c43192a7f2fa612963dd5f6e52e43e851036bec 100644
--- a/research/audio/wavenet/scripts/run_eval_cpu.sh
+++ b/research/audio/wavenet/scripts/run_eval_cpu.sh
@@ -15,7 +15,8 @@
 # ============================================================================
 if [ $# == 5 ]
 then
-    python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --is_numpy --output_path=$5 > eval.log 2>&1 &
+    python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --is_numpy --output_path=$5 --platform=CPU \
+    > eval.log 2>&1 &
 else
-    python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --output_path=$4 > eval.log 2>&1 &
+    python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --output_path=$4 --platform=CPU > eval.log 2>&1 &
 fi
diff --git a/research/audio/wavenet/train.py b/research/audio/wavenet/train.py
index 1c2d02b3bd82675bd7308e6503001e9c3415bf0c..80b13acfee58519cfd0d7b5abac51c02753da835 100644
--- a/research/audio/wavenet/train.py
+++ b/research/audio/wavenet/train.py
@@ -66,10 +66,15 @@ if __name__ == '__main__':
     else:
         context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False)
 
-    rank_id = int(os.getenv('RANK_ID'))
-    group_size = int(os.getenv('RANK_SIZE'))
-    device_id = int(os.getenv("DEVICE_ID"))
-    context.set_context(device_id=device_id)
+    if target == 'Ascend':
+        rank_id = int(os.getenv('RANK_ID'))  # rank/group/device ids are exported by the Ascend launch environment
+        group_size = int(os.getenv('RANK_SIZE'))
+        device_id = int(os.getenv("DEVICE_ID"))
+        context.set_context(device_id=device_id)
+    else:
+        rank_id = 0  # single-device defaults for GPU/CPU
+        group_size = 1
+        device_id = 0
 
     if args.is_distributed:
         context.reset_auto_parallel_context()
@@ -154,7 +159,7 @@ if __name__ == '__main__':
     else:
         lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs, step_size_per_epoch)
         lr = Tensor(lr)
-        if arg.checkpoint != '':
+        if args.checkpoint != '':
             param_dict = load_checkpoint(args.checkpoint)
             load_param_into_net(model, param_dict)
             print('Successfully loading the pre-trained model')