#!/usr/bin/env python


from queue import Queue

from hailo_sdk_client.emulator.tf_model import NMS_FIRST_OP, NMS_LAST_OP
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendQuantizationException


def _build_set_of_layer_predecessors(layer_name, head_layer, port, conv_layers_inference):
    """
    Collect the (layer, port type) couples that feed the given port of a head layer.

    Starting from the tensor(s) stored in ``head_layer[port + "_tensor"]`` (and, when present,
    ``head_layer["second_" + port + "_tensor"]``), the graph is walked backwards through
    ``tensor.op.inputs``. The walk along a path stops at the first tensor that is the
    ``"output_tensor"`` of an entry in ``conv_layers_inference``.

    Args:
        layer_name: name of the head layer; used only to build the ``(layer_name, port)`` entry.
        head_layer: mapping describing the head layer. ``head_layer[port + "_tensor"]`` is either
            a single tensor or a list of tensors.
        port: port name of the head layer, e.g. ``"input"`` or ``"elementwise"``.
        conv_layers_inference: mapping of layer name -> layer mapping holding an
            ``"output_tensor"`` entry.

    Returns:
        A set of tuples: ``(name, "output")`` for every layer whose output tensor was reached,
        plus ``(layer_name, port)`` for the head layer itself.
    """
    tensor_name = port + "_tensor"
    src_tensor = head_layer[tensor_name]
    if not isinstance(src_tensor, list):
        src_tensor = [src_tensor]

    src_tensor_queue = Queue()
    for tensor in src_tensor:
        src_tensor_queue.put(tensor)

    # Currently only used in bbox decoder. This forces scale matching on both inputs.
    if ("second_" + tensor_name) in head_layer:
        src_tensor_queue.put(head_layer["second_" + tensor_name])

    prev_l = set()
    # Tensors that were already handled. A tensor that is reachable through several paths
    # (e.g. both branches of a residual/concat block) always contributes the same layers, so
    # processing it again only costs time - and on a cyclic graph the walk would never end.
    visited = set()

    while src_tensor_queue.qsize() > 0:
        src_tensor_to_process = src_tensor_queue.get()
        if src_tensor_to_process in visited:
            continue
        visited.add(src_tensor_to_process)

        if NMS_LAST_OP in src_tensor_to_process.name:
            # Special treatment for nms - if nms is in the graph we skip all of its inner ops
            # by jumping straight from its last op to its first op.
            op_name = src_tensor_to_process.name.replace(NMS_LAST_OP, NMS_FIRST_OP)
            src_tensor_to_process = src_tensor_to_process.graph.get_tensor_by_name(op_name)
            set_of_input_conv_layers = set()
        else:
            # Find the conv/dense layers whose output is src_tensor_to_process.
            set_of_input_conv_layers = {
                (n, "output")
                for n, layer in conv_layers_inference.items()
                if layer["output_tensor"] == src_tensor_to_process
            }

        if set_of_input_conv_layers:
            # We found the layer(s) connected to this tensor, so this path is finished.
            prev_l |= set_of_input_conv_layers
            continue

        # Nothing produced this tensor directly: keep crawling down through the inputs of the
        # op that produced it.
        for tns in set(src_tensor_to_process.op.inputs):
            src_tensor_queue.put(tns)

    prev_l.add((layer_name, port))
    return prev_l


def unify_connection_sets(connection_sets):
    """
    Unify sets of predecessors if they have a common element (layer/port tuple).

    Sets that share an element - directly, or through a chain of other sets - end up in the
    same equivalence class. This step takes care of layers that affect the scale of another
    layer without sharing data.

    Args:
        connection_sets: dict whose keys are ``(layer, port type)`` tuples and whose values are
            sets of ``(layer, port type)`` tuples. ``None`` values are ignored. The dict is
            updated in place: every non-``None`` value is replaced by the unified set of the
            class its key belongs to.

    Returns:
        A set of frozensets, one per equivalence class.
    """
    # Disjoint-set forest over the dict keys: parent[key] leads towards the representative key
    # of the class that `key` currently belongs to.
    parent = {}

    def _find(key):
        # Locate the representative, then point every visited key straight at it.
        root = key
        while parent[root] != root:
            root = parent[root]
        while parent[key] != root:
            next_key = parent[key]
            parent[key] = root
            key = next_key
        return root

    # Maps every element to the first key whose set contained it. Two keys whose sets share an
    # element are joined through that element's owner, so no pairwise comparison is needed.
    owner = {}
    for key, members in connection_sets.items():
        if members is None:
            continue
        parent[key] = key
        for member in members:
            first_key = owner.setdefault(member, key)
            if first_key == key:
                continue
            first_root, key_root = _find(first_key), _find(key)
            if first_root != key_root:
                parent[key_root] = first_root

    # Union of the original sets of every class, stored under the class representative.
    merged = {}
    for key in parent:
        merged.setdefault(_find(key), set()).update(connection_sets[key])

    # Keep the in-place update of the input dict; every key gets its own copy.
    for key in parent:
        connection_sets[key] = set(merged[_find(key)])

    return {frozenset(equivalence_class) for equivalence_class in merged.values()}


def find_layer_connection_set(layer_name, head_layer, conv_layers_inference):
    """
    Build the predecessor sets of a conv layer (called 'head layer'), one per port.

    By predecessors we mean conv layers whose output flows to a port of the head layer without
    passing through other conv layers, but possibly through some arithmetic or control layers
    such as max-pool, concat etc.

    The 'input' port is always handled. The 'elementwise' and 'weights_input' ports are handled
    only when the head layer has an "elementwise_name" / "weights_input_name" entry.

    Returns:
        A dict whose keys are tuples (layer name, port type) and whose values are sets of
        tuples of the shape (layer, port type), as produced by
        _build_set_of_layer_predecessors().
    """
    # (port type, key of head_layer whose presence enables that port)
    optional_ports = (
        ("elementwise", "elementwise_name"),
        ("weights_input", "weights_input_name"),
    )

    connection_sets = {
        (layer_name, "input"): _build_set_of_layer_predecessors(
            layer_name,
            head_layer,
            "input",
            conv_layers_inference,
        ),
    }

    for port_type, marker_key in optional_ports:
        if marker_key not in head_layer:
            continue
        connection_sets[(layer_name, port_type)] = _build_set_of_layer_predecessors(
            layer_name,
            head_layer,
            port_type,
            conv_layers_inference,
        )

    return connection_sets


def match_scales_in_connection_set(conv_layers_inference, connection_sets):
    """
    Equalize the min/max range of all the elements of every equivalence class.

    Each equivalence class is a collection of tuples of the form (layer, port). Members whose
    layer has a "dummy_conv" entry are ignored. For the remaining members:

    * if no member has a forced range, the class range is (smallest min, largest max);
    * if one or more members have a forced range, that range is used, provided all the forced
      ranges are identical.

    The class range is written to ``layer[port + "_minmax"]`` of every member, and the position
    of the class within ``connection_sets`` is written to ``layer[port + "_group"]``.

    Args:
        conv_layers_inference: dict of layer name -> layer dict; updated in place.
        connection_sets: iterable of equivalence classes.

    Raises:
        BackendQuantizationException: if members of one class have different forced ranges.
    """
    for group_index, equivalence_class in enumerate(connection_sets):
        # One (layer_name, port, min, max, is_forced) record per real member, in sorted order.
        # Keeping the name next to its values guarantees that the error message below names
        # the layers that actually carry the conflicting ranges.
        members = []
        for layer_name, port in sorted(equivalence_class):
            layer = conv_layers_inference[layer_name]
            if "dummy_conv" in layer:
                continue
            # Only the input/output ports have a "forced" flag; other ports are never forced.
            is_forced = layer[f"limvals_{port}_forced"] if port in {"input", "output"} else False
            minmax = layer[port + "_minmax"]
            members.append((layer_name, port, minmax[0], minmax[1], is_forced))

        if not members:
            # If we got here there are no conv-like layers and the quantization is irrelevant
            continue

        forced_members = [member for member in members if member[4]]
        if not forced_members:
            total_minmax_for_equivalence_class = (
                min(member[2] for member in members),
                max(member[3] for member in members),
            )
        else:
            forced_min_vals = [member[2] for member in forced_members]
            forced_max_vals = [member[3] for member in forced_members]
            all_min_same = all(forced_min_vals[0] == v for v in forced_min_vals)
            all_max_same = all(forced_max_vals[0] == v for v in forced_max_vals)
            if not (all_min_same and all_max_same):
                forced_layers = [member[0] for member in forced_members]
                # TODO: raise actual error
                raise BackendQuantizationException(
                    f"Conflicting forced ranges for layers - {forced_layers}; "
                    f"min values - {forced_min_vals}; max values - {forced_max_vals}",
                )
            total_minmax_for_equivalence_class = forced_min_vals[0], forced_max_vals[0]

        for layer_name, port, _, _, _ in members:
            layer = conv_layers_inference[layer_name]
            layer[port + "_minmax"] = total_minmax_for_equivalence_class
            layer[port + "_group"] = group_index
