#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
__all__ = ["custom_pattern_scope"]
from collections import namedtuple
from contextlib import contextmanager
from .. import layers as onnx_qlayers
# Define named tuples for QuantizerPattern
QuantizePattern = namedtuple('QuantizerPattern', ['pattern', 'f'])
# List of supported patterns, together with matching function
CUSTOM_PATTERNS_MAP = []
PATTERNS_MAP = [
    QuantizePattern(("Conv", "Relu", "GlobalAveragePool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "MaxPool", "Relu"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "GlobalAveragePool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Relu"), [onnx_qlayers.get_qdepthwise, onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Clip", "GlobalAveragePool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "MaxPool", "Clip"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "activation", "GlobalAveragePool"), [onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "Clip"), [onnx_qlayers.get_qdepthwise, onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv", "activation"), [onnx_qlayers.get_qdepthwise, onnx_qlayers.get_qconv]),
    QuantizePattern(("Conv",), [onnx_qlayers.get_qdepthwise, onnx_qlayers.get_qconv]),
    QuantizePattern(("Flatten", "Gemm", "Relu"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Flatten", "Gemm", "Clip"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Flatten", "Gemm"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Gemm", "Relu"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Gemm", "Clip"), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Gemm",), [onnx_qlayers.get_qgemm]),
    QuantizePattern(("Add", "Relu"), [onnx_qlayers.get_qadd]),
    QuantizePattern(("Add",), [onnx_qlayers.get_qadd]),
    QuantizePattern(("Concat",), [onnx_qlayers.get_qconcat]),
    QuantizePattern(("ConvTranspose", "Clip"), [onnx_qlayers.get_qconv_transpose]),
    QuantizePattern(("ConvTranspose", "Relu"), [onnx_qlayers.get_qconv_transpose]),
    QuantizePattern(("ConvTranspose", "activation"), [onnx_qlayers.get_qconv_transpose]),
    QuantizePattern(("ConvTranspose",), [onnx_qlayers.get_qconv_transpose]),
    QuantizePattern(("Transpose", "Mul", "Add"), [onnx_qlayers.get_input_quantizer]),
    QuantizePattern(("Transpose", "Mul"), [onnx_qlayers.get_input_quantizer]),
    QuantizePattern(("Mul", "Add"), [onnx_qlayers.get_input_quantizer]),
    QuantizePattern(("Mul",), [onnx_qlayers.get_input_quantizer]),
    QuantizePattern(("BufferTempConv", "Relu"), [onnx_qlayers.get_qbtc]),
    QuantizePattern(("DepthwiseBufferTempConv", "Relu"), [onnx_qlayers.get_qdbtc]),
    QuantizePattern(("BufferTempConv",), [onnx_qlayers.get_qbtc]),
    QuantizePattern(("DepthwiseBufferTempConv",), [onnx_qlayers.get_qdbtc]),
]
[docs]
@contextmanager
def custom_pattern_scope(new_patterns):
    """Register a custom pattern in the context to be used at quantization time.
    A pattern is understood as a sequence of continuous operations in the graph,
    whose representation can converge in an ``OnnxLayer``.
    Args:
        new_patterns (dict): a list of sequence of nodes (keys) and their mapper function (values).
    """
    # Use of global parameters
    global CUSTOM_PATTERNS_MAP
    # Transform input patterns in a valid format
    qpatterns = []
    for new_pattern, func in new_patterns.items():
        qpatterns.append(_custom_pattern_to_qpattern(new_pattern, func))
    try:
        # Extend CUSTOM_PATTERNS_MAP with new qpatterns
        CUSTOM_PATTERNS_MAP.extend(qpatterns)
        yield
    finally:
        # Restore to previous state
        CUSTOM_PATTERNS_MAP.clear() 
def _custom_pattern_to_qpattern(pattern, func):
    assert callable(func), f"function has to be a callable. Receives: {func}"
    if isinstance(pattern, str):
        pattern = (pattern,)
    if not (isinstance(pattern, tuple) and all(isinstance(x, str) for x in pattern)):
        raise ValueError(f"Pattern must be a string-tuple. Receives: {pattern}")
    return QuantizePattern(pattern, [func])