diff --git a/official/cv/faster_rcnn/default_config.yaml b/official/cv/faster_rcnn/default_config.yaml index 35314cdf52793976e8fc192877a4c8857c2fdf26..ae0d4e1e7e7cd16d624866ec5d32f9582d4857e1 100644 --- a/official/cv/faster_rcnn/default_config.yaml +++ b/official/cv/faster_rcnn/default_config.yaml @@ -18,12 +18,6 @@ flip_ratio: 0.5 expand_ratio: 1.0 # anchor -feature_shapes: -- [192, 320] -- [96, 160] -- [48, 80] -- [24, 40] -- [12, 20] anchor_scales: [8] anchor_ratios: [0.5, 1.0, 2.0] anchor_strides: [4, 8, 16, 32, 64] @@ -52,7 +46,6 @@ rpn_target_stds: [1.0, 1.0, 1.0, 1.0] neg_iou_thr: 0.3 pos_iou_thr: 0.7 min_pos_iou: 0.3 -num_bboxes: 245520 num_gts: 128 num_expected_neg: 256 num_expected_pos: 128 @@ -121,7 +114,9 @@ batch_size: 2 loss_scale: 256 momentum: 0.91 weight_decay: 0.00001 -epoch_size: 12 +epoch_size: 20 +run_eval: False +interval: 1 save_checkpoint: True save_checkpoint_epochs: 1 keep_checkpoint_max: 5 @@ -152,6 +147,7 @@ coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] num_classes: 81 +prefix: "" # annotations file(json format or user defined text format) anno_path: '' @@ -209,6 +205,7 @@ checkpoint_path: "Checkpoint file path." ckpt_file: 'fasterrcnn ckpt file.' result_path: "result file path." 
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152" +interval: "val interval" --- device_target: ['Ascend', 'GPU', 'CPU'] diff --git a/official/cv/faster_rcnn/default_config_101.yaml b/official/cv/faster_rcnn/default_config_101.yaml index 48ef8396a88f14cf7e39df650e43eecc74ccc47a..4dabdbed681baa06bf860818a3349b6dabc834cd 100644 --- a/official/cv/faster_rcnn/default_config_101.yaml +++ b/official/cv/faster_rcnn/default_config_101.yaml @@ -19,12 +19,6 @@ flip_ratio: 0.5 expand_ratio: 1.0 # anchor -feature_shapes: -- [192, 320] -- [96, 160] -- [48, 80] -- [24, 40] -- [12, 20] anchor_scales: [8] anchor_ratios: [0.5, 1.0, 2.0] anchor_strides: [4, 8, 16, 32, 64] @@ -53,7 +47,6 @@ rpn_target_stds: [1.0, 1.0, 1.0, 1.0] neg_iou_thr: 0.3 pos_iou_thr: 0.7 min_pos_iou: 0.3 -num_bboxes: 245520 num_gts: 128 num_expected_neg: 256 num_expected_pos: 128 @@ -123,6 +116,8 @@ loss_scale: 256 momentum: 0.91 weight_decay: 0.00001 epoch_size: 20 +run_eval: False +interval: 1 save_checkpoint: True save_checkpoint_epochs: 1 keep_checkpoint_max: 5 @@ -153,6 +148,7 @@ coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] num_classes: 81 +prefix: "" # annotations file(json format or user defined text format) anno_path: '' @@ -210,6 +206,7 @@ checkpoint_path: "Checkpoint file path." ckpt_file: 'fasterrcnn ckpt file.' result_path: "result file path." 
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152" +interval: "val interval" --- device_target: ['Ascend', 'GPU', 'CPU'] diff --git a/official/cv/faster_rcnn/default_config_152.yaml b/official/cv/faster_rcnn/default_config_152.yaml index 65ea34a74857f2e73d6ac92618bc577b68888730..a52360639d608f223d9bb17d12e4e52b38e94b39 100644 --- a/official/cv/faster_rcnn/default_config_152.yaml +++ b/official/cv/faster_rcnn/default_config_152.yaml @@ -19,12 +19,6 @@ flip_ratio: 0.5 expand_ratio: 1.0 # anchor -feature_shapes: -- [192, 320] -- [96, 160] -- [48, 80] -- [24, 40] -- [12, 20] anchor_scales: [8] anchor_ratios: [0.5, 1.0, 2.0] anchor_strides: [4, 8, 16, 32, 64] @@ -53,7 +47,6 @@ rpn_target_stds: [1.0, 1.0, 1.0, 1.0] neg_iou_thr: 0.3 pos_iou_thr: 0.7 min_pos_iou: 0.3 -num_bboxes: 245520 num_gts: 128 num_expected_neg: 256 num_expected_pos: 128 @@ -123,6 +116,8 @@ loss_scale: 256 momentum: 0.91 weight_decay: 0.00001 epoch_size: 20 +run_eval: False +interval: 1 save_checkpoint: True save_checkpoint_epochs: 1 keep_checkpoint_max: 5 @@ -153,6 +148,7 @@ coco_classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] num_classes: 81 +prefix: "" # annotations file(json format or user defined text format) anno_path: '' @@ -210,6 +206,7 @@ checkpoint_path: "Checkpoint file path." ckpt_file: 'fasterrcnn ckpt file.' result_path: "result file path." 
backbone: "backbone network name, options:resnet_v1_50, resnet_v1.5_50, resnet_v1_101, resnet_v1_152" +interval: "val interval" --- device_target: ['Ascend', 'GPU', 'CPU'] diff --git a/official/cv/faster_rcnn/eval.py b/official/cv/faster_rcnn/eval.py index 5f867621e9e21c19eb38d8487a29bd059dc3fb70..cf501f025190f4ee8344e2458bf0fdd894753a6c 100644 --- a/official/cv/faster_rcnn/eval.py +++ b/official/cv/faster_rcnn/eval.py @@ -78,7 +78,6 @@ def fasterrcnn_eval(dataset_path, ckpt_path, anno_path): max_num = 128 for data in ds.create_dict_iterator(num_epochs=1): eval_iter = eval_iter + 1 - img_data = data['image'] img_metas = data['image_shape'] gt_bboxes = data['box'] @@ -117,12 +116,12 @@ def fasterrcnn_eval(dataset_path, ckpt_path, anno_path): eval_types = ["bbox"] result_files = results2json(dataset_coco, outputs, "./results.pkl") - coco_eval(result_files, eval_types, dataset_coco, single_result=True) + coco_eval(config, result_files, eval_types, dataset_coco, single_result=False) def modelarts_pre_process(): pass - # config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path) + @moxing_wrapper(pre_process=modelarts_pre_process) def eval_fasterrcnn(): @@ -154,5 +153,36 @@ def eval_fasterrcnn(): print("Start Eval!") fasterrcnn_eval(mindrecord_file, config.checkpoint_path, config.anno_path) + flags = [0] * 3 + config.eval_result_path = os.path.abspath("./eval_result") + if os.path.exists(config.eval_result_path): + result_files = os.listdir(config.eval_result_path) + for file in result_files: + if file == "statistics.csv": + with open(os.path.join(config.eval_result_path, "statistics.csv"), "r") as f: + res = f.readlines() + if len(res) > 4: + if "class_name" in res[3] and "tp_num" in res[3] and len(res[4].strip().split(",")) > 1: + flags[0] = 1 + elif file in ("precision_ng_images", "recall_ng_images", "ok_images"): + imgs = os.listdir(os.path.join(config.eval_result_path, file)) + if imgs: + flags[1] = 1 + + elif file == 
"pr_curve_image": + imgs = os.listdir(os.path.join(config.eval_result_path, "pr_curve_image")) + if imgs: + flags[2] = 1 + else: + pass + + if sum(flags) == 3: + print("eval success.") + exit(0) + else: + print("eval failed.") + exit(-1) + + if __name__ == '__main__': eval_fasterrcnn() diff --git a/official/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py b/official/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py index 991ad33d92788133fc96ebe5636b76131760efdd..2cdaef7873a01d6026e03f2687adf21a5e95e2ca 100644 --- a/official/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py +++ b/official/cv/faster_rcnn/src/FasterRcnn/fpn_neck.py @@ -20,6 +20,7 @@ from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.common import dtype as mstype from mindspore.common.initializer import initializer +from src.model_utils.config import config def bias_init_zeros(shape): @@ -84,9 +85,9 @@ class FeatPyramidNeck(nn.Cell): self.fpn_convs_.append(fpn_conv) self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) - self.interpolate1 = P.ResizeNearestNeighbor((48, 80)) - self.interpolate2 = P.ResizeNearestNeighbor((96, 160)) - self.interpolate3 = P.ResizeNearestNeighbor((192, 320)) + self.interpolate1 = P.ResizeNearestNeighbor(config.feature_shapes[2]) + self.interpolate2 = P.ResizeNearestNeighbor(config.feature_shapes[1]) + self.interpolate3 = P.ResizeNearestNeighbor(config.feature_shapes[0]) self.maxpool = P.MaxPool(kernel_size=1, strides=2, pad_mode="same") def construct(self, inputs): diff --git a/official/cv/faster_rcnn/src/detecteval.py b/official/cv/faster_rcnn/src/detecteval.py new file mode 100644 index 0000000000000000000000000000000000000000..beafc8cecc71a22636d1f7339e519a7d4a0d5edb --- /dev/null +++ b/official/cv/faster_rcnn/src/detecteval.py @@ -0,0 +1,849 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may 
not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from typing import List +import os +import csv +import warnings +import cv2 +import numpy as np + +from pycocotools.cocoeval import COCOeval +import matplotlib.pyplot as plt +from matplotlib import gridspec +import seaborn as sns + + +warnings.filterwarnings("ignore") +COLOR_MAP = [ + (0, 255, 255), + (0, 255, 0), + (255, 0, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 128, 128), + (0, 128, 0), + (128, 0, 0), + (0, 0, 128), + (128, 128, 0), + (128, 0, 128), +] + + +def write_list_to_csv(file_path, data_to_write, append=False): + print('Saving data into file [{}]...'.format(file_path)) + if append: + open_mode = 'a' + else: + open_mode = 'w' + with open(file_path, open_mode) as csvfile: + writer = csv.writer(csvfile) + writer.writerow(data_to_write) + + +def read_image(image_path): + image = cv2.imread(image_path) + if image is None: + return False, None + return True, image + + +def save_image(image_path, image): + return cv2.imwrite(image_path, image) + + +def draw_rectangle(image, pt1, pt2, label=None): + if label is not None: + map_index = label % len(COLOR_MAP) + color = COLOR_MAP[map_index] + else: + color = COLOR_MAP[0] + thickness = 5 + cv2.rectangle(image, pt1, pt2, color, thickness) + + +def draw_text(image, text, org, label=None): + if label is not None: + map_index = label % len(COLOR_MAP) + color = COLOR_MAP[map_index] + else: + color = COLOR_MAP[0] + font_face = cv2.FONT_HERSHEY_SIMPLEX + 
font_scale = 0.6 + thickness = 1 + cv2.putText(image, text, org, font_face, font_scale, color, thickness) + + +def draw_one_box(image, label, box, cat_id, line_thickness=None): + tl = line_thickness or round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1 + if cat_id is not None: + map_index = cat_id % len(COLOR_MAP) + color = COLOR_MAP[map_index] + else: + color = COLOR_MAP[0] + c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(image, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + + tf = max(tl - 1, 1) + t_size = cv2.getTextSize(label, 0, fontScale=tl / 6, thickness=tf // 2)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(image, c1, c2, color, -1, cv2.LINE_AA) + cv2.putText(image, label, (c1[0], c1[1] - 2), 0, tl / 6, [255, 255, 255], thickness=tf // 2, lineType=cv2.LINE_AA) + + +class DetectEval(COCOeval): + def __init__(self, cocoGt=None, cocoDt=None, iouType="bbox"): + assert iouType == "bbox", "iouType only supported bbox" + + super().__init__(cocoGt, cocoDt, iouType) + if not self.cocoGt is None: + cat_infos = cocoGt.loadCats(cocoGt.getCatIds()) + self.params.labels = {} + # self.params.labels = ["" for i in range(len(self.params.catIds))] + for cat in cat_infos: + self.params.labels[cat["id"]] = cat["name"] + + # add new + def catId_summarize(self, catId, iouThr=None, areaRng="all", maxDets=100): + p = self.params + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + + s = self.eval["recall"] + if iouThr is not None: + iou = np.where(iouThr == p.iouThrs)[0] + s = s[iou] + + if isinstance(catId, int): + s = s[:, catId, aind, mind] + else: + s = s[:, :, aind, mind] + + not_empty = len(s[s > -1]) == 0 + if not_empty: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + return mean_s + + def compute_gt_dt_num(self): + p = self.params + catIds_gt_num = {} + catIds_dt_num = {} + + for ids in p.catIds: + gts_cat_id = 
self.cocoGt.loadAnns(self.cocoGt.getAnnIds(catIds=[ids])) + dts_cat_id = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(catIds=[ids])) + catIds_gt_num[ids] = len(gts_cat_id) + catIds_dt_num[ids] = len(dts_cat_id) + + return catIds_gt_num, catIds_dt_num + + def evaluate_ok_ng(self, img_id, catIds, iou_threshold=0.5): + """ + evaluate every if this image is ok、precision_ng、recall_ng + img_id: int + cat_ids:list + iou_threshold:int + """ + p = self.params + img_id = int(img_id) + + # Save the results of precision_ng and recall_ng for each category on a picture + cat_id_result = {} + for cat_id in catIds: + gt = self._gts[img_id, cat_id] + dt = self._dts[img_id, cat_id] + ious = self.computeIoU(img_id, cat_id) + + # Sort dt in descending order, and only take the first 100 + inds = np.argsort([-d['score'] for d in dt], kind='mergesort') + dt = [dt[i] for i in inds] + + # p.maxDets must be set in ascending order + if len(dt) > p.maxDets[-1]: + dt = dt[0:p.maxDets[-1]] + + # The first case: gt, dt are both 0: skip + if not gt and not dt: + cat_id_result[cat_id] = (False, False) + continue + # The second case: gt = 0, dt !=0: precision_ng + if not gt and dt: + cat_id_result[cat_id] = (True, False) + continue + # The third case: gt != 0, dt = 0: recall_ng + if gt and not dt: + cat_id_result[cat_id] = (False, True) + continue + # The fourth case: gt and dt are matched in pairs + gtm = [0] * len(gt) + dtm = [0] * len(dt) + + for dind in range(len(dt)): + # dt:[a] gt [b] ious = [a*b] + iou = min([iou_threshold, 1 - 1e-10]) + # m records the position of the gt with the best match + m = -1 + for gind in range(len(gt)): + # If gt[gind] already matches, skip it. 
+ if gtm[gind] > 0: + continue + # If the iou(dind, gind) is less than the threshold, traverse + if ious[dind, gind] < iou: + continue + iou = ious[dind, gind] + m = gind + if m == -1: + continue + dtm[dind] = 1 + gtm[m] = 1 + + # If gt is all matched, gtm is all 1 + precision_ng = sum(dtm) < len(dtm) + recall_ng = sum(gtm) < len(gtm) + cat_id_result[cat_id] = (precision_ng, recall_ng) + + # As long as the precision_ng in a class is True, the picture is precision_ng, and recall_ng is the same + # Subsequent development of NG pictures for each category can be saved + precision_result = False + recall_result = False + for ng in cat_id_result.values(): + precision_ng = ng[0] + recall_ng = ng[1] + if precision_ng: + precision_result = precision_ng + if recall_ng: + recall_result = recall_ng + return precision_result, recall_result + + def evaluate_every_class(self): + """ + compute every class's: + [label, tp_num, gt_num, dt_num, precision, recall] + """ + print("Evaluate every class's predision and recall") + p = self.params + cat_ids = p.catIds + labels = p.labels + result = [] + catIds_gt_num, catIds_dt_num = self.compute_gt_dt_num() + sum_gt_num = 0 + sum_dt_num = 0 + for value in catIds_gt_num.values(): + sum_gt_num += value + for value in catIds_dt_num.values(): + sum_dt_num += value + sum_tp_num = 0 + + for i, cat_id in enumerate(cat_ids): + # Here is hard-coded + stats = self.catId_summarize(catId=i) + recall = stats + gt_num = catIds_gt_num[cat_id] + tp_num = recall * gt_num + sum_tp_num += tp_num + dt_num = catIds_dt_num[cat_id] + if dt_num <= 0: + if gt_num == 0: + precision = -1 + else: + precision = 0 + else: + precision = tp_num / dt_num + label = labels[cat_id] + class_result = [label, int(round(tp_num)), gt_num, int(round(dt_num)), round(precision, 3), + round(recall, 3)] + result.append(class_result) + all_precision = sum_tp_num / sum_dt_num + all_recall = sum_tp_num / sum_gt_num + all_result = ["all", int(round(sum_tp_num)), sum_gt_num, 
int(round(sum_dt_num)), round(all_precision, 3), + round(all_recall, 3)] + result.append(all_result) + + print("Done") + return result + + def plot_pr_curve(self, eval_result_path): + + """ + precisions[T, R, K, A, M] + T: iou thresholds [0.5 : 0.05 : 0.95], idx from 0 to 9 + R: recall thresholds [0 : 0.01 : 1], idx from 0 to 100 + K: category, idx from 0 to ... + A: area range, (all, small, medium, large), idx from 0 to 3 + M: max dets, (1, 10, 100), idx from 0 to 2 + """ + print("Plot pr curve about every class") + precisions = self.eval["precision"] + p = self.params + cat_ids = p.catIds + labels = p.labels + + pr_dir = os.path.join(eval_result_path, "./pr_curve_image") + if not os.path.exists(pr_dir): + os.mkdir(pr_dir) + + for i, cat_id in enumerate(cat_ids): + pr_array1 = precisions[0, :, i, 0, 2] # iou = 0.5 + x = np.arange(0.0, 1.01, 0.01) + # plot PR curve + plt.plot(x, pr_array1, label="iou=0.5," + labels[cat_id]) + plt.xlabel("recall") + plt.ylabel("precision") + plt.xlim(0, 1.0) + plt.ylim(0, 1.01) + plt.grid(True) + plt.legend(loc="lower left") + plt_path = os.path.join(pr_dir, "pr_curve_" + labels[cat_id] + ".png") + plt.savefig(plt_path) + plt.close(1) + print("Done") + + def save_images(self, config, eval_result_path, iou_threshold=0.5): + """ + save ok_images, precision_ng_images, recall_ng_images + Arguments: + config: dict, config about parameters + eval_result_path: str, path to save images + iou_threshold: int, iou_threshold + """ + print("Saving images of ok ng") + p = self.params + img_ids = p.imgIds + cat_ids = p.catIds if p.useCats else [-1] # list: [0,1,2,3....] 
+ labels = p.labels + + dt = self.cocoDt.getAnnIds() + dts = self.cocoDt.loadAnns(dt) + + for img_id in img_ids: + img_id = int(img_id) + img_info = self.cocoGt.loadImgs(img_id) + + if config.dataset == "coco": + im_path_dir = os.path.join(config.coco_root, config.val_data_type) + elif config.dataset == "voc": + im_path_dir = os.path.join(config.voc_root, 'eval', "JPEGImages") + + assert config.dataset in ("coco", "voc") + + # Return whether the image is precision_ng or recall_ng + precision_ng, recall_ng = self.evaluate_ok_ng(img_id, cat_ids, iou_threshold) + # Save to ok_images + if not precision_ng and not recall_ng: + # origin image path + im_path = os.path.join(im_path_dir, img_info[0]['file_name']) + # output image path + im_path_out_dir = os.path.join(eval_result_path, 'ok_images') + if not os.path.exists(im_path_out_dir): + os.makedirs(im_path_out_dir) + im_path_out = os.path.join(im_path_out_dir, img_info[0]['file_name']) + + success, image = read_image(im_path) + assert success + + for obj in dts: + _id = obj["image_id"] + if _id == img_id: + bbox = obj["bbox"] + score = obj["score"] + category_id = obj["category_id"] + label = labels[category_id] + + xmin = int(bbox[0]) + ymin = int(bbox[1]) + width = int(bbox[2]) + height = int(bbox[3]) + xmax = xmin + width + ymax = ymin + height + + label = label + " " + str(round(score, 3)) + draw_one_box(image, label, (xmin, ymin, xmax, ymax), category_id) + save_image(im_path_out, image) + else: + # Save to precision_ng_images + if precision_ng: + # origin image path + im_path = os.path.join(im_path_dir, img_info[0]['file_name']) + # output image path + im_path_out_dir = os.path.join(eval_result_path, 'precision_ng_images') + if not os.path.exists(im_path_out_dir): + os.makedirs(im_path_out_dir) + im_path_out = os.path.join(im_path_out_dir, img_info[0]['file_name']) + + success, image = read_image(im_path) + assert success + + for obj in dts: + _id = obj["image_id"] + if _id == img_id: + bbox = obj["bbox"] + score 
= obj["score"] + category_id = obj["category_id"] + label = labels[category_id] + + xmin = int(bbox[0]) + ymin = int(bbox[1]) + width = int(bbox[2]) + height = int(bbox[3]) + xmax = xmin + width + ymax = ymin + height + + label = label + " " + str(round(score, 3)) + draw_one_box(image, label, (xmin, ymin, xmax, ymax), category_id) + save_image(im_path_out, image) + + # Save to recall_ng_images + if recall_ng: + # origin image path + im_path = os.path.join(im_path_dir, img_info[0]['file_name']) + # output image path + im_path_out_dir = os.path.join(eval_result_path, 'recall_ng_images') + if not os.path.exists(im_path_out_dir): + os.makedirs(im_path_out_dir) + + im_path_out = os.path.join(im_path_out_dir, img_info[0]['file_name']) + success, image = read_image(im_path) + if not success: + raise Exception('Failed reading image from [{}]'.format(im_path)) + for obj in dts: + _id = obj["image_id"] + if _id == img_id: + bbox = obj["bbox"] + score = obj["score"] + category_id = obj["category_id"] + label = labels[category_id] + + xmin = int(bbox[0]) + ymin = int(bbox[1]) + width = int(bbox[2]) + height = int(bbox[3]) + xmax = xmin + width + ymax = ymin + height + + label = label + " " + str(round(score, 3)) + draw_one_box(image, label, (xmin, ymin, xmax, ymax), category_id) + save_image(im_path_out, image) + + print("Done") + + def compute_precison_recall_f1(self, min_score=0.1): + print('Compute precision, recall, f1...') + if not self.evalImgs: + print('Please run evaluate() first') + p = self.params + catIds = p.catIds if p.useCats == 1 else [-1] + labels = p.labels + + assert len(p.maxDets) == 1 + assert len(p.iouThrs) == 1 + assert len(p.areaRng) == 1 + + # get inds to evaluate + k_list = [n for n, k in enumerate(p.catIds)] + m_list = [m for n, m in enumerate(p.maxDets)] + a_list: List[int] = [n for n, a in enumerate(p.areaRng)] + i_list = [n for n, i in enumerate(p.imgIds)] + I0 = len(p.imgIds) + A0 = len(p.areaRng) + + # cat_pr_dict: + # {label1:[precision_li, 
recall_li, f1_li, score_li], label2:[precision_li, recall_li, f1_li, score_li]} + cat_pr_dict = {} + cat_pr_dict_origin = {} + + for k0 in k_list: + Nk = k0 * A0 * I0 + # areagRng + for a0 in a_list: + Na = a0 * I0 + # maxDet + for maxDet in m_list: + E = [self.evalImgs[Nk + Na + i] for i in i_list] + E = [e for e in E if not e is None] + if not E: + continue + dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E]) + + # different sorting method generates slightly different results. + # mergesort is used to be consistent as Matlab implementation. + inds = np.argsort(-dtScores, kind='mergesort') + dtScoresSorted = dtScores[inds] + + dtm = np.concatenate([e['dtMatches'][:, 0:maxDet] for e in E], axis=1)[:, inds] + dtIg = np.concatenate([e['dtIgnore'][:, 0:maxDet] for e in E], axis=1)[:, inds] + gtIg = np.concatenate([e['gtIgnore'] for e in E]) + npig = np.count_nonzero(gtIg == 0) + if npig == 0: + continue + tps = np.logical_and(dtm, np.logical_not(dtIg)) + fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg)) + + # Ensure that iou has only one value + assert (tps.shape[0]) == 1 + assert (fps.shape[0]) == 1 + + tp_sum = np.cumsum(tps, axis=1).astype(dtype=float) + fp_sum = np.cumsum(fps, axis=1).astype(dtype=float) + ids = catIds[k0] + label = labels[ids] + + self.calculate_pr_dict(tp_sum, fp_sum, label, npig, dtScoresSorted, cat_pr_dict, cat_pr_dict_origin, + min_score=min_score) + print("Done") + return cat_pr_dict, cat_pr_dict_origin + + def calculate_pr_dict(self, tp_sum, fp_sum, label, npig, dtScoresSorted, cat_pr_dict, cat_pr_dict_origin, + min_score=0.1): + # iou + for (tp, fp) in zip(tp_sum, fp_sum): + tp = np.array(tp) + fp = np.array(fp) + rc = tp / npig + pr = tp / (fp + tp + np.spacing(1)) + + f1 = np.divide(2 * (rc * pr), pr + rc, out=np.zeros_like(2 * (rc * pr)), where=pr + rc != 0) + + conf_thres = [int(i) * 0.01 for i in range(10, 100, 10)] + dtscores_ascend = dtScoresSorted[::-1] + inds = np.searchsorted(dtscores_ascend, 
conf_thres, side='left') + pr_new = [0.0] * len(conf_thres) + rc_new = [0.0] * len(conf_thres) + f1_new = [0.0] * len(conf_thres) + pr_ascend = pr[::-1] + rc_ascend = rc[::-1] + f1_ascend = f1[::-1] + try: + for i, ind in enumerate(inds): + if conf_thres[i] >= min_score: + pr_new[i] = pr_ascend[ind] + rc_new[i] = rc_ascend[ind] + f1_new[i] = f1_ascend[ind] + else: + pr_new[i] = 0.0 + rc_new[i] = 0.0 + f1_new[i] = 0.0 + except IndexError: + pass + # Ensure that the second, third, and fourth for loops only enter once + if label not in cat_pr_dict.keys(): + cat_pr_dict_origin[label] = [pr[::-1], rc[::-1], f1[::-1], dtScoresSorted[::-1]] + cat_pr_dict[label] = [pr_new, rc_new, f1_new, conf_thres] + else: + break + + def compute_tp_fp_confidence(self): + print('Compute tp and fp confidences') + if not self.evalImgs: + print('Please run evaluate() first') + p = self.params + catIds = p.catIds if p.useCats == 1 else [-1] + labels = p.labels + + assert len(p.maxDets) == 1 + assert len(p.iouThrs) == 1 + assert len(p.areaRng) == 1 + + # get inds to evaluate + m_list = [m for n, m in enumerate(p.maxDets)] + k_list = list(range(len(p.catIds))) + a_list = list(range(len(p.areaRng))) + i_list = list(range(len(p.imgIds))) + + I0 = len(p.imgIds) + A0 = len(p.areaRng) + # cat_dict + correct_conf_dict = {} + incorrect_conf_dict = {} + + for k0 in k_list: + Nk = k0 * A0 * I0 + # areagRng + for a0 in a_list: + Na = a0 * I0 + # maxDet + for maxDet in m_list: + E = [self.evalImgs[Nk + Na + i] for i in i_list] + E = [e for e in E if not e is None] + if not E: + continue + dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E]) + + inds = np.argsort(-dtScores, kind='mergesort') + dtScoresSorted = dtScores[inds] + + dtm = np.concatenate([e['dtMatches'][:, 0:maxDet] for e in E], axis=1)[:, inds] + dtIg = np.concatenate([e['dtIgnore'][:, 0:maxDet] for e in E], axis=1)[:, inds] + gtIg = np.concatenate([e['gtIgnore'] for e in E]) + npig = np.count_nonzero(gtIg == 0) + if npig == 0: + 
continue + tps = np.logical_and(dtm, np.logical_not(dtIg)) + fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg)) + + # Ensure that iou has only one value + assert (tps.shape[0]) == 1 + assert (fps.shape[0]) == 1 + + tp_inds = np.where(tps) + fp_inds = np.where(fps) + + tp_confidence = dtScoresSorted[tp_inds[1]] + fp_confidence = dtScoresSorted[fp_inds[1]] + tp_confidence_li = tp_confidence.tolist() + fp_confidence_li = fp_confidence.tolist() + ids = catIds[k0] + label = labels[ids] + + # Ensure that the second and third for loops only enter once + if label not in correct_conf_dict.keys(): + correct_conf_dict[label] = tp_confidence_li + else: + print("maxDet:", maxDet, " ", "areagRng:", p.areaRng) + break + + if label not in incorrect_conf_dict.keys(): + incorrect_conf_dict[label] = fp_confidence_li + else: + print("maxDet:", maxDet, " ", "areagRng:", p.areaRng) + break + print("Done") + return correct_conf_dict, incorrect_conf_dict + + def write_best_confidence_threshold(self, cat_pr_dict, cat_pr_dict_origin, eval_result_path): + """ + write best confidence threshold + """ + print("Write best confidence threshold to csv") + result_csv = os.path.join(eval_result_path, "best_threshold.csv") + result = ["cat_name", "best_f1", "best_precision", "best_recall", "best_score"] + write_list_to_csv(result_csv, result, append=False) + return_result = [] + for cat_name, cat_info in cat_pr_dict.items(): + f1_li = cat_info[2] + score_li = cat_info[3] + max_f1 = [f1 for f1 in f1_li if abs(f1 - max(f1_li)) <= 0.001] + thre_ = [0.003] + [int(i) * 0.001 for i in range(10, 100, 10)] + [0.099] + # Find the best confidence threshold for 10 levels of confidence thresholds + if len(max_f1) == 1: + # max_f1 is on the far right + if f1_li.index(max_f1) == len(f1_li) - 1: + index = f1_li.index(max_f1) - 1 + # max_f1 is in the middle + elif f1_li.index(max_f1) != len(f1_li) - 1 and f1_li.index(max_f1) != 0: + index_a = f1_li.index(max_f1) - 1 + index_b = f1_li.index(max_f1) + 1 
+ if f1_li[index_a] >= f1_li[index_b]: + index = index_a + else: + index = f1_li.index(max_f1) + # max_f1 is on the far left + elif f1_li.index(max_f1) == 0: + index = f1_li.index(max_f1) + + best_thre = score_li[index] + # thre_ = [0.003] + [int(i) * 0.001 for i in range(10, 100, 10)] + [0.099] + second_thre = [best_thre + i for i in thre_] + + elif len(max_f1) > 1: + thre_pre = [index for (index, value) in enumerate(f1_li) if abs(value - max(f1_li)) <= 0.001] + best_thre = score_li[thre_pre[int((len(thre_pre) - 1) / 2)]] + # thre_ = [0.003] + [int(i) * 0.001 for i in range(10, 100, 10)] + [0.099] + second_thre = [best_thre + i for i in thre_] + + # Reduce the step unit to find the second confidence threshold + cat_info_origin = cat_pr_dict_origin[cat_name] + dtscores_ascend = cat_info_origin[3] + inds = np.searchsorted(dtscores_ascend, second_thre, side='left') + + pr_second = [0] * len(second_thre) + rc_second = [0] * len(second_thre) + f1_second = [0] * len(second_thre) + + try: + for i, ind in enumerate(inds): + if ind >= len(cat_info_origin[0]): + ind = len(cat_info_origin[0]) - 1 + pr_second[i] = cat_info_origin[0][ind] + rc_second[i] = cat_info_origin[1][ind] + f1_second[i] = cat_info_origin[2][ind] + except IndexError: + pass + + best_f1 = max(f1_second) + best_index = f1_second.index(best_f1) + best_precision = pr_second[best_index] + best_recall = rc_second[best_index] + best_score = second_thre[best_index] + result = [cat_name, best_f1, best_precision, best_recall, best_score] + return_result.append(result) + write_list_to_csv(result_csv, result, append=True) + return return_result + + def plot_mc_curve(self, cat_pr_dict, eval_result_path): + """ + plot matrix-confidence curve + cat_pr_dict:{"label_name":[precision, recall, f1, score]} + """ + print('Plot mc curve') + savefig_path = os.path.join(eval_result_path, 'pr_cofidence_fig') + if not os.path.exists(savefig_path): + os.mkdir(savefig_path) + + xlabel = "Confidence" + ylabel = "Metric" + for 
def plot_hist_curve(self, input_data, eval_result_path):
    """
    Plot, for every class key, the confidence histograms of its detections.

    One figure per key of the "correct" dict is saved as
    ``<eval_result_path>/hist_curve_fig/<key>.jpg``. The upper part of the figure holds the
    histogram(s); the lower part holds a table with the number of samples, their mean, max and
    min, and the values found at the 99% / 99.9% positions of the descending-sorted list.

    Args:
        input_data (sequence): ``input_data[0]`` maps a class key to the list of confidences of
            its correct detections, ``input_data[1]`` does the same for the incorrect ones.
            The lists that get plotted are sorted in place (descending).
        eval_result_path (str): directory under which ``hist_curve_fig`` is created.
    """
    correct_conf_dict, incorrect_conf_dict = input_data[0], input_data[1]
    savefig_path = os.path.join(eval_result_path, 'hist_curve_fig')
    # makedirs also copes with a missing parent directory and with an already existing target.
    os.makedirs(savefig_path, exist_ok=True)

    col_names = ['number', 'mean', 'max', 'min', 'min99%', 'min99.9%']

    def summarize(confidences, head_fmt):
        """Sort ``confidences`` descending in place and return the table row describing them."""
        confidences.sort(reverse=True)
        count = len(confidences)
        return [count,
                head_fmt % np.mean(confidences),
                head_fmt % max(confidences),
                '%.2f' % min(confidences),
                '%.2f' % confidences[int(count * 0.99) - 1],
                '%.2f' % confidences[int(count * 0.999) - 1]]

    def draw_hist(confidences, color, **extra):
        """Draw a 50-bin histogram with its density curve in a single colour."""
        # NOTE(review): sns.distplot is deprecated in newer seaborn releases - confirm the pinned version.
        sns.distplot(confidences, bins=50, kde_kws={'color': color, 'lw': 3},
                     hist_kws={'color': color, 'alpha': 0.3}, **extra)

    for label, correct in correct_conf_dict.items():
        incorrect = incorrect_conf_dict.get(label)
        if not correct and not incorrect:
            # Nothing to draw: do not open a figure that would never be saved.
            continue

        fig = plt.figure(figsize=(7, 7))
        try:
            gs = gridspec.GridSpec(4, 1)
            plt.subplot(gs[:3, 0])
            rows = [col_names]
            row_name = ['']

            if correct:
                # Two decimals when both histograms share the figure, four when "correct" stands alone.
                rows.append(summarize(correct, '%.2f' if incorrect else '%.4f'))
                row_name.append('correct')
                sns.set_palette('hls')
                draw_hist(correct, 'b')
                plt.xlim((0, 1))
                plt.xlabel(label)
                plt.ylabel("numbers")
                if incorrect:
                    # The incorrect detections get their own y axis on top of the same x axis.
                    ax1 = plt.twinx()
                    rows.append(summarize(incorrect, '%.2f'))
                    row_name.append('incorrect')
                    draw_hist(incorrect, 'r', ax=ax1)
            else:
                rows.append(summarize(incorrect, '%.4f'))
                row_name.append('incorrect')
                sns.set_palette('hls')
                draw_hist(incorrect, 'b')
                plt.xlim((0, 1))
                plt.xlabel(label)

            plt.grid(True)
            plt.subplot(gs[3, 0])
            plt.axis('off')
            table = plt.table(cellText=rows, rowLabels=row_name,
                              loc='center', cellLoc='center', rowLoc='center')
            table.auto_set_font_size(False)
            table.set_fontsize(10)
            table.scale(1, 1.5)
            plt.savefig(os.path.join(savefig_path, label) + '.jpg')
        finally:
            # Release the figure; otherwise one figure per class stays alive until the process exits.
            plt.close(fig)
class EvalCallBack(Callback):
    """
    Evaluate the model at the end of training epochs and keep a copy of the best checkpoint.

    Args:
        config: configuration object. ``interval`` and ``epoch_size`` are read from it and
            ``current_epoch`` is written to it at the end of every epoch.
        net: network handed to ``apply_eval`` for evaluation.
        apply_eval (function): called as
            ``apply_eval(net, config, mindrecord_path, ckpt_file, anno_json)``; its return value
            is compared with the best result seen so far (larger is better).
        datasetsize (int): second number of the checkpoint file name
            ``faster_rcnn-<epoch>_<datasetsize>.ckpt``.
        mindrecord_path (str): passed through to ``apply_eval``.
        anno_json (str): passed through to ``apply_eval``.
        checkpoint_path (str): directory in which the per-epoch checkpoint file is looked up.

    Returns:
        None

    Examples:
        >>> EvalCallBack(config, eval_net, apply_eval, dataset_size, mindrecord_path, anno_json, ckpt_dir)
    """

    def __init__(self, config, net, apply_eval, datasetsize, mindrecord_path, anno_json, checkpoint_path):
        super().__init__()
        self.faster_rcnn_eval = apply_eval
        self.mindrecord_path = mindrecord_path
        self.anno_json = anno_json
        self.datasetsize = datasetsize
        self.config = config
        self.checkpoint_path = checkpoint_path
        self.net = net
        self.best_epoch = 0
        self.best_res = 0
        # The best checkpoint is always written to ./best_ckpt/best.ckpt (relative to the working dir).
        self.best_ckpt_path = os.path.abspath("./best_ckpt")

    def epoch_end(self, run_context):
        """Evaluate the checkpoint of the finished epoch and store it as best.ckpt if it improves."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        ckpt_file_name = "faster_rcnn-{}_{}.ckpt".format(cur_epoch, self.datasetsize)
        checkpoint_file = os.path.join(self.checkpoint_path, ckpt_file_name)
        self.config.current_epoch = cur_epoch

        # Evaluate every `interval` epochs and always after the last epoch.
        is_eval_epoch = cur_epoch % self.config.interval == 0 or cur_epoch == self.config.epoch_size
        if not is_eval_epoch:
            return

        res = self.faster_rcnn_eval(self.net, self.config, self.mindrecord_path, checkpoint_file, self.anno_json)
        if res <= self.best_res:
            return

        self.best_epoch = cur_epoch
        self.best_res = res

        # Start from an empty directory so that only the current best checkpoint is kept.
        if os.path.exists(self.best_ckpt_path):
            shutil.rmtree(self.best_ckpt_path)
        os.makedirs(self.best_ckpt_path)
        save_checkpoint(cb_params.train_network, os.path.join(self.best_ckpt_path, "best.ckpt"))

        print("update best result: {} in the {} th epoch".format(self.best_res, self.best_epoch), flush=True)

    def end(self, run_context):
        """Report the best result and the epoch that produced it once training has finished."""
        print("End training the best {0} epoch is {1}".format(self.best_res, self.best_epoch), flush=True)
def create_eval_mindrecord(config):
    """
    Create the evaluation MindRecord file unless ``config.mindrecord_file`` already exists.

    ``config.mindrecord_dir`` is created when missing. For ``config.dataset == "coco"`` the data
    is converted only if ``config.coco_root`` is a directory; for any other dataset both
    ``config.image_dir`` and ``config.anno_path`` must exist. When the inputs are missing a
    message is printed and nothing is created.

    Args:
        config: configuration object providing the attributes named above plus ``prefix``.
    """
    print("CHECKING MINDRECORD FILES ...")
    if os.path.exists(config.mindrecord_file):
        return

    os.makedirs(config.mindrecord_dir, exist_ok=True)

    if config.dataset == "coco":
        if not os.path.isdir(config.coco_root):
            print("coco_root not exits.")
            return
        dataset_kind = "coco"
    else:
        if not (os.path.isdir(config.image_dir) and os.path.exists(config.anno_path)):
            print("IMAGE_DIR or ANNO_PATH not exits.")
            return
        dataset_kind = "other"

    print("Create Mindrecord. It may take some time.")
    data_to_mindrecord_byte_image(config, dataset_kind, False, config.prefix, file_num=1)
    print("Create Mindrecord Done, at {}".format(config.mindrecord_dir))


def apply_eval(net, config, dataset_path, ckpt_path, anno_path):
    """
    FasterRcnn evaluation.

    Loads ``ckpt_path`` into ``net``, runs the network over the dataset at ``dataset_path``,
    writes the detections through ``results2json`` and returns the value of ``metrics_map``.

    Args:
        net: network to evaluate; it is switched to inference mode.
        config: configuration object (``test_batch_size``, ``device_target``, ``dataset``,
            ``num_classes`` and ``current_epoch`` are read here).
        dataset_path (str): dataset location handed to ``create_fasterrcnn_dataset``.
        ckpt_path (str): checkpoint file to load.
        anno_path (str): COCO json file for the "coco" dataset, otherwise the annotation file
            handed to ``parse_json_annos_from_txt``.

    Returns:
        The value returned by ``metrics_map`` for the "bbox" results.

    Raises:
        RuntimeError: if ``ckpt_path`` is not an existing file.
    """
    if not os.path.isfile(ckpt_path):
        raise RuntimeError("CheckPoint file {} is not valid.".format(ckpt_path))
    ds = create_fasterrcnn_dataset(config, dataset_path, batch_size=config.test_batch_size, is_training=False)

    param_dict = load_checkpoint(ckpt_path)
    if config.device_target == "GPU":
        # On GPU every loaded parameter is re-created as float32.
        for key, value in param_dict.items():
            tensor = value.asnumpy().astype(np.float32)
            param_dict[key] = Parameter(tensor, key)
    load_param_into_net(net, param_dict)

    net.set_train(False)
    if context.get_context("device_target") == "Ascend":
        net.to_float(mstype.float16)

    total = ds.get_dataset_size()
    outputs = []

    if config.dataset != "coco":
        dataset_coco = COCO()
        dataset_coco.dataset, dataset_coco.anns, dataset_coco.cats, dataset_coco.imgs = dict(), dict(), dict(), dict()
        dataset_coco.imgToAnns, dataset_coco.catToImgs = defaultdict(list), defaultdict(list)
        dataset_coco.dataset = parse_json_annos_from_txt(anno_path, config)
        dataset_coco.createIndex()
    else:
        dataset_coco = COCO(anno_path)

    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    max_num = 128  # keep at most this many boxes per image
    for data in ds.create_dict_iterator(num_epochs=1):
        # run net
        output = net(data['image'], data['image_shape'], data['box'], data['label'], data['valid_num'])

        # Convert every output once per batch instead of once per image.
        all_bbox = output[0].asnumpy()
        all_label = output[1].asnumpy()
        all_mask = output[2].asnumpy()

        for j in range(config.test_batch_size):
            all_bbox_squee = np.squeeze(all_bbox[j, :, :])
            all_label_squee = np.squeeze(all_label[j, :, :])
            all_mask_squee = np.squeeze(all_mask[j, :, :])

            all_bboxes_tmp_mask = all_bbox_squee[all_mask_squee, :]
            all_labels_tmp_mask = all_label_squee[all_mask_squee]

            if all_bboxes_tmp_mask.shape[0] > max_num:
                # Sort by the last box column, descending, and keep the first max_num rows.
                inds = np.argsort(-all_bboxes_tmp_mask[:, -1])[:max_num]
                all_bboxes_tmp_mask = all_bboxes_tmp_mask[inds]
                all_labels_tmp_mask = all_labels_tmp_mask[inds]

            outputs.append(bbox2result_1image(all_bboxes_tmp_mask, all_labels_tmp_mask, config.num_classes))

    eval_types = ["bbox"]
    result_path = "./{}epoch_results.pkl".format(config.current_epoch)
    result_files = results2json(dataset_coco, outputs, result_path)

    return metrics_map(result_files, eval_types, dataset_coco, single_result=False)


def metrics_map(result_files, result_types, coco, max_dets=(100, 300, 1000), single_result=False):
    """
    coco eval for fasterrcnn.

    Args:
        result_files (dict): maps a result type to the path of its ``.json`` result file; the
            ``'bbox'`` entry must be present.
        result_types (list): result types to evaluate; the returned value belongs to the last one.
        coco (COCO or str): ground truth, either a loaded ``COCO`` object or an annotation path.
        max_dets (tuple): ``maxDets`` used for the ``'proposal'`` result type only.
        single_result (bool): evaluate on the image ids of the detections instead of all
            ground-truth image ids.

    Returns:
        0 if the ``'bbox'`` result file holds no annotations, otherwise ``cocoEval.stats[0]``.
    """
    # Close the result file deterministically instead of leaving the handle to the garbage collector.
    with open(result_files['bbox']) as result_fp:
        anns = json.load(result_fp)
    if not anns:
        return 0

    if isinstance(coco, str):
        coco = COCO(coco)
    assert isinstance(coco, COCO)

    for res_type in result_types:
        result_file = result_files[res_type]
        assert result_file.endswith('.json')

        coco_dets = coco.loadRes(result_file)
        det_img_ids = coco_dets.getImgIds()
        gt_img_ids = coco.getImgIds()
        iou_type = 'bbox' if res_type == 'proposal' else res_type
        cocoEval = COCOeval(coco, coco_dets, iou_type)
        if res_type == 'proposal':
            cocoEval.params.useCats = 0
            cocoEval.params.maxDets = list(max_dets)

        cocoEval.params.imgIds = det_img_ids if single_result else gt_img_ids
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

    return cocoEval.stats[0]
a/official/cv/faster_rcnn/src/model_utils/config.py b/official/cv/faster_rcnn/src/model_utils/config.py index e3a8262e7473c0c6da7af1d931d986d296f852c9..7b758095bb20f148518645400c14044f84ce9a33 100644 --- a/official/cv/faster_rcnn/src/model_utils/config.py +++ b/official/cv/faster_rcnn/src/model_utils/config.py @@ -120,9 +120,19 @@ def get_config(): path_args, _ = parser.parse_known_args() default, helper, choices = parse_yaml(path_args.config_path) args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) - final_config = merge(args, default) - pprint(final_config) + default = Config(merge(args, default)) + default.feature_shapes = [ + [default.img_height // 4, default.img_width // 4], + [default.img_height // 8, default.img_width // 8], + [default.img_height // 16, default.img_width // 16], + [default.img_height // 32, default.img_width // 32], + [default.img_height // 64, default.img_width // 64], + ] + default.num_bboxes = default.num_anchors * sum([lst[0] * lst[1] for lst in default.feature_shapes]) + pprint(default) print("Please check the above information for the configurations", flush=True) - return Config(final_config) + + return default + config = get_config() diff --git a/official/cv/faster_rcnn/src/util.py b/official/cv/faster_rcnn/src/util.py index 8be037cf7bdaace4d1addf021e1cf35acd64b31c..d18d490788bedf9196267c6cc85192ccf0f5d191 100644 --- a/official/cv/faster_rcnn/src/util.py +++ b/official/cv/faster_rcnn/src/util.py @@ -13,10 +13,15 @@ # limitations under the License. 
def write_list_to_csv(file_path, data_to_write, append=False):
    """
    Write one row to a CSV file.

    Args:
        file_path (str): path of the CSV file.
        data_to_write (list): values of the row; an empty list writes an empty row.
        append (bool): add the row to the end of the file when True, otherwise overwrite the file.
    """
    open_mode = 'a' if append else 'w'
    # newline='' lets the csv writer control line endings; without it, platforms that translate
    # "\n" would turn the writer's "\r\n" into "\r\r\n" and show a blank line after every row.
    with open(file_path, open_mode, newline='') as csvfile:
        csv.writer(csvfile).writerow(data_to_write)
def calcuate_pr_rc_f1(config, coco, coco_dets, tgt_ids, iou_type):
    """
    Run the detailed detection evaluation and write its report into ``./eval_result``.

    The directory is recreated from scratch on every call. It receives ``statistics.csv`` plus
    whatever the ``DetectEval`` plotting / saving helpers called below write into it.

    Args:
        config: configuration object, forwarded to ``DetectEval.save_images``.
        coco: ground-truth object handed to ``DetectEval``.
        coco_dets: detection results handed to ``DetectEval``.
        tgt_ids: image ids used for the overall evaluation.
        iou_type (str): iou type of the overall and the confidence evaluations.

    Returns:
        ``stats[0]`` of the overall evaluation.
    """

    def use_single_iou(evaluator):
        # One IoU threshold (0.5), at most 100 detections, a single area range.
        evaluator.params.iouThrs = [0.5]
        evaluator.params.maxDets = [100]
        evaluator.params.areaRng = [[0 ** 2, 1e5 ** 2]]

    # 1 overall statistics on the requested images
    overall_eval = DetectEval(coco, coco_dets, iou_type)
    overall_eval.params.imgIds = tgt_ids
    overall_eval.evaluate()
    overall_eval.accumulate()
    overall_eval.summarize()
    stats_all = overall_eval.stats

    eval_result_path = os.path.abspath("./eval_result")
    if os.path.exists(eval_result_path):
        shutil.rmtree(eval_result_path)
    os.mkdir(eval_result_path)

    result_csv = os.path.join(eval_result_path, "statistics.csv")
    write_list_to_csv(result_csv,
                      ["ap@0.5:0.95", "ap@0.5", "ap@0.75", "ar@0.5:0.95", "ar@0.5", "ar@0.75"],
                      append=False)
    # The six header columns are taken from these positions of the stats vector.
    rounded_stats = [round(stats_all[position], 3) for position in (0, 1, 2, 6, 7, 8)]
    write_list_to_csv(result_csv, rounded_stats, append=True)
    write_list_to_csv(result_csv, [], append=True)
    # 1.2 plot_pr_curve
    overall_eval.plot_pr_curve(eval_result_path)

    # 2 confidence analysis
    conf_eval = DetectEval(coco, coco_dets, iou_type)
    use_single_iou(conf_eval)
    conf_eval.evaluate()
    # 2.1 plot hist_curve of every class's tp's confidence and fp's confidence
    conf_eval.plot_hist_curve(conf_eval.compute_tp_fp_confidence(), eval_result_path)

    # 2.2 write best_threshold and p r to csv and plot
    cat_pr_dict, cat_pr_dict_origin = conf_eval.compute_precison_recall_f1()
    best_confidence_thres = conf_eval.write_best_confidence_threshold(cat_pr_dict, cat_pr_dict_origin,
                                                                      eval_result_path)
    print("best_confidence_thres: ", best_confidence_thres)
    conf_eval.plot_mc_curve(cat_pr_dict, eval_result_path)

    # 3
    # 3.1 compute every class's p r and save every class's p and r at iou = 0.5
    class_eval = DetectEval(coco, coco_dets, iouType='bbox')
    use_single_iou(class_eval)
    class_eval.evaluate()
    class_eval.accumulate()
    per_class = class_eval.evaluate_every_class()

    print_info = ["class_name", "tp_num", "gt_num", "dt_num", "precision", "recall"]
    write_list_to_csv(result_csv, print_info, append=True)
    print(*print_info)
    for class_result in per_class:
        print(class_result)
        write_list_to_csv(result_csv, class_result, append=True)

    # 3.2 save ng / ok images
    class_eval.save_images(config, eval_result_path, 0.5)

    return stats_all[0]
config.save_checkpoint_path = config.output_path + @moxing_wrapper(pre_process=modelarts_pre_process) def train_fasterrcnn(): """ train_fasterrcnn """ @@ -150,7 +150,6 @@ def train_fasterrcnn(): newkey = oldkey.replace(k, v) param_dict[newkey] = param_dict.pop(oldkey) break - for item in list(param_dict.keys()): if not item.startswith('backbone'): param_dict.pop(item) @@ -186,8 +185,25 @@ def train_fasterrcnn(): ckpoint_cb = ModelCheckpoint(prefix='faster_rcnn', directory=save_checkpoint_path, config=ckptconfig) cb += [ckpoint_cb] + if config.run_eval: + from src.eval_callback import EvalCallBack + from src.eval_utils import create_eval_mindrecord, apply_eval + config.prefix = "FasterRcnn_eval.mindrecord" + anno_json = os.path.join(config.coco_root, "annotations/instances_val2017.json") + mindrecord_path = os.path.join(config.coco_root, "FASTERRCNN_MINDRECORD", config.prefix) + config.instance_set = "annotations/instances_val2017.json" + + if not os.path.exists(mindrecord_path): + config.mindrecord_file = mindrecord_path + create_eval_mindrecord(config) + eval_net = Faster_Rcnn_Resnet(config) + eval_cb = EvalCallBack(config, eval_net, apply_eval, dataset_size, mindrecord_path, anno_json, + save_checkpoint_path) + cb += [eval_cb] + model = Model(net) - model.train(config.epoch_size, dataset, callbacks=cb) + model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=False) + if __name__ == '__main__': train_fasterrcnn()