Source code for quantizeml.layers.stateful_recurrent

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["StatefulRecurrent", "QuantizedStatefulRecurrent", "reset_states",
           "StatefulProjection", "QuantizedStatefulProjection", "update_batch_size",
           "PicoPostProcessing", "QuantizedPicoPostProcessing"]

import tf_keras as keras
import tensorflow as tf

from .recorders import (NonTrackVariable, TensorRecorder, NonTrackFixedPointVariable,
                        FixedPointRecorder)
from .layers_base import (register_quantize_target, tensor_inputs, apply_buffer_bitwidth,
                          register_aligned_inputs, QuantizedLayer, neural_layer_init,
                          rescale_outputs)
from .quantizers import WeightQuantizer, OutputQuantizer
from ..tensors import FixedPoint, QTensor, QFloat


[docs] @keras.saving.register_keras_serializable() class StatefulRecurrent(keras.layers.Layer): """ A recurrent layer with an internal state. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._internal_state_real = NonTrackVariable("internal_state_real") self._internal_state_imag = NonTrackVariable("internal_state_imag") def build(self, input_shape): assert input_shape[0] is not None, f"{self.name} must be built with a known batch size." assert input_shape[-1] is not None, \ f"{self.name} must be built with a known input channels." with tf.name_scope(self.name + '/'): super().build(input_shape) # 'A' weight is a complex64 tensor stored as two float32 tensor to ease quantization self.A_real = self.add_weight(name='A_real', shape=(input_shape[-1],)) self.A_imag = self.add_weight(name='A_imag', shape=(input_shape[-1],)) # Initialize the internal state variables, drop the timesteps dimension self._internal_state_real.init_var(tf.zeros((input_shape[0], input_shape[-1]))) self._internal_state_imag.init_var(tf.zeros((input_shape[0], input_shape[-1]))) def call(self, inputs): """ For every input step, the internal state is updated using the inputs which should be the updated state from the previous layer. """ # Build output tensors that will contain all updates, initialize them with internal state assert inputs.shape[1] is not None, f"{self.name} requires a known number of timesteps." multiples = [1, inputs.shape[1], 1] state_real = tf.tile(tf.expand_dims(self._internal_state_real.var, axis=1), multiples) state_imag = tf.tile(tf.expand_dims(self._internal_state_imag.var, axis=1), multiples) # Loop over timesteps for i in range(tf.shape(inputs)[1]): # Compute real and imaginary part separately updated_real = state_real[:, i - 1] * self.A_real - \ state_imag[:, i - 1] * self.A_imag + inputs[:, i] updated_imag = self.A_imag * state_real[:, i - 1] + \ self.A_real * state_imag[:, i - 1] indices = tf.stack([tf.range(state_real.shape[0]), tf.repeat(i, state_real.shape[0])], axis=1) state_real = tf.tensor_scatter_nd_update(state_real, indices, updated_real) state_imag = tf.tensor_scatter_nd_update(state_imag, indices, updated_imag) # Update internal state for next call self._internal_state_real.set_var(state_real[:, -1]) self._internal_state_imag.set_var(state_imag[:, -1]) return tf.stack([state_real, state_imag], -1) def reset_layer_states(self): """ Resets internal state (real and imaginary part).""" self._internal_state_real.reset_var() self._internal_state_imag.reset_var()
[docs] @register_quantize_target([StatefulRecurrent], has_weights=True) @keras.saving.register_keras_serializable() class QuantizedStatefulRecurrent(QuantizedLayer, StatefulRecurrent): """ A quantized version of the StatefulRecurrent layer that operates on quantized inputs, weights and internal state. Note that internal state is quantized to 16-bits for accuracy reasons, inputs and outputs of this layer are then also 16-bits. Args: quant_config (dict, optional): the serialized quantization configuration. Defaults to None. """ def __init__(self, *args, quant_config=None, **kwargs): super().__init__(*args, quant_config=quant_config, **kwargs) self._internal_state_real = NonTrackFixedPointVariable("internal_state_real") self._internal_state_imag = NonTrackFixedPointVariable("internal_state_imag") # Build weight quantizer for A_real and A_imag (sharing the same quantizer) if "a_quantizer" not in self.quant_config: # Forcing to: # - per-tensor to ensure alignment in the call operations # - 16-bits for accuracy reasons # - FixedPoint quantization to prevent scale_out operations on internal_state self.quant_config["a_quantizer"] = {"bitwidth": 16, "axis": None, "fp_quantizer": True} a_quantizer_cfg = self.quant_config["a_quantizer"] self.a_quantizer = WeightQuantizer(name="a_quantizer", **a_quantizer_cfg) # Finalize output quantizer, add one with default configuration if there is None in the # config as state must be quantized if "output_quantizer" not in self.quant_config: self.quant_config["output_quantizer"] = {"bitwidth": 16, "axis": "per-tensor"} out_quant_cfg = self.quant_config["output_quantizer"] self.out_quantizer = OutputQuantizer(name="output_quantizer", **out_quant_cfg) self.buffer_bitwidth = apply_buffer_bitwidth(self.quant_config, signed=True) # Prepare the variable that should be recorded self.new_state_shift = TensorRecorder(name=self.name + "/new_state_shift") def build(self, input_shape): assert input_shape[0] is not None, f"{self.name} must be built with a known batch size." assert input_shape[-1] is not None, \ f"{self.name} must be built with a known input channels." with tf.name_scope(self.name + '/'): # Explicitly build the Keras.layer and not StatefulRecurrent because state is not of # the same type keras.layers.Layer.build(self, input_shape) # 'A' weight is a complex64 tensor stored as two float32 tensor to ease quantization self.A_real = self.add_weight(name='A_real', shape=(input_shape[-1],)) self.A_imag = self.add_weight(name='A_imag', shape=(input_shape[-1],)) # Explicitly build the OutputQuantizer so that output frac_bits can be computed self.out_quantizer.build((input_shape[0], input_shape[-1])) # Initialize the internal state variables zeros = FixedPoint(tf.zeros((input_shape[0], input_shape[-1])), self.out_quantizer.value_bits, self.out_quantizer.frac_bits) self._internal_state_real.init_var(zeros) self._internal_state_imag.init_var(zeros) @tensor_inputs([QTensor]) def call(self, inputs): if isinstance(inputs, QFloat): # Handle QFloat inputs by quantizing the scale: reuse output quantizer to get scale_bits # because the Stateful layer should be homogeneous inputs, qscales = inputs.to_fixed_point() if getattr(self, 'new_state_scale', None) is None: # from tf_keras documentation, any variable creation taking place in call # should be wrapped with tf.init_scope with tf.init_scope(): self.new_state_scale = FixedPointRecorder(self.name + "/new_state_scale") self.new_state_scale(qscales) # Quantize A matrices A_real = self.a_quantizer(self.A_real) A_imag = self.a_quantizer(self.A_imag) # Align inputs with {A * state}, which is out_quantizer.frac_bits + A.frac_bits inputs, shift = inputs.rescale(self.out_quantizer.frac_bits + A_real.frac_bits, inputs.value_bits) self.new_state_shift(shift) # Set the appropriate frac_bit in state variable self._internal_state_real._frac_bits.assign(self.out_quantizer.frac_bits) self._internal_state_imag._frac_bits.assign(self.out_quantizer.frac_bits) # Define state update step internal_state_real_step = self._internal_state_real.var internal_state_imag_step = self._internal_state_imag.var # Build output tensors that will contain all updates, initialize them with internal state zero_fp = FixedPoint(tf.zeros(inputs.shape), self._internal_state_real.var.value_bits, self._internal_state_real.var.frac_bits) next_internal_state_real, next_internal_state_imag = zero_fp, zero_fp # Define a zero padding value padding_value = FixedPoint(0, self.out_quantizer.value_bits, self._internal_state_real._frac_bits) # Loop over timesteps timesteps = tf.shape(inputs)[1] for i in range(timesteps): # Promote internal_state_step internal_state_real_step = internal_state_real_step.promote(self.buffer_bitwidth) internal_state_imag_step = internal_state_imag_step.promote(self.buffer_bitwidth) # Update internal state: compute real and imaginary part separately using current step updated_real = tf.multiply(internal_state_real_step, A_real) - \ tf.multiply(internal_state_imag_step, A_imag) # Get the inputs for this step input_step = inputs[:, i] # At this point addition is possible updated_real = updated_real + input_step # Same for imaginary part updated_imag = tf.multiply(internal_state_real_step, A_imag) + \ tf.multiply(internal_state_imag_step, A_real) # Quantize down the update for next step using the layer output quantizer internal_state_real_step = self.out_quantizer(updated_real) internal_state_imag_step = self.out_quantizer(updated_imag) # Store the updates in next_internal_state. To do so, pad the updated state with zeroes # to left and right in order to rebuild a 'full' state where i-th timestep is equal to # computed update. Then add it to next_internal_state that was initialized to zeroes: # this effectively mimics an inplace update (working around Tensorflow logic). paddings = [[0, 0], [i, timesteps - 1 - i], [0, 0]] padded_real = tf.pad(tf.expand_dims(internal_state_real_step, axis=1), paddings=paddings, constant_values=padding_value) padded_imag = tf.pad(tf.expand_dims(internal_state_imag_step, axis=1), paddings=paddings, constant_values=padding_value) # Ensure the shape is as expected because TensorFlow will not tolerate shape changes # within the loop padded_real.values.set_shape(next_internal_state_real.shape) padded_imag.values.set_shape(next_internal_state_imag.shape) next_internal_state_real += padded_real next_internal_state_imag += padded_imag # Update internal state members self._internal_state_real.set_var(internal_state_real_step) self._internal_state_imag.set_var(internal_state_imag_step) # Return the concatenated states return tf.stack([next_internal_state_real, next_internal_state_imag], -1)
[docs] def reset_states(model): """ Resets all StatefulRecurrent layers internal states in the model. Args: model (keras.Model): the model to reset """ for layer in model.layers: if isinstance(layer, StatefulRecurrent): layer.reset_layer_states()
[docs] def update_batch_size(model, batch_size): """ Updates the batch size in a model. Similar to keras RNN/Cell behavior where batch size must be known, the StatefulRecurrent layer state is built with a (None, ...) shape that must be defined runtime. This helper allow to set or update a model batch size. Args: model (keras.Model): the model to update batch_size (int): batch size to set Returns: keras.Model: an updated model (the original model when batch size is unchanged) """ # Update batch size in config config = model.get_config() input_shape = config['layers'][0]['config']['batch_input_shape'] if input_shape[0] == batch_size: return model config['layers'][0]['config']['batch_input_shape'] = (batch_size, *input_shape[1:]) # Force layers shapes to be recomputed for layer_config in config["layers"]: layer_config.pop('build_config', None) # Rebuild model and transfer weights updated_model = model.from_config(config) updated_model.set_weights(model.get_weights()) return updated_model
[docs] @keras.saving.register_keras_serializable() class StatefulProjection(keras.layers.Dense): """ Same as a Dense layer but with optional subsampling. Args: subsample (int, optional): subsampling factor applied to the timesteps. Defaults to 1. """ def __init__(self, *args, subsample=1, **kwargs): super().__init__(*args, **kwargs) self.subsample = subsample def build(self, input_shape): with tf.name_scope(self.name + '/'): super().build(input_shape) def call(self, inputs): # Apply the optional subsampling if self.subsample > 1: assert inputs.shape[1] % self.subsample == 0, \ f"Number of timesteps: {inputs.shape[1]} must be a multiple of subsample " \ f"ratio: {self.subsample}." inputs = inputs[:, (self.subsample - 1)::self.subsample, :] # Standard Dense operation outputs = super().call(inputs) return outputs def get_config(self): config = super().get_config() config["subsample"] = self.subsample return config
[docs] @register_quantize_target([StatefulProjection], has_weights=True) @register_aligned_inputs @keras.saving.register_keras_serializable() class QuantizedStatefulProjection(QuantizedLayer, StatefulProjection): """ A quantized version of the StatefulProjection layer that operates on quantized inputs. """ @neural_layer_init(False) def __init__(self, *args, **kwargs): # Limit buffer bitwidth to 27 for HW constraint self.quant_config['buffer_bitwidth'] = min(28, self.quant_config['buffer_bitwidth']) self.buffer_bitwidth = self.quant_config['buffer_bitwidth'] - 1 @tensor_inputs([QTensor, tf.Tensor]) @rescale_outputs def call(self, inputs): if self.subsample > 1: assert inputs.shape[1] % self.subsample == 0, \ f"Number of timesteps: {inputs.shape[1]} must be a multiple of subsample " \ f"ratio: {self.subsample}." inputs = inputs[:, (self.subsample - 1)::self.subsample, :] # Quantize the weights kernel = self.weight_quantizer(self.kernel) outputs = tf.matmul(inputs, kernel) if self.use_bias: # Quantize and align biases bias = self.bias_quantizer(self.bias, outputs) outputs = tf.add(outputs, bias) return outputs
[docs] @keras.saving.register_keras_serializable() class PicoPostProcessing(keras.layers.Layer): """ Post-processing layer that computes the mean absolute difference between predictions and targets along a given axis and binarizes the result using a threshold. This layer is useful for light-weight evaluation where a continuous error metric is converted into a binary decision per sample (e.g. anomaly detection or simple pass/fail). Args: threshold (float): threshold applied to the mean absolute difference. Values greater than or equal to it become 1.0, otherwise 0.0. Example: >>> layer = PicoPostProcessing(threshold=0.5) >>> out = layer(y_pred, y_true) """ def __init__(self, *args, threshold, **kwargs): super().__init__(*args, **kwargs) self.threshold = threshold def _check_input_constraints(self, y_pred, y_true): # Check input constraints. if not isinstance(y_true, tf.Tensor): raise TypeError(f"{self.__class__.__name__} only accepts {tf.Tensor} for second input. " f"Receives '{type(y_true)}'.") if not (y_true_dtype := tf.as_dtype(y_true.dtype)).is_integer: raise TypeError(f"{self.__class__.__name__} only accepts integer for second input. " f"Receives '{y_true_dtype}'.") def call(self, y_pred, y_true): self._check_input_constraints(y_pred, y_true) # Compare inputs. y_true = tf.cast(y_true, y_pred.dtype) outputs = tf.reduce_mean(tf.abs(y_pred - y_true), axis=1) # Binarize the output. outputs = tf.where(outputs >= self.threshold, 1.0, 0.0) return outputs def get_config(self): config = super().get_config() config["threshold"] = self.threshold return config
[docs] @register_quantize_target([PicoPostProcessing]) @keras.saving.register_keras_serializable() class QuantizedPicoPostProcessing(QuantizedLayer, PicoPostProcessing): """Quantized PicoPostProcessing layer. A quantized variant of PicoPostProcessing that operates on quantized predictions (QTensor) and integer targets. Args: threshold (float): threshold applied to the mean absolute difference. Values greater than or equal to it become 1.0, otherwise 0.0. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Promote buffer bitwidth to 44. self.quant_config['buffer_bitwidth'] = max(44, self.quant_config.get('buffer_bitwidth', 0)) self.buffer_bitwidth = self.quant_config['buffer_bitwidth'] - 1 # Initialize prediction quantizer with default parameters. self.pred_quantizer = OutputQuantizer(bitwidth=16, axis="per-axis", scale_bits=16, buffer_bitwidth=self.quant_config['buffer_bitwidth'], name="prediction_quantizer") def build(self, input_shape): super().build(input_shape) with tf.name_scope(self.name + '/'): # Explicitly build the OutputQuantizers so that output frac_bits can be computed. self.pred_quantizer.build(input_shape) def _check_input_constraints(self, y_pred, y_true): super()._check_input_constraints(y_pred, y_true) if not isinstance(y_pred, QTensor): raise TypeError(f"{self.__class__.__name__} only accepts {QTensor} for first input. " f"Receives '{type(y_pred)}'.") def call(self, y_pred, y_true): self._check_input_constraints(y_pred, y_true) # Force quantizer to rescale the y_pred to '0' frac_bits, as this is the input scale. # This is achieved by setting the maximum range to the limit. max_range = 2.0**self.pred_quantizer.value_bits self.pred_quantizer.range_max.assign( max_range * tf.ones_like(self.pred_quantizer.range_max)) # Convert both inputs into a FixedPoint. y_true = FixedPoint(y_true, value_bits=self.pred_quantizer.value_bits, frac_bits=0) y_pred = self.pred_quantizer(y_pred) # Compare inputs. outputs = tf.abs(y_pred - y_true) outputs = outputs.promote(self.buffer_bitwidth) outputs = tf.reduce_mean(outputs, axis=1) # Binarize the output. outputs = tf.where(outputs.values >= self.threshold, 1.0, 0.0) return FixedPoint(outputs, value_bits=1, frac_bits=0)