# Source code for quantizeml.models.transforms.remove_zeropadding2d

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Tools to remove ZeroPadding2D layers from a model.

A ZeroPadding2D layer that is immediately followed by a Conv2D, SeparableConv2D or
DepthwiseConv2D layer using 'valid' padding can be dropped by switching that following layer to
'same' padding, as long as the combination actually corresponds to a 'same' padding.
"""

__all__ = ["remove_zeropadding2d"]

from copy import deepcopy

from keras.models import Sequential
from keras.layers import ZeroPadding2D, Conv2D, SeparableConv2D, DepthwiseConv2D

from .transforms_utils import get_layers, get_layer_index, get_layers_by_type, update_inbound


def _find_removable_zeropadding(model):
    """ Retrieves ZeroPadding2D layers that can be removed.

    This is limited to ZeroPadding2D layers that come before supported layer types and that perform
    a 'same' padding.

    Args:
        model (keras.Model): a model

    Returns:
        dict: map between a ZeroPadding2D and the layer that follows
    """
    map_zeropadding_next = {}

    # Define layers that will support ZeroPadding removal
    supported_layers = (Conv2D, SeparableConv2D, DepthwiseConv2D)

    # Get all ZeroPadding2D layers present in the model
    zeropaddings = get_layers_by_type(model, ZeroPadding2D)

    # Sequential models must be built because a check must be done on shapes
    if zeropaddings and isinstance(model, Sequential) and not model.built:
        raise ValueError("This model has not yet been built.")

    # Find the ones that can be removed
    for zeropadding in zeropaddings:
        # Limit support to single inbound/outbound
        outbounds = zeropadding.outbound_nodes
        if len(zeropadding.inbound_nodes) != 1 or len(outbounds) != 1:
            continue

        # Check that the layer that follows is supported and has a 'valid' padding
        following_layer = outbounds[0].layer
        if not isinstance(following_layer, supported_layers) or following_layer.padding != 'valid':
            continue

        # Check that the combination of ZeroPadding2D + following layer performs a 'same' padding:
        # this is done by checking that next_layer.output_shape * strides = zeropadding.input_shape
        out_spatial_dims = following_layer.output_shape[1:3]
        stride = following_layer.strides
        rectified_out_spatial_dims = tuple(dim * s for dim, s in zip(out_spatial_dims, stride))
        if rectified_out_spatial_dims != zeropadding.input_shape[1:3]:
            continue

        # At this point the ZeroPadding2D is a valid candidate
        map_zeropadding_next[zeropadding] = following_layer
    return map_zeropadding_next


def _get_zeropadding_less_model(model, map_zeropadding_next):
    """ Rebuilds a model whose configuration no longer contains the given ZeroPadding2D layers.

    Each layer following a removed ZeroPadding2D is switched to 'same' padding. For non-Sequential
    models, that layer is also re-wired to consume the ZeroPadding2D input directly.

    Args:
        model (keras.Model): a model
        map_zeropadding_next (dict): map between a ZeroPadding2D and the layer that follows

    Returns:
        keras.Model: an updated model without ZeroPadding2D layers
    """
    # The get_config documentation recommends working on a copy when the config is modified
    new_config = deepcopy(model.get_config())
    layer_configs = new_config['layers']
    # Sequential models have no explicit graph: dropping the layer entries is enough
    is_sequential = isinstance(model, Sequential)

    for zp_layer, conv_layer in map_zeropadding_next.items():
        conv_config = layer_configs[get_layer_index(layer_configs, conv_layer.name)]
        conv_config['config']['padding'] = 'same'

        if not is_sequential:
            # 'inbound_nodes' is a nested list whose innermost first item is the inbound layer
            # name, e.g. [[['conv1', 0, 0, {}]]]; a single inbound is assumed here.
            zp_config = layer_configs[get_layer_index(layer_configs, zp_layer.name)]
            source_name = zp_config['inbound_nodes'][0][0][0]
            # Bypass the ZeroPadding2D: inbound > ZeroPadding2D > next becomes inbound > next
            update_inbound(conv_config, zp_layer.name, source_name)

    # Drop the ZeroPadding2D entries from the configuration
    removed_names = [zp_layer.name for zp_layer in map_zeropadding_next]
    for stale_config in get_layers(new_config, removed_names):
        layer_configs.remove(stale_config)

    # Instantiate a fresh model from the edited configuration
    return model.from_config(new_config)


def remove_zeropadding2d(model):
    """ Removes ZeroPadding2D layers from a model.

    ZeroPadding2D layers are not supported by quantization. This transform removes those that are
    immediately followed by a convolution layer with 'valid' padding, and switches that convolution
    to 'same' padding instead. This only happens when the padding applied by the ZeroPadding2D
    actually corresponds to a 'same' padding.

    Args:
        model (keras.Model): the model to update

    Returns:
        keras.Model: the original model or a new model with ZeroPadding2D removed

    Raises:
        ValueError: if the model is an unbuilt Sequential model containing ZeroPadding2D layers
    """
    # Find ZeroPadding2D and following layer pairs that are candidates for removal
    map_zeropadding_next = _find_removable_zeropadding(model)

    # When there are no valid candidates, return the original model
    if not map_zeropadding_next:
        return model

    # Rebuild a model without ZeroPadding2D by editing the configuration
    updated_model = _get_zeropadding_less_model(model, map_zeropadding_next)

    # Restore model weights: ZeroPadding2D layers have no weights, so the weights list of the
    # remaining layers is unchanged
    updated_model.set_weights(model.get_weights())
    return updated_model