diff --git a/official/cv/yolov3_darknet53/src/yolo.py b/official/cv/yolov3_darknet53/src/yolo.py index 72bd46e32c132761cd0bda3b30f75eb7345004f9..c6c8b0366196be7c291de593c0cfcef30fb134a2 100644 --- a/official/cv/yolov3_darknet53/src/yolo.py +++ b/official/cv/yolov3_darknet53/src/yolo.py @@ -211,8 +211,6 @@ class DetectionBlock(nn.Cell): box_xy = prediction[:, :, :, :, :2] box_wh = prediction[:, :, :, :, 2:4] - box_confidence = prediction[:, :, :, :, 4:5] - box_probs = prediction[:, :, :, :, 5:] # gridsize1 is x # gridsize0 is y @@ -220,11 +218,13 @@ class DetectionBlock(nn.Cell): grid_size[0])), ms.float32) # box_wh is w->h box_wh = ops.Exp()(box_wh) * self.anchors / input_shape - box_confidence = self.sigmoid(box_confidence) - box_probs = self.sigmoid(box_probs) if self.conf_training: return grid, prediction, box_xy, box_wh + box_confidence = prediction[:, :, :, :, 4:5] + box_probs = prediction[:, :, :, :, 5:] + box_confidence = self.sigmoid(box_confidence) + box_probs = self.sigmoid(box_probs) return self.concat((box_xy, box_wh, box_confidence, box_probs))