diff --git a/main.py b/main.py index f4ad80b78211f34ee848debc76a6faf78eca962d..2a920a42c042b137bdf02dc0919718c9da4070fc 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ from unet_medical.unet_model import UNetMedical from nets.deeplab_v3 import deeplab_v3 from dataset import GetDatasetGenerator from loss import SoftmaxCrossEntropyLoss -from utils.learning_rates import exponential_lr +import utils.learning_rates as learning_rates context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target='Ascend', device_id=7) @@ -25,7 +25,7 @@ train_dataset_generator = GetDatasetGenerator('./datasets', 'train') train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True) train_dataset = train_dataset.batch(4, drop_remainder=True) -lr_iter = exponential_lr(3e-5, 20, 0.98, 500, staircase=True) +lr_iter = learning_rates.exponential_lr(3e-5, 20, 0.98, 500, staircase=True) net_loss = SoftmaxCrossEntropyLoss(6, 255) net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_iter)