Commit 1dae8756 authored by maijianqiang

add detection function

parent e98cbff3
@@ -18,12 +18,6 @@ flip_ratio: 0.5
expand_ratio: 1.0
# anchor
-feature_shapes:
-- [192, 320]
-- [96, 160]
-- [48, 80]
-- [24, 40]
-- [12, 20]
anchor_scales: [8]
anchor_ratios: [0.5, 1.0, 2.0]
anchor_strides: [4, 8, 16, 32, 64]
@@ -52,7 +46,6 @@ rpn_target_stds: [1.0, 1.0, 1.0, 1.0]
neg_iou_thr: 0.3
pos_iou_thr: 0.7
min_pos_iou: 0.3
-num_bboxes: 245520
num_gts: 128
num_expected_neg: 256
num_expected_pos: 128
@@ -121,7 +114,9 @@ batch_size: 2
loss_scale: 256
momentum: 0.91
weight_decay: 0.00001
-epoch_size: 12
+epoch_size: 20
+run_eval: False
+interval: 1
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 5
@@ -152,6 +147,7 @@ coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
num_classes: 81
+prefix: ""
# annotations file(json format or user defined text format)
anno_path: ''
@@ -209,6 +205,7 @@ checkpoint_path: "Checkpoint file path."
ckpt_file: 'fasterrcnn ckpt file.'
result_path: "result file path."
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152"
+interval: "val interval"
---
device_target: ['Ascend', 'GPU', 'CPU']
...
@@ -19,12 +19,6 @@ flip_ratio: 0.5
expand_ratio: 1.0
# anchor
-feature_shapes:
-- [192, 320]
-- [96, 160]
-- [48, 80]
-- [24, 40]
-- [12, 20]
anchor_scales: [8]
anchor_ratios: [0.5, 1.0, 2.0]
anchor_strides: [4, 8, 16, 32, 64]
@@ -53,7 +47,6 @@ rpn_target_stds: [1.0, 1.0, 1.0, 1.0]
neg_iou_thr: 0.3
pos_iou_thr: 0.7
min_pos_iou: 0.3
-num_bboxes: 245520
num_gts: 128
num_expected_neg: 256
num_expected_pos: 128
@@ -123,6 +116,8 @@ loss_scale: 256
momentum: 0.91
weight_decay: 0.00001
epoch_size: 20
+run_eval: False
+interval: 1
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 5
@@ -153,6 +148,7 @@ coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
num_classes: 81
+prefix: ""
# annotations file(json format or user defined text format)
anno_path: ''
@@ -210,6 +206,7 @@ checkpoint_path: "Checkpoint file path."
ckpt_file: 'fasterrcnn ckpt file.'
result_path: "result file path."
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152"
+interval: "val interval"
---
device_target: ['Ascend', 'GPU', 'CPU']
...
@@ -19,12 +19,6 @@ flip_ratio: 0.5
expand_ratio: 1.0
# anchor
-feature_shapes:
-- [192, 320]
-- [96, 160]
-- [48, 80]
-- [24, 40]
-- [12, 20]
anchor_scales: [8]
anchor_ratios: [0.5, 1.0, 2.0]
anchor_strides: [4, 8, 16, 32, 64]
@@ -53,7 +47,6 @@ rpn_target_stds: [1.0, 1.0, 1.0, 1.0]
neg_iou_thr: 0.3
pos_iou_thr: 0.7
min_pos_iou: 0.3
-num_bboxes: 245520
num_gts: 128
num_expected_neg: 256
num_expected_pos: 128
@@ -123,6 +116,8 @@ loss_scale: 256
momentum: 0.91
weight_decay: 0.00001
epoch_size: 20
+run_eval: False
+interval: 1
save_checkpoint: True
save_checkpoint_epochs: 1
keep_checkpoint_max: 5
@@ -153,6 +148,7 @@ coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
num_classes: 81
+prefix: ""
# annotations file(json format or user defined text format)
anno_path: ''
@@ -210,6 +206,7 @@ checkpoint_path: "Checkpoint file path."
ckpt_file: 'fasterrcnn ckpt file.'
result_path: "result file path."
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152"
+interval: "val interval"
---
device_target: ['Ascend', 'GPU', 'CPU']
...
@@ -78,7 +78,6 @@ def fasterrcnn_eval(dataset_path, ckpt_path, anno_path):
max_num = 128
for data in ds.create_dict_iterator(num_epochs=1):
eval_iter = eval_iter + 1
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
@@ -117,12 +116,12 @@ def fasterrcnn_eval(dataset_path, ckpt_path, anno_path):
eval_types = ["bbox"]
result_files = results2json(dataset_coco, outputs, "./results.pkl")
-coco_eval(result_files, eval_types, dataset_coco, single_result=True)
+coco_eval(config, result_files, eval_types, dataset_coco, single_result=False)
def modelarts_pre_process():
pass
# config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def eval_fasterrcnn():
@@ -154,5 +153,36 @@ def eval_fasterrcnn():
print("Start Eval!")
fasterrcnn_eval(mindrecord_file, config.checkpoint_path, config.anno_path)
flags = [0] * 3
config.eval_result_path = os.path.abspath("./eval_result")
if os.path.exists(config.eval_result_path):
result_files = os.listdir(config.eval_result_path)
for file in result_files:
if file == "statistics.csv":
with open(os.path.join(config.eval_result_path, "statistics.csv"), "r") as f:
res = f.readlines()
if len(res) > 4:
if "class_name" in res[3] and "tp_num" in res[3] and len(res[4].strip().split(",")) > 1:
flags[0] = 1
elif file in ("precision_ng_images", "recall_ng_images", "ok_images"):
imgs = os.listdir(os.path.join(config.eval_result_path, file))
if imgs:
flags[1] = 1
elif file == "pr_curve_image":
imgs = os.listdir(os.path.join(config.eval_result_path, "pr_curve_image"))
if imgs:
flags[2] = 1
else:
pass
if sum(flags) == 3:
print("eval success.")
exit(0)
else:
print("eval failed.")
exit(-1)
if __name__ == '__main__':
eval_fasterrcnn()
@@ -20,6 +20,7 @@ from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
+from src.model_utils.config import config
def bias_init_zeros(shape):
@@ -84,9 +85,9 @@ class FeatPyramidNeck(nn.Cell):
self.fpn_convs_.append(fpn_conv)
self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
-self.interpolate1 = P.ResizeNearestNeighbor((48, 80))
-self.interpolate2 = P.ResizeNearestNeighbor((96, 160))
-self.interpolate3 = P.ResizeNearestNeighbor((192, 320))
+self.interpolate1 = P.ResizeNearestNeighbor(config.feature_shapes[2])
+self.interpolate2 = P.ResizeNearestNeighbor(config.feature_shapes[1])
+self.interpolate3 = P.ResizeNearestNeighbor(config.feature_shapes[0])
self.maxpool = P.MaxPool(kernel_size=1, strides=2, pad_mode="same")
def construct(self, inputs):
...
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from typing import List
import os
import csv
import warnings
import cv2
import numpy as np
from pycocotools.cocoeval import COCOeval
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns
warnings.filterwarnings("ignore")
COLOR_MAP = [
(0, 255, 255),
(0, 255, 0),
(255, 0, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 128, 128),
(0, 128, 0),
(128, 0, 0),
(0, 0, 128),
(128, 128, 0),
(128, 0, 128),
]
def write_list_to_csv(file_path, data_to_write, append=False):
print('Saving data into file [{}]...'.format(file_path))
if append:
open_mode = 'a'
else:
open_mode = 'w'
with open(file_path, open_mode) as csvfile:
writer = csv.writer(csvfile)
writer.writerow(data_to_write)
def read_image(image_path):
image = cv2.imread(image_path)
if image is None:
return False, None
return True, image
def save_image(image_path, image):
return cv2.imwrite(image_path, image)
def draw_rectangle(image, pt1, pt2, label=None):
if label is not None:
map_index = label % len(COLOR_MAP)
color = COLOR_MAP[map_index]
else:
color = COLOR_MAP[0]
thickness = 5
cv2.rectangle(image, pt1, pt2, color, thickness)
def draw_text(image, text, org, label=None):
if label is not None:
map_index = label % len(COLOR_MAP)
color = COLOR_MAP[map_index]
else:
color = COLOR_MAP[0]
font_face = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.6
thickness = 1
cv2.putText(image, text, org, font_face, font_scale, color, thickness)
def draw_one_box(image, label, box, cat_id, line_thickness=None):
tl = line_thickness or round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1
if cat_id is not None:
map_index = cat_id % len(COLOR_MAP)
color = COLOR_MAP[map_index]
else:
color = COLOR_MAP[0]
c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
cv2.rectangle(image, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
tf = max(tl - 1, 1)
t_size = cv2.getTextSize(label, 0, fontScale=tl / 6, thickness=tf // 2)[0]
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
cv2.rectangle(image, c1, c2, color, -1, cv2.LINE_AA)
cv2.putText(image, label, (c1[0], c1[1] - 2), 0, tl / 6, [255, 255, 255], thickness=tf // 2, lineType=cv2.LINE_AA)
class DetectEval(COCOeval):
def __init__(self, cocoGt=None, cocoDt=None, iouType="bbox"):
assert iouType == "bbox", "only iouType 'bbox' is supported"
super().__init__(cocoGt, cocoDt, iouType)
if self.cocoGt is not None:
cat_infos = cocoGt.loadCats(cocoGt.getCatIds())
self.params.labels = {}
# self.params.labels = ["" for i in range(len(self.params.catIds))]
for cat in cat_infos:
self.params.labels[cat["id"]] = cat["name"]
# add new
def catId_summarize(self, catId, iouThr=None, areaRng="all", maxDets=100):
p = self.params
aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
s = self.eval["recall"]
if iouThr is not None:
iou = np.where(iouThr == p.iouThrs)[0]
s = s[iou]
if isinstance(catId, int):
s = s[:, catId, aind, mind]
else:
s = s[:, :, aind, mind]
is_empty = len(s[s > -1]) == 0
if is_empty:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
return mean_s
def compute_gt_dt_num(self):
p = self.params
catIds_gt_num = {}
catIds_dt_num = {}
for ids in p.catIds:
gts_cat_id = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(catIds=[ids]))
dts_cat_id = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(catIds=[ids]))
catIds_gt_num[ids] = len(gts_cat_id)
catIds_dt_num[ids] = len(dts_cat_id)
return catIds_gt_num, catIds_dt_num
def evaluate_ok_ng(self, img_id, catIds, iou_threshold=0.5):
"""
evaluate whether this image is ok, precision_ng or recall_ng
img_id: int
cat_ids: list
iou_threshold: float
"""
p = self.params
img_id = int(img_id)
# Save the results of precision_ng and recall_ng for each category on a picture
cat_id_result = {}
for cat_id in catIds:
gt = self._gts[img_id, cat_id]
dt = self._dts[img_id, cat_id]
ious = self.computeIoU(img_id, cat_id)
# Sort dt in descending order, and only take the first 100
inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
dt = [dt[i] for i in inds]
# p.maxDets must be set in ascending order
if len(dt) > p.maxDets[-1]:
dt = dt[0:p.maxDets[-1]]
# The first case: gt, dt are both 0: skip
if not gt and not dt:
cat_id_result[cat_id] = (False, False)
continue
# The second case: gt = 0, dt !=0: precision_ng
if not gt and dt:
cat_id_result[cat_id] = (True, False)
continue
# The third case: gt != 0, dt = 0: recall_ng
if gt and not dt:
cat_id_result[cat_id] = (False, True)
continue
# The fourth case: gt and dt are matched in pairs
gtm = [0] * len(gt)
dtm = [0] * len(dt)
for dind in range(len(dt)):
# dt:[a] gt [b] ious = [a*b]
iou = min([iou_threshold, 1 - 1e-10])
# m records the position of the gt with the best match
m = -1
for gind in range(len(gt)):
# If gt[gind] already matches, skip it.
if gtm[gind] > 0:
continue
# Skip this gt if iou(dind, gind) is below the current best (initialized to the threshold)
if ious[dind, gind] < iou:
continue
iou = ious[dind, gind]
m = gind
if m == -1:
continue
dtm[dind] = 1
gtm[m] = 1
# If gt is all matched, gtm is all 1
precision_ng = sum(dtm) < len(dtm)
recall_ng = sum(gtm) < len(gtm)
cat_id_result[cat_id] = (precision_ng, recall_ng)
# If precision_ng is True for any class, the whole image counts as precision_ng; the same rule applies to recall_ng
# Per-category NG images could also be saved in a later extension
precision_result = False
recall_result = False
for ng in cat_id_result.values():
precision_ng = ng[0]
recall_ng = ng[1]
if precision_ng:
precision_result = precision_ng
if recall_ng:
recall_result = recall_ng
return precision_result, recall_result
def evaluate_every_class(self):
"""
compute every class's:
[label, tp_num, gt_num, dt_num, precision, recall]
"""
print("Evaluate every class's precision and recall")
p = self.params
cat_ids = p.catIds
labels = p.labels
result = []
catIds_gt_num, catIds_dt_num = self.compute_gt_dt_num()
sum_gt_num = 0
sum_dt_num = 0
for value in catIds_gt_num.values():
sum_gt_num += value
for value in catIds_dt_num.values():
sum_dt_num += value
sum_tp_num = 0
for i, cat_id in enumerate(cat_ids):
# Note: catId_summarize expects the positional index into p.catIds, not the COCO category id
stats = self.catId_summarize(catId=i)
recall = stats
gt_num = catIds_gt_num[cat_id]
tp_num = recall * gt_num
sum_tp_num += tp_num
dt_num = catIds_dt_num[cat_id]
if dt_num <= 0:
if gt_num == 0:
precision = -1
else:
precision = 0
else:
precision = tp_num / dt_num
label = labels[cat_id]
class_result = [label, int(round(tp_num)), gt_num, int(round(dt_num)), round(precision, 3),
round(recall, 3)]
result.append(class_result)
all_precision = sum_tp_num / sum_dt_num
all_recall = sum_tp_num / sum_gt_num
all_result = ["all", int(round(sum_tp_num)), sum_gt_num, int(round(sum_dt_num)), round(all_precision, 3),
round(all_recall, 3)]
result.append(all_result)
print("Done")
return result
def plot_pr_curve(self, eval_result_path):
"""
precisions[T, R, K, A, M]
T: iou thresholds [0.5 : 0.05 : 0.95], idx from 0 to 9
R: recall thresholds [0 : 0.01 : 1], idx from 0 to 100
K: category, idx from 0 to ...
A: area range, (all, small, medium, large), idx from 0 to 3
M: max dets, (1, 10, 100), idx from 0 to 2
"""
print("Plot pr curve about every class")
precisions = self.eval["precision"]
p = self.params
cat_ids = p.catIds
labels = p.labels
pr_dir = os.path.join(eval_result_path, "./pr_curve_image")
if not os.path.exists(pr_dir):
os.mkdir(pr_dir)
for i, cat_id in enumerate(cat_ids):
pr_array1 = precisions[0, :, i, 0, 2] # iou = 0.5
x = np.arange(0.0, 1.01, 0.01)
# plot PR curve
plt.plot(x, pr_array1, label="iou=0.5," + labels[cat_id])
plt.xlabel("recall")
plt.ylabel("precision")
plt.xlim(0, 1.0)
plt.ylim(0, 1.01)
plt.grid(True)
plt.legend(loc="lower left")
plt_path = os.path.join(pr_dir, "pr_curve_" + labels[cat_id] + ".png")
plt.savefig(plt_path)
plt.close(1)
print("Done")
def save_images(self, config, eval_result_path, iou_threshold=0.5):
"""
save ok_images, precision_ng_images, recall_ng_images
Arguments:
config: dict, config about parameters
eval_result_path: str, path to save images
iou_threshold: float, IoU threshold used for matching
"""
print("Saving images of ok ng")
p = self.params
img_ids = p.imgIds
cat_ids = p.catIds if p.useCats else [-1] # list: [0,1,2,3....]
labels = p.labels
dt = self.cocoDt.getAnnIds()
dts = self.cocoDt.loadAnns(dt)
for img_id in img_ids:
img_id = int(img_id)
img_info = self.cocoGt.loadImgs(img_id)
if config.dataset == "coco":
im_path_dir = os.path.join(config.coco_root, config.val_data_type)
elif config.dataset == "voc":
im_path_dir = os.path.join(config.voc_root, 'eval', "JPEGImages")
assert config.dataset in ("coco", "voc")
# Return whether the image is precision_ng or recall_ng
precision_ng, recall_ng = self.evaluate_ok_ng(img_id, cat_ids, iou_threshold)
# Save to ok_images
if not precision_ng and not recall_ng:
# origin image path
im_path = os.path.join(im_path_dir, img_info[0]['file_name'])
# output image path
im_path_out_dir = os.path.join(eval_result_path, 'ok_images')
if not os.path.exists(im_path_out_dir):
os.makedirs(im_path_out_dir)
im_path_out = os.path.join(im_path_out_dir, img_info[0]['file_name'])
success, image = read_image(im_path)
assert success
for obj in dts:
_id = obj["image_id"]
if _id == img_id:
bbox = obj["bbox"]
score = obj["score"]
category_id = obj["category_id"]
label = labels[category_id]
xmin = int(bbox[0])
ymin = int(bbox[1])
width = int(bbox[2])
height = int(bbox[3])
xmax = xmin + width
ymax = ymin + height
label = label + " " + str(round(score, 3))
draw_one_box(image, label, (xmin, ymin, xmax, ymax), category_id)
save_image(im_path_out, image)
else:
# Save to precision_ng_images
if precision_ng:
# origin image path
im_path = os.path.join(im_path_dir, img_info[0]['file_name'])
# output image path
im_path_out_dir = os.path.join(eval_result_path, 'precision_ng_images')
if not os.path.exists(im_path_out_dir):
os.makedirs(im_path_out_dir)
im_path_out = os.path.join(im_path_out_dir, img_info[0]['file_name'])
success, image = read_image(im_path)
assert success
for obj in dts:
_id = obj["image_id"]
if _id == img_id:
bbox = obj["bbox"]
score = obj["score"]
category_id = obj["category_id"]
label = labels[category_id]
xmin = int(bbox[0])
ymin = int(bbox[1])
width = int(bbox[2])
height = int(bbox[3])
xmax = xmin + width
ymax = ymin + height
label = label + " " + str(round(score, 3))
draw_one_box(image, label, (xmin, ymin, xmax, ymax), category_id)
save_image(im_path_out, image)
# Save to recall_ng_images
if recall_ng:
# origin image path
im_path = os.path.join(im_path_dir, img_info[0]['file_name'])
# output image path
im_path_out_dir = os.path.join(eval_result_path, 'recall_ng_images')
if not os.path.exists(im_path_out_dir):
os.makedirs(im_path_out_dir)
im_path_out = os.path.join(im_path_out_dir, img_info[0]['file_name'])
success, image = read_image(im_path)
if not success:
raise Exception('Failed reading image from [{}]'.format(im_path))
for obj in dts:
_id = obj["image_id"]
if _id == img_id:
bbox = obj["bbox"]
score = obj["score"]
category_id = obj["category_id"]
label = labels[category_id]
xmin = int(bbox[0])
ymin = int(bbox[1])
width = int(bbox[2])
height = int(bbox[3])
xmax = xmin + width
ymax = ymin + height
label = label + " " + str(round(score, 3))
draw_one_box(image, label, (xmin, ymin, xmax, ymax), category_id)
save_image(im_path_out, image)
print("Done")
def compute_precision_recall_f1(self, min_score=0.1):
print('Compute precision, recall, f1...')
if not self.evalImgs:
print('Please run evaluate() first')
p = self.params
catIds = p.catIds if p.useCats == 1 else [-1]
labels = p.labels
assert len(p.maxDets) == 1
assert len(p.iouThrs) == 1
assert len(p.areaRng) == 1
# get inds to evaluate
k_list = [n for n, k in enumerate(p.catIds)]
m_list = [m for n, m in enumerate(p.maxDets)]
a_list: List[int] = [n for n, a in enumerate(p.areaRng)]
i_list = [n for n, i in enumerate(p.imgIds)]
I0 = len(p.imgIds)
A0 = len(p.areaRng)
# cat_pr_dict:
# {label1:[precision_li, recall_li, f1_li, score_li], label2:[precision_li, recall_li, f1_li, score_li]}
cat_pr_dict = {}
cat_pr_dict_origin = {}
for k0 in k_list:
Nk = k0 * A0 * I0
# areagRng
for a0 in a_list:
Na = a0 * I0
# maxDet
for maxDet in m_list:
E = [self.evalImgs[Nk + Na + i] for i in i_list]
E = [e for e in E if e is not None]
if not E:
continue
dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
# different sorting method generates slightly different results.
# mergesort is used to be consistent as Matlab implementation.
inds = np.argsort(-dtScores, kind='mergesort')
dtScoresSorted = dtScores[inds]
dtm = np.concatenate([e['dtMatches'][:, 0:maxDet] for e in E], axis=1)[:, inds]
dtIg = np.concatenate([e['dtIgnore'][:, 0:maxDet] for e in E], axis=1)[:, inds]
gtIg = np.concatenate([e['gtIgnore'] for e in E])
npig = np.count_nonzero(gtIg == 0)
if npig == 0:
continue
tps = np.logical_and(dtm, np.logical_not(dtIg))
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg))
# Ensure that iou has only one value
assert (tps.shape[0]) == 1
assert (fps.shape[0]) == 1
tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
ids = catIds[k0]
label = labels[ids]
self.calculate_pr_dict(tp_sum, fp_sum, label, npig, dtScoresSorted, cat_pr_dict, cat_pr_dict_origin,
min_score=min_score)
print("Done")
return cat_pr_dict, cat_pr_dict_origin
def calculate_pr_dict(self, tp_sum, fp_sum, label, npig, dtScoresSorted, cat_pr_dict, cat_pr_dict_origin,
min_score=0.1):
# iou
for (tp, fp) in zip(tp_sum, fp_sum):
tp = np.array(tp)
fp = np.array(fp)
rc = tp / npig
pr = tp / (fp + tp + np.spacing(1))
f1 = np.divide(2 * (rc * pr), pr + rc, out=np.zeros_like(2 * (rc * pr)), where=pr + rc != 0)
conf_thres = [int(i) * 0.01 for i in range(10, 100, 10)]
dtscores_ascend = dtScoresSorted[::-1]
inds = np.searchsorted(dtscores_ascend, conf_thres, side='left')
pr_new = [0.0] * len(conf_thres)
rc_new = [0.0] * len(conf_thres)
f1_new = [0.0] * len(conf_thres)
pr_ascend = pr[::-1]
rc_ascend = rc[::-1]
f1_ascend = f1[::-1]
try:
for i, ind in enumerate(inds):
if conf_thres[i] >= min_score:
pr_new[i] = pr_ascend[ind]
rc_new[i] = rc_ascend[ind]
f1_new[i] = f1_ascend[ind]
else:
pr_new[i] = 0.0
rc_new[i] = 0.0
f1_new[i] = 0.0
except IndexError:
pass
# Ensure that the second, third, and fourth for loops only enter once
if label not in cat_pr_dict.keys():
cat_pr_dict_origin[label] = [pr[::-1], rc[::-1], f1[::-1], dtScoresSorted[::-1]]
cat_pr_dict[label] = [pr_new, rc_new, f1_new, conf_thres]
else:
break
def compute_tp_fp_confidence(self):
print('Compute tp and fp confidences')
if not self.evalImgs:
print('Please run evaluate() first')
p = self.params
catIds = p.catIds if p.useCats == 1 else [-1]
labels = p.labels
assert len(p.maxDets) == 1
assert len(p.iouThrs) == 1
assert len(p.areaRng) == 1
# get inds to evaluate
m_list = [m for n, m in enumerate(p.maxDets)]
k_list = list(range(len(p.catIds)))
a_list = list(range(len(p.areaRng)))
i_list = list(range(len(p.imgIds)))
I0 = len(p.imgIds)
A0 = len(p.areaRng)
# cat_dict
correct_conf_dict = {}
incorrect_conf_dict = {}
for k0 in k_list:
Nk = k0 * A0 * I0
# areagRng
for a0 in a_list:
Na = a0 * I0
# maxDet
for maxDet in m_list:
E = [self.evalImgs[Nk + Na + i] for i in i_list]
E = [e for e in E if e is not None]
if not E:
continue
dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
inds = np.argsort(-dtScores, kind='mergesort')
dtScoresSorted = dtScores[inds]
dtm = np.concatenate([e['dtMatches'][:, 0:maxDet] for e in E], axis=1)[:, inds]
dtIg = np.concatenate([e['dtIgnore'][:, 0:maxDet] for e in E], axis=1)[:, inds]
gtIg = np.concatenate([e['gtIgnore'] for e in E])
npig = np.count_nonzero(gtIg == 0)
if npig == 0:
continue
tps = np.logical_and(dtm, np.logical_not(dtIg))
fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg))
# Ensure that iou has only one value
assert (tps.shape[0]) == 1
assert (fps.shape[0]) == 1
tp_inds = np.where(tps)
fp_inds = np.where(fps)
tp_confidence = dtScoresSorted[tp_inds[1]]
fp_confidence = dtScoresSorted[fp_inds[1]]
tp_confidence_li = tp_confidence.tolist()
fp_confidence_li = fp_confidence.tolist()
ids = catIds[k0]
label = labels[ids]
# Ensure that the second and third for loops only enter once
if label not in correct_conf_dict.keys():
correct_conf_dict[label] = tp_confidence_li
else:
print("maxDet:", maxDet, " ", "areaRng:", p.areaRng)
break
if label not in incorrect_conf_dict.keys():
incorrect_conf_dict[label] = fp_confidence_li
else:
print("maxDet:", maxDet, " ", "areaRng:", p.areaRng)
break
print("Done")
return correct_conf_dict, incorrect_conf_dict
def write_best_confidence_threshold(self, cat_pr_dict, cat_pr_dict_origin, eval_result_path):
"""
write best confidence threshold
"""
print("Write best confidence threshold to csv")
result_csv = os.path.join(eval_result_path, "best_threshold.csv")
result = ["cat_name", "best_f1", "best_precision", "best_recall", "best_score"]
write_list_to_csv(result_csv, result, append=False)
return_result = []
for cat_name, cat_info in cat_pr_dict.items():
f1_li = cat_info[2]
score_li = cat_info[3]
max_f1 = [f1 for f1 in f1_li if abs(f1 - max(f1_li)) <= 0.001]
thre_ = [0.003] + [int(i) * 0.001 for i in range(10, 100, 10)] + [0.099]
# Find the best confidence threshold for 10 levels of confidence thresholds
if len(max_f1) == 1:
# max_f1 is on the far right
if f1_li.index(max_f1[0]) == len(f1_li) - 1:
index = f1_li.index(max_f1[0]) - 1
# max_f1 is in the middle
elif f1_li.index(max_f1[0]) != len(f1_li) - 1 and f1_li.index(max_f1[0]) != 0:
index_a = f1_li.index(max_f1[0]) - 1
index_b = f1_li.index(max_f1[0]) + 1
if f1_li[index_a] >= f1_li[index_b]:
index = index_a
else:
index = f1_li.index(max_f1[0])
# max_f1 is on the far left
elif f1_li.index(max_f1[0]) == 0:
index = f1_li.index(max_f1[0])
best_thre = score_li[index]
# thre_ = [0.003] + [int(i) * 0.001 for i in range(10, 100, 10)] + [0.099]
second_thre = [best_thre + i for i in thre_]
elif len(max_f1) > 1:
thre_pre = [index for (index, value) in enumerate(f1_li) if abs(value - max(f1_li)) <= 0.001]
best_thre = score_li[thre_pre[int((len(thre_pre) - 1) / 2)]]
# thre_ = [0.003] + [int(i) * 0.001 for i in range(10, 100, 10)] + [0.099]
second_thre = [best_thre + i for i in thre_]
# Reduce the step unit to find the second confidence threshold
cat_info_origin = cat_pr_dict_origin[cat_name]
dtscores_ascend = cat_info_origin[3]
inds = np.searchsorted(dtscores_ascend, second_thre, side='left')
pr_second = [0] * len(second_thre)
rc_second = [0] * len(second_thre)
f1_second = [0] * len(second_thre)
try:
for i, ind in enumerate(inds):
if ind >= len(cat_info_origin[0]):
ind = len(cat_info_origin[0]) - 1
pr_second[i] = cat_info_origin[0][ind]
rc_second[i] = cat_info_origin[1][ind]
f1_second[i] = cat_info_origin[2][ind]
except IndexError:
pass
best_f1 = max(f1_second)
best_index = f1_second.index(best_f1)
best_precision = pr_second[best_index]
best_recall = rc_second[best_index]
best_score = second_thre[best_index]
result = [cat_name, best_f1, best_precision, best_recall, best_score]
return_result.append(result)
write_list_to_csv(result_csv, result, append=True)
return return_result
def plot_mc_curve(self, cat_pr_dict, eval_result_path):
"""
plot metric-confidence curves (precision, recall and f1 against confidence)
cat_pr_dict:{"label_name":[precision, recall, f1, score]}
"""
print('Plot mc curve')
savefig_path = os.path.join(eval_result_path, 'pr_cofidence_fig')
if not os.path.exists(savefig_path):
os.mkdir(savefig_path)
xlabel = "Confidence"
ylabel = "Metric"
for cat_name, cat_info in cat_pr_dict.items():
precision = [round(p, 3) for p in cat_info[0]]
recall = [round(r, 3) for r in cat_info[1]]
f1 = [round(f, 3) for f in cat_info[2]]
score = [round(s, 3) for s in cat_info[3]]
plt.figure(figsize=(9, 9))
gs = gridspec.GridSpec(4, 1)
plt.subplot(gs[:3, 0])
# 1.precision-confidence
plt.plot(score, precision, linewidth=2, color="deepskyblue", label="precision")
# 2.recall-confidence
plt.plot(score, recall, linewidth=2, color="limegreen", label="recall")
# 3.f1-confidence
plt.plot(score, f1, linewidth=2, color="tomato", label="f1_score")
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(cat_name, fontsize=15)
plt.xlim(0, 1)
plt.xticks((np.arange(0, 1, 0.1)))
plt.ylim(0, 1.10)
plt.legend(loc="lower left")
row_name = ["conf_threshold", "precision", "recall", "f1"]
plt.grid(True)
plt.subplot(gs[3, 0])
plt.axis('off')
colors = ["white", "deepskyblue", "limegreen", "tomato"]
plt.table(cellText=[score, precision, recall, f1], rowLabels=row_name, loc='center', cellLoc='center',
rowLoc='center', rowColours=colors)
plt.subplots_adjust(left=0.2, bottom=0.2)
plt.savefig(os.path.join(savefig_path, cat_name) + '.png', dpi=250)
print("Done")
def plot_hist_curve(self, input_data, eval_result_path):
correct_conf_dict, incorrect_conf_dict = input_data[0], input_data[1]
savefig_path = os.path.join(eval_result_path, 'hist_curve_fig')
if not os.path.exists(savefig_path):
os.mkdir(savefig_path)
for l in correct_conf_dict.keys():
plt.figure(figsize=(7, 7))
if l in incorrect_conf_dict.keys() and correct_conf_dict[l] and incorrect_conf_dict[l]:
gs = gridspec.GridSpec(4, 1)
plt.subplot(gs[:3, 0])
correct_conf_dict[l].sort()
correct_conf_dict[l].reverse()
col_name_correct = ['number', 'mean', 'max', 'min', 'min99%', 'min99.9%']
col_val_correct = [len(correct_conf_dict[l]),
('%.2f' % np.mean(correct_conf_dict[l])),
('%.2f' % max(correct_conf_dict[l])), ('%.2f' % min(correct_conf_dict[l])),
('%.2f' % correct_conf_dict[l][int(len(correct_conf_dict[l]) * 0.99) - 1]),
('%.2f' % correct_conf_dict[l][int(len(correct_conf_dict[l]) * 0.999) - 1])]
sns.set_palette('hls')
sns.distplot(correct_conf_dict[l], bins=50, kde_kws={'color': 'b', 'lw': 3},
hist_kws={'color': 'b', 'alpha': 0.3})
plt.xlim((0, 1))
plt.xlabel(l)
plt.ylabel("numbers")
ax1 = plt.twinx()
incorrect_conf_dict[l].sort()
incorrect_conf_dict[l].reverse()
col_val_incorrect = [len(incorrect_conf_dict[l]),
('%.2f' % np.mean(incorrect_conf_dict[l])),
('%.2f' % max(incorrect_conf_dict[l])), ('%.2f' % min(incorrect_conf_dict[l])),
('%.2f' % incorrect_conf_dict[l][int(len(incorrect_conf_dict[l]) * 0.99) - 1]),
('%.2f' % incorrect_conf_dict[l][int(len(incorrect_conf_dict[l]) * 0.999) - 1])]
sns.distplot(incorrect_conf_dict[l], bins=50, kde_kws={'color': 'r', 'lw': 3},
hist_kws={'color': 'r', 'alpha': 0.3}, ax=ax1)
plt.grid(True)
plt.subplot(gs[3, 0])
plt.axis('off')
row_name = ['', 'correct', 'incorrect']
table = plt.table(cellText=[col_name_correct, col_val_correct, col_val_incorrect], rowLabels=row_name,
loc='center', cellLoc='center', rowLoc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 1.5)
plt.savefig(os.path.join(savefig_path, l) + '.jpg')
elif correct_conf_dict[l]:
gs = gridspec.GridSpec(4, 1)
plt.subplot(gs[:3, 0])
correct_conf_dict[l].sort()
correct_conf_dict[l].reverse()
col_name_correct = ['number', 'mean', 'max', 'min', 'min99%', 'min99.9%']
col_val_correct = [len(correct_conf_dict[l]),
('%.4f' % np.mean(correct_conf_dict[l])),
('%.4f' % max(correct_conf_dict[l])), ('%.2f' % min(correct_conf_dict[l])),
('%.2f' % correct_conf_dict[l][int(len(correct_conf_dict[l]) * 0.99) - 1]),
('%.2f' % correct_conf_dict[l][int(len(correct_conf_dict[l]) * 0.999) - 1])]
sns.set_palette('hls')
sns.distplot(correct_conf_dict[l], bins=50, kde_kws={'color': 'b', 'lw': 3},
hist_kws={'color': 'b', 'alpha': 0.3})
plt.xlim((0, 1))
plt.xlabel(l)
plt.ylabel("numbers")
plt.grid(True)
plt.subplot(gs[3, 0])
plt.axis('off')
row_name = ['', 'correct']
table = plt.table(cellText=[col_name_correct, col_val_correct], rowLabels=row_name,
loc='center', cellLoc='center', rowLoc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 1.5)
plt.savefig(os.path.join(savefig_path, l) + '.jpg')
elif l in incorrect_conf_dict.keys() and incorrect_conf_dict[l]:
gs = gridspec.GridSpec(4, 1)
plt.subplot(gs[:3, 0])
incorrect_conf_dict[l].sort()
incorrect_conf_dict[l].reverse()
col_name_correct = ['number', 'mean', 'max', 'min', 'min99%', 'min99.9%']
col_val_correct = [len(incorrect_conf_dict[l]),
('%.4f' % np.mean(incorrect_conf_dict[l])),
('%.4f' % max(incorrect_conf_dict[l])), ('%.2f' % min(incorrect_conf_dict[l])),
('%.2f' % incorrect_conf_dict[l][int(len(incorrect_conf_dict[l]) * 0.99) - 1]),
('%.2f' % incorrect_conf_dict[l][int(len(incorrect_conf_dict[l]) * 0.999) - 1])]
sns.set_palette('hls')
sns.distplot(incorrect_conf_dict[l], bins=50, kde_kws={'color': 'b', 'lw': 3},
hist_kws={'color': 'b', 'alpha': 0.3})
plt.xlim((0, 1))
plt.xlabel(l)
plt.grid(True)
plt.subplot(gs[3, 0])
plt.axis('off')
row_name = ['', 'incorrect']
table = plt.table(cellText=[col_name_correct, col_val_correct], rowLabels=row_name,
loc='center', cellLoc='center', rowLoc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 1.5)
plt.savefig(os.path.join(savefig_path, l) + '.jpg')
if __name__ == "__main__":
cocoeval = DetectEval()
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import shutil
from mindspore import save_checkpoint
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
config (Config): configuration object; uses interval, epoch_size and current_epoch.
net (Cell): network used for evaluation.
apply_eval (function): evaluation function returning the metric that selects the best checkpoint.
datasetsize (int): number of steps per training epoch, used to locate the saved checkpoint file.
mindrecord_path (str): path of the evaluation MindRecord file.
anno_json (str): path of the COCO annotation json file.
checkpoint_path (str): directory where training checkpoints are saved.
Returns:
None
Examples:
>>> EvalCallBack(config, net, apply_eval, datasetsize, mindrecord_path, anno_json, checkpoint_path)
"""
def __init__(self, config, net, apply_eval, datasetsize, mindrecord_path, anno_json, checkpoint_path):
super(EvalCallBack, self).__init__()
self.faster_rcnn_eval = apply_eval
self.mindrecord_path = mindrecord_path
self.anno_json = anno_json
self.datasetsize = datasetsize
self.config = config
self.checkpoint_path = checkpoint_path
self.net = net
self.best_epoch = 0
self.best_res = 0
self.best_ckpt_path = os.path.abspath("./best_ckpt")
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
ckpt_file_name = "faster_rcnn-{}_{}.ckpt".format(cur_epoch, self.datasetsize)
checkpoint_path = os.path.join(self.checkpoint_path, ckpt_file_name)
self.config.current_epoch = cur_epoch
res = 0
if self.config.current_epoch % self.config.interval == 0 or self.config.current_epoch == self.config.epoch_size:
res = self.faster_rcnn_eval(self.net, self.config, self.mindrecord_path, checkpoint_path, self.anno_json)
if res > self.best_res:
self.best_epoch = cur_epoch
self.best_res = res
if os.path.exists(self.best_ckpt_path):
shutil.rmtree(self.best_ckpt_path)
os.mkdir(self.best_ckpt_path)
save_checkpoint(cb_params.train_network, os.path.join(self.best_ckpt_path, "best.ckpt"))
print("update best result: {} in the {} th epoch".format(self.best_res, self.best_epoch), flush=True)
def end(self, run_context):
print("End training, the best result is {0}, achieved at epoch {1}".format(self.best_res, self.best_epoch), flush=True)
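For orientation, a minimal usage sketch follows. It is not part of the committed files; it condenses the run_eval branch that this commit adds to train.py further below, and it assumes the surrounding train_fasterrcnn() context where config, dataset_size, save_checkpoint_path and the callback list cb already exist.

# Hypothetical fragment inside train_fasterrcnn(), shown only to illustrate how EvalCallBack is wired up.
from src.eval_callback import EvalCallBack
from src.eval_utils import apply_eval
anno_json = os.path.join(config.coco_root, "annotations/instances_val2017.json")          # COCO val annotations
mindrecord_path = os.path.join(config.coco_root, "FASTERRCNN_MINDRECORD", config.prefix)  # evaluation MindRecord
eval_net = Faster_Rcnn_Resnet(config)
eval_cb = EvalCallBack(config, eval_net, apply_eval, dataset_size, mindrecord_path, anno_json, save_checkpoint_path)
cb += [eval_cb]  # apply_eval() then runs every config.interval epochs during model.train()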
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Coco metrics utils"""
import os
import json
from collections import defaultdict
import numpy as np
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import Parameter
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset, parse_json_annos_from_txt
from src.util import bbox2result_1image, results2json
def create_eval_mindrecord(config):
"""Create the MindRecord file used for evaluation if it does not already exist."""
print("CHECKING MINDRECORD FILES ...")
if not os.path.exists(config.mindrecord_file):
if not os.path.isdir(config.mindrecord_dir):
os.makedirs(config.mindrecord_dir)
if config.dataset == "coco":
if os.path.isdir(config.coco_root):
print("Create Mindrecord. It may take some time.")
data_to_mindrecord_byte_image(config, "coco", False, config.prefix, file_num=1)
print("Create Mindrecord Done, at {}".format(config.mindrecord_dir))
else:
print("coco_root not exists.")
else:
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
print("Create Mindrecord. It may take some time.")
data_to_mindrecord_byte_image(config, "other", False, config.prefix, file_num=1)
print("Create Mindrecord Done, at {}".format(config.mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH not exists.")
def apply_eval(net, config, dataset_path, ckpt_path, anno_path):
"""FasterRcnn evaluation."""
if not os.path.isfile(ckpt_path):
raise RuntimeError("CheckPoint file {} is not valid.".format(ckpt_path))
ds = create_fasterrcnn_dataset(config, dataset_path, batch_size=config.test_batch_size, is_training=False)
param_dict = load_checkpoint(ckpt_path)
if config.device_target == "GPU":
for key, value in param_dict.items():
tensor = value.asnumpy().astype(np.float32)
param_dict[key] = Parameter(tensor, key)
load_param_into_net(net, param_dict)
net.set_train(False)
device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
if device_type == "Ascend":
net.to_float(mstype.float16)
eval_iter = 0
total = ds.get_dataset_size()
outputs = []
if config.dataset != "coco":
dataset_coco = COCO()
dataset_coco.dataset, dataset_coco.anns, dataset_coco.cats, dataset_coco.imgs = dict(), dict(), dict(), dict()
dataset_coco.imgToAnns, dataset_coco.catToImgs = defaultdict(list), defaultdict(list)
dataset_coco.dataset = parse_json_annos_from_txt(anno_path, config)
dataset_coco.createIndex()
else:
dataset_coco = COCO(anno_path)
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
max_num = 128
for data in ds.create_dict_iterator(num_epochs=1):
eval_iter = eval_iter + 1
img_data = data['image']
img_metas = data['image_shape']
gt_bboxes = data['box']
gt_labels = data['label']
gt_num = data['valid_num']
# run net
output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num)
# output
all_bbox = output[0]
all_label = output[1]
all_mask = output[2]
for j in range(config.test_batch_size):
all_bbox_squee = np.squeeze(all_bbox.asnumpy()[j, :, :])
all_label_squee = np.squeeze(all_label.asnumpy()[j, :, :])
all_mask_squee = np.squeeze(all_mask.asnumpy()[j, :, :])
all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
all_labels_tmp_mask = all_label_squee[all_mask_squee]
if all_bboxes_tmp_mask.shape[0] > max_num:
inds = np.argsort(-all_bboxes_tmp_mask[:, -1])
inds = inds[:max_num]
all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
all_labels_tmp_mask = all_labels_tmp_mask[inds]
outputs_tmp = bbox2result_1image(all_bboxes_tmp_mask, all_labels_tmp_mask, config.num_classes)
outputs.append(outputs_tmp)
eval_types = ["bbox"]
result_path = "./{}epoch_results.pkl".format(config.current_epoch)
result_files = results2json(dataset_coco, outputs, result_path)
return metrics_map(result_files, eval_types, dataset_coco, single_result=False)
def metrics_map(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
"""coco eval for fasterrcnn"""
anns = json.load(open(result_files['bbox']))
if not anns:
return 0
if isinstance(coco, str):
coco = COCO(coco)
assert isinstance(coco, COCO)
for res_type in result_types:
result_file = result_files[res_type]
assert result_file.endswith('.json')
coco_dets = coco.loadRes(result_file)
det_img_ids = coco_dets.getImgIds()
gt_img_ids = coco.getImgIds()
iou_type = 'bbox' if res_type == 'proposal' else res_type
cocoEval = COCOeval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
tgt_ids = gt_img_ids if not single_result else det_img_ids
cocoEval.params.imgIds = tgt_ids
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
return cocoEval.stats[0]
@@ -120,9 +120,19 @@ def get_config():
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
-final_config = merge(args, default)
-pprint(final_config)
+default = Config(merge(args, default))
+default.feature_shapes = [
+[default.img_height // 4, default.img_width // 4],
+[default.img_height // 8, default.img_width // 8],
+[default.img_height // 16, default.img_width // 16],
+[default.img_height // 32, default.img_width // 32],
+[default.img_height // 64, default.img_width // 64],
+]
+default.num_bboxes = default.num_anchors * sum([lst[0] * lst[1] for lst in default.feature_shapes])
+pprint(default)
print("Please check the above information for the configurations", flush=True)
-return Config(final_config)
+return default
config = get_config()
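As a sanity check, a small standalone sketch (assuming the previous defaults img_height=768 and img_width=1280, and num_anchors = len(anchor_scales) * len(anchor_ratios) = 3) reproduces exactly the values that were removed from the YAML files above:

# Assumed defaults (768x1280 input, 3 anchors per location); not part of this commit.
img_height, img_width, num_anchors = 768, 1280, 3
feature_shapes = [[img_height // s, img_width // s] for s in (4, 8, 16, 32, 64)]
num_bboxes = num_anchors * sum(h * w for h, w in feature_shapes)
print(feature_shapes)  # [[192, 320], [96, 160], [48, 80], [24, 40], [12, 20]]
print(num_bboxes)      # 245520, matching the removed num_bboxes entry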
@@ -13,10 +13,15 @@
# limitations under the License.
# ============================================================================
"""coco eval for fasterrcnn"""
import json
+import os
+import csv
+import shutil
import numpy as np
from pycocotools.coco import COCO
-from pycocotools.cocoeval import COCOeval
+from src.detecteval import DetectEval
_init_value = np.array(0.0)
summary_init = {
@@ -35,7 +40,18 @@ summary_init = {
}
-def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
+def write_list_to_csv(file_path, data_to_write, append=False):
+# print('Saving data into file [{}]...'.format(file_path))
+if append:
+open_mode = 'a'
+else:
+open_mode = 'w'
+with open(file_path, open_mode) as csvfile:
+writer = csv.writer(csvfile)
+writer.writerow(data_to_write)
+def coco_eval(config, result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
"""coco eval for fasterrcnn"""
anns = json.load(open(result_files['bbox']))
if not anns:
@@ -53,7 +69,7 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), singl
gt_img_ids = coco.getImgIds()
det_img_ids = coco_dets.getImgIds()
iou_type = 'bbox' if res_type == 'proposal' else res_type
-cocoEval = COCOeval(coco, coco_dets, iou_type)
+cocoEval = DetectEval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
@@ -63,7 +79,7 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), singl
if single_result:
res_dict = dict()
for id_i in tgt_ids:
-cocoEval = COCOeval(coco, coco_dets, iou_type)
+cocoEval = DetectEval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
@@ -74,7 +90,7 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), singl
cocoEval.summarize()
res_dict.update({coco.imgs[id_i]['file_name']: cocoEval.stats[1]})
-cocoEval = COCOeval(coco, coco_dets, iou_type)
+cocoEval = DetectEval(coco, coco_dets, iou_type)
if res_type == 'proposal':
cocoEval.params.useCats = 0
cocoEval.params.maxDets = list(max_dets)
@@ -99,7 +115,74 @@ def coco_eval(result_files, result_types, coco, max_dets=(100, 300, 1000), singl
'Recall/AR@100 (large)': cocoEval.stats[11],
}
-return summary_metrics
+print("summary_metrics: ")
+print(summary_metrics)
+res = calculate_pr_rc_f1(config, coco, coco_dets, tgt_ids, iou_type)
+return res
def calculate_pr_rc_f1(config, coco, coco_dets, tgt_ids, iou_type):
cocoEval = DetectEval(coco, coco_dets, iou_type)
cocoEval.params.imgIds = tgt_ids
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
stats_all = cocoEval.stats
eval_result_path = os.path.abspath("./eval_result")
if os.path.exists(eval_result_path):
shutil.rmtree(eval_result_path)
os.mkdir(eval_result_path)
result_csv = os.path.join(eval_result_path, "statistics.csv")
eval_item = ["ap@0.5:0.95", "ap@0.5", "ap@0.75", "ar@0.5:0.95", "ar@0.5", "ar@0.75"]
write_list_to_csv(result_csv, eval_item, append=False)
eval_result = [round(stats_all[0], 3), round(stats_all[1], 3), round(stats_all[2], 3), round(stats_all[6], 3),
round(stats_all[7], 3), round(stats_all[8], 3)]
write_list_to_csv(result_csv, eval_result, append=True)
write_list_to_csv(result_csv, [], append=True)
# 1.2 plot_pr_curve
cocoEval.plot_pr_curve(eval_result_path)
# 2
E = DetectEval(coco, coco_dets, iou_type)
E.params.iouThrs = [0.5]
E.params.maxDets = [100]
E.params.areaRng = [[0 ** 2, 1e5 ** 2]]
E.evaluate()
# 2.1 plot hist_curve of every class's tp's confidence and fp's confidence
confidence_dict = E.compute_tp_fp_confidence()
E.plot_hist_curve(confidence_dict, eval_result_path)
# 2.2 write best_threshold and p r to csv and plot
cat_pr_dict, cat_pr_dict_origin = E.compute_precision_recall_f1()
# E.write_best_confidence_threshold(cat_pr_dict, cat_pr_dict_origin, eval_result_path)
best_confidence_thres = E.write_best_confidence_threshold(cat_pr_dict, cat_pr_dict_origin, eval_result_path)
print("best_confidence_thres: ", best_confidence_thres)
E.plot_mc_curve(cat_pr_dict, eval_result_path)
# 3
# 3.1 compute every class's p r and save every class's p and r at iou = 0.5
E = DetectEval(coco, coco_dets, iouType='bbox')
E.params.iouThrs = [0.5]
E.params.maxDets = [100]
E.params.areaRng = [[0 ** 2, 1e5 ** 2]]
E.evaluate()
E.accumulate()
result = E.evaluate_every_class()
print_info = ["class_name", "tp_num", "gt_num", "dt_num", "precision", "recall"]
write_list_to_csv(result_csv, print_info, append=True)
print("class_name", "tp_num", "gt_num", "dt_num", "precision", "recall")
for class_result in result:
print(class_result)
write_list_to_csv(result_csv, class_result, append=True)
# 3.2 save ng / ok images
E.save_images(config, eval_result_path, 0.5)
return stats_all[0]
def xyxy2xywh(bbox):
...
@@ -36,7 +36,6 @@ from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
@@ -44,7 +43,7 @@ if config.backbone in ("resnet_v1.5_50", "resnet_v1_101", "resnet_v1_152"):
from src.FasterRcnn.faster_rcnn_resnet import Faster_Rcnn_Resnet
elif config.backbone == "resnet_v1_50":
from src.FasterRcnn.faster_rcnn_resnet50v1 import Faster_Rcnn_Resnet
-config.epoch_size = 20
+# config.epoch_size = 20
if config.device_target == "GPU":
context.set_context(enable_graph_kernel=True)
@@ -122,6 +121,7 @@ def train_fasterrcnn_():
def modelarts_pre_process():
config.save_checkpoint_path = config.output_path
@moxing_wrapper(pre_process=modelarts_pre_process)
def train_fasterrcnn():
""" train_fasterrcnn """
@@ -150,7 +150,6 @@ def train_fasterrcnn():
newkey = oldkey.replace(k, v)
param_dict[newkey] = param_dict.pop(oldkey)
break
for item in list(param_dict.keys()):
if not item.startswith('backbone'):
param_dict.pop(item)
@@ -186,8 +185,25 @@ def train_fasterrcnn():
ckpoint_cb = ModelCheckpoint(prefix='faster_rcnn', directory=save_checkpoint_path, config=ckptconfig)
cb += [ckpoint_cb]
+if config.run_eval:
+from src.eval_callback import EvalCallBack
+from src.eval_utils import create_eval_mindrecord, apply_eval
+config.prefix = "FasterRcnn_eval.mindrecord"
+anno_json = os.path.join(config.coco_root, "annotations/instances_val2017.json")
+mindrecord_path = os.path.join(config.coco_root, "FASTERRCNN_MINDRECORD", config.prefix)
+config.instance_set = "annotations/instances_val2017.json"
+if not os.path.exists(mindrecord_path):
+config.mindrecord_file = mindrecord_path
+create_eval_mindrecord(config)
+eval_net = Faster_Rcnn_Resnet(config)
+eval_cb = EvalCallBack(config, eval_net, apply_eval, dataset_size, mindrecord_path, anno_json,
+save_checkpoint_path)
+cb += [eval_cb]
model = Model(net)
-model.train(config.epoch_size, dataset, callbacks=cb)
+model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False)
if __name__ == '__main__':
train_fasterrcnn()