diff --git a/official/cv/nasnet/src/nasnet_a_mobile.py b/official/cv/nasnet/src/nasnet_a_mobile.py
index 0988a50a8acc07b4403db6b16c14402595dea6f6..10904a52e55fc2ddb6577be7940f6f366e3a4a11 100644
--- a/official/cv/nasnet/src/nasnet_a_mobile.py
+++ b/official/cv/nasnet/src/nasnet_a_mobile.py
@@ -18,9 +18,7 @@ import numpy as np
 from mindspore import context
 from mindspore import Tensor
 import mindspore.nn as nn
-
 from mindspore.nn.loss.loss import LossBase
-
 import mindspore.ops.operations as P
 import mindspore.ops.functional as F
 import mindspore.ops.composite as C
@@ -61,9 +59,8 @@ def _clip_grad(clip_type, clip_value, grad):
 class CrossEntropy(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
 
-    def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
+    def __init__(self, smooth_factor=0, num_classes=1000):
         super(CrossEntropy, self).__init__()
-        self.factor = factor
         self.onehot = P.OneHot()
         self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
         self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
@@ -71,14 +68,11 @@ class CrossEntropy(LossBase):
         self.mean = P.ReduceMean(False)
 
     def construct(self, logits, label):
-        logit, aux = logits
+        logit = logits[0]
         one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
         loss_logit = self.ce(logit, one_hot_label)
         loss_logit = self.mean(loss_logit, 0)
-        one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value)
-        loss_aux = self.ce(aux, one_hot_label_aux)
-        loss_aux = self.mean(loss_aux, 0)
-        return loss_logit + self.factor*loss_aux
+        return loss_logit
 
 
 class AuxLogits(nn.Cell):
@@ -896,7 +890,7 @@ class NASNetAMobileWithLoss(nn.Cell):
         super(NASNetAMobileWithLoss, self).__init__()
         self.network = NASNetAMobile(config.num_classes, is_training)
         self.loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
-                                 num_classes=config.num_classes, factor=config.aux_factor)
+                                 num_classes=config.num_classes)
         self.cast = P.Cast()
 
     def construct(self, data, label):