Source code for quantizeml.layers.shiftmax

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************

__all__ = ["shiftmax", "Shiftmax", "QuantizedShiftmax"]

import tensorflow as tf
import keras

from .layers_base import (register_quantize_target, register_no_output_quantizer,
                          register_aligned_inputs, apply_buffer_bitwidth, init_quant_config)
from .recorders import TensorRecorder
from ..tensors import FixedPoint, floor_through, round_log2, pow2


def shiftmax(logits, axis=-1):
    """Approximate a softmax with powers of two instead of natural exponentials.

    The computation is equivalent to

    .. code-block:: python

        logits = tf.floor(logits)
        exp = 2 ** logits
        sum_exp_shift = tf.round(tf.log2(tf.reduce_sum(exp, axis, keepdims=True)))
        softmax = exp / 2 ** sum_exp_shift  # == 2 ** (logits - sum_exp_shift)

    where ``2 ** sum_exp_shift`` is the Power-of-Two (PoT) approximation of
    the sum of the exponentials, so the division reduces to a shift.

    To avoid very large exponentials (which tensorflow would represent as
    ``inf``), the maximum of the logits is subtracted first. With
    :math:`y = logits - x0` the ratio is unchanged:

    .. math::

        \\frac{2^y}{sum(2^y)}
        = \\frac{2^{logits} * 2^{-x0}}{2^{-x0} * sum(2^{logits})}
        = \\frac{2^{logits}}{sum(2^{logits})}

    and :math:`x0 = max(logits)` is used here.

    Args:
        logits (tf.Tensor): a non-empty `Tensor`.
        axis (int, list, optional): the dimension shiftmax is performed on.
            Defaults to -1, which indicates the last dimension.

    Returns:
        tf.Tensor: value of the shiftmax function, with the same type and
        shape as `logits`.

    Raises:
        InvalidArgumentError: if `logits` is empty or `axis` is beyond the
            last dimension of `logits`.

    Note:
        The logits are floored so that the float results stay close to those
        obtained once the operation is quantized.
    """
    floored = floor_through(logits)
    # Re-center on the maximum: every exponent becomes <= 0, so 2 ** exponent
    # stays within (0, 1] and cannot overflow.
    centered = floored - tf.reduce_max(floored, axis=axis, keepdims=True)
    powers = tf.cast(2 ** centered, dtype=centered.dtype)
    # Replace the true denominator by its nearest power of two, expressed as
    # an exponent, so the division becomes a subtraction of exponents.
    denominator_log2 = round_log2(tf.reduce_sum(powers, axis=axis, keepdims=True))
    return 2 ** (centered - denominator_log2)
@tf.keras.utils.register_keras_serializable()
class Shiftmax(keras.layers.Layer):
    """Keras layer wrapping the `shiftmax` function, a softmax-like activation.

    The shiftmax operation is always performed along the last axis.
    """

    def __init__(self, *args, **kwargs):
        # No layer-specific options: everything is forwarded to the base Layer.
        super().__init__(*args, **kwargs)

    def call(self, inputs):
        # shiftmax is called with its default axis (-1), i.e. the last dimension.
        return shiftmax(inputs)
@register_quantize_target(Shiftmax)
@register_no_output_quantizer
@register_aligned_inputs
@tf.keras.utils.register_keras_serializable()
class QuantizedShiftmax(keras.layers.Layer):
    """Quantized layer computing a softmax-like function in base 2 instead of e.

    The usual softmax

    .. math:: softmax(x_i) = \\frac{e^{x_i}}{sum(e^{x_k})}

    is first replaced by its base 2 counterpart

    .. math:: softmax2(x_i) = \\frac{2^{x_i}}{sum(2^{x_k})}

    which is close enough to the original function. To make it more hardware
    friendly, :math:`sum(2^{x_k})` is in turn approximated by the closest
    power of two:

    .. math:: shiftmax(x_i) = \\frac{2^{x_i}}{2^{round(log2(sum(2^{x_k})))}}

    so that the division can be implemented with a simple shift operation.

    Implementation is inspired from this paper:

    Cardarilli, G.C., Di Nunzio, L., Fazzolari, R. et al. A pseudo-softmax
    function for hardware-based high speed image classification.
    Sci Rep 11, 15307 (2021). https://doi.org/10.1038/s41598-021-94691-7

    Args:
        quant_config (dict, optional): the serialized quantization
            configuration. Defaults to None.
    """

    def __init__(self, *args, quant_config=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.quant_config = init_quant_config(quant_config)
        self.buffer_bitwidth = apply_buffer_bitwidth(self.quant_config, signed=True)
        self.exp_bitwidth = self.quant_config.get("exp_bitwidth", 10)
        # The largest exponential computed in call() is 2^(2 * exp_bitwidth),
        # so the buffer must be strictly wider than twice exp_bitwidth.
        if self.buffer_bitwidth <= 2 * self.exp_bitwidth:
            raise ValueError(f"exp_bitwidth={self.exp_bitwidth} must be less than "
                             f"half of buffer_size={self.buffer_bitwidth}.")

        # Recorder storing the shift applied when flooring the inputs.
        self.input_shift = TensorRecorder(self.name + "/input_shift")

    def call(self, x):
        # Only FixedPoint inputs are supported.
        if not isinstance(x, FixedPoint):
            raise TypeError("QuantizedShiftmax only accepts FixedPoint inputs. Receives "
                            f"{type(x)} inputs.")

        # Drop the fractional part of the inputs (floor(x)). Exponentials grow
        # so fast that the fractional part is negligible in the ratio with
        # their sum.
        floored, shift = x.floor()
        # Update the recorded shift values if recording is enabled.
        self.input_shift(shift)
        # Once floored, the raw values can be used directly.
        int_values = floored.values

        # The pseudo-softmax
        #
        #   p = 2^x / sum(2^x)
        #
        # is invariant to a common offset x0 applied to all inputs:
        #
        #   2^(x - x0) / sum(2^(x - x0)) = (2^x * 2^-x0) / (2^-x0 * sum(2^x)) = p
        #
        # x0 is chosen as max(x) - exp_bitwidth, so that
        #
        #   y = x - (max(x) - exp_bitwidth)
        #
        # is never higher than exp_bitwidth.
        values_max = tf.reduce_max(int_values, axis=-1, keepdims=True)
        offset = values_max - self.exp_bitwidth
        y = int_values - offset

        # 2^y is evaluated with a maximum precision of exp_bitwidth: entries
        # below -exp_bitwidth are neglected because they barely contribute to
        # the sum and would be quantized to zero after the division.
        #
        # In hardware this amounts to:
        #   - setting exp to zero if y < -exp_bitwidth,
        #   - otherwise doing a left shift with a fixed offset of exp_bitwidth.
        # Example:
        #   exp_bitwidth = 4
        #   y = [-5, 3, -4, -1, 1]
        #   exp = [0, 1 << (4 + 3), 1 << (4 - 4), 1 << (4 - 1), 1 << (4 + 1)]
        #   exp = [0, 128, 1, 8, 32]
        is_significant = y >= -self.exp_bitwidth
        scaled_pow2 = tf.where(is_significant, pow2(y + self.exp_bitwidth), 0)
        # The raw values would be enough for the maths, but wrapping them in a
        # FixedPoint makes sure the underlying buffer is not saturated.
        exp = FixedPoint(scaled_pow2, self.buffer_bitwidth, self.exp_bitwidth)

        # Sum of the exponentials (saturation may happen here).
        exp_total = tf.reduce_sum(exp, axis=-1, keepdims=True)

        # As in the float version, the division by the sum is replaced by a
        # shift by its closest integer log2 (see round_log2 for the hardware
        # operation). The frac_bits are subtracted because the values are
        # scaled up.
        total_shift = round_log2(exp_total.values) - exp_total.frac_bits
        shifted = exp.shift(-total_shift)
        # Since sum_exp > exp, the results lie in [0, 1], so the output can be
        # rewritten with exp_bitwidth fractional bits.
        return FixedPoint(shifted.values, self.exp_bitwidth + 1, self.exp_bitwidth)

    def get_config(self):
        # Extend the base layer configuration with the quantization settings.
        config = super().get_config()
        config.update(quant_config=self.quant_config)
        return config