From 12d99e65356c4b9e82b2757684e76db6d1984228 Mon Sep 17 00:00:00 2001 From: hexiangdong2019 <PIpi589632147> Date: Wed, 4 Aug 2021 21:47:02 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=AE=AD=E7=BB=83=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 263811a..421d70d 100644 --- a/main.py +++ b/main.py @@ -28,8 +28,7 @@ train_dataset_generator = GetDatasetGenerator('./datasets', 'train') train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True) train_dataset = train_dataset.batch(4, drop_remainder=True) -# lr_iter = exponential_lr(3e-5, 100, 0.98, 500, staircase=True) -lr_iter = cosine_lr(3e-5, 20, 500) +lr_iter = exponential_lr(3e-5, 20, 0.9, 100, staircase=True) net_loss = SoftmaxCrossEntropyLoss(6, 255) net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_iter) @@ -177,7 +176,7 @@ def net_eval(): time_cb = TimeMonitor(data_size=60) loss_cb = LossMonitor() model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) -model.train(10, train_dataset, callbacks=[time_cb, loss_cb, ckpoint], dataset_sink_mode=True) +model.train(100, train_dataset, callbacks=[time_cb, loss_cb, ckpoint], dataset_sink_mode=True) net_eval() -- GitLab