Commit 168b810a authored by 何向东

Use a variable learning rate

parent 1ed758bf
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P


class SoftmaxCrossEntropyLoss(nn.Cell):
    """Softmax cross-entropy loss that masks out pixels labeled `ignore_label`."""

    def __init__(self, num_cls=21, ignore_label=255):
        super(SoftmaxCrossEntropyLoss, self).__init__()
        self.one_hot = P.OneHot(axis=-1)
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.cast = P.Cast()
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.not_equal = P.NotEqual()
        self.num_cls = num_cls
        self.ignore_label = ignore_label
        self.mul = P.Mul()
        self.sum = P.ReduceSum(False)
        self.div = P.RealDiv()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, logits, labels):
        # Flatten labels to (N*H*W,) and logits from NCHW to (N*H*W, num_cls).
        labels_int = self.cast(labels, mstype.int32)
        labels_int = self.reshape(labels_int, (-1,))
        logits_ = self.transpose(logits, (0, 2, 3, 1))
        logits_ = self.reshape(logits_, (-1, self.num_cls))
        # Zero-weight every pixel whose label equals `ignore_label`.
        weights = self.not_equal(labels_int, self.ignore_label)
        weights = self.cast(weights, mstype.float32)
        one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
        loss = self.ce(logits_, one_hot_labels)
        loss = self.mul(weights, loss)
        # Average over the non-ignored pixels only.
        loss = self.div(self.sum(loss), self.sum(weights))
        return loss
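As a quick smoke test (not part of this commit), the loss can be exercised with random NCHW logits and an integer label map; shapes and the `num_cls=6, ignore_label=255` values below mirror the train.py diff further down:

    # Hypothetical usage sketch: random logits and labels through the loss.
    import numpy as np
    loss_fn = SoftmaxCrossEntropyLoss(num_cls=6, ignore_label=255)
    logits = Tensor(np.random.randn(4, 6, 32, 32).astype(np.float32))
    labels = Tensor(np.random.randint(0, 6, (4, 32, 32)).astype(np.int32))
    print(loss_fn(logits, labels))  # scalar mean loss over non-ignored pixels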
@@ -14,6 +14,7 @@ from unet_medical.unet_model import UNetMedical
 from nets.deeplab_v3 import deeplab_v3
 from dataset import GetDatasetGenerator
 from loss import SoftmaxCrossEntropyLoss
+from utils.learning_rates import exponential_lr
 context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False,
                     device_target='Ascend', device_id=7)
@@ -24,6 +25,8 @@ train_dataset_generator = GetDatasetGenerator('./datasets', 'train')
 train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True)
 train_dataset = train_dataset.batch(4, drop_remainder=True)
+lr_iter = exponential_lr(3e-5, 20, 0.98, 500, staircase=True)
 net_loss = SoftmaxCrossEntropyLoss(6, 255)
 net_opt = nn.Adam(net.trainable_params(), learning_rate=3e-5)
...
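The visible part of the diff still passes a fixed rate to nn.Adam; the truncated remainder presumably wires `lr_iter` into the optimizer. A common MindSpore pattern (a sketch under that assumption, not the commit's exact code) is to materialize the generator into a per-step list, which nn.Adam accepts as a dynamic learning rate:

    # Sketch only: one learning rate per training step, consumed in order.
    lr_iter = exponential_lr(3e-5, 20, 0.98, 500, staircase=True)
    net_opt = nn.Adam(net.trainable_params(), learning_rate=list(lr_iter))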
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np


def cosine_lr(base_lr, decay_steps, total_steps):
    """Cosine decay from `base_lr` toward 0 over `decay_steps` steps."""
    for i in range(total_steps):
        step_ = min(i, decay_steps)
        yield base_lr * 0.5 * (1 + np.cos(np.pi * step_ / decay_steps))


def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
    """Polynomial decay from `base_lr` to `end_lr` over `decay_steps` steps."""
    for i in range(total_steps):
        step_ = min(i, decay_steps)
        yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr


def exponential_lr(base_lr, decay_steps, decay_rate, total_steps, staircase=False):
    """Exponential decay; with `staircase`, the rate drops once every `decay_steps` steps."""
    for i in range(total_steps):
        if staircase:
            power_ = i // decay_steps
        else:
            power_ = float(i) / decay_steps
        yield base_lr * (decay_rate ** power_)
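As a sanity check (hypothetical, not in the repo), the staircase schedule used in train.py holds each rate for 20 steps and multiplies by 0.98 at every boundary:

    # Sample the schedule passed to train.py above.
    rates = list(exponential_lr(3e-5, 20, 0.98, 500, staircase=True))
    print(rates[0], rates[20], rates[40])
    # 3e-05, 2.94e-05 (3e-5 * 0.98), 2.8812e-05 (3e-5 * 0.98**2)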