from typing import List, Tuple

import networkx as nx
import numpy as np
from disjoint_set import DisjointSet

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer


class EquivFlow(nx.DiGraph):
    """
    Graph that divides a model flow into equivalence sets which share the same scale/zp.

    Every node of this graph is an edge ``(producer_layer, consumer_layer)`` of the model flow and carries the
    attributes ``u_index`` / ``v_index`` (positions of the producer / consumer layer in the model's topological
    order). A directed edge ``(a, b) -> (b, c)`` exists when layer ``b`` is a "full preserver"
    (see ``_is_full_preserver``), so each weakly connected component is one equiv set of model edges.
    """

    @classmethod
    def from_hailo_model(cls, hailo_model):
        """
        Create an equiv graph from a hailo model.

        Args:
            hailo_model: model whose ``flow`` provides ``toposort()``, ``edges()`` and ``out_edges()``, and whose
                ``layers`` maps a layer name to its layer object.

        Returns:
            EquivFlow: initialized graph of model-flow edges, connected through full-preserver layers.

        """
        graph = cls()
        # Position of every layer in the topological order. ``setdefault`` keeps the first occurrence (the same
        # result ``list.index`` gives) while avoiding an O(V) scan for every edge.
        topo_index = {}
        for index, layer_name in enumerate(hailo_model.flow.toposort()):
            topo_index.setdefault(layer_name, index)

        # Every edge of the original model graph becomes a node of this graph.
        for model_edge in hailo_model.flow.edges():
            graph.add_node(
                model_edge,
                u_index=topo_index[model_edge[0]],
                v_index=topo_index[model_edge[1]],
            )
        # Decide which layers act as continuers (full preservers) and which are stoppers, and connect accordingly.
        graph._add_edges(hailo_model)

        return graph

    def _add_edges(self, hailo_model):
        """
        Add the edges of the equiv graph, in place.

        For each node ``u`` (a model-flow edge ending at layer ``dest``), when ``dest`` is a full preserver,
        ``u`` is connected to every model-flow edge leaving ``dest``.

        Args:
            hailo_model: the model this graph was built from.

        Returns:
            None. ``self`` is modified in place.

        """
        for u in self.nodes:
            # u is an edge (src_layer, dest_layer) of the model flow
            dest_layer_name = u[1]
            layer = hailo_model.layers[dest_layer_name]

            if self._is_full_preserver(layer):
                # Assumes the graph was populated from the same model flow (as ``from_hailo_model`` does), so each
                # out-edge is already a node and ``add_edge`` never adds nodes while ``self.nodes`` is iterated.
                for v in hailo_model.flow.out_edges(dest_layer_name):
                    self.add_edge(u, v)

    @staticmethod
    def source_layers(matching_component):
        """
        Returns: the producer layers of the component's source edges (nodes with in-degree 0).

        A layer appears once per source edge it produces, so the list may contain duplicates.
        """
        return [edge[0] for edge, degree in matching_component.in_degree if degree == 0]

    @staticmethod
    def consumer_layers(matching_component):
        """
        Returns: the consumer layers of the component's consumer edges (nodes with out-degree 0).

        A layer appears once per consumer edge it terminates, so the list may contain duplicates.
        """
        return [edge[1] for edge, degree in matching_component.out_degree if degree == 0]

    @classmethod
    def consumer_layers_groups(cls, matching_component_group):
        """
        Returns: the distinct consumer layers of all components in ``matching_component_group`` (order unspecified).
        """
        consumer_layers = set()
        for matching_component in matching_component_group:
            consumer_layers.update(cls.consumer_layers(matching_component))

        return list(consumer_layers)

    @classmethod
    def source_layers_group(cls, matching_component_group):
        """
        Returns: the distinct source layers of all components in ``matching_component_group`` (order unspecified).
        """
        source_layers = set()
        for matching_component in matching_component_group:
            source_layers.update(cls.source_layers(matching_component))

        return list(source_layers)

    @classmethod
    def layers_group(cls, matching_component_group):
        """
        Returns: every distinct layer (producer or consumer) of any model edge that belongs to a component of
        ``matching_component_group`` (order unspecified).
        """
        layers = set()
        for matching_component in matching_component_group:
            for edge in matching_component.nodes:
                layers.update(edge)
        return list(layers)

    @staticmethod
    def get_comp_order_from_source_consumer(matching_component):
        """
        Compute the sort key of a component.

        ``matching_component`` is a subgraph of this graph (its nodes are model-flow edges).

        Returns:
            tuple: ``(max_source_u_index, min_consumer_v_index)`` where ``max_source_u_index`` is the largest
            ``u_index`` among the source edges (in-degree 0) and ``min_consumer_v_index`` is the smallest
            ``v_index`` among the consumer edges (out-degree 0).

        Raises:
            ValueError: if the component has no source or no consumer edge (e.g. it is empty).
        """
        # NOTE(review): these values were historically named ``min_u_index`` / ``max_v_index``, but the key has
        # always used the MAX source u_index and the MIN consumer v_index -- confirm this ordering is intended.
        # Fetch the attribute maps once instead of once per node.
        u_indices = nx.get_node_attributes(matching_component, "u_index")
        v_indices = nx.get_node_attributes(matching_component, "v_index")
        max_source_u_index = np.max(
            [u_indices[node] for node, degree in matching_component.in_degree if degree == 0],
        )
        min_consumer_v_index = np.min(
            [v_indices[node] for node, degree in matching_component.out_degree if degree == 0],
        )
        return max_source_u_index, min_consumer_v_index

    @staticmethod
    def _is_full_preserver(layer):
        """
        Indicate whether the layer's input scale/zp and output scale/zp are the "same" up to order change.

        Notes
            1. layers that are up_to_scalar preservers are not full preservers (like avgpool, conv and add)
            2. input_layers and output layers are preservers
            3. non-NN-core layers (``BaseHailoNonNNCoreLayer``) are never full preservers

        Returns:
            bool: True when the layer has no consumer input scale, no activation, and is homogeneous.
        """
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return False
        consumer_input_scale = layer.consumer_input_scale
        has_activation = layer.has_activation
        homogeneous = layer.homogeneous

        if not consumer_input_scale and not has_activation and homogeneous:
            # like all non arithmetic ops layer -
            return True
        return False

    def get_toposorted_components(self):
        """
        Split the graph into its equiv sets (weakly connected components).

        Returns:
            list[EquivFlow]: independent copies of every component, sorted by
            ``get_comp_order_from_source_consumer``.

        """
        matching_components = [self.subgraph(comp).copy() for comp in nx.weakly_connected_components(self)]
        matching_components.sort(key=self.get_comp_order_from_source_consumer)
        return matching_components

    def get_groups_components(self):
        """
        Group together the matching components that share source layers.

        The returned groups satisfy:
        1. the union of the groups is all the matching components;
        2. for two different groups, sources(group1).intersection(sources(group2)) == set();
        3. inside a group, any two components are linked by a chain of components in which consecutive
           components share at least one source layer.

        Groups are ordered by their first component, and components inside a group keep the order returned by
        ``get_toposorted_components``.

        for example:
        1) two components with conv1, order is important
        matching_component_1: conv1 -->> conv4
        matching_component_2: conv1 -->> concat1, conv2 -->> concat1, concat1 -->> conv3

        in this case we will want to do zp matching on matching_component_2 and only then update
        matching_component_1.

        2) two components with conv1, must work together!
        matching_component_1: conv1 -->> concat2, conv4 -->> concat2, concat1 -->> conv5
        matching_component_2: conv1 -->> concat1, conv2 -->> concat1, concat1 -->> conv3

        in this case we will want to do zp matching on both components and hence must run the matching together.

        Returns:
            list[list[EquivFlow]]: groups of components that need to be scale-matched together.

        """
        matching_components_sorted = self.get_toposorted_components()
        disjoint_matching_indices = DisjointSet()

        # For each source layer, the index of the first component seen with it; every later component that shares
        # that layer is merged into the same set. This is linear in the number of sources instead of pairwise.
        first_index_by_source = {}
        for index, matching_component in enumerate(matching_components_sorted):
            # Register the index even if it shares no source with another component (keeps property 1).
            disjoint_matching_indices.union(index, index)
            for source_layer in set(self.source_layers(matching_component)):
                first_index = first_index_by_source.setdefault(source_layer, index)
                if first_index != index:
                    disjoint_matching_indices.union(first_index, index)

        # ``itersets`` yields unordered sets; sort explicitly so the toposort order is preserved deterministically.
        index_groups = [sorted(index_set) for index_set in disjoint_matching_indices.itersets()]
        index_groups.sort(key=lambda indices: indices[0])
        return [[matching_components_sorted[i] for i in indices] for indices in index_groups]

    def get_sorted_u_nodes_in_componenets_group(self, matching_component_group):
        """
        Returns: the distinct producer layers (the ``u`` side of every node) of all components in the group,
        sorted by their ``u_index`` (position in the model's topological order).
        """
        nodes_indices = {}
        for component in matching_component_group:
            for model_edge, u_index in nx.get_node_attributes(component, "u_index").items():
                nodes_indices[model_edge[0]] = u_index
        return sorted(nodes_indices, key=nodes_indices.get)

    @classmethod
    def replace_component_group_source(
        cls,
        matching_component_group: List["EquivFlow"],
        old_source: str,
        new_source: str,
    ):
        """
        Return copies of the group's components, with ``old_source`` replaced by ``new_source`` wherever it is a
        source layer. The input components are not modified.
        """
        res = []
        for matching_component in matching_component_group:
            copy_component = matching_component.copy()
            if old_source in cls.source_layers(matching_component):
                copy_component.replace_source(old_source, new_source)
            res.append(copy_component)
        return res

    def replace_source(self, old_source: str, new_source: str):
        """
        Replace, in place, every source node ``(old_source, x)`` with ``(new_source, x)``.

        The node attributes are carried over and the edges are transferred to the new node.
        NOTE(review): the carried-over ``u_index`` still holds ``old_source``'s topological position.

        Raises:
            ValueError: if no source node has ``old_source`` as its producer layer.
        """
        for old_node in self.get_nodes_by_source(old_source):
            new_node = (new_source, old_node[1])
            self.replace_node(old_node, new_node)

    def replace_node(self, old_node, new_node, new_attrs=None):
        """
        Replace ``old_node`` by ``new_node`` in place, moving all of its incoming and outgoing edges.

        Args:
            old_node: existing node to replace.
            new_node: node that takes its place. If it already exists, its attributes are updated and its current
                edges are kept.
            new_attrs: attributes for ``new_node``; defaults to ``old_node``'s attributes.

        Edge attributes are not copied.
        """
        # Step 1: Determine the attributes for the new node
        # If new_attrs is None, use old node's attributes; otherwise, use new_attrs
        final_attrs = self.nodes[old_node] if new_attrs is None else new_attrs

        # Step 2: Add the new node with the determined attributes
        # If the new node already exists, update its attributes
        if new_node in self:
            self.nodes[new_node].update(final_attrs)
        else:
            self.add_node(new_node, **final_attrs)

        if old_node == new_node:
            # Replacing a node by itself: its edges are already in place, and removing old_node would delete it.
            return

        # Step 3: Transfer all incoming and outgoing edges to/from the new node
        # Note: This step does not involve copying edge attributes
        for predecessor in list(self.predecessors(old_node)):
            self.add_edge(predecessor, new_node)
        for successor in list(self.successors(old_node)):
            self.add_edge(new_node, successor)

        # Step 4: Remove the old node
        self.remove_node(old_node)

    def get_nodes_by_source(self, source: str) -> List[Tuple[str, str]]:
        """
        Return the source nodes (in-degree 0) whose producer layer is ``source``.

        Raises:
            ValueError: if there is no such node.
        """
        # TODO Need to check if the source is unique
        source_nodes = [node for node, degree in self.in_degree if degree == 0 and node[0] == source]
        if not source_nodes:
            raise ValueError(f"source {source} does not exist in the graph")
        return source_nodes
