#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ["StatefulRecurrent", "QuantizedStatefulRecurrent", "reset_states",
"StatefulProjection", "QuantizedStatefulProjection", "update_batch_size",
"PicoPostProcessing", "QuantizedPicoPostProcessing"]
import tf_keras as keras
import tensorflow as tf
from .recorders import (NonTrackVariable, TensorRecorder, NonTrackFixedPointVariable,
FixedPointRecorder)
from .layers_base import (register_quantize_target, tensor_inputs, apply_buffer_bitwidth,
register_aligned_inputs, QuantizedLayer, neural_layer_init,
rescale_outputs)
from .quantizers import WeightQuantizer, OutputQuantizer
from ..tensors import FixedPoint, QTensor, QFloat
@keras.saving.register_keras_serializable()
class StatefulRecurrent(keras.layers.Layer):
    """ A recurrent layer with an internal state.

    The layer maintains a complex-valued internal state, stored as two separate float32
    tensors (real and imaginary parts), that is updated at every timestep and preserved
    across calls until :meth:`reset_layer_states` is invoked.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Wrapped in NonTrackVariable so the state is not tracked as a trainable weight
        self._internal_state_real = NonTrackVariable("internal_state_real")
        self._internal_state_imag = NonTrackVariable("internal_state_imag")

    def build(self, input_shape):
        """ Builds the 'A' weights and the internal state variables.

        Args:
            input_shape (tuple): the (batch, timesteps, channels) input shape. Batch size
                and channels must be known because the state shape depends on both.
        """
        assert input_shape[0] is not None, f"{self.name} must be built with a known batch size."
        assert input_shape[-1] is not None, \
            f"{self.name} must be built with a known input channels."
        with tf.name_scope(self.name + '/'):
            super().build(input_shape)
            # 'A' weight is a complex64 tensor stored as two float32 tensor to ease quantization
            self.A_real = self.add_weight(name='A_real', shape=(input_shape[-1],))
            self.A_imag = self.add_weight(name='A_imag', shape=(input_shape[-1],))
            # Initialize the internal state variables, drop the timesteps dimension
            self._internal_state_real.init_var(tf.zeros((input_shape[0], input_shape[-1])))
            self._internal_state_imag.init_var(tf.zeros((input_shape[0], input_shape[-1])))

    def call(self, inputs):
        """ For every input step, the internal state is updated using the inputs which should be the
        updated state from the previous layer.
        """
        # Build output tensors that will contain all updates, initialize them with internal state
        assert inputs.shape[1] is not None, f"{self.name} requires a known number of timesteps."
        timesteps = inputs.shape[1]
        multiples = [1, timesteps, 1]
        state_real = tf.tile(tf.expand_dims(self._internal_state_real.var, axis=1), multiples)
        state_imag = tf.tile(tf.expand_dims(self._internal_state_imag.var, axis=1), multiples)
        # Loop over timesteps using the static dimension: a dynamic tf.shape(...) value cannot
        # be passed to python range() when tracing in graph mode, and the static value is
        # guaranteed by the assert above.
        for i in range(timesteps):
            # Compute real and imaginary part separately. At i == 0, index -1 reads the last
            # timestep of the tiled tensor, which still holds the initial internal state.
            updated_real = state_real[:, i - 1] * self.A_real - \
                state_imag[:, i - 1] * self.A_imag + inputs[:, i]
            updated_imag = self.A_imag * state_real[:, i - 1] + \
                self.A_real * state_imag[:, i - 1]
            # Scatter the update at timestep i for every sample of the batch
            indices = tf.stack([tf.range(state_real.shape[0]), tf.repeat(i, state_real.shape[0])],
                               axis=1)
            state_real = tf.tensor_scatter_nd_update(state_real, indices, updated_real)
            state_imag = tf.tensor_scatter_nd_update(state_imag, indices, updated_imag)
        # Update internal state for next call
        self._internal_state_real.set_var(state_real[:, -1])
        self._internal_state_imag.set_var(state_imag[:, -1])
        return tf.stack([state_real, state_imag], -1)

    def reset_layer_states(self):
        """ Resets internal state (real and imaginary part)."""
        self._internal_state_real.reset_var()
        self._internal_state_imag.reset_var()
@register_quantize_target([StatefulRecurrent], has_weights=True)
@keras.saving.register_keras_serializable()
class QuantizedStatefulRecurrent(QuantizedLayer, StatefulRecurrent):
    """ A quantized version of the StatefulRecurrent layer that operates on quantized inputs,
    weights and internal state.

    Note that internal state is quantized to 16-bits for accuracy reasons, inputs and outputs of
    this layer are then also 16-bits.

    Args:
        quant_config (dict, optional): the serialized quantization configuration. Defaults to None.
    """

    def __init__(self, *args, quant_config=None, **kwargs):
        super().__init__(*args, quant_config=quant_config, **kwargs)
        # The state is held as FixedPoint values, hence the dedicated non-tracked wrappers
        # (overriding the float variables created by the parent __init__)
        self._internal_state_real = NonTrackFixedPointVariable("internal_state_real")
        self._internal_state_imag = NonTrackFixedPointVariable("internal_state_imag")
        # Build weight quantizer for A_real and A_imag (sharing the same quantizer)
        if "a_quantizer" not in self.quant_config:
            # Forcing to:
            #   - per-tensor to ensure alignment in the call operations
            #   - 16-bits for accuracy reasons
            #   - FixedPoint quantization to prevent scale_out operations on internal_state
            self.quant_config["a_quantizer"] = {"bitwidth": 16, "axis": None, "fp_quantizer": True}
        a_quantizer_cfg = self.quant_config["a_quantizer"]
        self.a_quantizer = WeightQuantizer(name="a_quantizer", **a_quantizer_cfg)
        # Finalize output quantizer, add one with default configuration if there is None in the
        # config as state must be quantized
        if "output_quantizer" not in self.quant_config:
            self.quant_config["output_quantizer"] = {"bitwidth": 16, "axis": "per-tensor"}
        out_quant_cfg = self.quant_config["output_quantizer"]
        self.out_quantizer = OutputQuantizer(name="output_quantizer", **out_quant_cfg)
        self.buffer_bitwidth = apply_buffer_bitwidth(self.quant_config, signed=True)
        # Prepare the variable that should be recorded (input alignment shift, see call)
        self.new_state_shift = TensorRecorder(name=self.name + "/new_state_shift")

    def build(self, input_shape):
        """ Builds the 'A' weights, the output quantizer and the FixedPoint internal state.

        Args:
            input_shape (tuple): the (batch, timesteps, channels) input shape. Batch size
                and channels must be known because the state shape depends on both.
        """
        assert input_shape[0] is not None, f"{self.name} must be built with a known batch size."
        assert input_shape[-1] is not None, \
            f"{self.name} must be built with a known input channels."
        with tf.name_scope(self.name + '/'):
            # Explicitly build the Keras.layer and not StatefulRecurrent because state is not of
            # the same type
            keras.layers.Layer.build(self, input_shape)
            # 'A' weight is a complex64 tensor stored as two float32 tensor to ease quantization
            self.A_real = self.add_weight(name='A_real', shape=(input_shape[-1],))
            self.A_imag = self.add_weight(name='A_imag', shape=(input_shape[-1],))
            # Explicitly build the OutputQuantizer so that output frac_bits can be computed
            self.out_quantizer.build((input_shape[0], input_shape[-1]))
            # Initialize the internal state variables as FixedPoint zeros, quantized like the
            # layer outputs so that state and output share the same representation
            zeros = FixedPoint(tf.zeros((input_shape[0], input_shape[-1])),
                               self.out_quantizer.value_bits, self.out_quantizer.frac_bits)
            self._internal_state_real.init_var(zeros)
            self._internal_state_imag.init_var(zeros)

    @tensor_inputs([QTensor])
    def call(self, inputs):
        """ Updates the quantized internal state for every timestep.

        Args:
            inputs (QTensor): the quantized (batch, timesteps, channels) inputs.

        Returns:
            the per-timestep real and imaginary state updates stacked on a trailing axis.
        """
        if isinstance(inputs, QFloat):
            # Handle QFloat inputs by quantizing the scale: reuse output quantizer to get scale_bits
            # because the Stateful layer should be homogeneous
            inputs, qscales = inputs.to_fixed_point()
            if getattr(self, 'new_state_scale', None) is None:
                # from tf_keras documentation, any variable creation taking place in call
                # should be wrapped with tf.init_scope
                with tf.init_scope():
                    self.new_state_scale = FixedPointRecorder(self.name + "/new_state_scale")
            self.new_state_scale(qscales)
        # Quantize A matrices
        A_real = self.a_quantizer(self.A_real)
        A_imag = self.a_quantizer(self.A_imag)
        # Align inputs with {A * state}, which is out_quantizer.frac_bits + A.frac_bits
        inputs, shift = inputs.rescale(self.out_quantizer.frac_bits + A_real.frac_bits,
                                       inputs.value_bits)
        # Record the applied shift so it can be exported with the model
        self.new_state_shift(shift)
        # Set the appropriate frac_bit in state variable
        self._internal_state_real._frac_bits.assign(self.out_quantizer.frac_bits)
        self._internal_state_imag._frac_bits.assign(self.out_quantizer.frac_bits)
        # Define state update step
        internal_state_real_step = self._internal_state_real.var
        internal_state_imag_step = self._internal_state_imag.var
        # Build output tensors that will contain all updates, initialize them with internal state
        zero_fp = FixedPoint(tf.zeros(inputs.shape),
                             self._internal_state_real.var.value_bits,
                             self._internal_state_real.var.frac_bits)
        next_internal_state_real, next_internal_state_imag = zero_fp, zero_fp
        # Define a zero padding value (FixedPoint, so tf.pad stays aligned with the state)
        padding_value = FixedPoint(0, self.out_quantizer.value_bits,
                                   self._internal_state_real._frac_bits)
        # Loop over timesteps
        timesteps = tf.shape(inputs)[1]
        for i in range(timesteps):
            # Promote internal_state_step to the buffer bitwidth before the products
            internal_state_real_step = internal_state_real_step.promote(self.buffer_bitwidth)
            internal_state_imag_step = internal_state_imag_step.promote(self.buffer_bitwidth)
            # Update internal state: compute real and imaginary part separately using current step
            updated_real = tf.multiply(internal_state_real_step, A_real) - \
                tf.multiply(internal_state_imag_step, A_imag)
            # Get the inputs for this step
            input_step = inputs[:, i]
            # At this point addition is possible (inputs were aligned to state + A frac_bits)
            updated_real = updated_real + input_step
            # Same for imaginary part
            updated_imag = tf.multiply(internal_state_real_step, A_imag) + \
                tf.multiply(internal_state_imag_step, A_real)
            # Quantize down the update for next step using the layer output quantizer
            internal_state_real_step = self.out_quantizer(updated_real)
            internal_state_imag_step = self.out_quantizer(updated_imag)
            # Store the updates in next_internal_state. To do so, pad the updated state with zeroes
            # to left and right in order to rebuild a 'full' state where i-th timestep is equal to
            # computed update. Then add it to next_internal_state that was initialized to zeroes:
            # this effectively mimics an inplace update (working around Tensorflow logic).
            paddings = [[0, 0], [i, timesteps - 1 - i], [0, 0]]
            padded_real = tf.pad(tf.expand_dims(internal_state_real_step, axis=1),
                                 paddings=paddings, constant_values=padding_value)
            padded_imag = tf.pad(tf.expand_dims(internal_state_imag_step, axis=1),
                                 paddings=paddings, constant_values=padding_value)
            # Ensure the shape is as expected because TensorFlow will not tolerate shape changes
            # within the loop
            padded_real.values.set_shape(next_internal_state_real.shape)
            padded_imag.values.set_shape(next_internal_state_imag.shape)
            next_internal_state_real += padded_real
            next_internal_state_imag += padded_imag
        # Update internal state members with the last computed step for the next call
        self._internal_state_real.set_var(internal_state_real_step)
        self._internal_state_imag.set_var(internal_state_imag_step)
        # Return the concatenated states
        return tf.stack([next_internal_state_real, next_internal_state_imag], -1)
def reset_states(model):
    """ Resets all StatefulRecurrent layers internal states in the model.

    Args:
        model (keras.Model): the model to reset
    """
    recurrent_layers = (ly for ly in model.layers if isinstance(ly, StatefulRecurrent))
    for recurrent in recurrent_layers:
        recurrent.reset_layer_states()
def update_batch_size(model, batch_size):
    """ Updates the batch size in a model.

    Similar to keras RNN/Cell behavior where batch size must be known, the StatefulRecurrent layer
    state is built with a (None, ...) shape that must be defined runtime. This helper allow to set
    or update a model batch size.

    Args:
        model (keras.Model): the model to update
        batch_size (int): batch size to set

    Returns:
        keras.Model: an updated model (the original model when batch size is unchanged)
    """
    # Edit the batch size in the serialized configuration of the first (input) layer
    config = model.get_config()
    first_layer_cfg = config['layers'][0]['config']
    current_shape = first_layer_cfg['batch_input_shape']
    # Nothing to do when the requested batch size is already set
    if current_shape[0] == batch_size:
        return model
    first_layer_cfg['batch_input_shape'] = (batch_size, *current_shape[1:])
    # Drop cached build information so that layers shapes are recomputed
    for layer_cfg in config["layers"]:
        layer_cfg.pop('build_config', None)
    # Rebuild a model from the edited config and transfer the original weights
    new_model = model.from_config(config)
    new_model.set_weights(model.get_weights())
    return new_model
@keras.saving.register_keras_serializable()
class StatefulProjection(keras.layers.Dense):
    """ Same as a Dense layer but with optional subsampling.

    Args:
        subsample (int, optional): subsampling factor applied to the timesteps. Defaults to 1.
    """

    def __init__(self, *args, subsample=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.subsample = subsample

    def build(self, input_shape):
        # Build under the layer name scope so variables are grouped consistently
        with tf.name_scope(self.name + '/'):
            super().build(input_shape)

    def call(self, inputs):
        # Optionally drop timesteps, keeping the last one of every 'subsample' group
        if self.subsample > 1:
            assert inputs.shape[1] % self.subsample == 0, \
                f"Number of timesteps: {inputs.shape[1]} must be a multiple of subsample " \
                f"ratio: {self.subsample}."
            inputs = inputs[:, (self.subsample - 1)::self.subsample, :]
        # Standard Dense operation
        return super().call(inputs)

    def get_config(self):
        return {**super().get_config(), "subsample": self.subsample}
@register_quantize_target([StatefulProjection], has_weights=True)
@register_aligned_inputs
@keras.saving.register_keras_serializable()
class QuantizedStatefulProjection(QuantizedLayer, StatefulProjection):
    """ A quantized version of the StatefulProjection layer that operates on quantized inputs.
    """

    @neural_layer_init(False)
    def __init__(self, *args, **kwargs):
        # NOTE(review): neural_layer_init presumably performs the base initialization and
        # populates self.quant_config and the weight/bias quantizers before this body runs —
        # confirm against layers_base.
        # Limit buffer bitwidth to 27 for HW constraint
        self.quant_config['buffer_bitwidth'] = min(28, self.quant_config['buffer_bitwidth'])
        # The effective buffer bitwidth is the configured value minus one (presumably reserving
        # the sign bit — TODO confirm), hence the 27-bit limit mentioned above.
        self.buffer_bitwidth = self.quant_config['buffer_bitwidth'] - 1

    @tensor_inputs([QTensor, tf.Tensor])
    @rescale_outputs
    def call(self, inputs):
        """ Projects the (optionally subsampled) inputs with quantized weights and biases. """
        # Apply the optional subsampling, keeping the last timestep of each 'subsample' group
        if self.subsample > 1:
            assert inputs.shape[1] % self.subsample == 0, \
                f"Number of timesteps: {inputs.shape[1]} must be a multiple of subsample " \
                f"ratio: {self.subsample}."
            inputs = inputs[:, (self.subsample - 1)::self.subsample, :]
        # Quantize the weights
        kernel = self.weight_quantizer(self.kernel)
        outputs = tf.matmul(inputs, kernel)
        if self.use_bias:
            # Quantize and align biases (alignment uses the matmul outputs)
            bias = self.bias_quantizer(self.bias, outputs)
            outputs = tf.add(outputs, bias)
        return outputs
@keras.saving.register_keras_serializable()
class PicoPostProcessing(keras.layers.Layer):
    """ Post-processing layer producing a binary decision per sample.

    The layer computes the mean absolute difference between predictions and integer targets
    along axis 1, then binarizes the result with a threshold. Useful for light-weight
    evaluation where a continuous error metric becomes a pass/fail decision (e.g. anomaly
    detection).

    Args:
        threshold (float): threshold applied to the mean absolute difference. Values greater
            than or equal to it become 1.0, otherwise 0.0.

    Example:
        >>> layer = PicoPostProcessing(threshold=0.5)
        >>> out = layer(y_pred, y_true)
    """

    def __init__(self, *args, threshold, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold

    def _check_input_constraints(self, y_pred, y_true):
        # The second input must be an integer tf.Tensor.
        if not isinstance(y_true, tf.Tensor):
            raise TypeError(f"{self.__class__.__name__} only accepts {tf.Tensor} for second input. "
                            f"Receives '{type(y_true)}'.")
        y_true_dtype = tf.as_dtype(y_true.dtype)
        if not y_true_dtype.is_integer:
            raise TypeError(f"{self.__class__.__name__} only accepts integer for second input. "
                            f"Receives '{y_true_dtype}'.")

    def call(self, y_pred, y_true):
        self._check_input_constraints(y_pred, y_true)
        # Compare inputs: cast targets to the prediction dtype then average |diff| on axis 1
        targets = tf.cast(y_true, y_pred.dtype)
        mean_abs_diff = tf.reduce_mean(tf.abs(y_pred - targets), axis=1)
        # Binarize the output
        return tf.where(mean_abs_diff >= self.threshold, 1.0, 0.0)

    def get_config(self):
        return {**super().get_config(), "threshold": self.threshold}
@register_quantize_target([PicoPostProcessing])
@keras.saving.register_keras_serializable()
class QuantizedPicoPostProcessing(QuantizedLayer, PicoPostProcessing):
    """Quantized PicoPostProcessing layer.

    A quantized variant of PicoPostProcessing that operates on quantized predictions
    (QTensor) and integer targets.

    Args:
        threshold (float): threshold applied to the mean absolute difference. Values greater
            than or equal to it become 1.0, otherwise 0.0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Promote buffer bitwidth to 44.
        self.quant_config['buffer_bitwidth'] = max(44, self.quant_config.get('buffer_bitwidth', 0))
        # Effective bitwidth is the configured value minus one (presumably reserving the sign
        # bit — TODO confirm against other quantized layers).
        self.buffer_bitwidth = self.quant_config['buffer_bitwidth'] - 1
        # Initialize prediction quantizer with default parameters.
        self.pred_quantizer = OutputQuantizer(bitwidth=16,
                                              axis="per-axis",
                                              scale_bits=16,
                                              buffer_bitwidth=self.quant_config['buffer_bitwidth'],
                                              name="prediction_quantizer")

    def build(self, input_shape):
        """ Builds the layer and its prediction quantizer. """
        super().build(input_shape)
        with tf.name_scope(self.name + '/'):
            # Explicitly build the OutputQuantizers so that output frac_bits can be computed.
            self.pred_quantizer.build(input_shape)

    def _check_input_constraints(self, y_pred, y_true):
        # In addition to the base checks, the first input must be quantized.
        super()._check_input_constraints(y_pred, y_true)
        if not isinstance(y_pred, QTensor):
            raise TypeError(f"{self.__class__.__name__} only accepts {QTensor} for first input. "
                            f"Receives '{type(y_pred)}'.")

    def call(self, y_pred, y_true):
        """ Binarizes the mean absolute difference between quantized predictions and targets.

        Args:
            y_pred (QTensor): the quantized predictions.
            y_true (tf.Tensor): the integer targets.

        Returns:
            FixedPoint: a 1-bit FixedPoint holding 1.0 where the mean absolute difference is
            greater than or equal to the threshold, 0.0 elsewhere.
        """
        self._check_input_constraints(y_pred, y_true)
        # Force quantizer to rescale the y_pred to '0' frac_bits, as this is the input scale.
        # This is achieved by setting the maximum range to the limit.
        max_range = 2.0**self.pred_quantizer.value_bits
        self.pred_quantizer.range_max.assign(
            max_range * tf.ones_like(self.pred_quantizer.range_max))
        # Convert both inputs into a FixedPoint.
        y_true = FixedPoint(y_true, value_bits=self.pred_quantizer.value_bits, frac_bits=0)
        y_pred = self.pred_quantizer(y_pred)
        # Compare inputs.
        outputs = tf.abs(y_pred - y_true)
        # Promote before the mean to avoid saturating the accumulation.
        outputs = outputs.promote(self.buffer_bitwidth)
        outputs = tf.reduce_mean(outputs, axis=1)
        # Binarize the output (comparison runs on the raw FixedPoint values).
        outputs = tf.where(outputs.values >= self.threshold, 1.0, 0.0)
        return FixedPoint(outputs, value_bits=1, frac_bits=0)