#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
TENN kernelized model conversion to stateful.
"""
__all__ = ['convert_to_stateful']
import numpy as np
from copy import deepcopy
from tf_keras import layers
from tf_keras.models import Sequential, Model
from tf_keras.saving import serialize_keras_object
from quantizeml.models.transforms.transforms_utils import (get_layers_by_type, get_layer_index,
inbound_node_generator, get_layers,
update_inbound)
from quantizeml.models.transforms import sanitize
from quantizeml.models.transforms.insert_layer import insert_in_config
from quantizeml.layers import (StatefulRecurrent, ExtractToken, StatefulProjection,
PicoPostProcessing)
from ..custom_layers import Kernelized, Stride, zoh_discretize
def _clean_build_config(config):
for layer_config in config["layers"]:
layer_config.pop('build_config', None)
def _replace_input_first_dim(model, config, timesteps=1):
    """ Replace input first dimension with 1.

    The input 'batch_input_shape' is rewritten as (1, timesteps, *rest) so that data
    can be streamed to the stateful model one chunk at a time.

    Args:
        model (keras.Model): original model
        config (dict): model config being updated
        timesteps (int, optional): number of timesteps. Defaults to 1.

    Raises:
        RuntimeError: when the model does not contain exactly one InputLayer.
    """
    # Renamed local from 'input' to avoid shadowing the builtin
    input_layers = get_layers_by_type(model, layers.InputLayer)
    if len(input_layers) != 1:
        raise RuntimeError(f'Detected {len(input_layers)} InputLayer layers while expecting 1.')
    input_index = get_layer_index(config['layers'], input_layers[0].name)
    input_config = config['layers'][input_index]
    shape = input_config['config']['batch_input_shape']
    # Force batch_size to 1, it can always be updated later
    input_config['config']['batch_input_shape'] = (1, timesteps, *shape[2:])
def _find_outbounds(config, target):
    """ Returns the names of the layers that consume the output of `target`.

    Args:
        config (dict): model config to inspect
        target (str): name of the layer whose outbounds are searched

    Returns:
        list: names of the layers connected to `target`.
    """
    consumers = []
    # Walk every layer config and inspect its inbound connections
    for layer_config in config['layers']:
        for node in inbound_node_generator(layer_config):
            # Inbound nodes can be serialized as dicts; normalize to an iterable of
            # connection entries, each shaped like the nested tfmot format, e.g.
            # [[[['conv1', 0, 0, {} ]]]]
            entries = node.values() if isinstance(node, dict) else node
            for entry in entries:
                # First element of a connection entry is the inbound layer name
                if entry[0] == target:
                    consumers.append(layer_config['name'])
    return consumers
def _remove_from_config(config, removables):
    """ Deletes the given layers from the model configuration.

    Args:
        config (dict): model config updated in place
        removables (list): layers whose configs must be dropped
    """
    names = [layer.name for layer in removables]
    for layer_config in get_layers(config, names):
        config['layers'].remove(layer_config)
def _replace_kernelized(model, config):
    """ Replaces Kernelized layers with StatefulRecurrent in the model config.

    Each Kernelized layer config is rewritten in place as a StatefulRecurrent layer, and
    three companion layers are inserted around it:

    - a StatefulProjection input projection (its weights are set later from the
      discretized B matrix),
    - an ExtractToken layer taking token 0 on the last axis of the recurrent output,
    - a StatefulProjection output projection (its weights are set later from C),
      subsampling when the Kernelized layer was directly followed by a single Stride
      layer (the Stride layer itself is removed in a later step).

    Args:
        model (keras.Model): original model
        config (dict): model config being updated

    Returns:
        dict: map of kernelized layer names to their related dense layers (list of input and output
            projections names).
    """
    kernelized_to_dense = {}
    # Retrieve Kernelized layers
    target_layers = get_layers_by_type(model, Kernelized)
    # Replace layers in config
    for kernelized in target_layers:
        kernelized_index = get_layer_index(config['layers'], kernelized.name)
        kernelized_config = config['layers'][kernelized_index]
        # Update configuration: 'out_channels' is kept aside for the output projection,
        # the remaining Kernelized-only parameters have no StatefulRecurrent equivalent
        out_channels = kernelized_config['config'].pop('out_channels')
        for param in ['speed', 'speed_range', 'repeat', 'num_coeffs', 'force_full_conv']:
            kernelized_config['config'].pop(param)
        # Update layer type
        new_config = StatefulRecurrent.from_config(kernelized_config['config'])
        kernelized_config.update(serialize_keras_object(new_config))
        # Build the input projection, output projection and extractToken layers as independent
        # layers
        input_proj = StatefulProjection(units=kernelized.num_coeffs * kernelized.repeat,
                                        use_bias=False, name=f"{kernelized.name}_input_proj")
        outb = _find_outbounds(config, kernelized.name)
        # When the only outbound is a Stride layer, fold its stride into the output
        # projection 'subsample' parameter
        if len(outb) == 1 and isinstance(model.get_layer(outb[0]), Stride):
            subsample = model.get_layer(outb[0]).stride
        else:
            subsample = 1
        output_proj = StatefulProjection(units=out_channels, use_bias=False, subsample=subsample,
                                         name=f"{kernelized.name}_output_proj")
        extract_token = ExtractToken(token=0, axis=-1,
                                     name=f"{kernelized.name}_extract_internal_state_real")
        # Add them to the configuration, producing the chain:
        # inbound > input_proj > recurrent > extract_token > output_proj > outbounds
        kernelized_inbounds = kernelized_config['inbound_nodes'][0]
        target_outbounds = kernelized.outbound_nodes
        outbound_names = [outbound.layer.name for outbound in target_outbounds]
        assert len(kernelized_inbounds) == 1, "Only supporting single inbound kernelized layers."
        insert_in_config(model, kernelized_inbounds[0][0], input_proj, config)
        insert_in_config(model, kernelized.name, extract_token, config)
        insert_in_config(model, extract_token.name, output_proj, config, outbound_names)
        # Store a map linking the kernelized name to the two dense projection layers to allow future
        # weight loading
        kernelized_to_dense[kernelized.name] = [input_proj.name, output_proj.name]
    return kernelized_to_dense
def _remove_non_recurrent(model, config):
    """ Edits configuration to remove GlobalAveragePooling1D, SpatialDropout1D and Stride layers.

    Removed layers are by-passed: their outbound layers are reconnected directly to
    their inbound layer before the layer configs are dropped.

    Args:
        model (keras.Model): original model
        config (dict): model config being updated
    """
    # Retrieve GlobalAveragePooling1D, SpatialDropout1D and Stride layers
    non_recurrent = (layers.GlobalAveragePooling1D, layers.SpatialDropout1D, Stride)
    removables = get_layers_by_type(model, non_recurrent)
    # For sequential model, the 'removable' layers will simply be removed in the following step. For
    # other models, the layers inbounds/outbounds must be rebuilt.
    if not isinstance(model, Sequential):
        for removable in removables:
            # Retrieve outbounds from the configuration and not from the model itself because when
            # replacing Kernelized layers, inbounds/outbounds have been updated in the configuration
            # but the model was not rebuilt yet.
            outbounds = _find_outbounds(config, removable.name)
            # Limit support to single inbound/outbound: layers with multiple connections
            # are left untouched here (but still removed from the config below)
            if len(removable.inbound_nodes) != 1 or len(outbounds) > 1:
                continue
            # Retrieve the 'removable' input layer, assuming it has only 1 inbound
            removable_index = get_layer_index(config['layers'], removable.name)
            # tfmot code: 'inbound_nodes' is a nested list where first element is the inbound
            # layer name, e.g: [[['conv1', 0, 0, {} ]]]
            updated_inbound = config['layers'][removable_index]['inbound_nodes'][0][0][0]
            # Update 'removable' outbounds layers: their current inbound is the 'removable' layer
            # that will be removed so it must be replaced with the 'removable' previous layer. This
            # results in by-passing the 'removable' layer: inbound > removable > outbounds becomes
            # inbound > outbounds.
            if len(outbounds) == 1:
                next_index = get_layer_index(config['layers'], outbounds[0])
                update_inbound(config['layers'][next_index], removable.name, updated_inbound)
            else:
                # no outbounds when 'removable' is the last layer in the model, then update the
                # model output_layers
                config['output_layers'][0][0] = updated_inbound
    _remove_from_config(config, removables)
def _set_stateful_weights(src_model, dst_model, kernelized_to_dense):
    """ Sets the given weights in the stateful model, adapting Kernelized weights to
    StatefulRecurrent.

    Kernelized weights (A, B, C, log_dt) are discretized: the real and imaginary parts
    of A_hat go to the StatefulRecurrent layer, while B_hat and C are loaded into the
    matching input/output StatefulProjection layers.

    Args:
        src_model (keras.Model): original model
        dst_model (keras.Model): stateful model to set weights into
        kernelized_to_dense (dict): map from kernelized layer names to their related dense layers
            (list of input and output projections names).
    """
    for layer in src_model.layers:
        src_weights = layer.get_weights()
        # Skip weightless layers (they have no counterpart state to transfer)
        if not src_weights:
            continue
        target = dst_model.get_layer(layer.name)
        if isinstance(target, StatefulRecurrent):
            # Retrieve A, B, C and log_dt from the Kernelized layer
            A, B, C, log_dt = src_weights
            # Discretize with zero-order hold to obtain A_hat and B_hat
            A_hat, B_hat, _ = zoh_discretize(A, B, log_dt)
            A_hat = A_hat.numpy()
            B_hat = B_hat.numpy()
            # The StatefulRecurrent layer stores A_hat as two weights: real and imag parts
            src_weights = [np.real(A_hat), np.imag(A_hat)]
            # B_hat and C go to the input and output dense projections respectively
            input_proj_name, output_proj_name = kernelized_to_dense[layer.name]
            dst_model.get_layer(input_proj_name).set_weights([B_hat])
            dst_model.get_layer(output_proj_name).set_weights([C])
        target.set_weights(src_weights)
def convert_to_stateful(model, timesteps=1, threshold=None):
    """ Converts the given Kernelized based model to a StatefulRecurrent based model.

    Args:
        model (keras.Model): the source model
        timesteps (int, optional): number of timesteps. Defaults to 1.
        threshold (int, optional): add a PicoPostProcessing layer with a threshold to compare
            the input with the reconstructed output. Useful in anomaly detection tasks.
            Defaults to None.

    Returns:
        keras model: stateful model
    """
    # Work on a copy so the source model configuration is left untouched
    stateful_config = deepcopy(model.get_config())
    # Data will be streamed: set the input shape to (1, timesteps, ...)
    _replace_input_first_dim(model, stateful_config, timesteps=timesteps)
    # Swap every Kernelized layer for a StatefulRecurrent + projection layers, keeping
    # track of the projections for the weight transfer below
    kernelized_to_dense = _replace_kernelized(model, stateful_config)
    # Drop GlobalAveragePooling1D, SpatialDropout1D and Stride layers
    _remove_non_recurrent(model, stateful_config)
    # Built shapes are stale after the layer surgery: remove them for a clean rebuild
    _clean_build_config(stateful_config)
    intermediate_model = model.from_config(stateful_config)
    # Transfer (and adapt) the original weights into the rebuilt model
    _set_stateful_weights(model, intermediate_model, kernelized_to_dense)
    # Sanitize to fold BN and obtain a 'clean' architecture, then rebuild once more
    sanitized_model = sanitize(intermediate_model)
    sanitized_config = deepcopy(sanitized_model.get_config())
    _clean_build_config(sanitized_config)
    stateful_model = model.from_config(sanitized_config)
    stateful_model.set_weights(sanitized_model.get_weights())
    # Append the post-processing layer when requested and not already present
    if threshold is not None and not get_layers_by_type(stateful_model, PicoPostProcessing):
        try:
            inputs = stateful_model.input
            outputs = PicoPostProcessing(threshold=threshold)(stateful_model.output, inputs)
            stateful_model = Model(inputs, outputs)
        except Exception as e:
            raise RuntimeError("Impossible to append threshold.") from e
    return stateful_model