# Source code for akida_models.tenn_recurrent.convert_recurrent

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
TENN kernelized model conversion to stateful.
"""

__all__ = ['convert_to_stateful']

import numpy as np

from copy import deepcopy
from tf_keras import layers
from tf_keras.models import Sequential, Model
from tf_keras.saving import serialize_keras_object

from quantizeml.models.transforms.transforms_utils import (get_layers_by_type, get_layer_index,
                                                           inbound_node_generator, get_layers,
                                                           update_inbound)
from quantizeml.models.transforms import sanitize
from quantizeml.models.transforms.insert_layer import insert_in_config
from quantizeml.layers import (StatefulRecurrent, ExtractToken, StatefulProjection,
                               PicoPostProcessing)


from ..custom_layers import Kernelized, Stride, zoh_discretize


def _clean_build_config(config):
    for layer_config in config["layers"]:
        layer_config.pop('build_config', None)


def _replace_input_first_dim(model, config, timesteps=1):
    """ Rewrites the model input shape for streaming: batch forced to 1, time set to ``timesteps``.

    The InputLayer 'batch_input_shape' entry in ``config`` becomes
    ``(1, timesteps, *original_shape[2:])``. The config is updated in place.

    Args:
        model (keras.Model): original model
        config (dict): model config being updated
        timesteps (int, optional): number of timesteps. Defaults to 1.

    Raises:
        RuntimeError: if the model does not contain exactly one InputLayer.
    """
    input_layers = get_layers_by_type(model, layers.InputLayer)
    if len(input_layers) != 1:
        raise RuntimeError(f'Detected {len(input_layers)} InputLayer layers while expecting 1.')
    input_index = get_layer_index(config['layers'], input_layers[0].name)
    input_layer_config = config['layers'][input_index]['config']
    shape = input_layer_config['batch_input_shape']
    # Force batch_size to 1, it can always be updated later
    input_layer_config['batch_input_shape'] = (1, timesteps, *shape[2:])


def _find_outbounds(config, target):
    """ Lists the layers of ``config`` that consume ``target`` as an input.

    Inbound nodes are read with ``inbound_node_generator``. A node given as a dict is searched
    through its values, otherwise it is iterated directly. The first item of each connection is
    compared to ``target``.

    Args:
        config (dict): model config to search
        target (str): name of the layer whose consumers are looked for

    Returns:
        list: names of the consuming layers, in config order. A layer is listed once per
        connection it has to ``target``.
    """
    consumers = []
    for layer_config in config['layers']:
        for node in inbound_node_generator(layer_config):
            connections = node.values() if isinstance(node, dict) else node
            # Each connection looks like ['conv1', 0, 0, {}]; its first item is the layer name
            consumers.extend(layer_config['name']
                             for connection in connections if connection[0] == target)
    return consumers


def _remove_from_config(config, removables):
    """ Deletes the configurations of the given layers from ``config['layers']``, in place.

    Args:
        config (dict): model config being updated
        removables (list): layer objects whose names select the configurations to delete
    """
    names = [layer.name for layer in removables]
    layer_configs = config['layers']
    for entry in get_layers(config, names):
        layer_configs.remove(entry)


def _replace_kernelized(model, config):
    """ Replaces Kernelized layers with StatefulRecurrent layers in the configuration.

    Each Kernelized layer configuration is converted to a StatefulRecurrent one. Three new
    layers are inserted around it:
    - a StatefulProjection input projection before it;
    - an ExtractToken (token 0, last axis) after it;
    - a StatefulProjection output projection after the ExtractToken.

    When the Kernelized layer feeds exactly one layer and that layer is a Stride, the Stride's
    ``stride`` is used as the output projection ``subsample``.

    Args:
        model (keras.Model): original model
        config (dict): model config being updated

    Returns:
        dict: map of kernelized layer names to their related dense layers (list of input and output
        projections names).

    Raises:
        RuntimeError: if a Kernelized layer has more than one inbound connection.
    """
    kernelized_to_dense = {}

    # Retrieve Kernelized layers
    target_layers = get_layers_by_type(model, Kernelized)

    # Replace layers in config
    for kernelized in target_layers:
        kernelized_index = get_layer_index(config['layers'], kernelized.name)
        kernelized_config = config['layers'][kernelized_index]

        # Validate before touching the config. An explicit raise is used instead of `assert`,
        # which is stripped under `python -O` and would let the first inbound be used silently.
        kernelized_inbounds = kernelized_config['inbound_nodes'][0]
        if len(kernelized_inbounds) != 1:
            raise RuntimeError("Only supporting single inbound kernelized layers.")

        # Update configuration: drop Kernelized-only parameters StatefulRecurrent does not take
        out_channels = kernelized_config['config'].pop('out_channels')
        for param in ['speed', 'speed_range', 'repeat', 'num_coeffs', 'force_full_conv']:
            kernelized_config['config'].pop(param)

        # Update layer type
        new_config = StatefulRecurrent.from_config(kernelized_config['config'])
        kernelized_config.update(serialize_keras_object(new_config))

        # Build the input projection, output projection and extractToken layers as independent
        # layers
        input_proj = StatefulProjection(units=kernelized.num_coeffs * kernelized.repeat,
                                        use_bias=False, name=f"{kernelized.name}_input_proj")

        # A single following Stride layer is folded into the output projection as 'subsample'
        outb = _find_outbounds(config, kernelized.name)
        if len(outb) == 1 and isinstance(model.get_layer(outb[0]), Stride):
            subsample = model.get_layer(outb[0]).stride
        else:
            subsample = 1
        output_proj = StatefulProjection(units=out_channels, use_bias=False, subsample=subsample,
                                         name=f"{kernelized.name}_output_proj")
        extract_token = ExtractToken(token=0, axis=-1,
                                     name=f"{kernelized.name}_extract_internal_state_real")

        # Add them to the configuration: inbound > input_proj > recurrent > extract_token >
        # output_proj > original outbounds
        target_outbounds = kernelized.outbound_nodes
        outbound_names = [outbound.layer.name for outbound in target_outbounds]
        insert_in_config(model, kernelized_inbounds[0][0], input_proj, config)
        insert_in_config(model, kernelized.name, extract_token, config)
        insert_in_config(model, extract_token.name, output_proj, config, outbound_names)

        # Store a map linking the kernelized name to the two dense projection layers to allow future
        # weight loading
        kernelized_to_dense[kernelized.name] = [input_proj.name, output_proj.name]

    return kernelized_to_dense


def _remove_non_recurrent(model, config):
    """ Edits configuration to remove GlobalAveragePooling1D, SpatialDropout1D and Stride layers.

    For non-Sequential models, each removed layer is by-passed before removal:
    inbound > removable > outbound becomes inbound > outbound. Any model output entry that pointed
    to the removed layer is redirected to the removed layer's inbound layer.

    Args:
        model (keras.Model): original model
        config (dict): model config being updated
    """
    # Retrieve GlobalAveragePooling1D, SpatialDropout1D and Stride layers
    non_recurrent = (layers.GlobalAveragePooling1D, layers.SpatialDropout1D, Stride)
    removables = get_layers_by_type(model, non_recurrent)

    # For sequential model, the 'removable' layers will simply be removed in the following step. For
    # other models, the layers inbounds/outbounds must be rebuilt.
    if not isinstance(model, Sequential):
        for removable in removables:
            # Retrieve outbound from the configuration and not from the model itself because when
            # replacing Kernelized layers, inbounds/outbounds have been updated in the configuration
            # but the model was not rebuilt yet.
            outbounds = _find_outbounds(config, removable.name)

            # Limit support to single inbound/outbound.
            # NOTE(review): skipped layers are still removed by _remove_from_config below, so their
            # consumers would reference a missing layer — confirm this is intended.
            if len(removable.inbound_nodes) != 1 or len(outbounds) > 1:
                continue

            # Retrieve the 'removable' input layer, assuming it has only 1 inbound
            removable_index = get_layer_index(config['layers'], removable.name)
            # tfmot code: 'inbound_nodes' is a nested list where first element is the inbound
            # layername, e.g: [[['conv1', 0, 0, {} ]]]
            updated_inbound = config['layers'][removable_index]['inbound_nodes'][0][0][0]

            # Update the 'removable' outbound layer: its current inbound is the 'removable' layer
            # that will be removed, so it must be replaced with the 'removable' previous layer.
            if outbounds:
                next_index = get_layer_index(config['layers'], outbounds[0])
                update_inbound(config['layers'][next_index], removable.name, updated_inbound)

            # Redirect only the model output entries that reference 'removable'. Overwriting the
            # first output unconditionally breaks multi-output models and leaves dangling outputs
            # when 'removable' is both an output and an intermediate layer.
            # Assumes a flat list of [layer_name, node_index, tensor_index] entries — TODO confirm
            # for nested (dict) outputs.
            for output_layer in config.get('output_layers', []):
                if output_layer[0] == removable.name:
                    output_layer[0] = updated_inbound

    _remove_from_config(config, removables)


def _set_stateful_weights(src_model, dst_model, kernelized_to_dense):
    """ Copies weights from the original model into the stateful model.

    Layers without weights are skipped. Every other layer's weights are copied unchanged into
    the same-named layer of ``dst_model``, except for StatefulRecurrent layers:
    - the source weights (A, B, C, log_dt) go through ``zoh_discretize``;
    - the discretized B goes to the input projection and C to the output projection;
    - the recurrent layer receives the real and imaginary parts of the discretized A.

    Args:
        src_model (keras.Model): original model
        dst_model (keras.Model): stateful model to set weights into
        kernelized_to_dense (dict): map from kernelized layer names to their related dense layers
            (list of input and output projections names).
    """
    for src_layer in src_model.layers:
        src_weights = src_layer.get_weights()
        if not src_weights:
            continue

        dst_layer = dst_model.get_layer(src_layer.name)
        if not isinstance(dst_layer, StatefulRecurrent):
            dst_layer.set_weights(src_weights)
            continue

        A, B, C, log_dt = src_weights
        A_hat, B_hat, _ = zoh_discretize(A, B, log_dt)
        A_hat = A_hat.numpy()
        B_hat = B_hat.numpy()

        # The input projection holds B_hat, the output projection holds C
        proj_names = kernelized_to_dense[src_layer.name]
        dst_model.get_layer(proj_names[0]).set_weights([B_hat])
        dst_model.get_layer(proj_names[1]).set_weights([C])

        # A_hat is complex: it is stored as two real tensors
        dst_layer.set_weights([np.real(A_hat), np.imag(A_hat)])


def convert_to_stateful(model, timesteps=1, threshold=None):
    """ Converts the given Kernelized based model to a StatefulRecurrent based model.

    Args:
        model (keras.Model): the source model
        timesteps (int, optional): number of timesteps. Defaults to 1.
        threshold (int, optional): add a PicoPostProcessing layer with a threshold to compare the
            input with the reconstructed output. Useful in anomaly detection tasks. Defaults to
            None.

    Returns:
        keras model: stateful model

    Raises:
        RuntimeError: if the PicoPostProcessing layer cannot be appended to the model.
    """
    # Copy configuration before applying modifications
    config = deepcopy(model.get_config())

    # Replace input shape first dimension with 'timesteps' as data will be streamed to the model
    _replace_input_first_dim(model, config, timesteps=timesteps)

    # Replace Kernelized layers with StatefulRecurrent
    kernelized_to_dense = _replace_kernelized(model, config)

    # Remove GlobalAveragePooling1D, SpatialDropout1D and Stride layers
    _remove_non_recurrent(model, config)

    # Since layers were replaced, the built shapes are dropped to allow for a clean rebuild
    _clean_build_config(config)

    # Reconstruct model from the config, using the cloned layers
    intermediate_model = model.from_config(config)

    # Restore model weights
    _set_stateful_weights(model, intermediate_model, kernelized_to_dense)

    # Apply sanitize to fold BN and get a 'clean' architecture
    sanitized_model = sanitize(intermediate_model)
    sanitized_config = deepcopy(sanitized_model.get_config())

    # Rebuild stateful model from sanitized config
    _clean_build_config(sanitized_config)
    stateful_model = model.from_config(sanitized_config)

    # Set back weights
    stateful_model.set_weights(sanitized_model.get_weights())

    # Insert post-processing if required and not already present
    if threshold is not None and not get_layers_by_type(stateful_model, PicoPostProcessing):
        try:
            x = stateful_model.input
            y = PicoPostProcessing(threshold=threshold)(stateful_model.output, x)
            stateful_model = Model(x, y)
        except Exception as e:
            raise RuntimeError("Impossible to append threshold.") from e

    return stateful_model