#!/usr/bin/env python
# ******************************************************************************
# Copyright 2025 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ["VariableRegistry", "FifoOp", "btc_function", "dwbtc_function",
"QuantizedBufferTempConv", "QuantizedDepthwiseBufferTempConv",
"get_qbtc", "get_qdbtc"]
import onnx
import numpy as np
from onnxruntime_extensions import onnx_op, PyCustomOpDef
from onnx import AttributeProto as AP, TensorProto as TP, NodeProto
from onnx.helper import make_node
from .base_layer import BRN_OPSET, ONNX_OPSET as op, OnnxLayer, register_node_format
from .subgraph_ops import cast_tensors_to, get_scale_out_ops
from .subgraph_ops.activation import get_activation_ops
from .set_weights import set_weights_on_qnode, set_range_max_on_qnode
from ..graph_tools import TENSOR_SHAPE, get_field, get_activation
from ..layers.compute_shapes import compute_onnx_btc_output
from ..quantization.core import quantize_to_qfloat, aligned_quantize, align_to, downscale
from .register import register_new_subgraph
class VariableRegistry:
    """A registry for storing and managing variables globally.

    It is used to store the FIFO buffers of BufferTempConv nodes.
    """
variables = {}
@staticmethod
def get_variable(model_id, variable_id, default_value):
models_vars = VariableRegistry.variables.get(model_id, {})
return models_vars.get(variable_id, default_value)
@staticmethod
def set_variable(model_id, variable_id, value):
if model_id not in VariableRegistry.variables:
VariableRegistry.variables[model_id] = {}
VariableRegistry.variables[model_id][variable_id] = value
return value
@staticmethod
def clear(model_ids=None):
""" Clear model ids from the registry.
If ids are provided, only those models will be removed.
Args:
            model_ids (list, optional): list of string ids to clear. Defaults to None.
"""
        if model_ids is not None:
            for model_id in model_ids:
                VariableRegistry.variables.pop(model_id, None)
else:
VariableRegistry.variables.clear()
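# A minimal usage sketch of the registry (hypothetical ids and shapes): buffers are keyed
# by (model_id, variable_id), created lazily from a default value and cleared per model.
#
#   fifo = VariableRegistry.get_variable("my_model", "conv1/fifo",
#                                        np.zeros((1, 4, 3, 8, 8), dtype=np.float32))
#   VariableRegistry.set_variable("my_model", "conv1/fifo", fifo)
#   VariableRegistry.clear(["my_model"])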
@onnx_op(op_type=f"{BRN_OPSET.domain}::FifoOp", inputs=[PyCustomOpDef.dt_float],
attrs={"model_id": PyCustomOpDef.dt_string, "variable_id": PyCustomOpDef.dt_string,
"fifo_size": PyCustomOpDef.dt_int64})
def FifoOp(x, model_id, variable_id, fifo_size):
    # The FifoOp operates on (B, C, H, W) tensors and uses the VariableRegistry to retrieve
    # buffers and store them back after a roll operation.
if x.ndim != 4:
raise RuntimeError("FifoOp only supports 4D tensors (B, C, H, W)")
default_fifo = np.zeros((*x.shape[:2], fifo_size, *x.shape[2:]), dtype=x.dtype)
fifo = VariableRegistry.get_variable(model_id, variable_id, default_fifo)
    # Check that the input dimensions match the existing FIFO
if fifo.shape[:2] != x.shape[:2] or fifo.shape[3:] != x.shape[2:]:
raise RuntimeError(f"Input dimensions {x.shape} do not match fifo dimensions {fifo.shape}")
fifo = np.concatenate((fifo[:, :, 1:, :, :], np.expand_dims(x, 2)), axis=2)
return VariableRegistry.set_variable(model_id, variable_id, fifo)
# The BufferTempConv operation is made of a FifoOp and a convolution:
# - Reshape the FIFO to perform a 2-D convolution: (B, C, T, H, W) -> (B, C * T, H, W)
# - Perform the convolution: (B, C * T, H, W) * (F, C * T, 1, 1) -> (B, F, H, W)
btc_function = onnx.parser.parse_function("""
<opset_import: ["": {onnx_opset}, "{domain}": 1], domain:"{domain}">
BufferTempConv <model_id: string, fifo_name: string, fifo_size: int>(X, W, B) => (Y)
{{
fifo = {domain}.FifoOp<model_id: string=@model_id, variable_id: string=@fifo_name,
fifo_size: int=@fifo_size>(X)
fifo_f32 = Cast<to=1>(fifo)
fifo_transposed = Transpose<perm=[0, 3, 4, 1, 2]>(fifo_f32)
target_shape = Constant<value_ints=[0, 0, 0, -1]>()
fifo_reshaped = Reshape(fifo_transposed, target_shape)
fifo_transposed_back = Transpose<perm=[0, 3, 1, 2]>(fifo_reshaped)
Y = Conv(fifo_transposed_back, W, B)
}}
""".format(onnx_opset=op.version, domain=BRN_OPSET.domain))
onnx.checker.check_function(btc_function)
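# For reference, a NumPy sketch (with hypothetical names) of what a single BufferTempConv
# step computes, given a FIFO of shape (batch, c, t, h, w), 1x1 kernel weights of shape
# (f, c * t, 1, 1) and a bias of shape (f,):
#
#   fifo = np.concatenate((fifo[:, :, 1:], x[:, :, None]), axis=2)       # FIFO roll
#   stacked = fifo.transpose(0, 3, 4, 1, 2).reshape(batch, h, w, c * t)  # time-stacked channels
#   y = np.einsum("bhwk,fk->bfhw", stacked, weights.reshape(f, c * t))
#   y += bias[None, :, None, None]                                       # (batch, f, h, w)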
# The DepthwiseBufferTempConv operation is made of a FifoOp and a multiplication:
# - Retrieve the FIFO
# - Element-wise multiplication: (B, C, T, H, W) * (C, T, H=1, W=1) -> (B, C, T, H, W)
# - Reduction sum along T dim: (B, C, T, H, W) -> (B, C, H, W)
dwbtc_function = onnx.parser.parse_function("""
<opset_import: ["": {onnx_opset}, "{domain}": 1], domain:"{domain}">
DepthwiseBufferTempConv <model_id: string, fifo_name: string, fifo_size: int>(X, W, B) => (Y)
{{
fifo = {domain}.FifoOp<model_id: string=@model_id, variable_id: string=@fifo_name,
fifo_size: int=@fifo_size>(X)
fifo_f32 = Cast<to=1>(fifo)
Y1 = Mul(fifo_f32, W)
axis = Constant<value_ints=[2]>()
Y2 = ReduceSum<keepdims=0>(Y1, axis)
Y = Add(Y2, B)
}}
""".format(onnx_opset=op.version, domain=BRN_OPSET.domain))
onnx.checker.check_function(dwbtc_function)
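# Similarly, a NumPy sketch of a single DepthwiseBufferTempConv step, given a FIFO of
# shape (batch, c, t, h, w), weights of shape (c, t, 1, 1) and a bias of shape (c, 1, 1):
#
#   fifo = np.concatenate((fifo[:, :, 1:], x[:, :, None]), axis=2)  # FIFO roll
#   y = (fifo * weights[None]).sum(axis=2) + bias[None]             # (batch, c, h, w)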
def reset_buffers(model):
""" Resets all FIFO-buffer of (Depthwise)BufferTempConv layers within the model.
Args:
model (ONNXModel): the model to reset
"""
model_ids = set()
for node in model.nodes():
if node.op_type in ["FifoOp", "BufferTempConv", "DepthwiseBufferTempConv"]:
model_id = onnx.helper.get_node_attr_value(node, "model_id").decode('utf-8')
model_ids.add(model_id)
VariableRegistry.clear(model_ids)
@register_node_format(requires_downscale=True)
class QuantizedBufferTempConv(OnnxLayer):
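    """Intermediate representation of a quantized BufferTempConv layer.

    It serializes into a BufferTempConv custom node followed by an optional
    activation and a scale-out stage.

    Args:
        fifo_name (str): name of the FIFO variable in the VariableRegistry.
        fifo_size (int): number of time steps stored in the FIFO.
        model_id (str): id of the model owning the FIFO.
        activation (str, optional): activation to fold into the operation chain.
            Defaults to "".
        name (str, optional): node name. Defaults to ''.
    """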
def __init__(self,
fifo_name,
fifo_size,
model_id,
activation="",
name=''):
super().__init__("QuantizedBufferTempConv",
fifo_name=fifo_name,
fifo_size=fifo_size,
model_id=model_id,
name=name)
        # Save properties needed to serialize the operation name
self.serialize_attr["activation"] = activation
self.serialize_attr["scale"] = True
# Declare weights
self._add_weight("kernel")
self._add_weight("bias")
self._add_weight("max_value")
self._add_weight("range_max", 1.0)
self._add_weight("act_range_max", 1.0)
def __build__(self, input_ts, downscale=True):
assert input_ts.dtype == np.int8
assert downscale, f"{self.name} ({self.base_name}) does not support 32-bit output"
# The chain of operations is modified if downscale is needed
self.serialize_attr["scale"] = downscale
# Compute output shape
conv_output_shape = compute_onnx_btc_output(self, input_ts.shape)
output_ts = TENSOR_SHAPE(conv_output_shape, np.dtype("int8"))
return output_ts
def __quantize__(self, qinput, force_fp=False):
i_scale = qinput.weights["scale"]
        # Perform cross-layer equalization, i.e. rescale the weights with the input scale.
        # To do that, first align i_scale on axis 1 so that it broadcasts over the kernel.
assert i_scale.ndim <= 1
kernel = self.weights["kernel"]
if i_scale.size > 1:
kernel = kernel / align_to(
np.repeat(i_scale, kernel.shape[1] // i_scale.shape[0]), kernel.ndim, axis=1)
else:
kernel = kernel / align_to(np.repeat(i_scale, kernel.shape[1]), kernel.ndim, axis=1)
# Quantize and set weights
qweights, i_scale = quantize_to_qfloat(kernel)
qweights = qweights.astype("int8")
# Prepare tensors list with unique names
conv_name = self.name
prefix = conv_name + "_"
weights_dict = {}
weights_dict[prefix + "Wi"] = qweights
bias = self.weights["bias"]
if "Biased" in self.op_type:
qbias = aligned_quantize(bias, i_scale)
weights_dict[prefix + "B"] = qbias
# Now consider calibrated output range
range_max = self.weights["range_max"]
scale, s_out, ocalib_scale = downscale(range_max, i_scale, force_fp=force_fp)
weights_dict.update({prefix + "M": align_to(scale.astype("uint8"), qweights.ndim),
prefix + "S_out": align_to(s_out, qweights.ndim)})
        # Return quantized weights and output scale
return weights_dict, ocalib_scale
@staticmethod
def build_subgraph(op_type):
# Cast input, weights (and bias) into float.
t_names = ["X", "W", ""]
if "Biased" in op_type:
t_names[-1] = "bias"
nodes, t_names = cast_tensors_to(t_names)
nodes.append(make_node("BufferTempConv", inputs=t_names, outputs=["Yi"],
domain=BRN_OPSET.domain))
nodes[-1].attribute.append(AP(name="fifo_name", ref_attr_name="fifo_name", type=AP.STRING))
nodes[-1].attribute.append(AP(name="fifo_size", ref_attr_name="fifo_size", type=AP.INTS))
nodes[-1].attribute.append(AP(name="model_id", ref_attr_name="model_id", type=AP.STRING))
# Activation (optional)
if "ReLU" in op_type:
            # Rename the previous output so it becomes the ReLU input
            nodes[-1].output[0] = nodes[-1].op_type
nodes += get_activation_ops(nodes[-1].output[0], "Yi", "ReLUClipped" in op_type)
# Scale out (with saturation) in float domain
nodes += get_scale_out_ops("Yi", "Yscaled")
        # Cast output to the expected type
nodes.append(make_node("Cast", ["Yscaled"], ["Y"], to=TP.INT8))
return nodes
def make_node(self, inputs, outputs):
node = super().make_node(inputs, outputs, use_custom_op=True)
register_new_subgraph(btc_function)
return node
def get_qbtc(nodes, graph, tensor_ranges):
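    """Builds a QuantizedBufferTempConv layer from a float BufferTempConv node.

    Args:
        nodes (list of NodeProto): the BufferTempConv node, optionally followed by an
            activation node.
        graph (GraphProto): the graph the nodes belong to.
        tensor_ranges (dict): calibration ranges, keyed by tensor name.

    Returns:
        QuantizedBufferTempConv: the quantized layer.
    """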
btc_node = nodes[0]
assert btc_node.op_type == 'BufferTempConv'
fifo_name = get_field(btc_node, "fifo_name")
fifo_size = get_field(btc_node, "fifo_size")
model_id = get_field(btc_node, "model_id") + "_quantized"
act_node = get_activation(nodes) or NodeProto()
activation = act_node.op_type
qconv = QuantizedBufferTempConv(fifo_name=fifo_name,
fifo_size=fifo_size,
model_id=model_id,
activation=activation)
set_weights_on_qnode(qconv, btc_node, graph)
# Set calibration ranges
set_range_max_on_qnode(qconv, tensor_ranges[nodes[-1].output[0]])
if act_node.op_type == "activation":
act_range_max = tensor_ranges[act_node.input[0]]
set_range_max_on_qnode(qconv, act_range_max, name="act_range_max", reduce=True)
return qconv
@register_node_format(requires_downscale=True)
class QuantizedDepthwiseBufferTempConv(OnnxLayer):
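    """Intermediate representation of a quantized DepthwiseBufferTempConv layer.

    It serializes into a DepthwiseBufferTempConv custom node followed by an optional
    activation and a scale-out stage.

    Args:
        fifo_name (str): name of the FIFO variable in the VariableRegistry.
        fifo_size (int): number of time steps stored in the FIFO.
        model_id (str): id of the model owning the FIFO.
        activation (str, optional): activation to fold into the operation chain.
            Defaults to "".
        name (str, optional): node name. Defaults to ''.
    """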
def __init__(self,
fifo_name,
fifo_size,
model_id,
activation="",
name=''):
super().__init__("QuantizedDepthwiseBufferTempConv",
fifo_name=fifo_name,
fifo_size=fifo_size,
model_id=model_id,
name=name)
        # Save properties needed to serialize the operation name
self.serialize_attr["activation"] = activation
self.serialize_attr["scale"] = True
# Declare weights
self._add_weight("kernel")
self._add_weight("bias")
self._add_weight("max_value")
self._add_weight("range_max", 1.0)
self._add_weight("act_range_max", 1.0)
def __build__(self, input_ts, downscale=True):
assert input_ts.dtype == np.int8
assert downscale, f"{self.name} ({self.base_name}) does not support 32-bit output"
assert self.weights["kernel"].ndim == 4
if self.weights["bias"].size == 0:
self.set_weight(
"bias", np.zeros(shape=(self.weights["kernel"].shape[0], 1, 1), dtype=np.float32))
self.serialize_attr["scale"] = downscale
# Compute output shape
conv_output_shape = compute_onnx_btc_output(self, input_ts.shape)
output_ts = TENSOR_SHAPE(conv_output_shape, np.dtype("int8"))
return output_ts
def __quantize__(self, qinput, force_fp=False):
i_scale = qinput.weights["scale"]
        # Perform cross-layer equalization, i.e. rescale the weights with the input scale.
        # To do that, first align i_scale on axis 0 so that it broadcasts over the kernel.
assert i_scale.ndim <= 1
kernel = self.weights["kernel"]
kernel = kernel / align_to(i_scale, kernel.ndim, axis=0)
# Quantize and set weights
qweights, i_scale = quantize_to_qfloat(kernel)
qweights = qweights.astype("int8")
# Prepare tensors list with unique names
conv_name = self.name
prefix = conv_name + "_"
weights_dict = {}
weights_dict[prefix + "Wi"] = qweights
bias = self.weights["bias"]
qbias = aligned_quantize(bias, align_to(i_scale, ndims=bias.ndim, axis=0))
weights_dict[prefix + "B"] = qbias
# Now consider calibrated output range
range_max = self.weights["range_max"]
scale, s_out, ocalib_scale = downscale(range_max, i_scale, force_fp=force_fp)
weights_dict.update({prefix + "M": align_to(scale.astype("uint8"), qweights.ndim),
prefix + "S_out": align_to(s_out, qweights.ndim)})
        # Return quantized weights and output scale
return weights_dict, ocalib_scale
@staticmethod
def build_subgraph(op_type):
# Cast input, weights (and bias) into float.
t_names = ["X", "W", "bias"]
nodes, t_names = cast_tensors_to(t_names)
nodes.append(make_node("DepthwiseBufferTempConv", inputs=t_names, outputs=["Yi"],
domain=BRN_OPSET.domain))
nodes[-1].attribute.append(AP(name="fifo_name", ref_attr_name="fifo_name", type=AP.STRING))
nodes[-1].attribute.append(AP(name="fifo_size", ref_attr_name="fifo_size", type=AP.INTS))
nodes[-1].attribute.append(AP(name="model_id", ref_attr_name="model_id", type=AP.STRING))
# Activation (optional)
if "ReLU" in op_type:
            # Rename the previous output so it becomes the ReLU input
            nodes[-1].output[0] = nodes[-1].op_type
nodes += get_activation_ops(nodes[-1].output[0], "Yi", "ReLUClipped" in op_type)
# Scale out (with saturation) in float domain
nodes += get_scale_out_ops("Yi", "Yscaled")
        # Cast output to the expected type
nodes.append(make_node("Cast", ["Yscaled"], ["Y"], to=TP.INT8))
return nodes
def make_node(self, inputs, outputs):
node = super().make_node(inputs, outputs, use_custom_op=True)
register_new_subgraph(dwbtc_function)
return node
def get_qdbtc(nodes, graph, tensor_ranges):
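    """Builds a QuantizedDepthwiseBufferTempConv layer from a float DepthwiseBufferTempConv node.

    Args:
        nodes (list of NodeProto): the DepthwiseBufferTempConv node, optionally followed
            by an activation node.
        graph (GraphProto): the graph the nodes belong to.
        tensor_ranges (dict): calibration ranges, keyed by tensor name.

    Returns:
        QuantizedDepthwiseBufferTempConv: the quantized layer.
    """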
btc_node = nodes[0]
assert btc_node.op_type == 'DepthwiseBufferTempConv'
fifo_name = get_field(btc_node, "fifo_name")
fifo_size = get_field(btc_node, "fifo_size")
model_id = get_field(btc_node, "model_id") + "_quantized"
act_node = get_activation(nodes) or NodeProto()
activation = act_node.op_type
qconv = QuantizedDepthwiseBufferTempConv(fifo_name=fifo_name,
fifo_size=fifo_size,
model_id=model_id,
activation=activation)
# Sets the weights to configure the operation chain
set_weights_on_qnode(qconv, btc_node, graph)
# Set calibration ranges
set_range_max_on_qnode(qconv, tensor_ranges[nodes[-1].output[0]])
if act_node.op_type == "activation":
act_range_max = tensor_ranges[act_node.input[0]]
set_range_max_on_qnode(qconv, act_range_max, name="act_range_max", reduce=True)
return qconv