# Source code for quantizeml.onnx_support.quantization.register_patterns

#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ["custom_pattern_scope"]

from collections import namedtuple
from inspect import signature
from contextlib import contextmanager
from .. import layers as onnx_qlayers

# Record pairing a node pattern with its mapper functions:
#   - pattern: tuple of node type names describing a sequence of operations
#   - f: list of mapper functions associated with that sequence
# Note that the namedtuple type name ('QuantizerPattern') differs from the
# module-level name it is bound to (QuantizePattern).
QuantizePattern = namedtuple('QuantizerPattern', ['pattern', 'f'])

# User-defined patterns. Filled by custom_pattern_scope() while its context is
# active; empty at import time.
CUSTOM_PATTERNS_MAP = []
# Built-in supported patterns, each with its list of candidate mapper functions.
# NOTE(review): longer sequences are listed before their own prefixes
# (e.g. Conv+Relu+MaxPool, then Conv+Relu, then Conv), which suggests the
# matcher takes the first hit - confirm before reordering entries.
# NOTE(review): several entries carry two candidates (get_qdepthwise, then
# get_qconv); presumably they are tried in list order - verify against the caller.
PATTERNS_MAP = [
    QuantizePattern(("Conv", "Relu", "GlobalAveragePool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Relu", "MaxPool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "GlobalAveragePool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Relu"), [onnx_qlayers.get_qdepthwise, onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Clip", "GlobalAveragePool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Clip", "MaxPool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Clip"), [onnx_qlayers.get_qdepthwise, onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv",), [onnx_qlayers.get_qdepthwise, onnx_qlayers.get_qconv]),
    QuantizePattern(("Flatten", "Gemm", "Relu"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Flatten", "Gemm", "Clip"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Flatten", "Gemm"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Gemm", "Relu"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Gemm", "Clip"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Gemm",), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Add", "Relu"), [onnx_qlayers.get_qadd]),
    QuantizePattern(("Add",), [onnx_qlayers.get_qadd]),
    QuantizePattern(("Concat",), [onnx_qlayers.get_qconcat]),
    QuantizePattern(("ConvTranspose", "Clip"), [onnx_qlayers.get_qconv_transpose]),
    QuantizePattern(("ConvTranspose", "Relu"), [onnx_qlayers.get_qconv_transpose]),
    QuantizePattern(("ConvTranspose",), [onnx_qlayers.get_qconv_transpose]),
    QuantizePattern(("Transpose", "Mul", "Add"), [onnx_qlayers.get_input_quantizer]),
    QuantizePattern(("Transpose", "Mul"), [onnx_qlayers.get_input_quantizer]),
    QuantizePattern(("Mul", "Add"), [onnx_qlayers.get_input_quantizer]),
    QuantizePattern(("Mul",), [onnx_qlayers.get_input_quantizer]),
]


@contextmanager
def custom_pattern_scope(new_patterns):
    """Register custom patterns in the context to be used at quantization time.

    A pattern is understood as a sequence of continuous operations in the graph,
    whose representation can converge in an ``OnnxLayer``.

    The new patterns are appended to ``CUSTOM_PATTERNS_MAP`` on entry. On exit,
    whether the body completed or raised, the list is put back to the content it
    had on entry. Scopes can therefore be nested: leaving an inner scope does not
    discard the patterns registered by an enclosing one.

    Args:
        new_patterns (dict): mapping whose keys are patterns (a string or a
            tuple of strings naming the sequence of nodes) and whose values are
            their mapper functions (callables taking two inputs:
            sequence_nodes and graph).
    """
    # Validate and convert every entry before touching the shared list, so an
    # invalid entry raises without leaving a partial registration behind.
    qpatterns = [_custom_pattern_to_qpattern(new_pattern, func)
                 for new_pattern, func in new_patterns.items()]

    # Snapshot the current content so that it can be restored on exit.
    previous_patterns = list(CUSTOM_PATTERNS_MAP)
    try:
        # Extend CUSTOM_PATTERNS_MAP with new qpatterns
        CUSTOM_PATTERNS_MAP.extend(qpatterns)
        yield
    finally:
        # Restore to previous state. The slice assignment mutates the list in
        # place, so references to CUSTOM_PATTERNS_MAP held elsewhere stay valid.
        CUSTOM_PATTERNS_MAP[:] = previous_patterns
def _custom_pattern_to_qpattern(pattern, func): assert callable(func), f"function has to be a callable. Receives: {func}" if len(signature(func).parameters) != 2: raise RuntimeError("function must have two inputs: sequence_nodes and graph") if isinstance(pattern, str): pattern = (pattern,) if not (isinstance(pattern, tuple) and all(isinstance(x, str) for x in pattern)): raise ValueError(f"Pattern must be a string-tuple. Receives: {pattern}") return QuantizePattern(pattern, [func])