Skip to content
Snippets Groups Projects
Commit 366d295c authored by panfengfeng's avatar panfengfeng
Browse files

fix naset performance

parent f0c085ed
No related branches found
No related tags found
No related merge requests found
......@@ -18,9 +18,7 @@ import numpy as np
from mindspore import context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.nn.loss.loss import LossBase
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
......@@ -61,9 +59,8 @@ def _clip_grad(clip_type, clip_value, grad):
class CrossEntropy(LossBase):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
def __init__(self, smooth_factor=0, num_classes=1000):
super(CrossEntropy, self).__init__()
self.factor = factor
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
......@@ -71,14 +68,11 @@ class CrossEntropy(LossBase):
self.mean = P.ReduceMean(False)
def construct(self, logits, label):
logit, aux = logits
logit = logits[0]
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss_logit = self.ce(logit, one_hot_label)
loss_logit = self.mean(loss_logit, 0)
one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value)
loss_aux = self.ce(aux, one_hot_label_aux)
loss_aux = self.mean(loss_aux, 0)
return loss_logit + self.factor*loss_aux
return loss_logit
class AuxLogits(nn.Cell):
......@@ -896,7 +890,7 @@ class NASNetAMobileWithLoss(nn.Cell):
super(NASNetAMobileWithLoss, self).__init__()
self.network = NASNetAMobile(config.num_classes, is_training)
self.loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
num_classes=config.num_classes, factor=config.aux_factor)
num_classes=config.num_classes)
self.cast = P.Cast()
def construct(self, data, label):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment