diff --git a/official/cv/psenet/src/PSENET/dice_loss.py b/official/cv/psenet/src/PSENET/dice_loss.py index cfb3edca58905917216fcd34a8dec347611565f3..a758d7da3945bdf5e1e76b0aeaeba306ef2ab68f 100644 --- a/official/cv/psenet/src/PSENET/dice_loss.py +++ b/official/cv/psenet/src/PSENET/dice_loss.py @@ -87,7 +87,7 @@ class DiceLoss(Cell): neg_num = self.minimum(3 * pos_num, neg_num) neg_num = self.cast(neg_num, mstype.int32) - neg_num = self.add(neg_num, self.negative_one_int32) + neg_num = neg_num + self.k - 1 neg_mask = self.less_equal(gt_text, self.threshold0) ignore_score = self.fill(mstype.float32, (640, 640), -1e3) neg_score = self.select(neg_mask, score, ignore_score)