Skip to content
Snippets Groups Projects
Select Git revision
  • 2d9841917815fb6352172aa8b09ad132f5072de1
  • master default protected
  • r1.8
  • r1.6
  • r1.9
  • r1.5
  • r1.7
  • r1.3
  • r1.4
  • r1.2
  • v1.6.0
  • v1.5.0
12 results

predict.py

Blame
  • predict.py 2.82 KiB
    # Copyright 2021 Huawei Technologies Co., Ltd
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ============================================================================
    """ prediction picture """
    import math
    import numpy as np
    
    from src.utils.transforms import transform_preds
    
    
    def get_max_preds(batch_heatmaps):
        """
        Get the per-joint peak location and value from a batch of score maps.

        Args:
            batch_heatmaps: numpy.ndarray of shape [batch_size, num_joints, height, width].

        Returns:
            preds: float32 array [batch_size, num_joints, 2] holding the (x, y) =
                (column, row) index of each heatmap's maximum (first occurrence on
                ties). Joints whose maximum is not strictly positive get (0, 0).
            maxvals: array [batch_size, num_joints, 1] with each heatmap's maximum,
                in the dtype of batch_heatmaps.

        Raises:
            AssertionError: if batch_heatmaps is not a 4-D numpy.ndarray (unless
                assertions are disabled with ``python -O``).
        """
        assert isinstance(batch_heatmaps, np.ndarray), 'batch_heatmaps should be numpy.ndarray'
        assert batch_heatmaps.ndim == 4, 'batch_heatmaps should be 4-ndim'

        batch_size, num_joints, height, width = batch_heatmaps.shape
        # Spell out the flattened size instead of -1: numpy cannot infer -1 when the
        # array is empty (e.g. batch_size == 0), so an explicit size keeps that case working.
        heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, height * width))
        idx = np.argmax(heatmaps_reshaped, 2).reshape((batch_size, num_joints, 1))
        maxvals = np.amax(heatmaps_reshaped, 2).reshape((batch_size, num_joints, 1))

        # Split the flat index with integer arithmetic before casting, so the
        # row/column values are exact regardless of heatmap size.
        rows, cols = np.divmod(idx, width)
        preds = np.concatenate((cols, rows), axis=2).astype(np.float32)

        # Broadcast the (batch, joints, 1) mask over both coordinates: joints whose
        # peak is not > 0 are reported at (0, 0).
        preds *= np.greater(maxvals, 0.0).astype(np.float32)
        return preds, maxvals
    
    
    def get_final_preds(config, batch_heatmaps, center, scale):
        """
        Decode a batch of heatmaps into final keypoint coordinates.

        Args:
            config: configuration object; only ``config.TEST.POST_PROCESS`` is read.
            batch_heatmaps: numpy.ndarray [batch_size, num_joints, height, width].
            center: per-sample values, indexed by batch position and passed to transform_preds.
            scale: per-sample values, indexed by batch position and passed to transform_preds.

        Returns:
            (preds, maxvals): preds has the shape/dtype of the coordinates produced by
            get_max_preds and holds, per sample, the transform_preds output; maxvals
            is returned unchanged from get_max_preds.
        """
        coords, maxvals = get_max_preds(batch_heatmaps)
        hm_height, hm_width = batch_heatmaps.shape[2:4]

        if config.TEST.POST_PROCESS:
            # Quarter-pixel refinement: shift each peak by 0.25 toward the higher-valued
            # neighbour along x and along y (no shift on ties). The bounds are strict and
            # start at 1, so peaks in rows/columns 0, 1 and the last one are left as-is.
            for n, p in np.ndindex(*coords.shape[:2]):
                heatmap = batch_heatmaps[n][p]
                x = int(math.floor(coords[n][p][0] + 0.5))
                y = int(math.floor(coords[n][p][1] + 0.5))
                if not (1 < x < hm_width - 1 and 1 < y < hm_height - 1):
                    continue
                gradient = np.array([heatmap[y][x + 1] - heatmap[y][x - 1],
                                     heatmap[y + 1][x] - heatmap[y - 1][x]])
                coords[n][p] += np.sign(gradient) * .25

        # Map each sample's heatmap-space coordinates through transform_preds
        # (presumably back to the original image frame -- see src.utils.transforms).
        preds = coords.copy()
        for sample in range(coords.shape[0]):
            preds[sample] = transform_preds(coords[sample], center[sample], scale[sample],
                                            [hm_width, hm_height])

        return preds, maxvals