Source code for quantizeml.layers.quantizers.output_quantizer

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["OutputQuantizer"]

import tensorflow as tf

from ...tensors import QTensor, FixedPoint, QFloat
from ..recorders import TensorRecorder, FixedPointRecorder
from .quantizers import Quantizer

[docs]@tf.keras.utils.register_keras_serializable() class OutputQuantizer(Quantizer): """A uniform FixedPoint quantizer that selects the optimal number of fractional bits for the range of its inputs and updates them accordingly. The typical use case is to decrease the bitwidth of the result of a quantized layer operation to avoid a saturation in downstream operations. If the input is a QFloat, it is converted to a FixedPoint before updating its bitwidth. Args: bitwidth (int, optional): the quantization bitwidth. Defaults to 8. signed (bool, optional): whether the quantizer expects signed values or unsigned. Defaults to True. axis (str, optional): the quantization range is a scalar ('per-tensor') or a vector corresponding to the last axis ('per-axis'). Defaults to 'per-tensor'. scale_bits: (int, optional): the bitwidth to use when quantizing output scales. Defaults to 8. """ def __init__(self, bitwidth=8, signed=True, axis="per-tensor", scale_bits=8, **kwargs): super().__init__(bitwidth, signed, **kwargs) if not (isinstance(axis, str) and axis in ["per-tensor", "per-axis"]): raise ValueError(f"Only support reduction 'per-tensor' or 'per-axis'. Given {axis}.") self.axis = axis self.scale_bits = scale_bits # Add object that will store the shift values. self.shift = TensorRecorder()
[docs] def build(self, input_shape): """Build the layer. Args: input_shape (list): the shape of input tensor. """ super().build(input_shape) # Convert axis to a list of int if self.axis == "per-axis": ndims = len(input_shape) if ndims < 3: raise ValueError("OutputQuantizer cannot quantize per-axis tensors " " with 2 dimensions or less.") self._axis = list(range(len(input_shape) - 1)) else: self._axis = None # Declares the constant/vector that will store the maximum values of the input. self.range_max = self.add_weight( name="range_max", shape=input_shape[-1] if self._axis is not None else (), dtype=tf.float32, initializer="ones", synchronization=tf.VariableSynchronization.ON_READ, trainable=False, aggregation=tf.VariableAggregation.MEAN, experimental_autocast=False, ) # Declare a rescaling_rate variable that will be set at calibration and that will hold # cross-layer equalization ideal range_max / calibrated range_max ratio. self.rescaling_rate = self.add_weight( name="rescaling_rate", shape=input_shape[-1] if self._axis is not None else (), dtype=tf.float32, initializer="ones", synchronization=tf.VariableSynchronization.ON_READ, trainable=False )
[docs] def call(self, inputs): """Quantize the QTensor inputs to a lower bitwidth. The quantization happens with the following steps: 1. Evaluate the nearest power(s) of two containing the quantization range(s) 2. Quantize the inputs. Args: inputs (:obj:`QTensor`): the inputs tensor. Returns: :obj:`FixedPoint`: the quantized tensor. """ if not isinstance(inputs, QTensor): raise TypeError("The OutputQuantizer accepts only QTensor inputs." f"Received {type(inputs)} inputs.") if isinstance(inputs, QFloat): if self.scale_bits is None: raise ValueError(f"{} receives QFloat inputs: the scale_bits parameter" " needs to be specified.") # Apply cross layer equalization rescaling rescaled_inputs = QFloat(inputs.fp, inputs.scales * self.rescaling_rate) inputs, qscales = rescaled_inputs.to_fixed_point(self.scale_bits) if getattr(self, 'qscales', None) is None: # From Keras documentation, any variable creation taking place in call # should be wrapped with tf.init_scope with tf.init_scope(): self.qscales = FixedPointRecorder() self.qscales(qscales) if inputs.value_bits <= self.value_bits: msg = f"Quantizing a {inputs.value_bits}-bit QTensor to "\ f"{self.value_bits}-bit is pointless." if inputs.value_bits < self.value_bits: msg += " Use a promotion instead." raise ValueError(msg) # Rescale to center around range_max and compress to a lower bitwidth frac_bits = tf.stop_gradient(FixedPoint.max_frac_bits(self.value_bits, self.range_max)) inputs, shift_value = inputs.rescale(frac_bits, self.value_bits) # update shift values self.shift(shift_value) return inputs
[docs] def get_config(self): """Get the config of the layer. Returns: dict: the config of the layer. """ config = super().get_config() config.update({"scale_bits": self.scale_bits}) config.update({"axis": self.axis}) return config