From 12d99e65356c4b9e82b2757684e76db6d1984228 Mon Sep 17 00:00:00 2001
From: hexiangdong2019 <PIpi589632147>
Date: Wed, 4 Aug 2021 21:47:02 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=AE=AD=E7=BB=83=E5=8F=82?=
 =?UTF-8?q?=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index 263811a..421d70d 100644
--- a/main.py
+++ b/main.py
@@ -28,8 +28,7 @@ train_dataset_generator = GetDatasetGenerator('./datasets', 'train')
 train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True)
 train_dataset = train_dataset.batch(4, drop_remainder=True)
 
-# lr_iter = exponential_lr(3e-5, 100, 0.98, 500, staircase=True)
-lr_iter = cosine_lr(3e-5, 20, 500)
+lr_iter = exponential_lr(3e-5, 20, 0.9, 100, staircase=True)
 
 net_loss = SoftmaxCrossEntropyLoss(6, 255)
 net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_iter)
@@ -177,7 +176,7 @@ def net_eval():
 time_cb = TimeMonitor(data_size=60)
 loss_cb = LossMonitor()
 model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-model.train(10, train_dataset, callbacks=[time_cb, loss_cb, ckpoint], dataset_sink_mode=True)
+model.train(100, train_dataset, callbacks=[time_cb, loss_cb, ckpoint], dataset_sink_mode=True)
 
 net_eval()
 
-- 
GitLab