from collections import defaultdict, deque, namedtuple
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import networkx as nx

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError

# Index data of a single flow edge: ``output_index`` is the edge's position among the source node's
# outputs, ``input_index`` is its position among the target node's inputs.
IndexMap = namedtuple("IndexMap", "output_index input_index")


@dataclass
class StaticMappings:
    """
    Precomputed connectivity data derived from a ``ModelFlow`` (see ``create_static_mapping``).
    """

    # non-source layer names, in toposort order
    main_flow: List[str]
    # sources that receive data: input layers and custom inputs
    data_sources: List[str]
    # sources that are not input layers and have no predecessors
    constant_sources: List[str]
    # layer -> predecessor -> {"output_index": int, "input_index": int}
    pred_mapping: Dict[str, Dict[str, Dict[str, int]]]
    # names of the layers whose results are the outputs, in output order
    output_mapping: List[str]
    # layer -> number of successors in the full flow
    out_degree: Dict[str, int]


class ModelFlow(nx.DiGraph):
    """
    The representation of the graph connectivity of the hailo_model.

    Nodes are layer names and carry an ``is_input`` attribute. Every edge ``(u, v)`` carries
    ``output_index`` (position of ``v`` among ``u``'s outputs) and ``input_index`` (position of ``u``
    among ``v``'s inputs).
    """

    def __init__(self):
        super().__init__()
        # lazily built list of weakly connected component subgraphs, see get_components
        self._all_components = None
        # list of layer names, set via set_output_order
        self._output_layer_order = None

    def toposort(self):
        """
        Create lexicographical toposort iterator (to ensure constant order for same input)
        Returns: lexicographical toposort iterator for graph
        """
        return nx.lexicographical_topological_sort(self)

    def descendants(self, source):
        """Return the set of nodes reachable from ``source`` (excluding ``source``)."""
        return nx.descendants(self, source)

    def ancestors(self, source):
        """Return the set of nodes from which ``source`` is reachable (excluding ``source``)."""
        return nx.ancestors(self, source)

    @property
    def input_nodes(self):
        """
        Get input nodes of graph
        Returns: all nodes whose ``is_input`` attribute is truthy, sorted by name
        """
        input_nodes = [node for node, is_input in nx.get_node_attributes(self, "is_input").items() if is_input]
        input_nodes.sort()
        return input_nodes

    @property
    def output_nodes(self):
        """
        Get output nodes of graph

        Returns: all nodes with out-degree of 0. When there is more than one, they are ordered by the
            position, in ``output_layer_order``, of their first (lowest input index) predecessor.
        """
        output_nodes = [node for node, degree in self.out_degree if degree == 0]
        if len(output_nodes) > 1:

            def get_output_index(output_layer):
                return self.output_layer_order.index(self.predecessors_sorted(output_layer)[0])

            output_nodes.sort(key=get_output_index)
        return output_nodes

    @classmethod
    def from_hn_layers(cls, hn_layers, output_layer_order):
        """
        Creates flow object from hn_layers dict
        Args:
            hn_layers: dict with hn layers
            output_layer_order: list of layer names; it is copied, so later in-place updates of the
                flow's output order (e.g. by ``insert_node``) do not mutate the caller's list
        Returns: initialized ``ModelFlow``
        """
        graph = cls()
        output_layer_order = output_layer_order.copy()
        for layer_name in hn_layers:
            graph.add_node(layer_name, is_input=hn_layers[layer_name]["type"] == LayerType.INPUT_LAYER.value)
        for u, v, input_index, output_index in cls._iter_edges(hn_layers):
            graph.add_edge(u, v, input_index=input_index, output_index=output_index)
        graph.set_output_order(output_layer_order)
        return graph

    def set_output_order(self, output_layer_order):
        self._output_layer_order = output_layer_order

    @property
    def output_layer_order(self):
        return self._output_layer_order

    @classmethod
    def _iter_edges(cls, hn_layers):
        """
        Iterate over the edges of the flow graph, breadth-first from the input and const-input layers.

        Args:
            hn_layers: dict of hn layers; each entry has a "type" and "input"/"output" lists of layer names

        Yields: ``(u, v, input_index, output_index)`` where ``input_index`` is the position of ``u`` in
            ``v``'s inputs and ``output_index`` is the position of ``v`` in ``u``'s outputs
        """
        # deque: popping from the left of a list is O(n)
        layers_to_handle = deque(cls._get_layers_by_type(hn_layers, "input_layer"))
        layers_to_handle.extend(cls._get_layers_by_type(hn_layers, "const_input"))
        handled_layers = set()
        while layers_to_handle:
            layer_name = layers_to_handle.popleft()
            if layer_name in handled_layers:
                continue
            handled_layers.add(layer_name)
            layer_outputs = hn_layers[layer_name]["output"]
            for output_name in layer_outputs:
                input_index = hn_layers[output_name]["input"].index(layer_name)
                output_index = layer_outputs.index(output_name)
                yield layer_name, output_name, input_index, output_index
                layers_to_handle.append(output_name)

    @classmethod
    def _get_layers_by_type(cls, hn_layers, ltype):
        """Return the names of all layers whose "type" equals ``ltype``, in ``hn_layers`` order."""
        # TODO: use enum; TODO: use pydantic
        return [layer_name for layer_name, layer_data in hn_layers.items() if layer_data["type"] == ltype]

    def replace_layer(self, old_layer, new_layer):
        """
        Rename ``old_layer`` to ``new_layer`` in place, keeping its node attributes and edges.
        Note that ``output_layer_order`` is not updated.
        """
        # copy=False is required: the default returns a relabelled copy and leaves this flow untouched
        nx.relabel_nodes(self, {old_layer: new_layer}, copy=False)
        self._all_components = None

    def replace_layer_manual(self, old_layer, new_layer):
        """
        Replace ``old_layer`` by ``new_layer``, moving every edge (with its index data) and the node
        attributes to the new node. Note that ``output_layer_order`` is not updated.
        """
        if new_layer == old_layer:
            # nothing to move; re-adding and removing the same node would delete it
            return
        # carry over node attributes (e.g. is_input) so replacing an input layer keeps it an input
        self.add_node(new_layer, **self.nodes[old_layer])
        for pred in self.predecessors_sorted(old_layer):
            self.add_edge(pred, new_layer, **self.get_edge_data(pred, old_layer))
            self.remove_edge(pred, old_layer)
        for succ in self.successors_sorted(old_layer):
            self.add_edge(new_layer, succ, **self.get_edge_data(old_layer, succ))
            self.remove_edge(old_layer, succ)
        self.remove_node(old_layer)
        self._all_components = None

    def predecessors_sorted(self, node, key=None):
        """
        Get sorted predecessors of a node, default order is based on input edges indices
        Args:
            node: predecessors will be searched for given node
            key: A custom key function, to customize the sort order

        Returns: sorted predecessor based on key (default key is input edges indices)
        """

        def get_sort_key(pred_node):
            return self.get_edge_input_index(pred_node, node)

        if key is None:
            key = get_sort_key
        return sorted(self.predecessors(node), key=key)

    def successors_sorted(self, node, key=None):
        """
        Get sorted successors of a node, default order is based on output edges indices
        Args:
            node: successors will be searched for given node
            key: A custom key function, to customize the sort order

        Returns: sorted successors based on key (default key is output edges indices)

        """

        def get_sort_key(succ_node):
            return self.get_edge_output_index(node, succ_node)

        if key is None:
            key = get_sort_key
        return sorted(self.successors(node), key=key)

    def get_edge_output_index(self, u, v):
        return self.get_edge_data(u, v)["output_index"]

    def get_edge_input_index(self, u, v):
        return self.get_edge_data(u, v)["input_index"]

    def get_components(self):
        """
        Return the weakly connected components of the graph as subgraphs. They are built on first use
        and cached; the cache is reset by insert_node, remove_layer, replace_layer and replace_layer_manual.
        """
        if self._all_components is None:
            self._all_components = [self.subgraph(comp) for comp in nx.weakly_connected_components(self)]
        return self._all_components

    def insert_node(self, node, edges, is_input=False):
        """
        Push ``node`` between the given edges: every ``(pred, succ)`` becomes ``pred -> node -> succ``.
        Only nodes with one input are supported, so all edges must share the same predecessor.
        If a ``succ`` is a graph output (out-degree 0) and ``pred`` is in ``output_layer_order``,
        ``node`` takes ``pred``'s place there (the list is updated in place).

        Args:
            node: node to push
            edges: non-empty list of ``(pred, succ)`` edges to push the node between
            is_input: value of the new node's ``is_input`` attribute

        Raises:
            AccelerasValueError: if ``edges`` is empty or the edges do not share one predecessor
        """
        if not edges:
            raise AccelerasValueError(f"no edges were given to insert {node} between")
        predecessor = edges[0][0]
        if any(pred != predecessor for pred, _ in edges):
            raise AccelerasValueError(f"multiple inputs to {node} where no input_indices was provided")

        self.add_node(node, is_input=is_input)
        for node_output_index, (pred, succ) in enumerate(edges):
            succ_input_index = self.get_edge_input_index(pred, succ)
            pred_output_index = self.get_edge_output_index(pred, succ)
            self.add_edge(pred, node, input_index=0, output_index=pred_output_index)
            self.add_edge(node, succ, input_index=succ_input_index, output_index=node_output_index)

            output_layers = self.output_layer_order
            if self.out_degree[succ] == 0 and output_layers is not None and pred in output_layers:
                # succ is a graph output fed by pred, so node now holds pred's slot in the output order
                output_layers[output_layers.index(pred)] = node
                self.set_output_order(output_layers)

            self.remove_edge(pred, succ)

        self._all_components = None

    def add_node(self, node, is_input=False, **kwargs):
        super().add_node(node, is_input=is_input, **kwargs)

    def add_edge(self, u, v, **kwargs):
        kwargs.setdefault("input_index", 0)
        kwargs.setdefault("output_index", 0)
        super().add_edge(u, v, **kwargs)

    def remove_layer(self, layer, connect_succ_and_pred=True):
        """
        Remove ``layer`` from the flow.

        With exactly one predecessor, the layer's successors are (optionally) connected to that
        predecessor: the first successor reuses the removed layer's output index, and later ones get
        output indices starting at the predecessor's out-degree before the removal. A layer with no
        successors or no predecessors is simply dropped together with its edges.

        Args:
            layer: layer to remove
            connect_succ_and_pred: whether to connect the predecessor to the layer's successors

        Raises:
            AccelerasValueError: if the layer has several predecessors and at least one successor
        """
        preds = self.predecessors_sorted(layer)
        succs = self.successors_sorted(layer)
        if len(preds) != 1:
            if not succs or not preds:
                # removing a node also removes all of its incident edges
                self.remove_node(layer)
                self._all_components = None
                return
            raise AccelerasValueError(
                f"Can only remove {layer} with 1 predecessor, but got {len(preds)} predecessors."
            )

        pred = preds[0]
        pred_out_degree = self.out_degree[pred]  # counted before rewiring, includes ``layer``
        layer_output_index = self.get_edge_output_index(pred, layer)
        for i, succ in enumerate(succs):
            input_index = self.get_edge_input_index(layer, succ)
            output_index = layer_output_index if i == 0 else pred_out_degree - 1 + i
            if connect_succ_and_pred:
                self.add_edge(pred, succ, input_index=input_index, output_index=output_index)
            self.remove_edge(layer, succ)
        self.remove_edge(pred, layer)
        self.remove_node(layer)
        self._all_components = None

    def get_sub_flow(self, start_node, end_node):
        """
        Get the sub flow between two nodes: every node that is reachable from ``start_node`` and from
        which ``end_node`` is reachable (both endpoints included), with all edges of this flow between
        those nodes. Nodes are added in toposort order.

        Returns: a new ``ModelFlow``; empty if ``end_node`` is missing or not reachable from ``start_node``

        Raises:
            KeyError: if ``start_node`` is not in the flow
        """
        if start_node not in self:
            raise KeyError(start_node)
        # two linear-time reachability queries instead of all-pairs shortest paths (quadratic time/memory)
        from_start = self.descendants(start_node) | {start_node}
        to_end = (self.ancestors(end_node) | {end_node}) if end_node in self else set()
        between = from_start & to_end

        sub_flow = ModelFlow()
        for node in self.toposort():
            if node in between:
                sub_flow.add_node(node, is_input=self.nodes[node].get("is_input", False))
        for node in sub_flow.nodes:
            for succ in self.successors_sorted(node):
                if succ in sub_flow.nodes:
                    input_index = self.get_edge_input_index(node, succ)
                    output_index = self.get_edge_output_index(node, succ)
                    sub_flow.add_edge(node, succ, input_index=input_index, output_index=output_index)
        return sub_flow

    def toposort_edges(self) -> List[Tuple[str, str]]:
        """Return all edges, ordered by the toposort of their source node, then by output index."""
        return [(node, successor) for node in self.toposort() for successor in self.successors_sorted(node)]

    def get_sources(self) -> List[str]:
        return [node for node, degree in self.in_degree() if degree == 0]

    def get_end_nodes(self) -> List[str]:
        return [node for node, degree in self.out_degree() if degree == 0]

    def get_index_data(self, u, v) -> IndexMap:
        """Return the ``(output_index, input_index)`` pair of edge ``(u, v)``."""
        data = self._succ[u][v]
        return IndexMap(data["output_index"], data["input_index"])


def create_predecesor_mapping(flow: ModelFlow) -> Dict[str, Dict[str, Dict[str, int]]]:
    """
    Map every layer to the index data of its incoming edges.

    Args:
        flow: the model flow

    Returns: a ``defaultdict(dict)`` of ``layer -> predecessor -> {"output_index": int, "input_index": int}``.
        Layers without predecessors have no entry; looking them up yields an empty dict.
    """
    predecessor_mapping = defaultdict(dict)
    for lname in flow.toposort():
        for pred in flow.predecessors_sorted(lname):
            index_map = flow.get_index_data(pred, lname)
            predecessor_mapping[lname][pred] = {
                "output_index": index_map.output_index,
                "input_index": index_map.input_index,
            }
    return predecessor_mapping


def create_static_mapping(
    flow: ModelFlow,
    custom_outputs: Optional[List[str]] = None,
    custom_inputs: Optional[List[str]] = None,
) -> StaticMappings:
    """
    Build the static mappings of ``flow``.

    Args:
        flow: the model flow
        custom_outputs: layers to use as outputs; only they and their ancestors are kept in the main flow.
            When not given, the first successor of each layer in ``flow.output_layer_order`` is used.
        custom_inputs: layers to use as additional sources; their ancestors are dropped from the main flow

    Returns: ``StaticMappings`` where ``main_flow`` is the toposorted list of kept non-source layers,
        ``data_sources`` are the input-layer sources followed by the sources that have predecessors
        (custom inputs), ``constant_sources`` are the remaining sources, ``pred_mapping`` and
        ``out_degree`` describe the full flow, and ``output_mapping`` lists the output layers.
    """
    pred_mapping = create_predecesor_mapping(flow)

    nodes = set(flow.nodes())
    potential_sources = set(flow.get_sources())
    if custom_inputs:
        for input_layer in custom_inputs:
            # layers feeding a custom input are cut off from the computation
            nodes -= flow.ancestors(input_layer)
        potential_sources.update(custom_inputs)

    # a potential source that was cut off (an ancestor of a custom input) is not a source anymore
    sources = sorted(src for src in potential_sources if src in nodes)
    nodes.difference_update(sources)

    if custom_outputs:
        required = set(custom_outputs)
        for output in custom_outputs:
            required.update(flow.ancestors(output))
        nodes &= required
        # copy so the returned mapping does not alias the caller's list
        output_mapping = list(custom_outputs)
    else:
        output_mapping = [flow.successors_sorted(lname)[0] for lname in flow.output_layer_order]

    main_flow = list(flow.subgraph(nodes).toposort())
    out_degree = {lname: flow.out_degree(lname) for lname in flow.toposort()}

    data_sources = [src for src in sources if flow.nodes[src]["is_input"]]
    # custom inputs that have predecessors in the full flow are fed with data as well
    data_sources.extend(src for src in sources if flow.in_degree(src) != 0)

    constant_sources = [src for src in sources if not flow.nodes[src]["is_input"] and flow.in_degree(src) == 0]

    return StaticMappings(main_flow, data_sources, constant_sources, pred_mapping, output_mapping, out_degree)
