from typing import List

import networkx as nx

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DataPath


class LayerFlow(nx.DiGraph):
    """
    The representation of the graph connectivity of the hailo_model

    Nodes are either atomic ops (keyed by ``op.full_name`` and carrying the op
    in the ``op`` node attribute) or input/output placeholders (carrying the
    ``is_input`` / ``is_output`` node attribute). Edges carry ``data_path``,
    ``input_index`` and ``output_index`` attributes.
    """

    INPUT_BASE_NAME = "input_placeholder_"
    OUTPUT_BASE_NAME = "output_placeholder_"

    def __init__(self, incoming_graph_data=None, **attr):
        """
        Args:
            incoming_graph_data: forwarded as-is to ``nx.DiGraph.__init__``
            **attr: forwarded as-is to ``nx.DiGraph.__init__``
        """
        super().__init__(incoming_graph_data, **attr)
        # NOTE(review): the counters always start at 0 and are not derived from
        # incoming_graph_data - confirm that graphs which already contain
        # placeholder nodes are never passed in.
        self._input_count = 0
        self._output_count = 0

    @property
    def num_inputs(self):
        """Number of input placeholders created through :meth:`add_input`."""
        return self._input_count

    @property
    def num_outputs(self):
        """Number of output placeholders created through :meth:`add_output`."""
        return self._output_count

    def toposort(self):
        """
        Create lexicographical toposort iterator (to ensure constant order for same input)
        Returns: lexicographical toposort iterator for graph
        """
        return nx.lexicographical_topological_sort(self)

    def toposort_ops(self):
        """
        Iterate over the op nodes only (placeholders are skipped), in toposort order
        Returns: generator of node names that carry an ``op`` attribute
        """
        ops = self._get_op_attribute()
        for node in self.toposort():
            if node in ops:
                yield node

    def toposort_edges(self):
        """
        Iterate over the edges of the graph, source nodes in toposort order and
        the successors of each source in their sorted (output index) order
        Returns: generator of (node name, successor name) tuples
        """
        for lname in self.toposort():
            for successor_name in self.successors_sorted(lname):
                yield lname, successor_name

    def add_input(self):
        """
        Add a new input placeholder node, named by the running input counter
        Returns: name of the created node
        """
        index = self._input_count
        node_name = f"{self.INPUT_BASE_NAME}{index}"
        # Bypass the overridden add_node, which expects an op object
        super().add_node(node_name, is_input=True)
        self._input_count = index + 1
        return node_name

    def add_output(self):
        """
        Add a new output placeholder node, named by the running output counter
        Returns: name of the created node
        """
        index = self._output_count
        node_name = f"{self.OUTPUT_BASE_NAME}{index}"
        # Bypass the overridden add_node, which expects an op object
        super().add_node(node_name, is_output=True)
        self._output_count = index + 1
        return node_name

    def add_node(self, op: BaseAtomicOp):
        """
        Add an op node to the graph
        Args:
            op: the op to add; the node is keyed by ``op.full_name`` and the op
                itself is stored in the ``op`` node attribute
        """
        super().add_node(op.full_name, op=op)

    def get_op(self, op_name) -> BaseAtomicOp:
        """
        Args:
            op_name: name of an op node

        Returns: the op stored on the given node
        Raises: KeyError if the node does not exist or carries no ``op`` attribute
        """
        return self._get_op_attribute()[op_name]

    def add_edge(self, u_node, v_node, data_path: DataPath, input_index=0, output_index=0):
        """
        Add a directed edge from u_node to v_node
        Args:
            u_node: source node, given either as a node name or as an op object
            v_node: destination node, given either as a node name or as an op object
            data_path: stored in the ``data_path`` edge attribute
            input_index: stored in the ``input_index`` edge attribute
                (default sort key of :meth:`predecessors_sorted`)
            output_index: stored in the ``output_index`` edge attribute
                (default sort key of :meth:`successors_sorted`)
        """
        u_node = self._get_node(u_node)
        v_node = self._get_node(v_node)
        return super().add_edge(u_node, v_node, data_path=data_path, input_index=input_index, output_index=output_index)

    def _get_node(self, node):
        """Translate an op object to its node name; any other value is returned unchanged."""
        if isinstance(node, BaseAtomicOp):
            return node.full_name
        return node

    def _get_op_attribute(self):
        """Returns: dict of node name -> op, for every node that has an ``op`` attribute."""
        return nx.get_node_attributes(self, "op")

    def get_ops(self) -> List[BaseAtomicOp]:
        """Returns: list of all the ops in the graph, in toposort order."""
        ops = self._get_op_attribute()
        return [ops[op_name] for op_name in self.toposort_ops()]

    def get_inputs(self):
        """Returns: names of the input placeholder nodes, sorted by their numeric index."""
        attr = nx.get_node_attributes(self, "is_input")
        input_nodes = [node for node, is_input in attr.items() if is_input]
        return sorted(input_nodes, key=self._get_input_index)

    def get_outputs(self):
        """Returns: names of the output placeholder nodes, sorted by their numeric index."""
        attr = nx.get_node_attributes(self, "is_output")
        output_nodes = [node for node, is_output in attr.items() if is_output]
        return sorted(output_nodes, key=self._get_output_index)

    def _get_input_index(self, node):
        """Parse the numeric suffix that follows INPUT_BASE_NAME in an input placeholder name."""
        return int(node[len(self.INPUT_BASE_NAME) :])

    def _get_output_index(self, node):
        """Parse the numeric suffix that follows OUTPUT_BASE_NAME in an output placeholder name."""
        return int(node[len(self.OUTPUT_BASE_NAME) :])

    def predecessors_sorted(self, node, key=None):
        """
        Get sorted predecessors of a node, default order is based on input edges indices
        Args:
            node: predecessors will be searched for given node
            key: A custom key function, to customize the sort order

        Returns: sorted predecessor based on key (default key is input edges indices)
        """
        if key is None:

            def key_func(pred_node):
                return self.get_edge_input_index(pred_node, node)

            key = key_func
        # Sort with `key` (not `key_func`), so a caller supplied key is honored;
        # `key_func` only exists when no key was given.
        return sorted(self.predecessors(node), key=key)

    def successors_sorted(self, node, key=None):
        """
        Get sorted successors of a node, default order is based on output edges indices
        Args:
            node: successors will be searched for given node
            key: A custom key function, to customize the sort order

        Returns: sorted successors based on key (default key is output edges indices)

        """
        if key is None:

            def key_func(succ_node):
                return self.get_edge_output_index(node, succ_node)

            key = key_func
        # Sort with `key` (not `key_func`), so a caller supplied key is honored;
        # `key_func` only exists when no key was given.
        return sorted(self.successors(node), key=key)

    def get_edge_data_path(self, node, succ_node):
        """Returns: the ``data_path`` attribute of the edge node -> succ_node."""
        return self.get_edge_data(node, succ_node)["data_path"]

    def visualize(self, path):
        """
        Draw the graph and save the drawing to a file
        Args:
            path: destination passed to ``matplotlib.pyplot.savefig``
        """
        # Imported lazily, so matplotlib is needed only when visualizing
        import matplotlib.pyplot as plt

        # NOTE(review): presumably draws on matplotlib's current figure, which
        # is then saved and left open - confirm this is the intended behavior.
        nx.draw(self)
        plt.savefig(path)

    def get_edge_input_index(self, pred_node, node):
        """Returns: the ``input_index`` attribute of the edge pred_node -> node."""
        return self.get_edge_data(pred_node, node)["input_index"]

    def get_edge_output_index(self, node, succ_node):
        """Returns: the ``output_index`` attribute of the edge node -> succ_node."""
        return self.get_edge_data(node, succ_node)["output_index"]

    def is_placeholder(self, node):
        """Returns: truthy value if the node is an input or an output placeholder."""
        is_input = self.nodes[node].get("is_input", False)
        is_output = self.nodes[node].get("is_output", False)
        return is_input or is_output
