# Source code for cnn2snn.calibration.adaround

# ******************************************************************************
# Copyright 2020 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""Implementation of the Adaround algorithm described in:
   Up or Down? Adaptive Rounding for Post-Training Quantization
   Markus Nagel, Rana Ali Amjad, Mart van Baalen, Christos Louizos, Tijmen
   Blankevoort
   https://arxiv.org/abs/2004.10568
"""
import shutil
from math import pi
import keras
import numpy as np
import tensorflow as tf

from ..transforms.clone import clone_model_with_weights
from ..quantization_layers import QuantizedDense, QuantizedConv2D, QuantizedSeparableConv2D
from ..quantization_ops import WeightQuantizer
from .calibration import QuantizationSampler, is_quantized_neural

# Rectified sigmoid constants: AdaRoundQuantizer.decimals() stretches the
# sigmoid output from [0, 1] to [GAMMA, ZETA] before clipping it back to
# [0, 1], which lets the decimals reach exactly 0 or 1 for finite alpha values.
GAMMA = -0.1
ZETA = 1.1
# Rounding loss annealing constants: compute_beta() applies a cosine decay that
# starts at BETA_START (step 0) and decreases towards BETA_END.
BETA_START = 20
BETA_END = 2
# Rounding loss regularization: factor applied to each quantizer rounding loss
# before it is added to the quantization loss in optimize_rounding().
RD_REG = 0.01


class AdaRoundQuantizer(WeightQuantizer):
    """Weight quantizer learning whether each weight is rounded up or down.

    The original weights are rescaled by the scale factor of the wrapped
    quantizer and split into a floored part (constant) and a decimal part.
    The decimal part is not stored directly: it is derived from a trainable
    ``alpha`` variable through a rectified sigmoid (see ``decimals``).

    A rounding loss is provided so that an optimizer can push every decimal
    towards 0 or 1.

    The class exposes the WeightQuantizer API (``scale_factor``,
    ``quantize``) so that it can stand in for the original quantizer of a
    layer, returning weights rebuilt from the floored weights and decimals.

    Args:
        quantizer (:obj:`cnn2snn.quantization_ops.WeightQuantizer`): the
           original quantizer.
        weights (:obj:`np.ndarray`): the original weights
    """

    def __init__(self, quantizer, weights, **kwargs):
        # Keep the wrapped quantizer: it is used at evaluation time and is
        # restored on the layer once the optimization is over
        self.quantizer = quantizer
        # The scale factor is evaluated once, on the original weights
        self.sf = quantizer.scale_factor(weights)
        # Project the weights onto the quantization grid
        rescaled = tf.constant(weights) * self.sf
        # Integer part of the rescaled weights (kept constant)
        self.w_floor = tf.constant(tf.math.floor(rescaled))
        # Fractional part of the rescaled weights
        fractions = rescaled - self.w_floor
        # Invert the rectified sigmoid so that self.decimals() starts at the
        # fractional part of the rescaled weights
        stretch = ZETA - GAMMA
        alpha_init = -tf.math.log(stretch / (fractions - GAMMA) - 1)
        self.alpha = tf.Variable(initial_value=alpha_init,
                                 trainable=True,
                                 name="alpha",
                                 dtype=tf.float32)
        # quantize() behaves differently at training and evaluation time
        self.training = True
        super().__init__(quantizer.bitwidth, **kwargs)

    def decimals(self):
        """Map the alpha variable to decimals in [0, 1]

        Returns:
            :obj:`tf.Tensor` : the reconstructed decimals
        """
        # sigmoid(alpha) lies in [0, 1]: stretch it to [GAMMA, ZETA] ...
        stretched = tf.math.sigmoid(self.alpha) * (ZETA - GAMMA) + GAMMA
        # ... then clip it back to [0, 1]
        return tf.clip_by_value(stretched, 0, 1)

    def loss(self, beta):
        """Evaluate the rounding loss pushing the decimals to 0 or 1.

        Args:
            beta: the loss annealing parameter

        Returns:
            the rounding loss: the constant 0 when beta equals zero, a
            :obj:`tf.Tensor` otherwise
        """
        if beta == 0:
            return 0
        # |2d - 1| equals 1 when a decimal sits on a boundary (0 or 1) and 0
        # when it sits halfway: each weight contributes 1 - |2d - 1|^beta
        boundary_distance = tf.abs(2 * self.decimals() - 1)
        return tf.reduce_sum(1 - boundary_distance**beta)

    def status(self):
        """Report the current rounding status

        A roamer is a weight whose decimal is neither exactly 0 nor exactly 1,
        i.e. a weight that has not been rounded yet.

        Returns:
            tuple: the number of roamers, the ratio of weights whose decimal
            is lower than or equal to 0.5 and the ratio of weights whose
            decimal is greater than 0.5
        """
        current = self.decimals().numpy()
        total = current.size
        settled = np.sum(np.isin(current, [0, 1]))
        roamers = total - settled
        floor_r = np.sum(current <= 0.5) / total
        ceil_r = np.sum(current > 0.5) / total
        return roamers, floor_r, ceil_r

    def recons_weights(self):
        """Rebuild float weights from the floored weights and the decimals.

        Returns:
            :obj:`tensorflow.Tensor`: a Tensor of float weights.
        """
        # Rescaled weights: floored rescaled weights + current decimals
        rescaled = self.w_floor + self.decimals()
        # Stay within [-kmax_, kmax_] ...
        clipped = tf.clip_by_value(rescaled, -self.kmax_, self.kmax_)
        # ... and go back to the original weights range
        return clipped / self.sf

    def scale_factor(self, _w):
        """Return the scale factor cached at construction time

        Args:
          _w (:obj:`tensorflow.Tensor`): ignored.

        Returns:
          :obj:`tensorflow.Tensor`: the scale factor evaluated by the original
                quantizer on the original weights.
        """
        return self.sf

    def quantize(self, _w):
        """Return the layer weights

        The weights passed as parameter are ignored: the reconstructed
        weights are used instead.

        The result depends on the ``training`` attribute:

        - when it is True, the reconstructed weights are returned without
          being quantized,
        - when it is False, the original quantizer is applied to the
          reconstructed weights.

        Args:
            _w (:obj:`tensorflow.Tensor`): ignored.

        Returns:
            :obj:`tensorflow.Tensor`: a Tensor of weights.
        """
        weights = self.recons_weights()
        if not self.training:
            weights = self.quantizer.quantize(weights)
        return weights


def initialize_training(optimizer, model):
    """Set up a model for Adaround optimization

    Every weight quantizer of the quantized neural layers is swapped for an
    AdaRoundQuantizer wrapping it.

    Args:
        optimizer (:obj:`tensorflow.keras.optimizers.Optimizer`): an optimizer
        model (:obj:`keras.models.Model`): a quantized Keras Model

    Returns:
        dict, dict: two dictionaries indexed by layer, the first one
        containing the list of AdaRound quantizers of the layer, the second
        one the list of their alpha variables

    """
    adarounds, trainable_variables = {}, {}
    for layer in model.layers:
        if not is_quantized_neural(layer):
            continue
        # List the (quantizer attribute, weights index) pairs to replace
        if isinstance(layer, QuantizedSeparableConv2D):
            # Depthwise weights come first, then pointwise weights
            targets = (("quantizer_dw", 0), ("quantizer", 1))
        elif isinstance(layer, (QuantizedDense, QuantizedConv2D)):
            targets = (("quantizer", 0),)
        else:
            targets = ()
        layer_quantizers = []
        if targets:
            layer_weights = layer.get_weights()
            for attr_name, w_index in targets:
                # Wrap the standard quantizer into an Adaround quantizer
                wrapped = AdaRoundQuantizer(getattr(layer, attr_name),
                                            layer_weights[w_index])
                setattr(layer, attr_name, wrapped)
                layer_quantizers.append(getattr(layer, attr_name))
        adarounds[layer] = layer_quantizers
        trainable_variables[layer] = [adr.alpha for adr in layer_quantizers]
    # Flatten the per-layer variables into a single list
    all_trainable_variables = [
        variable
        for layer_variables in trainable_variables.values()
        for variable in layer_variables
    ]
    # Since TensorFlow 2.12, the optimizer must build its variables before
    # they are used.
    if hasattr(optimizer, 'build'):
        optimizer.build(all_trainable_variables)
    return adarounds, trainable_variables


def _count_changed(quantizer, original_w, rounded_w):
    """Count the weights whose quantized value differs between two tensors.

    Args:
        quantizer: the quantizer applied to both sets of weights
        original_w: the weights before the rounding optimization
        rounded_w: the weights reconstructed by the AdaRound quantizer

    Returns:
        the number of positions where the two quantized tensors differ
    """
    return np.sum(
        quantizer.quantize(original_w) != quantizer.quantize(rounded_w))


def finalize_training(layer):
    """Restore target layer after Adaround optimization

    The original quantizers are restored and the layer weights are replaced
    by the weights reconstructed by the AdaRound quantizers. The number of
    weights whose quantized value changed is reported on the console.

    Layers that are not quantized neural layers, or whose quantizer is not an
    AdaRoundQuantizer, are left untouched.

    Args:
        layer (:obj:`keras.layers.Layer`): the layer to restore

    """
    if not is_quantized_neural(layer) or not isinstance(layer.quantizer,
                                                        AdaRoundQuantizer):
        return
    if isinstance(layer, QuantizedSeparableConv2D):
        # Get new rounded weights while the AdaRound quantizers are still set
        adaround_dw = layer.quantizer_dw.recons_weights()
        adaround_pw = layer.quantizer.recons_weights()
        # Restore original quantizers
        layer.quantizer_dw = layer.quantizer_dw.quantizer
        layer.quantizer = layer.quantizer.quantizer
        # Compare quantized weights: original vs rounded
        w_list = layer.get_weights()
        print_console("Changed dw weights: {} / {}".format(
            _count_changed(layer.quantizer_dw, w_list[0], adaround_dw),
            adaround_dw.numpy().size))
        print_console("Changed pw weights: {} / {}".format(
            _count_changed(layer.quantizer, w_list[1], adaround_pw),
            adaround_pw.numpy().size))
        # Replace original weights by rounded weights
        w_list[0] = adaround_dw
        w_list[1] = adaround_pw
        layer.set_weights(w_list)
    elif isinstance(layer, (QuantizedConv2D, QuantizedDense)):
        # Get new rounded weights while the AdaRound quantizer is still set
        adaround_w = layer.quantizer.recons_weights()
        # Restore original quantizer
        layer.quantizer = layer.quantizer.quantizer
        # Compare quantized weights: original vs rounded
        w_list = layer.get_weights()
        # Report through print_console, like the separable branch does
        print_console("Changed weights: {} / {}".format(
            _count_changed(layer.quantizer, w_list[0], adaround_w),
            adaround_w.numpy().size))
        # Replace original weights by rounded weights
        w_list[0] = adaround_w
        layer.set_weights(w_list)


def compute_beta(step, max_step):
    """Evaluate the rounding loss annealing parameter for a given step

    The parameter follows a cosine decay: it equals BETA_START at step zero
    and decreases towards BETA_END as the step gets closer to max_step.

    Args:
        step (int): the current optimization step
        max_step (int): the maximum optimization step

    Returns:
        the annealing parameter, as produced by the TensorFlow cosine op
    """
    # Relative progress of the optimization
    progress = step / max_step
    # Cosine factor going from 1 (progress == 0) down to 0 (progress == 1)
    cosine_factor = 0.5 * (1 + tf.math.cos(progress * pi))
    beta_range = BETA_START - BETA_END
    return BETA_END + beta_range * cosine_factor


def print_console(message, new_line=True):
    """Write a message to the console, optionally as an overwritable line.

    With new_line set to True, this behaves like a plain print.

    With new_line set to False, the message is left-aligned and padded with
    spaces up to the terminal width, then followed by a carriage return
    instead of an end-of-line, so that the next message printed overwrites
    it.

    Args:
        message (str): the message to print on the console
        new_line (bool): go to the next line after the message has been printed
    """
    if not new_line:
        # Pad up to the terminal width to erase a previous, longer message
        width = shutil.get_terminal_size()[0]
        padded = format(message, f"<{width}")
        print(padded, end='\r', flush=True)
        return
    print(message)


def optimize_rounding(model, samples, layer, optimizer, epochs, loss,
                      batch_size, include_activation, adarounds, trainable_variables):
    """AdaRound main optimization method

    Optimize the AdaRound alpha variables using gradient descent to minimize
    the quantization error.

    For each batch, the total loss is the sum of the quantization loss
    (float outputs vs outputs obtained with the AdaRound quantizers) and of
    the rounding losses of the AdaRound quantizers, weighted by RD_REG.

    The optimization stops before the maximum number of epochs if the rounding
    loss reaches zero. The quantization error before and after optimization is
    printed on the console, then finalize_training is called on the layer.

    Args:
        model (:obj:`keras.models.Model`): a quantized Keras Model
        samples (:obj:`np.ndarray`): a set of samples used for calibration
        layer (:obj:`keras.layers.Layer`): the layer to optimize
        optimizer (:obj:`tensorflow.keras.optimizers.Optimizer`): an optimizer
        epochs (int): the maximum number of epochs
        loss (:obj:`tensorflow.keras.losses.Loss`): the error loss function
        batch_size (int): the batch size used when evaluating samples
        include_activation (bool): quantization error is evaluated after
          activation.
        adarounds (list): the AdaRound quantizers of the layer, whose rounding
          losses are added to the quantization loss
        trainable_variables (list): the variables updated by the optimizer
          (the alpha variables of the AdaRound quantizers)
    """
    # Instantiate a sampler and select the target layer
    sampler = QuantizationSampler(model, samples, batch_size)
    sampler.select_layer(layer, include_activation)
    m = keras.metrics.MeanSquaredError()
    # Evaluate quantization error
    err_before = sampler.quantization_error(m)
    float_outputs = None
    # Training loop
    print_console("Optimizing {}:".format(layer.name))
    for e in range(epochs):
        # Evaluate rounding loss parameter
        beta = compute_beta(e, epochs)
        for i in range(sampler.n_batches):

            # If we have a single batch, we can cache the float outputs
            if sampler.n_batches > 1 or float_outputs is None:
                # Evaluate the float outputs for this batch
                float_outputs = sampler.float_outputs(i)

            # Optimize the adaround alpha variables
            with tf.GradientTape() as tape:

                # Compute layer outputs with adaround
                outputs = sampler.quantized_outputs(i)

                # Evaluate quantization loss
                q_loss = loss(float_outputs, outputs)

                # Evaluate rounding loss for each set of weights
                r_loss = 0
                for adr in adarounds:
                    r_loss += RD_REG * adr.loss(beta)

                total_loss = q_loss + r_loss

            # Get gradients of loss wrt alpha
            gradients = tape.gradient(total_loss, trainable_variables)

            # Update the alpha variables
            optimizer.apply_gradients(zip(gradients, trainable_variables))

        # The losses logged below are those of the last batch of the epoch.
        # NOTE(review): q_loss and r_loss are only bound inside the batch loop,
        # so this assumes sampler.n_batches >= 1 - confirm against the sampler.
        log = f"Epoch {e + 1}"
        log += f" - loss (quantization|rounding) {q_loss:.8f} | {r_loss:.8f}"
        for adr in adarounds:
            roamers, floor_r, ceil_r = adr.status()
            log += f" - roamers: {roamers} - ↓↑ [{floor_r:.2f}, {ceil_r:.2f}]"
        # Overwrite the previous epoch log instead of adding a new line
        print_console(log, new_line=False)
        if beta > 0 and r_loss == 0:
            # All weights have been rounded
            break
    # Go to the next line in console
    print_console("")
    # Set Adaround quantizer to evaluation mode
    for adr in adarounds:
        adr.training = False
    # Evaluate final quantization error
    err_after = sampler.quantization_error(m)
    print_console(f"quantization error: {err_before:.4f} -> {err_after:.4f}")
    # Freeze optimized weights
    finalize_training(layer)


def adaround(model,
             samples,
             optimizer,
             epochs,
             loss=tf.keras.losses.MeanSquaredError(),
             batch_size=None,
             include_activation=False):
    """Optimize the rounding of quantized weights.

    This implements the Adaround algorithm described in:
    Up or Down? Adaptive Rounding for Post-Training Quantization
    Markus Nagel, Rana Ali Amjad, Mart van Baalen, Christos Louizos, Tijmen
    Blankevoort
    https://arxiv.org/abs/2004.10568

    Instead of rounding weights to the nearest, Adaround introduces a tensor
    of continuous variables representing the decimals of the float weights,
    and thus formulates the minimization of the quantization error as a
    Quadratic Unconstrained Binary Optimization problem, iteratively pushing
    the decimal variables to a distribution of 0 and 1 minimizing the error.

    After the optimization, the quantization scales are preserved, but each
    weight is closer or equal to a quantized value.

    When optimizing a model, the following must be provided:

    - a set of samples (typically from the training dataset),
    - an optimizer,
    - the maximum number of epochs (the optimization of a layer stops when
      all weights have been rounded).

    The input model is not modified: the optimization is applied to a clone
    obtained with clone_model_with_weights.

    Args:
        model (:obj:`keras.Model`): a quantized Keras Model
        samples (:obj:`np.ndarray`): a set of samples used for calibration
        optimizer (:obj:`tensorflow.keras.optimizers.Optimizer`): an optimizer
        epochs (int): the maximum number of epochs
        loss (:obj:`tensorflow.keras.losses.Loss`): the error loss function
        batch_size (int): the batch size used when evaluating samples
        include_activation (bool): quantization error is evaluated after
            activation.

    Returns:
        :obj:`keras.Model`: a quantized Keras model whose weights have been
        optimized
    """
    adaround_model = clone_model_with_weights(model)
    # Initialize all trainable variables
    adarounds, trainable_variables = initialize_training(
        optimizer, adaround_model)
    # Layers are optimized one at a time, in model order
    for layer in adaround_model.layers:
        if is_quantized_neural(layer):
            # Optimize rounding
            optimize_rounding(adaround_model, samples, layer, optimizer,
                              epochs, loss, batch_size, include_activation,
                              adarounds[layer], trainable_variables[layer])
    return adaround_model