From 50f6ed403c396496b0aa1e4d5bb287f85dc51b69 Mon Sep 17 00:00:00 2001
From: deepr <hexiangdong2020@outlook.com>
Date: Tue, 20 Jul 2021 23:08:44 +0800
Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3utils=E5=BC=95=E7=94=A8?=
 =?UTF-8?q?=E5=86=B2=E7=AA=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py           | 4 ++--
 utils/__init__.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index aae5d59..3771fd8 100644
--- a/main.py
+++ b/main.py
@@ -14,7 +14,7 @@ from unet_medical.unet_model import UNetMedical
 from nets.deeplab_v3 import deeplab_v3
 from dataset import GetDatasetGenerator
 from loss import SoftmaxCrossEntropyLoss
-from utils import exponential_lr
+import utils
 
 context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target='Ascend', device_id=7)
 
@@ -25,7 +25,7 @@ train_dataset_generator = GetDatasetGenerator('./datasets', 'train')
 train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True)
 train_dataset = train_dataset.batch(4, drop_remainder=True)
 
-lr_iter = exponential_lr(3e-5, 20, 0.98, 500, staircase=True)
+lr_iter = utils.exponential_lr(3e-5, 20, 0.98, 500, staircase=True)
 
 net_loss = SoftmaxCrossEntropyLoss(6, 255)
 net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_iter)
diff --git a/utils/__init__.py b/utils/__init__.py
index c759065..5d1904a 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1 +1 @@
-from learning_rates import cosine_lr, poly_lr, exponential_lr
\ No newline at end of file
+from .learning_rates import cosine_lr, poly_lr, exponential_lr
-- 
GitLab