Source code for akida_models.filter_pruning

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2021 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""Filter pruning tools to automatically prune a model"""

import numpy as np
from tensorflow.keras import Sequential, Input
from tensorflow.keras.layers import (InputLayer, Conv2D, SeparableConv2D, Dense,
                                     BatchNormalization, MaxPool2D,
                                     GlobalAvgPool2D, Dropout, ReLU, Softmax,
                                     Activation, Flatten, Reshape)
from tensorflow.keras.layers.experimental.preprocessing import Rescaling

import cnn2snn


def delete_filters(model, layer_to_prune, filters_to_prune):
    """Deletes filters in the given layer and updates weights in it and its
    subsequent layers.

    A pruned model is returned. Only linear models are supported.

    Args:
        model (tf.keras.Model): the model to prune.
        layer_to_prune (str): the name of the neural layer where filters will
            be deleted.
        filters_to_prune (list): indices of filters to delete in the given
            layer. Indices listed several times are only counted once.

    Returns:
        tf.keras.Sequential: the pruned model
    """
    # The pruned layer is sized from the *number* of indices, whereas the
    # weights are masked by index membership. Deduplicating keeps both steps
    # consistent when an index is listed more than once.
    unique_filters = sorted(set(filters_to_prune))

    pruned_model = _instantiate_pruned_model(model, layer_to_prune,
                                             len(unique_filters))
    _assign_pruned_weights(model, pruned_model, layer_to_prune, unique_filters)
    return pruned_model
def _instantiate_pruned_model(model, layer_to_prune, num_filters_to_prune):
    """Returns the pruned model where the layer to prune has less filters than
    the original layer.

    Note that the weights are not transferred yet from the original model to
    the pruned one.

    Args:
        model (tf.keras.Model): the model to prune.
        layer_to_prune (str): the name of the layer where filters will be
            deleted.
        num_filters_to_prune (int): number of filters to delete in the given
            layer.

    Returns:
        tf.keras.Sequential: the pruned model (without weights from the
        original model)

    Raises:
        RuntimeError: if pruning would leave the layer without any filter/unit,
            or if the model contains a layer type that is not supported.
    """
    pruned_model = Sequential()
    pruned_model.add(Input(model.input_shape[1:]))

    # Create the pruned model by browsing the layers of the original model and
    # by deleting filters/channels in the concerned layers:
    # - for each layer of the original model, create a new layer with the new
    #   configuration
    # - add the new layer to the pruned Sequential model
    for layer in model.layers:
        if isinstance(layer, InputLayer):
            continue

        if isinstance(layer, (Conv2D, SeparableConv2D, Dense)):
            # Dense layers store their output size under 'units', convolutions
            # under 'filters'.
            size_key = 'units' if isinstance(layer, Dense) else 'filters'
            new_config = layer.get_config()
            if layer.name == layer_to_prune:
                new_config[size_key] = (new_config[size_key] -
                                        num_filters_to_prune)
                # "<= 0" also rejects requests that remove more filters than
                # the layer owns, instead of building an invalid layer.
                if new_config[size_key] <= 0:
                    raise RuntimeError(
                        f"Impossible to prune all {size_key} of "
                        f"layer {layer_to_prune}.")
            new_layer = layer.__class__.from_config(new_config)
        elif isinstance(layer, Reshape):
            new_config = layer.get_config()
            # If the layer preceding a Reshape layer was pruned, the new Reshape
            # layer must handle the new target shape. Note that, in cnn2snn,
            # only Reshape layer (N,)->(1,1,N) is supported.
            if layer.target_shape[-1] != pruned_model.output_shape[-1]:
                new_target_shape = layer.target_shape[:-1] + (
                    pruned_model.output_shape[-1],)
                new_config['target_shape'] = new_target_shape
            new_layer = layer.__class__.from_config(new_config)
        elif isinstance(layer,
                        (BatchNormalization, MaxPool2D, GlobalAvgPool2D,
                         Dropout, Flatten, ReLU, Softmax, Activation,
                         Rescaling, cnn2snn.QuantizedActivation)):
            new_layer = layer.__class__.from_config(layer.get_config())
        else:
            raise RuntimeError(
                f"Layer {layer.name} of type {layer.__class__.__name__} "
                "is not supported for pruning")

        pruned_model.add(new_layer)

    return pruned_model


def _assign_pruned_weights(model, pruned_model, layer_to_prune,
                           filters_to_prune):
    """Assigns new weights from the original model to the pruned one.

    For each layer in the model, two steps are performed:

    - update the outbound mask. The outbound mask (of length the number of
      output channels) represents the filters/channels to keep.
    - set the new pruned weights, according to the inbound/outbound masks.

    Args:
        model (tf.keras.Model): the model to prune.
        pruned_model (tf.keras.Sequential): the pruned model where the weights
            must be set.
        layer_to_prune (str): the name of the layer where filters were deleted.
        filters_to_prune (list): indices of filters to delete in the given
            layer.

    Raises:
        RuntimeError: if the model contains a layer type that is not supported.
    """
    # Build the lookup once: membership is tested for every output channel.
    pruned_indices = set(filters_to_prune)

    # Mask representing the filters to keep (False = filters to prune)
    inbound_mask = [True] * model.input_shape[-1]

    for layer in model.layers:
        # Step 1: Update the outbound mask.
        # Only layers that modify the number of channels are concerned (i.e.
        # neural layers and Flatten layers). The mask for the layer to prune has
        # "False" values for indices to prune. The mask for other neural layers
        # is "True" for all output channels.
        if layer.name == layer_to_prune:
            outbound_mask = [
                i not in pruned_indices
                for i in range(layer.output_shape[-1])
            ]
        elif isinstance(layer, (Conv2D, SeparableConv2D, Dense)):
            # All other (i.e. not pruned) neural layers reset the outbound_mask
            # (all filters are kept)
            outbound_mask = [True] * layer.output_shape[-1]
        elif isinstance(layer, Flatten):
            # A Flatten layer modifies the number of channels: C_out = H*W*C_in.
            # The outbound mask is updated, depending on the input size:
            # input_size = H*W = C_out/C_in
            input_size = int(layer.output_shape[-1] / layer.input_shape[-1])
            outbound_mask = inbound_mask * input_size
        elif isinstance(layer,
                        (InputLayer, BatchNormalization, MaxPool2D,
                         GlobalAvgPool2D, Dropout, ReLU, Softmax, Activation,
                         Reshape, Rescaling, cnn2snn.QuantizedActivation)):
            # All these layers do not change the number of channels. Reshape
            # could modify the number of channels. However we only support
            # (N,) -> (1,1,N), thus there is no change in the number of
            # channels.
            outbound_mask = inbound_mask
        else:
            raise RuntimeError(
                f"Layer {layer.name} of type {layer.__class__.__name__} "
                "is not supported for pruning")

        # Step 2: set new pruned weights.
        # The concerned layers are all layers between the neural layer to prune
        # and the next neural layer. Layers with weights are
        # (Quantized)Conv2D/SeparableConv2D/Dense and BatchNormalization layers.
        if isinstance(layer, Conv2D):
            new_weights = layer.get_weights()
            new_weights[0] = new_weights[0][:, :, inbound_mask, :]
            new_weights[0] = new_weights[0][..., outbound_mask]
            if layer.use_bias:
                new_weights[1] = new_weights[1][outbound_mask]
            pruned_model.get_layer(layer.name).set_weights(new_weights)
        elif isinstance(layer, SeparableConv2D):
            new_weights = layer.get_weights()
            # Index 0 is the depthwise kernel, index 1 the pointwise kernel.
            new_weights[0] = new_weights[0][:, :, inbound_mask, :]
            # NOTE(review): masking the pointwise input axis with inbound_mask
            # presumably assumes depth_multiplier == 1 -- confirm.
            new_weights[1] = new_weights[1][:, :, inbound_mask, :]
            new_weights[1] = new_weights[1][..., outbound_mask]
            if layer.use_bias:
                new_weights[2] = new_weights[2][outbound_mask]
            pruned_model.get_layer(layer.name).set_weights(new_weights)
        elif isinstance(layer, Dense):
            new_weights = layer.get_weights()
            new_weights[0] = new_weights[0][inbound_mask, :]
            new_weights[0] = new_weights[0][:, outbound_mask]
            if layer.use_bias:
                new_weights[1] = new_weights[1][outbound_mask]
            pruned_model.get_layer(layer.name).set_weights(new_weights)
        elif isinstance(layer, BatchNormalization):
            # A BatchNormalization layer never changes the number of channels,
            # whatever mask was computed in step 1.
            outbound_mask = inbound_mask
            # Slice every weight array actually present rather than assuming
            # there are exactly four of them.
            new_weights = [
                weight[inbound_mask] for weight in layer.get_weights()
            ]
            pruned_model.get_layer(layer.name).set_weights(new_weights)

        # Update inbound mask for next layer
        inbound_mask = outbound_mask


def neural_layers(model):
    """Returns the list of neural layers names, except the last one.

    Layers are selected on their exact type (Conv2D, SeparableConv2D or Dense):
    subclasses of these types are not matched.

    Args:
        model (`tf.keras.Model`): the model to prune.

    Returns:
        list: the names of the prunable layers.
    """
    layer_names = [
        layer.name
        for layer in model.layers
        if type(layer) in (Conv2D, SeparableConv2D, Dense)
    ]
    return layer_names[:-1]


def _get_weights_for_pruning(layer):
    """Returns the kernel of the layer where filters will be pruned.

    For SeparableConv2D, it returns the pointwise kernel.

    Args:
        layer (`tf.keras.layers.Layer`): the layer to prune.

    Returns:
        `:obj:np.ndarray`: the weights used for pruning.

    Raises:
        RuntimeError: if the layer is not exactly of type Conv2D,
            SeparableConv2D or Dense.
    """
    # Position of the kernel holding the output filters in get_weights()
    weight_index_per_layer_type = {Conv2D: 0, SeparableConv2D: 1, Dense: 0}
    layer_type = type(layer)
    if layer_type not in weight_index_per_layer_type:
        raise RuntimeError(
            f"Layer {layer.name} must be of type Conv2D, SeparableConv2D or "
            f"Dense. Receives type {type(layer)}.")
    return layer.get_weights()[weight_index_per_layer_type[layer_type]]


def _filter_magnitudes(layer):
    """Returns the absolute mean of filters for the given layer. This is called
    the magnitude of a filter.

    Args:
        layer (`tf.keras.layers.Layer`): the layer to prune.

    Returns:
        `:obj:np.ndarray`: the magnitudes of the layer's filters.
    """
    weight = _get_weights_for_pruning(layer)
    # Average over every axis but the last one, which indexes the filters.
    return np.mean(np.abs(weight), axis=tuple(range(weight.ndim - 1)))


def smallest_filters(layer, pruning_rate):
    """Returns indices of the smallest filters to prune, based on the rate of
    filters to prune.

    Pruning rate must be a value between 0 and 1. For instance, 0.3 means that
    30% of the filters will be pruned. The function then returns the indices of
    the 30% smallest filters in this layer.

    Args:
        layer (`tf.keras.layers.Layer`): the layer to prune.
        pruning_rate (float): the pruning rate (between 0 and 1).

    Returns:
        `:obj:np.ndarray`: the indices of the filters to prune.

    Raises:
        ValueError: if the pruning rate is not between 0 and 1.
    """
    # A negative rate would turn the slice below into "all but the k largest
    # filters", which is never what the caller asked for.
    if not 0 <= pruning_rate <= 1:
        raise ValueError("Pruning rate must be between 0 and 1. "
                         f"Receives {pruning_rate}.")

    l1_norms = _filter_magnitudes(layer)
    indices_sorted = np.argsort(l1_norms)
    num_filters_to_prune = int(np.round(l1_norms.size * pruning_rate))
    return indices_sorted[:num_filters_to_prune]
def prune_model(model,
                acceptance_function,
                pruning_rates=None,
                prunable_layers_policy=neural_layers,
                prunable_filters_policy=smallest_filters):
    """Prune model automatically based on an acceptance function.

    The algorithm for filter pruning is as follows:

    1. Select the first prunable layer (according to the
       ``prunable_layers_policy`` function).
    2. As long as the ``acceptance_function`` returns True, prune successively
       the layer with different pruning rates (according to ``pruning_rates``
       and ``prunable_filters_policy``).
    3. When the current pruned model is not acceptable, the last valid pruning
       rate is selected for the final pruned model. The search for a layer also
       stops when a rate would remove all the filters of that layer.
    4. Repeat steps 1, 2 and 3 for the next prunable layers.

    Note that the first argument given to ``acceptance_function`` is the model
    accepted so far (batchnorm folded and already pruned in the previous
    layers), not the ``model`` argument itself.

    Examples:

        .. code-block:: python

            acceptable_drop = 0.05

            def acceptance_function(base_model, pruned_model):
                # This function returns True if the pruned_model is acceptable.
                # Here, the pruned model is acceptable if the accuracy drops
                # less than 5% from the base model.

                def evaluate(model):
                    _, accuracy = model.evaluate(data, labels)
                    return accuracy

                return (evaluate(base_model) - evaluate(pruned_model) <=
                        acceptable_drop)

            # Prune model
            pruned_model, pruning_rates = prune_model(model,
                                                      acceptance_function)

    Args:
        model (:obj:`tf.keras.Model`): a tf.keras model to prune
        acceptance_function (function): a criterion function that returns True
            if the pruned model is acceptable. The signature must be
            `function(base_model, pruned_model)`.
        pruning_rates (list, optional): a list of pruning rates to test, tried
            in the given order. Default is [0.1, 0.2, ..., 0.9].
        prunable_layers_policy (function, optional): a function returning the
            layers to prune in the model. The signature must be
            `function(model)`, and must return a list of prunable layer names.
            By default, all neural layers (Conv2D/SeparableConv2D/Dense) are
            candidates for pruning.
        prunable_filters_policy (function, optional): a function that returns
            the filters to prune in a given layer for a specific pruning rate.
            The signature must be `function(layer, pruning_rate)` and returns a
            list of indices to prune. By default, filters with the lowest
            magnitude are pruned.

    Returns:
        tuple: the pruned model and the pruning rates.
    """
    if pruning_rates is None:
        pruning_rates = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    # Start by creating a copy of the model where batchnorm is folded
    pruned_model = cnn2snn.utils.invert_batchnorm_pooling(model)
    pruned_model = cnn2snn.utils.fold_batch_norms(pruned_model)

    # Get the list of prunable layers
    layers_to_prune = prunable_layers_policy(pruned_model)

    # Initialize pruning rates (no pruning)
    rates = {name: 0.0 for name in layers_to_prune}

    for name in layers_to_prune:
        layer = pruned_model.get_layer(name)
        num_filters = layer.output_shape[-1]
        next_pruned_model = None

        for pruning_rate in pruning_rates:
            # Get the list of prunable filters for the specified rate
            filters = prunable_filters_policy(layer, pruning_rate)

            # delete_filters refuses to remove every filter of a layer. Rather
            # than aborting the whole search, keep the last accepted rate for
            # this layer and move on to the next one.
            if len(filters) >= num_filters:
                break

            candidate_model = delete_filters(pruned_model, layer.name, filters)

            if acceptance_function(pruned_model, candidate_model):
                # The candidate model passed the acceptance criteria
                next_pruned_model = candidate_model
                rates[name] = pruning_rate
            else:
                # Stop at that rate
                break

        # Explicit None test: do not rely on the truthiness of a model object
        if next_pruned_model is not None:
            pruned_model = next_pruned_model

    return pruned_model, rates