Source code for quantizeml.analysis.tools.metrics

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2024 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Public API of this module.
__all__ = [
    "SMAPE",
    "Saturation",
    "print_metric_table",
]

from collections import defaultdict
import warnings
import numpy as np
import keras


class SMAPE(keras.metrics.Metric):
    """Compute the Symmetric Mean Absolute Percentage Error (SMAPE) as:

    >>> mean(abs(x - y) / (abs(x) + abs(y)))

    Reference: https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error

    Inputs are converted to float64 before the computation, so integer (e.g. quantized)
    inputs cannot overflow or wrap around when subtracted or passed to ``abs``.

    Args:
        name (str, optional): name of the metric. Defaults to "smape".
    """

    def __init__(self, name="smape", **kwargs):
        super().__init__(name=name, **kwargs)
        # Running sum of the per-element SMAPE terms.
        self.error = self.add_weight(name='error', initializer='zeros', dtype="float64")
        # Running number of elements seen, including the skipped (0, 0) pairs.
        self.count = self.add_weight(name='count', initializer='zeros', dtype="float64")

    def update_state(self, y_true, y_pred):
        """Accumulate the SMAPE terms of a new pair of inputs.

        Args:
            y_true (array-like): reference values.
            y_pred (array-like): values compared to ``y_true``, with the same shape.
        """
        # Work in float64: with integer inputs `y_true - y_pred` and `np.abs` can
        # overflow (int8: 127 - (-128), abs(-128)) or wrap around (unsigned types).
        y_true = np.asarray(y_true, dtype=np.float64)
        y_pred = np.asarray(y_pred, dtype=np.float64)
        assert y_true.shape == y_pred.shape
        total = y_true.size

        # Skip values that undefine the metric (0 / 0)
        mask = (y_true == 0) & (y_pred == 0)
        y_true = y_true[~mask]
        y_pred = y_pred[~mask]

        # Compute smape
        smape = np.sum(np.abs(y_true - y_pred) / (np.abs(y_true) + np.abs(y_pred)))
        self.error.assign_add(smape)

        # Update the metric count.
        # Note here we take into account the set of values y_true = y_pred = 0,
        # since the error they contribute is zero.
        self.count.assign_add(total)

    def result(self):
        """Return the mean SMAPE over all elements seen so far (0.0 if none)."""
        if self.count == 0:
            return 0.0
        return self.error / self.count

    def reset_state(self):
        """Reset the accumulated error and element count to zero."""
        self.error.assign(0.0)
        self.count.assign(0.0)

    def reset_states(self):
        """Deprecated alias of :meth:`reset_state`, kept for backward compatibility."""
        self.reset_state()
class Saturation(keras.metrics.Metric):
    """Returns the percentage of saturating values.

    We consider a value saturated if it is one of {min_value, max_value}

    Args:
        name (str, optional): name of the metric. Defaults to "saturation".
        min_value (np.ndarray, optional): the minimum of values. If not provided, it is inferred
            from the metric ``dtype`` (forwarded to ``keras.metrics.Metric`` through
            ``kwargs``), which must then be an integer type. Defaults to None.
        max_value (np.ndarray, optional): the maximum of values. If not provided, it is inferred
            from the metric ``dtype`` (forwarded to ``keras.metrics.Metric`` through
            ``kwargs``), which must then be an integer type. Defaults to None.
    """

    def __init__(self, name="saturation", min_value=None, max_value=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self._min_value = min_value
        self._max_value = max_value
        # Number of saturated values seen so far.
        self.total = self.add_weight(name='total', initializer='zeros', dtype="int64")
        # Number of values seen so far.
        self.count = self.add_weight(name='count', initializer='zeros', dtype="int64")

    def _dtype_info(self, kind):
        """Return ``np.iinfo`` of the metric dtype, used to infer a missing bound.

        Args:
            kind (str): "minimum" or "maximum", only used in the error message.

        Raises:
            AssertionError: if the metric dtype is not an integer type.
        """
        assert np.issubdtype(self.dtype, np.integer), \
            f"Unknown {kind} value for data type {self.dtype}"
        # Even if asserts are stripped (python -O), np.iinfo raises a clear error
        # for a non-integer dtype instead of failing later on `None.min`.
        return np.iinfo(self.dtype)

    @property
    def min_value(self):
        """Lower saturation bound: the provided value, or the dtype minimum."""
        if self._min_value is None:
            return self._dtype_info("minimum").min
        return self._min_value

    @property
    def max_value(self):
        """Upper saturation bound: the provided value, or the dtype maximum."""
        if self._max_value is None:
            return self._dtype_info("maximum").max
        return self._max_value

    def update_state(self, values):
        """Count the values equal to one of the saturation bounds.

        Args:
            values (array-like): values to inspect.
        """
        values = np.asarray(values)
        # Resolve each bound once (max first, as the range check evaluates it first)
        # instead of re-running the properties on every use.
        max_value = self.max_value
        min_value = self.min_value
        if np.any(values > max_value) or np.any(values < min_value):
            warnings.warn(f"Saturation is not accurate: there are values outside of range "
                          f"[{min_value}, {max_value}].")
        sat_values = np.sum((values == min_value) | (values == max_value))
        self.total.assign_add(sat_values)
        self.count.assign_add(values.size)

    def result(self):
        """Return the percentage (0-100) of saturated values seen so far."""
        # Avoid a division by zero before any update.
        count = self.count if self.count != 0 else 1
        return 100 * (self.total / count)

    def reset_state(self):
        """Reset the saturated and total value counters to zero."""
        self.total.assign(0)
        self.count.assign(0)

    def reset_states(self):
        """Deprecated alias of :meth:`reset_state`, kept for backward compatibility."""
        self.reset_state()

    def get_config(self):
        """Return the metric config, including the user-provided bounds."""
        config = super().get_config()
        config.update({"min_value": self._min_value, "max_value": self._max_value})
        return config