Source code for quantizeml.layers.normalization

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["QuantizedLayerNormalization", "LayerMadNormalization"]

import tensorflow as tf
import keras
from tensorflow.python.framework import ops

from .layers_base import (register_quantize_target, rescale_outputs, tensor_inputs,
                          register_aligned_inputs, apply_buffer_bitwidth)
from .recorders import TensorRecorder
from .quantizers import AlignedWeightQuantizer, WeightQuantizer, OutputQuantizer
from ..tensors import FixedPoint

[docs]@tf.keras.utils.register_keras_serializable() class LayerMadNormalization(keras.layers.LayerNormalization): r"""Approximates the `keras.layers.LayerNormalization` (LN), replacing the computation of the standard deviation by the mean average deviation (mad). Taking into account the complexity of computing the standard deviation, the LayerMadNormalization (LMN) is intended to replace the :math:`std(x)` by :math:`mad(x)` defined as: .. math:: mad(x) = \frac{sum(|x - mean(x)|)}{nb\_channels} To simplify it even more and make it more hardware-friendly, :math:`mean(x)` is zeroed: .. math:: mad(x) = \frac{sum(|x|)}{nb\_channels} Then, the equation of the layer is defined as: .. math:: LMN(x) = \gamma\frac{x}{mad(x)} + \beta .. note:: A tuning step in the switching procedure between the LN to LMN layer will be required to find the :math:`(\gamma, \beta)` parameters that match the standard deviation changes. """ def build(self, input_shape): super().build(input_shape) if len(self.axis) > 1: raise ValueError(f"{self.__class__.__name__} does not support multiple axis.") def _get_mad(self, inputs, axis, *args, name=None, **kwargs): """Compute the mean absolute deviation (mad). Args: inputs (tf.Tensor): input tensor. axis (array of ints): axis along which to compute mad. name (str, optional): name used to scope the operations that compute the mad. Defaults to None. *args, **kwargs (dict, optional): additional arguments of :func:`tf.math.reduce_sum`. Returns: tf.Tensor: mean absolute deviation (mad). """ # Compute the mad with ops.name_scope(name, "_mad", [inputs, axis]): mad = tf.math.reduce_sum(tf.math.abs(inputs), axis, *args, **kwargs, name="mad") nb_channels = tf.cast(inputs.shape[self.axis[0]], tf.float32) mad = mad / nb_channels return mad def call(self, inputs): """Use the same call as the parent, changing tf.nn.moments instead of self._get_mad, and replacing mean by 0""" # Compute the axis along which to reduce the mad input_shape = inputs.shape ndims = len(input_shape) # Broadcasting only necessary for norm when the axis is not just the last dimension broadcast_shape = [1] * ndims for dim in self.axis: broadcast_shape[dim] = input_shape.dims[dim].value def _broadcast(v): if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]): return tf.reshape(v, broadcast_shape) return v input_dtype = inputs.dtype if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32': # If mixed precision is used, cast inputs to float32 so that this is at # least as numerically stable as the fused version. inputs = tf.cast(inputs, 'float32') # Calculate the mad on the last axis (layer activations). mad = self._get_mad(inputs, self.axis, keepdims=True) scale, offset = _broadcast(self.gamma), _broadcast(self.beta) # Compute layer normalization as layer mad normalization outputs = scale * inputs / mad + offset outputs = tf.cast(outputs, input_dtype) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) return outputs
[docs]@register_quantize_target(LayerMadNormalization) @register_aligned_inputs @tf.keras.utils.register_keras_serializable() class QuantizedLayerNormalization(LayerMadNormalization): """ A LayerNormalization layer that operates on quantized inputs and weights. Args: quant_config (dict, optional): the serialized quantization configuration. Defaults to None. """ def __init__(self, *args, quant_config=None, **kwargs): super().__init__(*args, **kwargs) self.quant_config = quant_config or dict() out_quant_cfg = self.quant_config.get("output_quantizer", False) if out_quant_cfg: self.out_quantizer = OutputQuantizer( name="output_quantizer", **out_quant_cfg) else: self.out_quantizer = None self.buffer_bitwidth = apply_buffer_bitwidth(self.quant_config, signed=True) if "gamma_quantizer" not in self.quant_config: self.quant_config["gamma_quantizer"] = {"bitwidth": 8} gamma_quantizer_cfg = self.quant_config["gamma_quantizer"] self.gamma_quantizer = WeightQuantizer(name="gamma_quantizer", **gamma_quantizer_cfg) beta_quantizer_cfg = self.quant_config.get("beta_quantizer", {}) self.beta_quantizer = AlignedWeightQuantizer(name="beta_quantizer", **beta_quantizer_cfg) # Add objects that will store the shift values. self.gamma_shift = TensorRecorder() @tensor_inputs([FixedPoint]) @rescale_outputs def call(self, inputs): """The quantized version of the LayerNormalization with Integer-only operations. This normalizes the input tensor then returns a quantized output. Args: inputs (:obj:`FixedPoint`): the inputs tensor. Returns: :obj:`FixedPoint`: output tensor. """ # Compute the sum of Absolute Deviations mad_sum = tf.math.reduce_sum( inputs.abs(), self.axis, keepdims=True, name="mad") # Instead of evaluating MAD, then apply a division and a multiplication: # # mad = mad_sum / nb_channels # y = x / mad * gamma # # We rather directly evaluate the resulting projection p and multiply: # # gamma = gamma * nb_channels # p = gamma / mad_sum # y = x * p # The first step is to quantize the rescaled gamma nb_channels = tf.cast(inputs.shape[self.axis[0]], tf.float32) gamma = self.gamma_quantizer(self.gamma * nb_channels) # We need to upscale gamma before dividing by mad_sum to reduce the error. # For that, we evaluate the maximum number of bits allowed to avoid a saturation. # The final operation corresponds to y = gamma * x / mad_sum # But we know that mad_sum is always > x, so y will never require more bits than gamma. gamma_bits = self.buffer_bitwidth if self.out_quantizer is not None and self.out_quantizer.scale_bits is not None: # We need to leave some space for the multiplication by the output scales gamma_bits -= self.out_quantizer.scale_bits # Upscale rescaled gamma as much as we can to reduce the division error gamma_int_bits = gamma.value_bits - gamma.frac_bits gamma_frac_bits = gamma_bits - gamma_int_bits gamma, shift = gamma.upscale(gamma_frac_bits, self.buffer_bitwidth) self.gamma_shift(shift) # Calculate the projection p. This can be evaluated once per token on hardware. p = tf.math.divide(gamma, mad_sum) # Apply the projection on inputs inputs_rescaled = tf.math.multiply(inputs, p) # Quantize and align beta on the rescaled inputs beta = self.beta_quantizer(self.beta, inputs_rescaled) # Evaluate outputs return tf.add(inputs_rescaled, beta) def get_config(self): config = super().get_config() config["quant_config"] = self.quant_config return config