# Source code for quantizeml.layers.batch_normalization

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Unless required by applicable law or agreed to in writing, software
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# ******************************************************************************

__all__ = ["QuantizedBatchNormalization"]

import tensorflow as tf
import keras

from .layers_base import (register_quantize_target, rescale_outputs,
tensor_inputs, apply_buffer_bitwidth)
from .quantizers import WeightQuantizer, AlignedWeightQuantizer, OutputQuantizer
from ..tensors import QTensor

[docs]@register_quantize_target(keras.layers.BatchNormalization)
@tf.keras.utils.register_keras_serializable()
class QuantizedBatchNormalization(keras.layers.Layer):
"""Layer that normalizes its inputs, on the last axis.

The normalization is applied like this:

.. math::

y = \\frac{(x - \\mu) \\cdot \\gamma}{\\sigma} + \\beta \\
= \\frac{x \\cdot \\gamma}{\\sigma} - \\
\\frac{\\mu\\cdot \\gamma}{\\sigma} + \\beta

if we consider:

.. math:: a = \\frac{\\gamma}{\\sigma}

and

.. math:: b = -\\frac{\\mu\\cdot \\gamma}{\\sigma} + \\beta

The normalization can be re-written as:

.. math:: y = a \\cdot x + b

Note that this layer will hold variables with names gamma, beta, moving_mean (:math:`\\mu`),
and moving_variance (:math:`\\sigma = \\sqrt{moving\\_variance + \\epsilon}`), so they can be
converted from a BatchNormalization layer. However, it is a and b that are going to be quantized.

Args:
quant_config (dict, optional): the serialized quantization configuration. Defaults to None.
axis (int, optional): The axis that was normalized on the
BatchNormalization layer. The only supported value is the
last dimension.
epsilon (float, optional): Small value to avoid dividing by zero.
Defaults to 1e-3.
"""

# Argument names declared as ignored by this layer. They match constructor
# arguments of keras.layers.BatchNormalization (the layer this class is
# registered against) that have no counterpart in this layer's __init__.
# NOTE(review): presumably consumed by the quantization machinery to drop
# these keys from a float layer's config during conversion -- confirm
# against register_quantize_target in layers_base.
ignored_args = ["momentum",
"center",
"scale",
"beta_initializer",
"gamma_initializer",
"moving_mean_initializer",
"moving_variance_initializer",
"beta_regularizer",
"gamma_regularizer",
"beta_constraint",
"gamma_constraint",
"renorm",
"renorm_clipping",
"renorm_momentum",
"fused",
"trainable",
"virtual_batch_size",
]

def __init__(self,
             *args,
             quant_config=None,
             axis=-1,
             epsilon=1e-3,
             **kwargs):
    """Create the layer together with the quantizers it relies on.

    Args:
        *args: positional arguments forwarded to the parent layer.
        quant_config (dict, optional): the serialized quantization
            configuration. The "output_quantizer", "a_quantizer" and
            "b_quantizer" entries are read here, and the whole dict is handed
            to ``apply_buffer_bitwidth``. When "a_quantizer" is absent, an
            8-bit default entry is stored in the configuration.
            Defaults to None.
        axis (int, optional): the axis the normalization applies to.
            Defaults to -1.
        epsilon (float, optional): small value to avoid dividing by zero.
            Defaults to 1e-3.
        **kwargs: keyword arguments forwarded to the parent layer.
    """
    super().__init__(*args, **kwargs)
    self.quant_config = quant_config or {}

    # The output quantizer is optional: it only exists when a non-empty
    # configuration is provided for it.
    output_cfg = self.quant_config.get("output_quantizer", False)
    self.out_quantizer = (
        OutputQuantizer(name="output_quantizer", **output_cfg)
        if output_cfg else None
    )

    # 'a' is always quantized: fall back to 8 bits and record that choice in
    # the configuration itself.
    if "a_quantizer" not in self.quant_config:
        self.quant_config["a_quantizer"] = {"bitwidth": 8}
    self.a_quantizer = WeightQuantizer(
        name="a_quantizer", **self.quant_config["a_quantizer"])

    # 'b' is quantized with a quantizer that takes a reference tensor to
    # align on (see the b() method), using its defaults if not configured.
    self.b_quantizer = AlignedWeightQuantizer(
        name="b_quantizer", **self.quant_config.get("b_quantizer", {}))

    self.buffer_bitwidth = apply_buffer_bitwidth(self.quant_config, signed=True)

    # Axis on which the operation is applied
    self.axis = axis
    # Small float number added to the variance to avoid dividing by zero
    self.epsilon = epsilon

def build(self, input_shape):
    """Validate the input shape and create the layer variables.

    Four float32 variables are created, each shaped after the last input
    dimension: ``gamma``, ``beta``, ``moving_mean`` and ``moving_variance``.

    Args:
        input_shape (tf.TensorShape or compatible): shape of the inputs.

    Raises:
        ValueError: if the inputs are not 3D or 4D, or if ``axis`` does not
            designate the last dimension (and only that one).
    """
    input_shape = tf.TensorShape(input_shape)
    rank = input_shape.rank

    if rank not in (3, 4):
        raise ValueError(
            "QuantizedBatchNormalization only supports 3D or 4D tensors. "
            f"Received an input of rank {rank}.")

    # Normalize axis: the result is a sequence of axes
    self.axis = keras.utils.tf_utils.validate_axis(self.axis, input_shape)
    # Check selected axis is valid: there must be exactly one axis and it
    # must be the last one. Both conditions are required, hence the 'or'.
    if len(self.axis) != 1 or self.axis[0] != rank - 1:
        raise ValueError("QuantizedBatchNormalization only supports axis "
                         "argument set to the last dimension.")

    # Shape for variables is always as if it was applied on the
    # last dimension.
    param_shape = input_shape[-1]

    # Gamma
    self.gamma = self.add_weight(
        name="gamma",
        shape=param_shape,
        dtype=tf.float32,
        initializer="ones",
        regularizer=None,
        constraint=None,
        trainable=True,
        experimental_autocast=False,
    )
    # Beta
    self.beta = self.add_weight(
        name="beta",
        shape=param_shape,
        dtype=tf.float32,
        initializer="zeros",
        regularizer=None,
        constraint=None,
        trainable=True,
        experimental_autocast=False,
    )
    # Mu = moving mean
    self.moving_mean = self.add_weight(
        name="moving_mean",
        shape=param_shape,
        dtype=tf.float32,
        initializer="zeros",
        regularizer=None,
        constraint=None,
        trainable=True,
        experimental_autocast=False,
    )
    # Sigma² = moving variance (epsilon is added when sigma is evaluated)
    self.moving_variance = self.add_weight(
        name="moving_variance",
        shape=param_shape,
        dtype=tf.float32,
        initializer="ones",
        regularizer=None,
        constraint=None,
        trainable=True,
        experimental_autocast=False,
    )

@property
def sigma_rec(self):
    """Reciprocal of sigma, i.e. 1 / sqrt(moving_variance + epsilon)."""
    return tf.math.rsqrt(self.moving_variance + self.epsilon)

@property
def a(self):
    """Quantized multiplicative term: a = gamma / sigma."""
    # Multiplying by the reciprocal of sigma stands for the division
    return self.a_quantizer(self.gamma * self.sigma_rec)

def b(self, inputs):
    """Quantized additive term: b = -mu * gamma / sigma + beta.

    Args:
        inputs: the tensor passed along with the float value to the b
            quantizer (``call`` passes the product a * x, so that b can be
            summed with it).

    Returns:
        the output of the b quantizer.
    """
    # -mu * gamma / sigma, evaluated left to right
    scaled_mean = -self.moving_mean * self.gamma * self.sigma_rec
    return self.b_quantizer(scaled_mean + self.beta, inputs)

@tensor_inputs([QTensor])
@rescale_outputs
def call(self, inputs):
# Calculation is equivalent to
# y = (x - mu) * gamma / sigma + beta
#   = x * gamma / sigma - mu * gamma / sigma + beta
#
# So if we consider
# a = gamma / sigma
# b = -mu * gamma / sigma + beta
# Then the evaluation is just y = a * x + b.

# outputs = a * x
outputs = tf.multiply(inputs, self.a)

# quantize and retrieve b, aligned on the outputs to allow sum
b = self.b(outputs)
# y = outputs + b