from enum import Enum
from typing import Callable, List, Optional, Tuple, Union

import networkx as nx
import tensorflow as tf

from hailo_model_optimization.acceleras.encoding.encoding_constraint import (
    EncodingConstraint,
    default_state,
    default_status,
)
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingNode, EncodingType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasEncodingError


class NodeType(Enum):
    """Kind of a node in an `EncodingFlowGraph`: either an encoding or a constraint between encodings."""

    # Node storing an `EncodingNode` under its "encoding" attribute.
    ENCODING = "encoding"
    # Node storing an `EncodingConstraint` under its "constraint" attribute.
    CONSTRAINT = "constraint"


class InferenceNodeType(Enum):
    """Kind of a node in an `EncodingInferenceFlowGraph`."""

    # Source node storing an `EncodingNode` (created by `EncodingInferenceFlowGraph.add_independent`).
    INDEPENDENT = "independent"
    # Node computed from its predecessors by the `EncodingConstraint` it stores.
    DEPENDANT = "dependant"


@tf.custom_gradient
def fake_quant_int(x, min_val, max_val):
    """
    Round `x` to the nearest integer and clip it into [min_val, max_val].

    The gradient is a straight-through estimator: the upstream gradient passes unchanged for elements where
    min_val <= x <= max_val and is zeroed elsewhere. `min_val` and `max_val` receive a zero gradient.
    """
    inside_range = tf.logical_and(tf.less_equal(min_val, x), tf.less_equal(x, max_val))
    quantized = tf.clip_by_value(tf.round(x), min_val, max_val)

    # NOTE: the `variables` keyword must stay in the signature; tf.custom_gradient inspects it.
    def grad(dy, variables=None):
        passed_through = tf.where(inside_range, dy, 0.0)
        return passed_through, 0, 0

    return quantized, grad


def generate_identity_function():
    """Return a callable that hands back its single argument unchanged."""

    def identity(value):
        return value

    return identity


def generate_independent_scalar_function(shape):
    """Return a callable that multiplies its argument by a ones tensor of `shape` (broadcasting a scalar)."""

    def broadcast_scalar(value):
        return tf.ones(shape) * value

    return broadcast_scalar


def generate_quant_function(func, quant_min, quant_max):
    """Return a callable that applies `func` and then `fake_quant_int` with the given bounds."""

    def quantized(value):
        return fake_quant_int(func(value), quant_min, quant_max)

    return quantized


class EncodingFlowGraph(nx.DiGraph):
    """
    The representation of the network encodings, and their constraints.

    This is a bipartite graph.

    The vertices are made of {e | e is an encoding in the network} and {c | c is a constraint on the encodings}.

    An edge from an encoding to a constraint represents that the encoding is an input of the constraint,
    while an edge from a constraint to an encoding represents that the encoding is an output of the constraint.
    Every edge carries an ``index`` attribute: the position of the encoding among the constraint's inputs
    (respectively outputs).
    """

    def __init__(self) -> None:
        super().__init__()
        # Counters used to generate unique names for dummy encodings and constraint nodes.
        self._dummy_count = 0
        self._constraint_count = 0

    def add_encoding(
        self,
        name: str,
        encoding_type: EncodingType = EncodingType.Dummy,
        scalar: Optional[bool] = None,
        shape: Optional[Tuple[int]] = None,
        initializer: Optional[Callable] = None,
        regularizer: Optional[Callable] = None,
        constraint: Optional[Callable] = None,
        quant: bool = False,
        quant_min: Optional[float] = None,
        quant_max: Optional[float] = None,
        hidden: bool = False,
    ):
        """
        Add an encoding node named `name`.

        See `EncodingNode` documentation for more info regarding arguments usage.

        Args:
            name (str): The name of the encoding.
            encoding_type (EncodingType, optional): The type of the encoding. Defaults to EncodingType.Dummy.
            scalar (Optional[bool], optional): Whether the represented encoding should be a scalar.
            shape (Optional[Tuple[int]], optional): The shape of the represented encoding.
            initializer (Optional[Callable], optional): Encoding initializer instance (callable).
            regularizer (Optional[Callable], optional): Encoding regularizer instance (callable).
            constraint (Optional[Callable], optional): Encoding constraint instance (callable).
            quant (bool, optional): Whether the encoding should be represented by an integer (gradients pass
                through a straight-through estimator, STE). Defaults to False.
            quant_min (Optional[float], optional): Minimum quant-value of the encoding. Should be specified iff
                quant is True.
            quant_max (Optional[float], optional): Maximum quant-value of the encoding. Should be specified iff
                quant is True.
            hidden (bool, optional): Whether the encoding should be hidden, e.g. omitted by
                `print(show_hidden=False)`. Defaults to False.

        """
        encoding = EncodingNode(
            encoding_type,
            scalar,
            shape,
            initializer,
            regularizer,
            constraint,
            quant,
            quant_min,
            quant_max,
        )
        self.add_node(name, encoding=encoding, type=NodeType.ENCODING, hidden=hidden)

    def add_dummy_encoding(self):
        """
        Add a hidden dummy encoding node with an auto-generated name.

        Returns:
            str: The generated node name (``dummy_encoding:<i>``).
        """
        dummy = f"dummy_encoding:{self._dummy_count}"
        self._dummy_count += 1
        self.add_encoding(dummy, hidden=True)
        return dummy

    def add_constraint(
        self,
        inputs: Union[List[str], str],
        outputs: Union[List[str], str],
        func: Callable,
        state_function: Callable[[List[EncodingNode], List[EncodingNode]], List[bool]] = default_state,
        status_function: Callable[[List[EncodingNode], List[EncodingNode]], bool] = default_status,
        func_string: Optional[str] = None,
    ):
        """
        Add a constraint node s.t. `outputs` = `func`(*`inputs`).

        See `EncodingConstraint` documentation for more info regarding arguments usage.

        Args:
            inputs (Union[List[str], str]): The input encoding names of the constraint.
            outputs (Union[List[str], str]): The output encoding names of the constraint.
            func (Callable): The main function this constraint represents.
            state_function (Callable[[List[EncodingNode], List[EncodingNode]], List[bool]], optional): A function
                that, given a list of input `EncodingNode`s and a list of output `EncodingNode`s, updates the
                outputs' states and returns a list with the same length as outputs specifying whether each state has
                changed. By default `state_function` doesn't change anything.
            status_function (Callable[[List[EncodingNode], List[EncodingNode]], bool], optional): A function that,
                given a list of input `EncodingNode`s and a list of output `EncodingNode`s, returns whether this
                constraint is valid. By default `status_function` always returns True.
            func_string (Optional[str], optional): A string template representing this constraint function.

        """
        inputs = inputs if isinstance(inputs, list) else [inputs]
        outputs = outputs if isinstance(outputs, list) else [outputs]

        constraint = EncodingConstraint(
            func,
            state_function=state_function,
            status_function=status_function,
            func_string=func_string,
        )
        constraint_name = f"constraint:{self._constraint_count}"
        self._constraint_count += 1
        self.add_node(constraint_name, constraint=constraint, type=NodeType.CONSTRAINT)

        # The edge `index` preserves argument order, which a DiGraph does not keep by itself.
        for i, inp in enumerate(inputs):
            self.add_edge(inp, constraint_name, index=i)
        for i, out in enumerate(outputs):
            self.add_edge(constraint_name, out, index=i)

    def update(self, flow):
        """
        Merge another `EncodingFlowGraph` into this one.

        Auto-generated node names of `flow` (``dummy_encoding:<i>`` and ``constraint:<i>``) are shifted by this
        graph's counters so they don't collide with existing nodes. Every other node keeps its name, so encodings
        that appear in both graphs end up as a single node.

        Args:
            flow (EncodingFlowGraph): The graph to merge into this one.
        """

        def relabel(node):
            # Split on the first ":" only and leave non-generated names untouched, so encoding names that don't
            # contain exactly one ":" don't break the unpacking.
            prefix, _, index = node.partition(":")
            if prefix == "dummy_encoding":
                return f"{prefix}:{int(index) + self._dummy_count}"
            if prefix == "constraint":
                return f"{prefix}:{int(index) + self._constraint_count}"
            return node

        super().update(nx.relabel_nodes(flow, relabel))
        self._dummy_count += flow._dummy_count
        self._constraint_count += flow._constraint_count

    def _derived(self, node, derived_set):
        """
        Return encoding nodes that could be derived immediately by adding `node` to `derived_set`.

        Args:
            node: Node to find all of the derived encodings from.
            derived_set: A set of all the nodes derived so far.

        Returns:
            Tuple[List, List]: A list of immediately derived nodes, and a list of their respective constraints.

        """
        derived_nodes = list()
        derived_constraints = list()
        for constraint in self.successors(node):
            if all(inp in derived_set for inp in self.predecessors(constraint)):
                for succ in self.successors(constraint):
                    if succ not in derived_set and succ not in derived_nodes:
                        derived_nodes.append(succ)
                        derived_constraints.append(constraint)
        return derived_nodes, derived_constraints

    @property
    def encoding_nodes(self):
        """
        Encoding node names, ordered so that encodings with more encoding descendants come first.

        Ties are broken by preferring scalar encodings, then quantized ones, then ones with an initializer.
        """
        nodes = [n for n, t in self.nodes(data="type") if t is NodeType.ENCODING]

        def key(node):
            # TODO: setting the sort most significant key to be constant break test zipped_yolox_model_level_conv45
            # I'm not sure why, this will be investigated in the future SDK-46258.
            # constant = 0 if any(self.in_degree(cons) == 0 for cons in self.predecessors(node)) else 1
            descendants = len(nodes) - len(
                [child for child in nx.descendants(self, node) if self.get_node_type(child) is NodeType.ENCODING],
            )
            scalar = 0 if self.get_scalar(node) else 1
            quant = 0 if self.get_encoding(node).quant else 1
            initializer = 0 if self.get_encoding(node).initializer is not None else 1
            return (descendants, scalar, quant, initializer)

        return sorted(nodes, key=key)

    @property
    def constraint_nodes(self):
        """Constraint node names, ordered so that constraints with more encoding descendants come first."""
        nodes = [n for n, t in self.nodes(data="type") if t is NodeType.CONSTRAINT]

        def key(node):
            return -len(
                [child for child in nx.descendants(self, node) if self.get_node_type(child) is NodeType.ENCODING],
            )

        return sorted(nodes, key=key)

    def get_encoding(self, node):
        """Return the `EncodingNode` of `node`; raise `AccelerasEncodingError` if it is not an encoding node."""
        if self.get_node_type(node) is not NodeType.ENCODING:
            raise AccelerasEncodingError(
                f"Node {node} must be of type {NodeType.ENCODING} (received type {self.get_node_type(node)}).",
            )
        return self.nodes[node]["encoding"]

    def get_shape(self, node):
        return self.get_encoding(node).shape

    def get_scalar(self, node):
        return self.get_encoding(node).scalar

    def get_encoding_type(self, node):
        return self.get_encoding(node).encoding_type

    def get_constraint(self, node):
        """Return the `EncodingConstraint` of `node`; raise `AccelerasEncodingError` if it is not a constraint."""
        if self.get_node_type(node) is not NodeType.CONSTRAINT:
            raise AccelerasEncodingError(
                f"Node {node} must be of type {NodeType.CONSTRAINT} (received type {self.get_node_type(node)}).",
            )
        return self.nodes[node]["constraint"]

    def get_func(self, node):
        return self.get_constraint(node).func

    def get_func_string(self, node):
        return self.get_constraint(node).func_string

    def get_state_function(self, node):
        return self.get_constraint(node).state_function

    def get_status_function(self, node):
        return self.get_constraint(node).status_function

    def get_node_type(self, node):
        return self.nodes[node]["type"]

    def inputs_sorted(self, constraint_name):
        """Return the input encodings of `constraint_name`, ordered by their edge ``index``."""
        if self.get_node_type(constraint_name) is not NodeType.CONSTRAINT:
            raise AccelerasEncodingError(
                f"Node {constraint_name} must be of type {NodeType.CONSTRAINT} (received type "
                f"{self.get_node_type(constraint_name)}).",
            )

        def get_sort_key(pred_node):
            return self.get_edge_data(pred_node, constraint_name)["index"]

        return sorted(self.predecessors(constraint_name), key=get_sort_key)

    def outputs_sorted(self, constraint_name):
        """Return the output encodings of `constraint_name`, ordered by their edge ``index``."""
        if self.get_node_type(constraint_name) is not NodeType.CONSTRAINT:
            raise AccelerasEncodingError(
                f"Node {constraint_name} must be of type {NodeType.CONSTRAINT} (received type "
                f"{self.get_node_type(constraint_name)}).",
            )

        def get_sort_key(succ_node):
            return self.get_edge_data(constraint_name, succ_node)["index"]

        return sorted(self.successors(constraint_name), key=get_sort_key)

    def _resolve_state(self):
        """
        Update the state of each encoding node.

        Each constraint's `state_function` is applied. Whenever it reports that an output changed, every
        already-processed constraint consuming that output is re-queued at the front, so the loop runs until no
        further state changes are reported.

        Raises:
            AccelerasEncodingError: If a `state_function` returns a list whose length differs from the number of
                outputs.
        """
        resolved = set()
        constraints = self.constraint_nodes
        while len(constraints) > 0:
            cons = constraints.pop(0)
            resolved.add(cons)
            state_function = self.get_state_function(cons)
            inputs = [self.get_encoding(inp) for inp in self.inputs_sorted(cons)]
            outputs = [self.get_encoding(out) for out in self.outputs_sorted(cons)]
            changes = state_function(inputs, outputs)
            if len(outputs) != len(changes):
                raise AccelerasEncodingError(
                    f"state_function of constraint {cons} must return a list of length "
                    f"{len(outputs)} (received len {len(changes)}).",
                )
            for output, change in zip(self.outputs_sorted(cons), changes):
                if change:
                    for constraint_name in self.successors(output):
                        if constraint_name in resolved:
                            resolved.remove(constraint_name)
                            constraints.insert(0, constraint_name)

    def _resolve_status(self):
        """
        Remove impossible constraints, i.e. constraints whose `status_function` returns False.
        """
        # `constraint_nodes` returns a fresh list, so removing nodes while iterating it is safe.
        constraints = self.constraint_nodes
        for constraint_name in constraints:
            status_function = self.get_status_function(constraint_name)
            inputs = [self.get_encoding(inp) for inp in self.inputs_sorted(constraint_name)]
            outputs = [self.get_encoding(out) for out in self.outputs_sorted(constraint_name)]
            if not status_function(inputs, outputs):
                self.remove_node(constraint_name)

    def _add_independent_encoding_node(self, inference_flow, node):
        """
        Add `node` to `inference_flow` as an encoding that is not derived from other encodings.

        If `node` is the output of a constraint without inputs (a constant constraint), it is inferred from that
        constraint. Otherwise a new independent source is added and `node` is defined from it. A scalar source is
        broadcast to the encoding shape, and the value is fake-quantized when the encoding is quantized.

        Raises:
            AccelerasEncodingError: If `node` is defined by more than one constant constraint.
        """
        encoding = self.get_encoding(node)
        const_constraints = [cons for cons in self.predecessors(node) if self.in_degree(cons) == 0]
        if len(const_constraints) > 1:
            raise AccelerasEncodingError(
                f"Node {node} can't have multiple constant constraint defining it ({len(const_constraints)}).",
            )
        elif len(const_constraints) == 0:
            source_name = inference_flow.add_independent(encoding)
            if encoding.scalar and encoding.shape != ():
                func = generate_independent_scalar_function(encoding.shape)
                func_string = f"tf.ones({encoding.shape}) * {{0}}"
            else:
                func = generate_identity_function()
                func_string = "{0}"
            if encoding.quant:
                func = generate_quant_function(func, encoding.quant_min, encoding.quant_max)
                func_string = f"ste({func_string})"
            constraint = EncodingConstraint(func, func_string=func_string)
            inference_flow.add_dependant([source_name], node, constraint, hidden=self.nodes[node]["hidden"])
        else:
            self._add_dependant_encoding_node(inference_flow, node, const_constraints[0])

    def _add_dependant_encoding_node(self, inference_flow, encoding_node, constraint_node):
        """
        Add `encoding_node` to `inference_flow`, computed by `constraint_node` from its (sorted) inputs.
        """
        inputs = self.inputs_sorted(constraint_node)
        func = self.get_func(constraint_node)
        func_string = self.get_func_string(constraint_node)
        constraint = EncodingConstraint(func, func_string=func_string)
        if self.out_degree(constraint_node) > 1:
            # If we haven't done so already, add a hidden node called `constraint_node` to the inference flow such
            # that `constraint_node` is equal to a tuple with all of the constraint outputs. Then set `encoding_node`
            # to be inferred by taking the relevant index of the `constraint_node` tuple.
            if constraint_node not in inference_flow.dependant_nodes:
                inference_flow.add_dependant(inputs, constraint_node, constraint, hidden=True)
            index = self.get_edge_data(constraint_node, encoding_node)["index"]
            index_constraint = EncodingConstraint(lambda x: x[index], func_string=f"{{0}}[{index}]")
            inference_flow.add_dependant(
                [constraint_node],
                encoding_node,
                index_constraint,
                hidden=self.nodes[encoding_node]["hidden"],
            )
        else:
            inference_flow.add_dependant(inputs, encoding_node, constraint, hidden=self.nodes[encoding_node]["hidden"])

    def _create_inference_flow(self):
        """
        Create a solved encoding inference flow.

        This is a greedy algorithm: each iteration adds one encoding node to the set of independent nodes, followed
        by all the nodes that can be derived from it.
        """
        inference_flow = EncodingInferenceFlowGraph()
        derived = set()
        nodes = self.encoding_nodes
        while len(nodes) > 0:
            node = nodes.pop(0)
            derived.add(node)
            self._add_independent_encoding_node(inference_flow, node)
            derived_nodes, derived_constraints = self._derived(node, derived)
            while len(derived_nodes) > 0:
                output, constraint_name = derived_nodes.pop(0), derived_constraints.pop(0)
                self._add_dependant_encoding_node(inference_flow, output, constraint_name)
                nodes.remove(output)
                derived.add(output)
                for new_node, new_constraint in zip(*self._derived(output, derived)):
                    if new_node not in derived_nodes:
                        derived_nodes.append(new_node)
                        derived_constraints.append(new_constraint)
        return inference_flow

    def solve(self):
        """
        Solve the encoding flow graph and return the solved encoding inference flow.

        Note that this mutates the graph: states are resolved and invalid constraints are removed.
        """
        self._resolve_state()
        self._resolve_status()
        return self._create_inference_flow()

    def print(self, show_hidden=True):
        """
        Print all the encodings' constraints in the graph.

        Args:
            show_hidden (bool): If False, skip constraints whose outputs are all hidden encodings.
        """
        for constraint in self.constraint_nodes:
            outputs = self.outputs_sorted(constraint)
            if show_hidden or not all(self.nodes[out]["hidden"] for out in outputs):
                format_function_string = self.get_constraint(constraint).format(*self.inputs_sorted(constraint))
                print(f'{", ".join(outputs)} = {format_function_string}')


class EncodingInferenceFlowGraph(nx.DiGraph):
    """
    The representation of the encoding connectivity of the hailo_model.

    Nodes are either independent (``source_<i>``, storing an `EncodingNode` under "encoding") or dependant (storing
    an `EncodingConstraint` under "constraint" whose function computes the node from its predecessors, ordered by the
    ``index`` attribute of the incoming edges).
    """

    def __init__(self) -> None:
        super().__init__()
        # Counter used to generate unique names for independent source nodes.
        self._source_count = 0

    def add_independent(self, encoding: EncodingNode):
        """
        Add and return a new independent encoding node.

        The new node copies the type, scalar flag, initializer, regularizer and constraint of `encoding`. A scalar
        encoding gets shape ``()``; quantization settings are not copied.

        Returns:
            str: The generated node name (``source_<i>``).
        """
        shape = () if encoding.scalar else encoding.shape
        independent_encoding = EncodingNode(
            encoding.encoding_type,
            encoding.scalar,
            shape,
            encoding.initializer,
            encoding.regularizer,
            encoding.constraint,
        )
        source_name = f"source_{self._source_count}"
        self._source_count += 1
        self.add_node(source_name, encoding=independent_encoding, type=InferenceNodeType.INDEPENDENT, hidden=True)
        return source_name

    def add_dependant(self, inputs, output, constraint, hidden=False):
        """
        Add a new dependant encoding node `output` = `constraint`(*`inputs`).

        Edges are tagged with ``index`` to preserve the order of `inputs`.
        """
        self.add_node(output, constraint=constraint, type=InferenceNodeType.DEPENDANT, hidden=hidden)
        for i, inp in enumerate(inputs):
            self.add_edge(inp, output, index=i)

    @property
    def independent_nodes(self):
        return [n for n, t in self.nodes(data="type") if t is InferenceNodeType.INDEPENDENT]

    @property
    def dependant_nodes(self):
        return [n for n, t in self.nodes(data="type") if t is InferenceNodeType.DEPENDANT]

    def get_encoding(self, node):
        """Return the `EncodingNode` of an independent `node`; raise `AccelerasEncodingError` otherwise."""
        if self.get_node_type(node) is not InferenceNodeType.INDEPENDENT:
            raise AccelerasEncodingError(
                f"Node {node} must be of type {InferenceNodeType.INDEPENDENT} (received type "
                f"{self.get_node_type(node)}).",
            )
        return self.nodes[node]["encoding"]

    def get_shape(self, node):
        return self.get_encoding(node).shape

    def get_scalar(self, node):
        return self.get_encoding(node).scalar

    def get_encoding_type(self, node):
        return self.get_encoding(node).encoding_type

    def get_constraint(self, node):
        """Return the `EncodingConstraint` of a dependant `node`; raise `AccelerasEncodingError` otherwise."""
        if self.get_node_type(node) is not InferenceNodeType.DEPENDANT:
            raise AccelerasEncodingError(
                f"Node {node} must be of type {InferenceNodeType.DEPENDANT} (received type "
                f"{self.get_node_type(node)}).",
            )
        return self.nodes[node]["constraint"]

    def get_func(self, node):
        return self.get_constraint(node).func

    def get_func_string(self, node):
        return self.get_constraint(node).func_string

    def get_node_type(self, node):
        return self.nodes[node]["type"]

    def toposort(self):
        """Return the nodes in lexicographical topological order (deterministic for string node names)."""
        return nx.lexicographical_topological_sort(self)

    def inputs_sorted(self, node):
        """Return the predecessors of `node`, ordered by their edge ``index``."""

        def get_sort_key(pred):
            return self.get_edge_data(pred, node)["index"]

        return sorted(self.predecessors(node), key=get_sort_key)

    def print(self, show_hidden=True):
        """
        Print all the encodings' constraints in the graph.

        Args:
            show_hidden (bool): If True, print one line per dependant node (hidden ones included), in terms of its
                direct inputs. Otherwise print only visible dependant nodes, inlining hidden intermediate nodes into
                the expression and stripping the first "/"-separated component from visible input names.
        """
        if show_hidden:
            for node in self.dependant_nodes:
                format_function_string = self.get_constraint(node).format(*self.inputs_sorted(node))
                print(f"{node} = {format_function_string}")
            return

        string_dict = dict()
        for node in self.toposort():
            if self.get_node_type(node) is InferenceNodeType.INDEPENDENT:
                string_dict[node] = node
            elif self.nodes[node]["hidden"]:
                input_strings = [string_dict[inp] for inp in self.inputs_sorted(node)]
                string_dict[node] = f"({self.get_constraint(node).format(*input_strings)})"
            else:
                # Strip the first "/"-separated component of the name; names without "/" are kept whole instead of
                # collapsing to an empty string.
                string_dict[node] = node.split("/", 1)[-1]
                format_function_string = self.get_constraint(node).format(
                    *[string_dict[inp] for inp in self.inputs_sorted(node)],
                )
                print(f"{node} = {format_function_string}")
