Source code for quantizeml.layers.quantizer_layers
#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ["InputQuantizer", "Dequantizer"]
import tensorflow as tf
import tf_keras as keras
from tf_keras.layers import Layer
from ..tensors import QTensor, QFloat, FixedPoint
from ..debugging import assert_less_equal
from .recorders import TensorRecorder, QFloatRecorder
from .layers_base import check_arg_constraints, QuantizedLayer
@keras.saving.register_keras_serializable()
class InputQuantizer(QuantizedLayer):
"""Quantizer layer for input tensors.
This layer quantizes input tensors to a q-tensor representation using a specified bitwidth,
supporting signed and axis quantization.
Args:
bitwidth (int, optional): the quantization bitwidth. Defaults to 8.
signed (bool, optional): whether the quantizer expects signed values or unsigned.
Defaults to True.
axis (str, optional): the quantization range is a scalar ('per-tensor') or a vector
corresponding to the last axis ('per-axis'). Defaults to 'per-tensor'.
Example:
>>> quantizer = InputQuantizer(bitwidth=8, signed=True)
>>> quantized = quantizer(tf.constant([[1.0, -2.0], [3.0, 4.0]]))
"""
    arg_constraints = {'axis': lambda: ["per-tensor", "per-axis"]}

    def __init__(self, bitwidth=8, signed=True, axis="per-tensor", **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.bitwidth = bitwidth
        self.signed = signed
        self.frac_bits = TensorRecorder()
        check_arg_constraints(self, self.get_config())

    @property
    def value_bits(self):
        return self.bitwidth - 1 if self.signed else self.bitwidth
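
    # For example, with bitwidth=8 the integer values span 7 bits when signed
    # (one bit encodes the sign) and the full 8 bits when unsigned.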

    def build(self, input_shape):
        super().build(input_shape)
        # Convert axis to a list of int.
        if self.axis == "per-axis":
            ndims = len(input_shape)
            if ndims < 3:
                raise ValueError(f"'{self.name}' (InputQuantizer) cannot quantize per-axis "
                                 "tensors with 2 dimensions or less.")
            range_axis = list(range(ndims - 1))
        else:
            range_axis = None
        # Declare the scalar/vector variables that will store the calibration range
        # (minimum and maximum values).
        self.range_min = self.add_weight(
            name="range_min",
            shape=input_shape[-1] if range_axis is not None else (),
            dtype=tf.float32,
            initializer="zeros",
            synchronization=tf.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf.VariableAggregation.MEAN,
            experimental_autocast=False,
        )
        self.range_max = self.add_weight(
            name="range_max",
            shape=input_shape[-1] if range_axis is not None else (),
            dtype=tf.float32,
            initializer="ones",
            synchronization=tf.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=tf.VariableAggregation.MEAN,
            experimental_autocast=False,
        )
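
    # Illustration (not from the library sources): with axis="per-axis" the range
    # variables are vectors over the last dimension, so the input must have at
    # least 3 dimensions. For instance:
    #
    #   InputQuantizer(axis="per-axis").build((None, 32, 16))  # ranges of shape (16,)
    #   InputQuantizer(axis="per-axis").build((None, 16))      # raises ValueError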

    def call(self, inputs):
        """Quantize the inputs according to the quantizer configuration.

        This method computes the scale and zero point from the provided calibration
        ranges and uses them to project the input to integer values.

        Args:
            inputs (tf.Tensor): the input tensor to be quantized.

        Returns:
            :obj:`QTensor`: the quantized tensor (a FixedPoint when signed, a QFloat
            otherwise).
        """
        if not isinstance(inputs, tf.Tensor):
            raise TypeError(f"{self.__class__.__name__} only accepts {tf.Tensor} inputs. "
                            f"Received {type(inputs)} inputs.")
        if inputs.dtype.is_integer:
            raise TypeError(f"{self.__class__.__name__} only accepts float inputs. "
                            f"Received {inputs.dtype} inputs.")
        inputs = tf.cast(inputs, tf.float32)
        # Compute range/zero_point.
        assert_less_equal(self.range_min,
                          self.range_max,
                          message="range_max must be greater than or equal to range_min.")
        if self.signed:
            # When signed is True, no zero point is required.
            range_abs = tf.maximum(tf.abs(self.range_max), tf.abs(self.range_min))
        else:
            # A negative zero point cannot be handled (HW constraint), which is why a
            # positive range_min is rejected. In other words, a positive range_min means
            # the inputs do not require a zero point, since quantization already
            # produces unsigned values.
            assert_less_equal(self.range_min, tf.zeros_like(self.range_min),
                              message="range_min > 0 is not allowed (HW constraint).")
            range_abs = self.range_max - self.range_min
        range_min = self.range_min
        # Build the recorder object for the zero point.
        with tf.init_scope():
            self.zero_points = QFloatRecorder(name="zero_point")
        # Compute the frac_bits to quantize the inputs.
        frac_bits = tf.stop_gradient(FixedPoint.max_frac_bits(self.value_bits, range_abs))
        self.frac_bits(frac_bits)
        # Quantize the inputs.
        q_inputs = FixedPoint.quantize(inputs, self.value_bits, frac_bits)
        # Change the output signature depending on the signed parameter.
        if not self.signed:
            q_inputs = QFloat(q_inputs, 1.0)
            # When the output is unsigned, a zero point is needed to shift the range
            # to an entirely positive one.
            q_zero_points = QFloat.quantize(-range_min, self.value_bits, q_inputs.scales,
                                            frac_bits)
            q_inputs = q_inputs + q_zero_points
            # Record the zero point.
            self.zero_points(q_zero_points)
        return q_inputs
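
    # Worked example (an illustration; the exact rounding is delegated to
    # FixedPoint.quantize): with signed=False, range_min=-1.0 and range_max=3.0,
    # range_abs is 4.0 and the quantized zero point shifts the inputs so that
    # -1.0 maps to 0 and the quantized range stays entirely positive.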

    def get_config(self):
        """Get the config of the layer.

        Returns:
            dict: the config of the layer.
        """
        config = super().get_config()
        config.update({"bitwidth": self.bitwidth,
                       "signed": self.signed,
                       "axis": self.axis})
        return config
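
    # Illustration (standard Keras serialization, not specific to this layer):
    # the config above allows a serialization round trip, e.g.
    #
    #   config = InputQuantizer(bitwidth=4, signed=False).get_config()
    #   clone = InputQuantizer.from_config(config)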

@keras.saving.register_keras_serializable()
class Dequantizer(Layer):
    """Layer that dequantizes its inputs."""

    scales: list = None
    frac_bits: list = None

    def _build_records(self, inputs):
        def _build(x):
            record_fb = record_scale = None
            # Per the tf_keras documentation, any variable creation taking place
            # in call should be wrapped with tf.init_scope.
            with tf.init_scope():
                if isinstance(x, QTensor):
                    record_fb = TensorRecorder(self.name + "/record_fb")
                if isinstance(x, QFloat):
                    record_scale = TensorRecorder(self.name + "/record_scale")
            return record_fb, record_scale

        if self.frac_bits is not None:
            # Records were already built: nothing to do.
            return
        if not isinstance(inputs, (tuple, list)):
            # Manage single inputs.
            self.frac_bits, self.scales = _build(inputs)
            return
        self.frac_bits = []
        self.scales = []
        with tf.init_scope():
            for x in inputs:
                frac_bits, scales = _build(x)
                self.frac_bits.append(frac_bits)
                self.scales.append(scales)
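
    # For example (following the logic above): a FixedPoint input gets a frac_bits
    # recorder only, a QFloat input gets both a frac_bits and a scale recorder, and
    # a plain tf.Tensor gets neither (both entries stay None).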

    def call(self, inputs):
        """Convert QTensor inputs to float.

        Args:
            inputs (tf.Tensor or :obj:`QTensor`): the input tensor(s).

        Returns:
            tf.Tensor: the dequantized tensor(s).
        """
        def dequantize(x, frac_bits_recorder=None, scales_recorder=None):
            if isinstance(x, QTensor):
                if frac_bits_recorder is not None:
                    frac_bits_recorder(x.fp.frac_bits if isinstance(x, QFloat) else x.frac_bits)
                if scales_recorder is not None:
                    scales_recorder(x.scales)
                return x.to_float()
            return x

        # Build the records.
        self._build_records(inputs)
        # Apply the dequantizer.
        if isinstance(inputs, (list, tuple)):
            return [dequantize(x, fb, scales) for x, fb, scales in
                    zip(inputs, self.frac_bits, self.scales)]
        return dequantize(inputs, self.frac_bits, self.scales)
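
# Usage sketch (an illustration, not part of the library): a quantize/dequantize
# round trip. The calibration range variables are assigned by hand here; in a
# real flow they are populated by the calibration utilities.
if __name__ == "__main__":
    x = tf.constant([[1.0, -2.0], [3.0, 4.0]])
    quantizer = InputQuantizer(bitwidth=8, signed=True)
    quantizer.build(x.shape)
    quantizer.range_min.assign(tf.reduce_min(x))
    quantizer.range_max.assign(tf.reduce_max(x))
    q_x = quantizer(x)
    x_hat = Dequantizer()(q_x)
    # The reconstruction error is bounded by the quantization step.
    print(tf.reduce_max(tf.abs(x - x_hat)).numpy())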