diff --git a/research/cv/dcgan/train.py b/research/cv/dcgan/train.py index 831830bd5758628601ca45d4780d2f7f109265cf..efd7c1fa042fa2e4247b7f876454082b1adfc589 100644 --- a/research/cv/dcgan/train.py +++ b/research/cv/dcgan/train.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2021-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ import numpy as np from mindspore import context from mindspore import nn, Tensor -from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam, ModelCheckpoint +from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam, ModelCheckpoint, RunContext from mindspore.context import ParallelMode from mindspore.communication.management import init, get_rank from src.dataset import create_dataset_imagenet @@ -170,6 +170,8 @@ if __name__ == '__main__': # For each epoch cb_params.cur_epoch_num = 0 cb_params.cur_step_num = 0 + run_context = RunContext(cb_params) + ckpt_cb.begin(run_context) np.random.seed(1) fixed_noise = Tensor(np.random.normal(size=(16, cfg.latent_size, 1, 1)).astype("float32")) @@ -195,7 +197,7 @@ if __name__ == '__main__': cb_params.cur_epoch_num = cb_params.cur_epoch_num + 1 print("================saving model===================") if args.device_id == 0 or not args.run_distribute: - ckpt_cb.save_ckpt(cb_params, True) + ckpt_cb.step_end(run_context) if run_modelart: fake = netG(fixed_noise) print("================saving images===================")