diff --git a/official/cv/psenet/train.py b/official/cv/psenet/train.py index ab8ce52173003ba60cc9a735634c8800133e7c88..cd2b6c98183bbabd77d09ad58378d602538175e7 100644 --- a/official/cv/psenet/train.py +++ b/official/cv/psenet/train.py @@ -118,9 +118,9 @@ def train(): net = TrainOneStepCell(net, opt) time_cb = TimeMonitor(data_size=step_size) - loss_cb = LossCallBack(per_print_times=10) + loss_cb = LossCallBack(per_print_times=step_size) # set and apply parameters of check point config.TRAIN_MODEL_SAVE_PATH - ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=3) + ckpoint_cf = CheckpointConfig(save_checkpoint_steps=20*step_size, keep_checkpoint_max=10) ckpoint_cb = ModelCheckpoint(prefix="PSENet", config=ckpoint_cf, directory="{}/ckpt_{}".format(config.TRAIN_MODEL_SAVE_PATH,