Skip to content
Snippets Groups Projects
Commit ca69ac96 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1315 Fixes Wavenet GPU training error

Merge pull request !1315 from huangbo/master_1209
parents 1402ba50 f3d8904e
No related branches found
No related tags found
No related merge requests found
......@@ -167,7 +167,7 @@ if __name__ == '__main__':
optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.)
train_net = TrainOneStepCell(loss_net, optimizer)
if target != 'CPU':
if target == 'Ascend':
summary_collector = SummaryCollector(summary_dir='summary_dir/device_{}'.format(device_id), collect_freq=1)
model = Model(train_net)
lr_cb = Monitor(lr)
......@@ -180,7 +180,7 @@ if __name__ == '__main__':
config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=hparams.nepochs)
ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck)
callback_list.append(ckpt_cb)
if target != 'CPU':
if target == 'Ascend':
callback_list.append(summary_collector)
if target == 'Ascend' and resume_epoch is not None:
model.train(hparams.nepochs - resume_epoch, data_loaders, callbacks=callback_list, dataset_sink_mode=False)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment