diff --git a/loss/__init__.py b/loss/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/loss/loss.py b/loss/loss.py
deleted file mode 100644
index cff0b0da0e7bbfbe1b5760aee968212bb8ba70c5..0000000000000000000000000000000000000000
--- a/loss/loss.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-from mindspore import Tensor
-import mindspore.common.dtype as mstype
-import mindspore.nn as nn
-from mindspore.ops import operations as P
-
-
-class SoftmaxCrossEntropyLoss(nn.Cell):
-    def __init__(self, num_cls=21, ignore_label=255):
-        super(SoftmaxCrossEntropyLoss, self).__init__()
-        self.one_hot = P.OneHot(axis=-1)
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
-        self.cast = P.Cast()
-        self.ce = nn.SoftmaxCrossEntropyWithLogits()
-        self.not_equal = P.NotEqual()
-        self.num_cls = num_cls
-        self.ignore_label = ignore_label
-        self.mul = P.Mul()
-        self.sum = P.ReduceSum(False)
-        self.div = P.RealDiv()
-        self.transpose = P.Transpose()
-        self.reshape = P.Reshape()
-
-    def construct(self, logits, labels):
-        labels_int = self.cast(labels, mstype.int32)
-        labels_int = self.reshape(labels_int, (-1,))
-        logits_ = self.transpose(logits, (0, 2, 3, 1))
-        logits_ = self.reshape(logits_, (-1, self.num_cls))
-        weights = self.not_equal(labels_int, self.ignore_label)
-        weights = self.cast(weights, mstype.float32)
-        one_hot_labels = self.one_hot(labels_int, self.num_cls, self.on_value, self.off_value)
-        loss = self.ce(logits_, one_hot_labels)
-        loss = self.mul(weights, loss)
-        loss = self.div(self.sum(loss), self.sum(weights))
-        return loss
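
For context (not part of the patch itself): the deleted cell flattens NCHW logits to (pixels, num_cls), computes a per-pixel softmax cross-entropy, zeroes out pixels whose label equals ignore_label, and divides the summed loss by the count of valid pixels. A minimal NumPy sketch of that same masked-mean computation; the helper name masked_softmax_ce is hypothetical and not from this repo:

import numpy as np

def masked_softmax_ce(logits, labels, num_cls=21, ignore_label=255):
    """Sketch of the deleted SoftmaxCrossEntropyLoss.construct in NumPy.

    logits: float array of shape (N, C, H, W); labels: int array of shape (N, H, W).
    """
    # NCHW -> NHWC, then flatten to (num_pixels, num_cls), mirroring the cell
    flat_logits = logits.transpose(0, 2, 3, 1).reshape(-1, num_cls)
    flat_labels = labels.astype(np.int32).reshape(-1)

    # Numerically stable log-softmax
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Mask of valid (non-ignored) pixels, the analogue of NotEqual + Cast
    weights = (flat_labels != ignore_label).astype(np.float32)

    # Clamp ignored labels so indexing stays in range; their loss is masked out anyway
    safe_labels = np.where(flat_labels == ignore_label, 0, flat_labels)
    per_pixel = -log_probs[np.arange(flat_labels.size), safe_labels]

    # Sum of masked losses divided by the number of valid pixels
    return (per_pixel * weights).sum() / weights.sum()

Note the design choice preserved here: dividing by sum(weights) rather than the total pixel count keeps the loss scale independent of how many pixels carry the ignore_label.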