diff --git a/research/audio/wavenet/train.py b/research/audio/wavenet/train.py index 80b13acfee58519cfd0d7b5abac51c02753da835..f982c5cbc4486b7415e0c013b7c950d461a5ac06 100644 --- a/research/audio/wavenet/train.py +++ b/research/audio/wavenet/train.py @@ -167,7 +167,8 @@ if __name__ == '__main__': optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.) train_net = TrainOneStepCell(loss_net, optimizer) - summary_collector = SummaryCollector(summary_dir='summary_dir/device_{}'.format(device_id), collect_freq=1) + if target != 'CPU': + summary_collector = SummaryCollector(summary_dir='summary_dir/device_{}'.format(device_id), collect_freq=1) model = Model(train_net) lr_cb = Monitor(lr) callback_list = [lr_cb] @@ -179,7 +180,8 @@ if __name__ == '__main__': config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=hparams.nepochs) ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck) callback_list.append(ckpt_cb) - callback_list.append(summary_collector) + if target != 'CPU': + callback_list.append(summary_collector) if target == 'Ascend' and resume_epoch is not None: model.train(hparams.nepochs - resume_epoch, data_loaders, callbacks=callback_list, dataset_sink_mode=False) else: