Source code for akida_models.centernet.centernet_loss

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
This module defines a custom loss function for CenterNet training.
"""

__all__ = ["CenternetLoss"]

from keras import backend as K
import tensorflow as tf


class CenternetLoss(tf.keras.losses.Loss):
    """Computes the CenterNet loss from the model raw output.

    The CenterNet loss computation is from https://arxiv.org/abs/1904.07850.

    Args:
        alpha (float, optional): alpha parameter in heatmap loss. Defaults to 2.0.
        gamma (float, optional): gamma parameter in heatmap loss. Defaults to 4.0.
        eps (float, optional): epsilon parameter in heatmap loss. Defaults to 1e-12.
        heatmap_loss_weight (float, optional): heatmap loss weight. Defaults to 1.0.
        wh_loss_weight (float, optional): location loss weight. Defaults to 0.1.
        offset_loss_weight (float, optional): offset loss weight. Defaults to 1.0.
    """

    def __init__(self,
                 alpha=2.0,
                 gamma=4.0,
                 eps=1e-12,
                 heatmap_loss_weight=1.0,
                 wh_loss_weight=0.1,
                 offset_loss_weight=1.0):
        super().__init__()
        # Parameters for the gaussian focal loss for the heatmap branch
        self._alpha = alpha
        self._gamma = gamma
        self._eps = eps

        # Loss weight parameters
        self.heatmap_loss_weight = heatmap_loss_weight
        self.wh_loss_weight = wh_loss_weight
        self.offset_loss_weight = offset_loss_weight

    def _transform_netout(self, y_pred_raw):
        """Transforms the output of the network:

            - casts to float32,
            - extracts the parallel wh, offset and heatmap branches from the fused map,
            - applies a sigmoid to the heatmap prediction.

        Args:
            y_pred_raw (tf.Tensor): raw network predictions.

        Returns:
            tuple of tf.Tensor: predictions transformed into xy, wh and offset values.
        """
        y_pred_raw = tf.cast(y_pred_raw, dtype=tf.float32)
        y_pred_xy = K.sigmoid(y_pred_raw[..., :-4])
        y_pred_wh = y_pred_raw[..., -4:-2]
        y_pred_offset = y_pred_raw[..., -2:]
        return y_pred_xy, y_pred_wh, y_pred_offset

    def _get_targets(self, y_true):
        """Extracts the ground truth for each branch and computes avg_factor and
        wh_offset_target_weight here so that they do not have to be passed through
        the whole model.

        Args:
            y_true (tf.Tensor): ground truth.

        Returns:
            tuple of tf.Tensor: labels in xy, wh and offset format, plus avg_factor
            and wh_offset_target_weight.
        """
        target_xy = y_true[..., :-4]
        target_wh = y_true[..., -4:-2]
        target_offset = y_true[..., -2:]

        # The average factor counts the number of targets to be learned:
        # max(1, center_heatmap_target.eq(1).sum())
        tmp = tf.equal(tf.constant(1.0, dtype=y_true.dtype), target_xy)
        tmp = tf.cast(tmp, dtype=tf.float32)
        tmp = tf.reduce_sum(tmp)
        avg_factor = tf.reduce_max([tf.constant(1.0, dtype=tmp.dtype), tmp])

        # Extract the wh offset target weight:
        # wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1
        # => 1 anywhere there is a target offset and wh
        tmp = tf.equal(tf.constant(0, dtype=y_true.dtype), target_offset)
        tmp = tf.logical_not(tmp)
        wh_offset_target_weight = tf.cast(tmp, dtype=tf.float32)

        return target_xy, target_wh, target_offset, avg_factor, wh_offset_target_weight

    def heatmap_loss(self, y_true, y_pred, avg_factor):
        """Implements `Gaussian Focal loss <https://arxiv.org/abs/1708.02002>`_ for
        targets in a gaussian distribution.

        Original source: mmdetection/losses/gaussian_focal_loss

        Args:
            y_true (tf.Tensor): tensor of true labels.
            y_pred (tf.Tensor): tensor of predicted labels.
            avg_factor (tf.Tensor): average factor.

        Returns:
            tf.Tensor: heatmap loss.
        """
        # Compute the loss
        pos_weights = tf.cast(tf.equal(y_true, 1.0), dtype=tf.float32)
        neg_weights = tf.math.pow((1 - y_true), self._gamma)
        pos_loss = -tf.math.log(y_pred + self._eps) * \
            tf.math.pow((1 - y_pred), self._alpha) * pos_weights
        neg_loss = -tf.math.log(1 - y_pred + self._eps) * \
            tf.math.pow(y_pred, self._alpha) * neg_weights
        loss = pos_loss + neg_loss

        # Compute the average across the matrix
        loss = tf.reduce_sum(loss) / avg_factor
        return loss

    def l1_loss(self, y_true, y_pred, avg_factor, weights=None):
        """L1 loss, used in the location (wh and offset) losses.

        Args:
            y_true (tf.Tensor): tensor of true labels.
            y_pred (tf.Tensor): tensor of predicted labels.
            avg_factor (tf.Tensor): average factor.
            weights (tf.Tensor, optional): factor to multiply the loss. Defaults to None.

        Returns:
            tf.Tensor: L1 loss.
        """
        difference = y_true - y_pred
        loss = tf.abs(difference)
        if weights is not None:
            loss *= weights
        loss = tf.reduce_sum(loss) / avg_factor
        return loss

    def __call__(self, y_true, y_pred_raw, sample_weight=None):
        # Get the average factor and the wh / offset weights
        (target_xy, target_wh, target_offset,
         avg_factor, wh_offset_target_weight) = self._get_targets(y_true)

        # Extract the 3 parallel branches and apply a sigmoid to the heatmap
        y_pred_xy, y_pred_wh, y_pred_offset = self._transform_netout(y_pred_raw)

        # Heatmap loss
        center_heatmap_loss = self.heatmap_loss(target_xy, y_pred_xy, avg_factor)
        center_heatmap_loss *= self.heatmap_loss_weight

        # Wh loss
        wh_loss = self.l1_loss(target_wh, y_pred_wh, avg_factor * 2, wh_offset_target_weight)
        wh_loss *= self.wh_loss_weight

        # Offset loss
        offset_loss = self.l1_loss(target_offset, y_pred_offset,
                                   avg_factor * 2, wh_offset_target_weight)
        offset_loss *= self.offset_loss_weight

        loss = center_heatmap_loss + wh_loss + offset_loss
        return loss

    def get_config(self):
        config = super().get_config()
        config.update({
            "alpha": self._alpha,
            "gamma": self._gamma,
            "eps": self._eps,
            "heatmap_loss_weight": self.heatmap_loss_weight,
            "wh_loss_weight": self.wh_loss_weight,
            "offset_loss_weight": self.offset_loss_weight
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
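

# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes the
# targets and raw predictions share the shape (batch, H, W, num_classes + 4),
# where the last four channels hold the wh and offset maps and the remaining
# channels hold the per-class center heatmap, as implied by the slicing in
# _transform_netout and _get_targets. Shapes and values below are placeholders.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    num_classes = 2

    # Build a dummy ground truth with a single positive center for class 0
    y_true = tf.zeros((1, 24, 24, num_classes + 4)).numpy()
    y_true[0, 12, 12, 0] = 1.0       # heatmap peak for class 0
    y_true[0, 12, 12, -4:-2] = 4.0   # target box width/height at that center
    y_true[0, 12, 12, -2:] = 0.5     # target sub-pixel offset at that center
    y_true = tf.constant(y_true)

    # Dummy raw network output (heatmap logits, wh and offset maps)
    y_pred_raw = tf.random.normal((1, 24, 24, num_classes + 4))

    loss_fn = CenternetLoss()
    print("CenterNet loss:", float(loss_fn(y_true, y_pred_raw)))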