Skip to content
Snippets Groups Projects
Commit 3bc21cf6 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!966 Fixes Wavenet training error on Windows

Merge pull request !966 from huangbo/master_1019
parents 68c514bd b53a7711
No related branches found
No related tags found
No related merge requests found
......@@ -167,7 +167,8 @@ if __name__ == '__main__':
optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.)
train_net = TrainOneStepCell(loss_net, optimizer)
summary_collector = SummaryCollector(summary_dir='summary_dir/device_{}'.format(device_id), collect_freq=1)
if target != 'CPU':
summary_collector = SummaryCollector(summary_dir='summary_dir/device_{}'.format(device_id), collect_freq=1)
model = Model(train_net)
lr_cb = Monitor(lr)
callback_list = [lr_cb]
......@@ -179,7 +180,8 @@ if __name__ == '__main__':
config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=hparams.nepochs)
ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck)
callback_list.append(ckpt_cb)
callback_list.append(summary_collector)
if target != 'CPU':
callback_list.append(summary_collector)
if target == 'Ascend' and resume_epoch is not None:
model.train(hparams.nepochs - resume_epoch, data_loaders, callbacks=callback_list, dataset_sink_mode=False)
else:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment