diff --git a/research/audio/wavenet/evaluate.py b/research/audio/wavenet/evaluate.py index bc11752f0fa1e0ad81a0891222131cc4712ebeae..6b869d954923af0d56d75f4aa55ce02dd4ee88a6 100644 --- a/research/audio/wavenet/evaluate.py +++ b/research/audio/wavenet/evaluate.py @@ -195,10 +195,10 @@ def save_ref_audio(hparam, ref, length, target_wav_path_): if __name__ == '__main__': - device_id = int(os.getenv("DEVICE_ID")) - if args.platform == 'CPU': + if args.platform != 'Ascend': context.set_context(mode=0, device_target=args.platform, save_graphs=False) else: + device_id = int(os.getenv("DEVICE_ID")) context.set_context(mode=1, device_target=args.platform, device_id=device_id) speaker_id = int(args.speaker_id) if args.speaker_id != '' else None diff --git a/research/audio/wavenet/scripts/run_eval_cpu.sh b/research/audio/wavenet/scripts/run_eval_cpu.sh index 02b2ba5c20e9a01acaa49f90006e320beebdf4b9..5c43192a7f2fa612963dd5f6e52e43e851036bec 100644 --- a/research/audio/wavenet/scripts/run_eval_cpu.sh +++ b/research/audio/wavenet/scripts/run_eval_cpu.sh @@ -15,7 +15,8 @@ # ============================================================================ if [ $# == 5 ] then - python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --is_numpy --output_path=$5 > eval.log 2>&1 & + python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --is_numpy --output_path=$5 --platform=CPU \ + > eval.log 2>&1 & else - python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --output_path=$4 > eval.log 2>&1 & + python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --output_path=$4 --platform=CPU > eval.log 2>&1 & fi diff --git a/research/audio/wavenet/train.py b/research/audio/wavenet/train.py index 1c2d02b3bd82675bd7308e6503001e9c3415bf0c..80b13acfee58519cfd0d7b5abac51c02753da835 100644 --- a/research/audio/wavenet/train.py +++ b/research/audio/wavenet/train.py @@ -66,10 +66,15 @@ if __name__ == '__main__': else: context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False) - rank_id = int(os.getenv('RANK_ID')) - group_size = int(os.getenv('RANK_SIZE')) - device_id = int(os.getenv("DEVICE_ID")) - context.set_context(device_id=device_id) + if target == 'Ascend': + rank_id = int(os.getenv('RANK_ID')) + group_size = int(os.getenv('RANK_SIZE')) + device_id = int(os.getenv("DEVICE_ID")) + context.set_context(device_id=device_id) + else: + rank_id = 0 + group_size = 1 + device_id = 0 if args.is_distributed: context.reset_auto_parallel_context() @@ -154,7 +159,7 @@ if __name__ == '__main__': else: lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs, step_size_per_epoch) lr = Tensor(lr) - if arg.checkpoint != '': + if args.checkpoint != '': param_dict = load_checkpoint(args.checkpoint) load_param_into_net(model, param_dict) print('Successfully loading the pre-trained model')