Source code for akida_models.centernet.centernet_processing

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Processing tools for CenterNet data handling.
"""

__all__ = ["decode_output"]

import numpy as np
from tensorflow.nn import max_pool2d

from ..detection.processing import BoundingBox


[docs]def decode_output(output, nb_classes, obj_threshold=0.1, max_detections=100, kernel=5): """ Decodes a CenterNet model. Args: output (tf.Tensor): model output to decode. nb_classes (int): number of classes. obj_threshold (float, optional): confidence threshold for a box. Defaults to 0.1. max_detection (int, optional): maximum number of boxes the model is allowed to produce. Defaults to 100. kernel (int, optional): max pool kernel size. Defaults to 5. Returns: List: `BoundingBox` objects """ def _sigmoid(x): return 1. / (1. + np.exp(-x)) grid_h, grid_w = output.shape[:2] # Decode the output of the network center_heatmap_pred = _sigmoid(output[..., :nb_classes]) wh_pred = output[..., nb_classes:nb_classes + 2] offset_pred = output[..., nb_classes + 2:nb_classes + 4] # Get local maximum hmax = max_pool2d(center_heatmap_pred[None, ...], ksize=[kernel, kernel], strides=1, padding='SAME', data_format='NHWC') center_heatmap_pred[hmax[0] != center_heatmap_pred] = 0 # Get top k from the heatmap perm_center_heatmap = np.transpose(center_heatmap_pred, (2, 0, 1)) flattened_heatmap = np.reshape(perm_center_heatmap, (-1)) topk_scores = np.partition(flattened_heatmap, -max_detections)[-max_detections:] topk_scores = np.flip(np.sort(topk_scores)) topk_inds = np.argpartition(flattened_heatmap, -max_detections)[-max_detections:] topk_inds = topk_inds[np.argsort(flattened_heatmap[topk_inds])][::-1] topk_labels = topk_inds // (grid_h * grid_w) topk_inds = topk_inds % (grid_h * grid_w) topk_ys = topk_inds // grid_h topk_xs = topk_inds % grid_w # Transpose and gather features for the WH and OFFSET. # Removed the transpose as we don't do it above either wh_pred = np.reshape(wh_pred, [-1, wh_pred.shape[-1]]) wh = wh_pred[topk_inds, ...] offset_pred = np.reshape(offset_pred, [-1, offset_pred.shape[-1]]) offset = offset_pred[topk_inds, ...] # The output should be x,y,w,h topk_xs = topk_xs + offset[..., 0] topk_ys = topk_ys + offset[..., 1] tl_x = np.clip((topk_xs - wh[..., 0] / 2) / grid_w, a_min=0, a_max=grid_w) tl_y = np.clip((topk_ys - wh[..., 1] / 2) / grid_h, a_min=0, a_max=grid_h) br_x = np.clip((topk_xs + wh[..., 0] / 2) / grid_w, a_min=0, a_max=grid_w) br_y = np.clip((topk_ys + wh[..., 1] / 2) / grid_h, a_min=0, a_max=grid_h) boxes = [] for i in range(max_detections): score = topk_scores[i] if score > obj_threshold: label = topk_labels[i] box = BoundingBox(tl_x[i], tl_y[i], br_x[i], br_y[i], score=score) box.label = label boxes.append(box) else: break return boxes