# Skip to content
# Snippets Groups Projects
# coco.py 4.37 KiB
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""coco"""
from __future__ import division

import json
import os
import pickle
from collections import defaultdict, OrderedDict
import numpy as np

try:
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    has_coco = True
except ImportError:
    has_coco = False

from src.utils.nms import oks_nms


def _write_coco_keypoint_results(img_kpts, num_joints, res_file):
    """Write per-image keypoint detections to a COCO-style results JSON file.

    Args:
        img_kpts: Mapping from image id to a list of detection dicts. Each
            dict must provide ``'keypoints'`` (indexable as
            ``[joint][component]`` with at least three components per joint)
            and ``'score'``. The image id must be convertible with ``int()``.
            Images with an empty detection list are skipped.
        num_joints: Number of joints per detection; every output record
            carries ``num_joints * 3`` flattened keypoint values.
        res_file: Path of the JSON file to create or overwrite.
    """
    results = []

    for img, items in img_kpts.items():
        if not items:
            continue
        item_size = len(items)
        kpts = np.array([item['keypoints'] for item in items])
        # ``np.float`` was only an alias of the builtin float (float64) and
        # was removed in NumPy 1.24, so name the dtype explicitly.
        keypoints = np.zeros((item_size, num_joints * 3), dtype=np.float64)
        # Interleave the three per-joint components as c0, c1, c2, c0, ...
        keypoints[:, 0::3] = kpts[:, :, 0]
        keypoints[:, 1::3] = kpts[:, :, 1]
        keypoints[:, 2::3] = kpts[:, :, 2]

        results.extend(
            {
                'image_id': int(img),
                # tolist() yields plain Python floats for the JSON encoder.
                'keypoints': row.tolist(),
                # Scores may be NumPy scalars (e.g. float32), which the json
                # module cannot serialize; convert to a builtin float.
                'score': float(item['score']),
                'category_id': 1,
            }
            for item, row in zip(items, keypoints)
        )

    with open(res_file, 'w') as f:
        json.dump(results, f, sort_keys=True, indent=4)


def _do_python_keypoint_eval(res_file, res_folder, ann_path):
    """Run the pycocotools keypoint evaluation and persist the evaluator.

    Args:
        res_file: Path to the detection results file passed to ``loadRes``.
        res_folder: Directory in which ``keypoints_results.pkl`` is written.
        ann_path: Path to the annotation file passed to ``COCO``.

    Returns:
        A list of ``(metric_name, value)`` tuples, pairing each of the ten
        metric labels with the entry of ``COCOeval.stats`` at the same index.
    """
    ground_truth = COCO(ann_path)
    detections = ground_truth.loadRes(res_file)

    evaluator = COCOeval(ground_truth, detections, 'keypoints')
    evaluator.params.useSegm = None
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()

    # Labels are positional: label i names evaluator.stats[i].
    metric_labels = (
        'AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)',
        'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)',
    )
    info_str = [(label, evaluator.stats[pos])
                for pos, label in enumerate(metric_labels)]

    # Pickle the whole evaluator object next to the other outputs.
    eval_file = os.path.join(res_folder, 'keypoints_results.pkl')
    with open(eval_file, 'wb') as handle:
        pickle.dump(evaluator, handle, pickle.HIGHEST_PROTOCOL)
    print('coco eval results saved to %s' % eval_file)

    return info_str


def evaluate(cfg, preds, output_dir, all_boxes, img_id, ann_path):
    """Rescore predictions, apply OKS NMS, write results and evaluate them.

    Args:
        cfg: Config object; reads ``MODEL.NUM_JOINTS``, ``TEST.IN_VIS_THRE``,
            ``TEST.OKS_THRE``, ``DATASET.TEST_SET`` and (only when
            ``ann_path`` is falsy) ``DATASET.ROOT``.
        preds: Per-detection keypoints, indexed as ``preds[idx][joint][2]``
            for the joint confidence.
        output_dir: Directory for ``keypoints_results.json`` (and the pickle
            written by the evaluation step); created if missing.
        all_boxes: Per-detection box info; ``all_boxes[idx][0]`` is stored
            as ``'area'`` and ``all_boxes[idx][1]`` as the box ``'score'``.
        img_id: Sequence of image ids, one per detection, aligned with
            ``preds`` and ``all_boxes``.
        ann_path: Annotation file path. When falsy, defaults to
            ``<DATASET.ROOT>/annotations/person_keypoints_<TEST_SET>.json``.

    Returns:
        ``(name_value, ap)`` where ``name_value`` is an ``OrderedDict`` of
        metric name to value and ``ap`` is ``name_value['AP']``. When the
        test set name contains ``'test'`` or pycocotools is unavailable, the
        results file is still written and ``({'Null': 0}, 0)`` is returned.
    """
    # exist_ok avoids the check-then-create race when several processes
    # start evaluating into the same directory.
    os.makedirs(output_dir, exist_ok=True)
    res_file = os.path.join(output_dir, 'keypoints_results.json')

    # Group detections by the image they belong to.
    img_kpts_dict = defaultdict(list)
    for idx, file_id in enumerate(img_id):
        img_kpts_dict[file_id].append({
            'keypoints': preds[idx],
            'area': all_boxes[idx][0],
            'score': all_boxes[idx][1],
        })

    # rescoring and oks nms
    num_joints = cfg.MODEL.NUM_JOINTS
    in_vis_thre = cfg.TEST.IN_VIS_THRE
    oks_thre = cfg.TEST.OKS_THRE
    oks_nmsed_kpts = {}
    for img, items in img_kpts_dict.items():
        for item in items:
            # Mean confidence over joints above the visibility threshold;
            # stays 0 when no joint passes, which zeroes the final score.
            visible = [
                conf
                for conf in (item['keypoints'][n_jt][2]
                             for n_jt in range(num_joints))
                if conf > in_vis_thre
            ]
            kpt_score = sum(visible) / len(visible) if visible else 0
            item['score'] = kpt_score * item['score']
        keep = oks_nms(items, oks_thre)
        # An empty keep list means "keep everything" rather than "drop all".
        if not keep:
            oks_nmsed_kpts[img] = items
        else:
            oks_nmsed_kpts[img] = [items[kep] for kep in keep]

    # evaluate and save
    image_set = cfg.DATASET.TEST_SET
    _write_coco_keypoint_results(oks_nmsed_kpts, num_joints, res_file)
    if 'test' in image_set or not has_coco:
        # Nothing to score against: the results file above is the only output.
        return {'Null': 0}, 0

    ann_path = ann_path or os.path.join(
        cfg.DATASET.ROOT, 'annotations',
        'person_keypoints_' + image_set + '.json')
    info_str = _do_python_keypoint_eval(res_file, output_dir, ann_path)
    name_value = OrderedDict(info_str)
    return name_value, name_value['AP']