diff --git a/loss/__init__.py b/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/loss/loss.py b/loss/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cff0b0da0e7bbfbe1b5760aee968212bb8ba70c5
--- /dev/null
+++ b/loss/loss.py
@@ -0,0 +1,59 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+
+
+class SoftmaxCrossEntropyLoss(nn.Cell):
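+    """Softmax cross-entropy over per-pixel logits.
+
+    Pixels labelled ``ignore_label`` are masked out and the loss is averaged
+    over the remaining valid pixels only.
+    """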
+    def __init__(self, num_cls=21, ignore_label=255):
+        super(SoftmaxCrossEntropyLoss, self).__init__()
+        self.one_hot = P.OneHot(axis=-1)
+        self.on_value = Tensor(1.0, mstype.float32)
+        self.off_value = Tensor(0.0, mstype.float32)
+        self.cast = P.Cast()
+        self.ce = nn.SoftmaxCrossEntropyWithLogits()
+        self.not_equal = P.NotEqual()
+        self.num_cls = num_cls
+        self.ignore_label = ignore_label
+        self.mul = P.Mul()
+        self.sum = P.ReduceSum(False)
+        self.div = P.RealDiv()
+        self.transpose = P.Transpose()
+        self.reshape = P.Reshape()
+
+    def construct(self, logits, labels):
+        # flatten the label map to one entry per pixel
+        labels_int = self.cast(labels, mstype.int32)
+        labels_int = self.reshape(labels_int, (-1,))
+        # NCHW -> NHWC, then flatten to (num_pixels, num_cls)
+        logits_ = self.transpose(logits, (0, 2, 3, 1))
+        logits_ = self.reshape(logits_, (-1, self.num_cls))
+        # 1.0 for valid pixels, 0.0 for pixels labelled ignore_label
+        weights = self.not_equal(labels_int, self.ignore_label)
+        weights = self.cast(weights, mstype.float32)
+        one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
+        loss = self.ce(logits_, one_hot_labels)
+        # zero out ignored pixels, then average over the valid ones
+        loss = self.mul(weights, loss)
+        loss = self.div(self.sum(loss), self.sum(weights))
+        return loss
diff --git a/main.py b/main.py
index e6b7edc2ea1e50fa63bd76fb8688afd8fc9045ee..10cd0e169fe1a7439bbe73a3db02f6d67e5bc300 100644
--- a/main.py
+++ b/main.py
@@ -14,6 +14,7 @@ from unet_medical.unet_model import UNetMedical
from nets.deeplab_v3 import deeplab_v3
from dataset import GetDatasetGenerator
from loss import SoftmaxCrossEntropyLoss
+from utils.learning_rates import exponential_lr
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False,
device_target='Ascend', device_id=7)
@@ -24,6 +25,9 @@ train_dataset_generator = GetDatasetGenerator('./datasets', 'train')
train_dataset = ds.GeneratorDataset(train_dataset_generator, ["data", "label"], shuffle=True)
train_dataset = train_dataset.batch(4, drop_remainder=True)
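+# per-step learning-rate generator: 3e-5 decayed by a factor of 0.98 every 20 steps (staircase), for 500 steps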
+lr_iter = exponential_lr(3e-5, 20, 0.98, 500, staircase=True)
+
net_loss = SoftmaxCrossEntropyLoss(6, 255)
net_opt = nn.Adam(net.trainable_params(), learning_rate=3e-5)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/learning_rates.py b/utils/learning_rates.py
new file mode 100644
index 0000000000000000000000000000000000000000..2267b1b6a464c7197b157908ed7952819777f412
--- /dev/null
+++ b/utils/learning_rates.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+
+
+def cosine_lr(base_lr, decay_steps, total_steps):
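+    """Yield a cosine-annealed learning rate for each of total_steps steps."""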
+    for i in range(total_steps):
+        step_ = min(i, decay_steps)
+        yield base_lr * 0.5 * (1 + np.cos(np.pi * step_ / decay_steps))
+
+
+def poly_lr(base_lr, decay_steps, total_steps, end_lr=0.0001, power=0.9):
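+    """Yield a polynomial-decay learning rate from base_lr toward end_lr."""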
+    for i in range(total_steps):
+        step_ = min(i, decay_steps)
+        yield (base_lr - end_lr) * ((1.0 - step_ / decay_steps) ** power) + end_lr
+
+
+def exponential_lr(base_lr, decay_steps, decay_rate, total_steps, staircase=False):
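+    """Yield base_lr * decay_rate ** (i / decay_steps) for step i;
+    with staircase=True the exponent is floored to give stepwise decay.
+    """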
+    for i in range(total_steps):
+        if staircase:
+            power_ = i // decay_steps
+        else:
+            power_ = float(i) / decay_steps
+        yield base_lr * (decay_rate ** power_)