#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Processing tools for CenterNet data handling.
"""
__all__ = ["decode_output"]
import numpy as np
from tensorflow.nn import max_pool2d
from ..detection.processing import BoundingBox
[docs]
def decode_output(output,
                  nb_classes,
                  obj_threshold=0.1,
                  max_detections=100,
                  kernel=5):
    """ Decodes the output of a CenterNet model into bounding boxes.

    The last axis of ``output`` is read as ``nb_classes`` center heatmap logits, followed by
    2 box size channels (width, height) and 2 center offset channels (x, y). The first two
    axes are the grid height and width.

    Args:
        output (tf.Tensor): model output to decode.
        nb_classes (int): number of classes.
        obj_threshold (float, optional): confidence threshold for a box. Defaults to 0.1.
        max_detections (int, optional): maximum number of boxes the model is allowed to
            produce. It is capped to the number of heatmap cells. Defaults to 100.
        kernel (int, optional): max pool kernel size. Defaults to 5.

    Returns:
        List: `BoundingBox` objects sorted by decreasing score, each with its ``label``
        attribute set to the predicted class index. Coordinates are divided by the grid
        size.
    """
    def _sigmoid(x):
        return 1. / (1. + np.exp(-x))

    grid_h, grid_w = output.shape[:2]

    # Split the output channels: class heatmaps, box sizes and center offsets
    center_heatmap_pred = _sigmoid(output[..., :nb_classes])
    wh_pred = output[..., nb_classes:nb_classes + 2]
    offset_pred = output[..., nb_classes + 2:nb_classes + 4]

    # Keep local maxima only: a cell survives if it equals the max over its neighbourhood
    hmax = max_pool2d(center_heatmap_pred[None, ...],
                      ksize=[kernel, kernel], strides=1, padding='SAME', data_format='NHWC')
    center_heatmap_pred[hmax[0] != center_heatmap_pred] = 0

    # Flatten the heatmap class-major, so that index = label * (H * W) + y * W + x
    flattened_heatmap = np.reshape(np.transpose(center_heatmap_pred, (2, 0, 1)), (-1))

    # argpartition raises if asked for more elements than available, so cap the count
    nb_candidates = min(max_detections, flattened_heatmap.size)
    if nb_candidates <= 0:
        return []

    # Get the top k indices, ordered by decreasing score, and gather their scores
    topk_inds = np.argpartition(flattened_heatmap, -nb_candidates)[-nb_candidates:]
    topk_inds = topk_inds[np.argsort(flattened_heatmap[topk_inds])][::-1]
    topk_scores = flattened_heatmap[topk_inds]

    # Recover the class and the grid position from the flat index. Rows are grid_w cells
    # long, so the row index is obtained by dividing by grid_w (not grid_h).
    topk_labels = topk_inds // (grid_h * grid_w)
    topk_inds = topk_inds % (grid_h * grid_w)
    topk_ys = topk_inds // grid_w
    topk_xs = topk_inds % grid_w

    # Gather WH and OFFSET for the selected cells. No transpose is needed here: flattening
    # the two spatial axes row-major matches the y * W + x indices computed above.
    wh_pred = np.reshape(wh_pred, [-1, wh_pred.shape[-1]])
    wh = wh_pred[topk_inds, ...]
    offset_pred = np.reshape(offset_pred, [-1, offset_pred.shape[-1]])
    offset = offset_pred[topk_inds, ...]

    # Refine the centers with the predicted offsets
    topk_xs = topk_xs + offset[..., 0]
    topk_ys = topk_ys + offset[..., 1]

    # Build the corners, divided by the grid size.
    # NOTE(review): the upper clip bound is the grid size although the values have already
    # been divided by it - presumably 1 was intended. Kept as is to preserve the outputs.
    tl_x = np.clip((topk_xs - wh[..., 0] / 2) / grid_w, a_min=0, a_max=grid_w)
    tl_y = np.clip((topk_ys - wh[..., 1] / 2) / grid_h, a_min=0, a_max=grid_h)
    br_x = np.clip((topk_xs + wh[..., 0] / 2) / grid_w, a_min=0, a_max=grid_w)
    br_y = np.clip((topk_ys + wh[..., 1] / 2) / grid_h, a_min=0, a_max=grid_h)

    boxes = []
    for score, label, x1, y1, x2, y2 in zip(topk_scores, topk_labels, tl_x, tl_y, br_x, br_y):
        # Scores are sorted in decreasing order: stop at the first one below the threshold
        if not score > obj_threshold:
            break
        box = BoundingBox(x1, y1, x2, y2, score=score)
        box.label = label
        boxes.append(box)
    return boxes