"""Helper that replaces lambdas with their equivalent Keras layer."""

# Public API of this module: only the transform entry point is exported.
__all__ = ["replace_lambda"]

from copy import deepcopy
from keras.layers import Activation, TFOpLambda

from .transforms_utils import get_layers_by_type, get_layer_index, inbound_node_generator

def _update_inbound_nodes(layer_config, target):
    """ Update a Lambda layer inbound node towards their Layer equivalent.

    The last element of each connection info holds the lambda config, that is the op name and
    other lambda specific parameters. Meaningful parameters are copied into the config of the
    new layer, then the lambda config is dropped (replaced by an empty dict).

    Args:
        layer_config (dict): config of the lambda layer, updated in place
        target (str): class name of the target layer

    Returns:
        str: name of the lambda operation, or None when the lambda config has no name

    Raises:
        ValueError: when target is 'Add' and the second operand 'y' is not a layer inbound
            (e.g. a constant), since an Add layer can only merge layer outputs.
    """
    name = None
    for inbound_node in inbound_node_generator(layer_config):
        if isinstance(inbound_node, dict):
            # 'dict.values()' is a view that supports neither indexing nor 'append': materialize
            # it as a list. The connection infos it holds are the same objects, so the in-place
            # updates below still reach 'layer_config'.
            inbound_node = list(inbound_node.values())
        connection_info = inbound_node[0]
        # "connection_info[-1]" holds the lambda config, that is the op name and other lambda
        # specific parameters. Code below will retrieve meaningful parameters that will be used to
        # define the config of the new layer and will then be dropped.
        lambda_params = connection_info[-1]
        if target == 'Reshape':
            # Set the 'target_shape' parameter from the 'shape' attribute
            layer_config['config']['target_shape'] = lambda_params.get('shape')
        elif target == 'Permute':
            # Set the 'dims' parameter from the 'perm' attribute
            perm = lambda_params.get('perm')
            # Permute 'dims' must start at 1 but transpose 'perm' starts at 0
            layer_config['config']['dims'] = [p + 1 for p in perm]
        elif target == 'Add':
            # Get the second inbound currently defined as {'y': ['name', 0, 0]}
            other_inbound = lambda_params.get('y')
            if not isinstance(other_inbound, (list, tuple)):
                raise ValueError(
                    f"Cannot convert '{layer_config.get('name')}' to an Add layer: its second "
                    f"operand is not a layer output ({other_inbound!r}).")
            # Modify it to fit the ['name', 0, 0, {}] convention and add the updated inbound
            # into the inbound_node list so that the Add layer receives both inputs
            inbound_node.append([*other_inbound, {}])
            # Set the updated inbound list in layer_config (it may be a fresh list, see above)
            layer_config['inbound_nodes'] = [inbound_node]
        # Retrieve the lambda name as it will be used to set the name for the layer
        name = lambda_params.get('name', None)
        # Drop the lambda config
        connection_info[-1] = {}
    return name

def _rename_in_config(obj, old_name, new_name):
    """Return ``obj`` with every string exactly equal to ``old_name`` replaced by ``new_name``.

    Dicts (keys and values), lists and tuples are traversed recursively and rebuilt; any other
    value is returned unchanged. This is an exact-match rename, so names that merely contain
    ``old_name`` as a substring are left untouched.
    """
    if isinstance(obj, str):
        return new_name if obj == old_name else obj
    if isinstance(obj, dict):
        return {_rename_in_config(key, old_name, new_name): _rename_in_config(value, old_name,
                                                                             new_name)
                for key, value in obj.items()}
    if isinstance(obj, list):
        return [_rename_in_config(item, old_name, new_name) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_rename_in_config(item, old_name, new_name) for item in obj)
    return obj


def replace_lambda(model):
    """ Replaces lambda layers from a model with their equivalent Keras layer.

    This transform handles the following replacements:

        - Lambda(relu) or Activation('relu') → ReLU,
        - Lambda(transpose) → Permute,
        - Lambda(reshape) → Reshape,
        - Lambda(add) → Add.

    Args:
        model (keras.Model): the model of interest

    Returns:
        keras.Model: the original model or a new one with lambda replaced.
    """
    # Map function names to Keras layers
    lambda_to_layer = {
        'nn.relu': 'ReLU',
        'math.add': 'Add',
        'reshape': 'Reshape',
        'transpose': 'Permute',
        'compat.v1.transpose': 'Permute'
    }

    # Get all Activations and TFOpLambda layers present in the model
    lambdas = get_layers_by_type(model, (Activation, TFOpLambda))

    # When there are no valid candidates, return the original model
    if not lambdas:
        return model

    # Copy configuration before applying modifications
    config = deepcopy(model.get_config())

    for layer in lambdas:
        # Locate the layer entry in the config (assumes get_layer_index looks layers up by
        # name — confirm against transforms_utils)
        layer_index = get_layer_index(config['layers'], layer.name)
        layer_config = config['layers'][layer_index]

        # Replace 'relu' Activations layers with ReLU layers
        if (layer_config['class_name'] == 'Activation' and
                layer_config['config']['activation'] == 'relu'):
            # Drop the 'activation' parameter and update 'class_name'
            layer_config['config'].pop('activation')
            layer_config['class_name'] = 'ReLU'
        # Replace TFOpLambda layers
        elif layer_config['class_name'] == 'TFOpLambda':
            # Retrieve the function used in the config and get the equivalent Keras layer name
            target = lambda_to_layer.get(layer_config['config']['function'], None)
            if target:
                # Drop the 'function' parameter and update 'class_name'
                layer_config['config'].pop('function')
                layer_config['class_name'] = target
                # Update the inbound part of the config: the last element of the inbound list
                # of lambda layers will contain the lambda op parameters that are used to set
                # the config for the new layer.
                new_name = _update_inbound_nodes(layer_config, target)
                # If layer name was updated, use the new name everywhere in the config. An
                # exact-match structural rename replaces the former str()/eval() round trip,
                # which could not handle non-literal reprs (e.g. inf/nan) and ran eval.
                if new_name:
                    config = _rename_in_config(config, layer_config['name'], new_name)

    # Reconstruct model from the config
    updated_model = model.from_config(config)
    # Restore model weights
    updated_model.set_weights(model.get_weights())
    return updated_model