diff --git a/research/audio/wavenet/train.py b/research/audio/wavenet/train.py index f982c5cbc4486b7415e0c013b7c950d461a5ac06..cda758e3f7fa4ecf9c662ceeab604c70b9d105ff 100644 --- a/research/audio/wavenet/train.py +++ b/research/audio/wavenet/train.py @@ -167,7 +167,7 @@ if __name__ == '__main__': optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.) train_net = TrainOneStepCell(loss_net, optimizer) - if target != 'CPU': + if target == 'Ascend': summary_collector = SummaryCollector(summary_dir='summary_dir/device_{}'.format(device_id), collect_freq=1) model = Model(train_net) lr_cb = Monitor(lr) @@ -180,7 +180,7 @@ if __name__ == '__main__': config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=hparams.nepochs) ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck) callback_list.append(ckpt_cb) - if target != 'CPU': + if target == 'Ascend': callback_list.append(summary_collector) if target == 'Ascend' and resume_epoch is not None: model.train(hparams.nepochs - resume_epoch, data_loaders, callbacks=callback_list, dataset_sink_mode=False)