diff --git a/main.py b/main.py index 10cd0e169fe1a7439bbe73a3db02f6d67e5bc300..f4ad80b78211f34ee848debc76a6faf78eca962d 100644 --- a/main.py +++ b/main.py @@ -28,7 +28,7 @@ train_dataset = train_dataset.batch(4, drop_remainder=True) lr_iter = exponential_lr(3e-5, 20, 0.98, 500, staircase=True) net_loss = SoftmaxCrossEntropyLoss(6, 255) -net_opt = nn.Adam(net.trainable_params(), learning_rate=3e-5) +net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_iter) config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10) ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)