import inspect
import json
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from typing import List, Optional, Tuple

import networkx as nx

from hailo_model_optimization.flows.utils.flow_memento import FlowMemento
from hailo_model_optimization.tools.base_memento import BaseMemento
from hailo_model_optimization.tools.base_memento import BaseMemento


class SupervisorState(Enum):
    """States for the supervisor.

    The state decides what a method wrapped by ``flow_control_method`` does when it is called.
    """

    RUNNING = "running"  # execute the method and record the call in the call graph
    IGNORING = "ignoring"  # skip the method and return None (after a ``stop`` call or the run_until exit)
    DISABLED = "disabled"  # execute the method without recording it in the call graph
    LOOPING = "looping"  # replay the recorded calls that lead to a loaded checkpoint


def looping_list(call_graph: nx.DiGraph, leaf_id: int) -> List[Tuple[int, bool, dict]]:
    """
    Build the replay schedule that leads from the root call to ``leaf_id``.

    Walks from ``leaf_id`` up to node 1 (the decorated function, treated as the root). Every node on
    that path is emitted with flag ``True``. For each node on the path, its siblings with a smaller
    node ID (calls that happened before it under the same parent) are emitted with flag ``False``.
    Later siblings are not included.

    Parameters:
    call_graph (nx.DiGraph): The directed graph representing call dependencies.
    leaf_id (int): The ID of the node to trace back from (typically the checkpoint node).

    Returns:
    List[Tuple[int, bool, dict]]: ``(node_id, on_path, node_attributes)`` tuples sorted by node ID.
    """
    schedule = []
    node = leaf_id
    # Node 1 is the first call recorded, i.e. the decorated function itself.
    while node != 1:
        parent = next(call_graph.predecessors(node))
        for sibling in sorted(call_graph.successors(parent), reverse=True):
            if sibling < node:
                schedule.append((sibling, False, call_graph.nodes[sibling]))
        schedule.append((node, True, call_graph.nodes[node]))
        node = parent
    schedule.append((node, True, call_graph.nodes[node]))

    return sorted(schedule, key=lambda entry: entry[0])


def build_nested_dict_tree(graph, start_node=None):
    """
    Render the part of ``graph`` reachable from ``start_node`` as nested dictionaries.

    Each node becomes a shallow copy of its attribute dict. Nodes with successors get an extra
    ``"children"`` key mapping each successor ID to its own rendered dict. The graph is not modified.

    Parameters:
    graph: A networkx-style directed graph (uses ``graph.nodes``/``graph.nodes()``
           and ``graph.successors``).
    start_node: Node to start from. Defaults to the first node in the graph's node order.

    Returns:
    dict: ``{start_node: rendered_dict}``, or ``{}`` when no start node is given and the graph
    has no nodes.
    """
    if start_node is None:
        # networkx forbids None as a node, so it is a safe "empty graph" sentinel here.
        start_node = next(iter(graph.nodes()), None)
        if start_node is None:
            # Previously this leaked a bare StopIteration (e.g. str() of a fresh Supervisor).
            return {}

    def recurse(node):
        # Copy so that adding "children" never mutates the graph's own attribute dict.
        node_dict = dict(graph.nodes[node])
        children = list(graph.successors(node))
        if children:
            node_dict["children"] = {child: recurse(child) for child in children}
        return node_dict

    return {start_node: recurse(start_node)}


class NodeCheckpoint(BaseMemento):
    """A flow checkpoint together with the call-graph node at which it was recorded.

    NOTE(review): instances are also built with ``base_path=...``, presumably a BaseMemento field.
    """

    # ID of the Supervisor call-graph node that was current when the checkpoint was recorded.
    node_id: int
    # TODO SDK-48985 This need to be generic
    # The flow memento itself; handed to the flow's load_state() when a run is resumed.
    checkpoint: FlowMemento


class FlowCheckPoint(BaseMemento):
    """Snapshot of a Supervisor: its call graph, last checkpoint and state.

    Built by ``Supervisor.dump()`` / ``Supervisor.export()`` and read back by ``Supervisor.load_state()``.
    """

    # Most recent checkpoint recorded by the Supervisor.
    node_checkpoint: NodeCheckpoint
    # Call graph in networkx adjacency-data form (nx.adjacency_data / nx.adjacency_graph).
    graph: dict
    # Supervisor state at the time the snapshot was taken.
    state: SupervisorState


class Supervisor:
    """
    The Supervisor class oversees the execution of methods and
    manages checkpoints in the state of the system.
    It uses a directed graph to track method calls and their relationships.

    Attributes:
        graph (DiGraph): A directed graph where each node represents a method call. Node IDs are
            assigned from ``call_counter`` starting at 1; the first node is attached to node 0.
        call_counter (int): A counter to uniquely identify each node (method call).
        current_node (int): The ID of the current node in the graph.
        state (SupervisorState): The current state of the supervisor, which controls how method calls are handled.
        current_checkpoint (Optional[NodeCheckpoint]): The most recent checkpoint, or ``None`` until one
            is recorded or loaded.
        depth (int): Tracks the nesting level of method calls.
        exit_at (str): Method name whose call nodes are flagged as checkpoints (see ``add_exit``).
        tf_safe (bool): Flag forwarded by the caller when restoring the flow's state.
        call_list (List[Tuple[int, bool, dict]]): Replay schedule built by ``enter_looping``.
    """

    # (node_id, on_path_to_checkpoint, node_attributes) entries, as produced by looping_list().
    call_list: List[Tuple[int, bool, dict]]

    def __init__(self):
        self.graph = nx.DiGraph()
        self.call_counter = 0
        self.current_node = 0
        self.state = SupervisorState.DISABLED
        self.current_checkpoint: Optional[NodeCheckpoint] = None
        self.depth = 0
        self.exit_at = ""
        self.tf_safe = True

    def add_exit(self, method_name: str):
        """Flag future calls to ``method_name`` as checkpoint nodes (see ``visit_node``)."""
        self.exit_at = method_name

    def load_state(self, memento: FlowCheckPoint):
        """Restore the call graph, last checkpoint and state from ``memento``.

        ``call_counter`` and ``current_node`` are not restored here.
        """
        self.graph = nx.adjacency_graph(memento.graph, directed=True)
        self.current_checkpoint = memento.node_checkpoint
        self.state = memento.state

    def enter_looping(self):
        """Build the replay schedule up to the current checkpoint and switch to LOOPING.

        Requires ``current_checkpoint`` to be set (e.g. by ``load_state``).
        """
        self.call_list = looping_list(self.graph, self.current_checkpoint.node_id)
        self.state = SupervisorState.LOOPING

    def _snapshot(self, graph: nx.DiGraph) -> FlowCheckPoint:
        # Shared by dump() and export(); base_path is taken from the current checkpoint, so this
        # raises AttributeError when no checkpoint has been recorded yet.
        return FlowCheckPoint(
            graph=nx.adjacency_data(graph),
            node_checkpoint=self.current_checkpoint,
            state=self.state,
            base_path=self.current_checkpoint.base_path,
        )

    def dump(self) -> FlowCheckPoint:
        """Snapshot the full call graph, including calls made after the checkpoint."""
        return self._snapshot(self.graph)

    def export(self) -> FlowCheckPoint:
        """Snapshot the call graph trimmed to nodes up to the current checkpoint (see ``groom_graph``)."""
        return self._snapshot(self.groom_graph())

    def add_node(self, method_name, is_checkpoint=False):
        """Record a call to ``method_name`` as a child of the current node and make it current.

        Returns:
            int: The new node's ID.
        """
        self.call_counter += 1
        node_id = self.call_counter
        # NOTE(review): "depht" is a misspelling of "depth"; kept because it is part of the
        # serialized graph and may be read elsewhere.
        self.graph.add_node(node_id, method_name=method_name, is_checkpoint=is_checkpoint, depht=self.depth)
        self.graph.add_edge(self.current_node, node_id)
        self.current_node = node_id
        return node_id

    def _make_checkpoint(self, memento: BaseMemento) -> NodeCheckpoint:
        # Bind ``memento`` to the node that is current right now.
        return NodeCheckpoint(base_path=memento.base_path, node_id=self.current_node, checkpoint=memento)

    def add_checkpoint(self, memento: BaseMemento) -> None:
        """Record ``memento`` as the latest checkpoint and flag the current node as a checkpoint."""
        self.current_checkpoint = self._make_checkpoint(memento)
        self.graph.nodes[self.current_node]["is_checkpoint"] = True

    def groom_graph(self):
        """Return a copy of the graph without the nodes created after the current checkpoint."""
        export_g = self.graph.copy()
        export_g.remove_nodes_from([node for node in self.graph.nodes if node > self.current_checkpoint.node_id])
        return export_g

    def __str__(self) -> str:
        # An empty graph (fresh supervisor) renders as an empty JSON object instead of failing.
        if self.graph.number_of_nodes() == 0:
            return json.dumps({}, indent=4)
        results = build_nested_dict_tree(self.graph)
        # Print as a pretty JSON for better readability
        return json.dumps(results, indent=4)

    def check_result(self, result: BaseMemento):
        """If the current node is flagged as a checkpoint, store ``result`` as the latest checkpoint.

        Returns:
            ``result`` unchanged.
        """
        if self.graph.nodes[self.current_node]["is_checkpoint"]:
            self.current_checkpoint = self._make_checkpoint(result)
        return result

    @contextmanager
    def visit_node(self, method_name):
        """Track a call to ``method_name`` for the duration of the ``with`` body.

        A child node is added under the current node. It is flagged as a checkpoint when the method
        is ``save_state`` or the configured exit method. On exit, including on error, the depth and
        current node are restored. Leaving a ``stop`` call switches the supervisor to IGNORING.
        """
        previous_node = self.current_node
        try:
            checkpoint = method_name == "save_state" or method_name == self.exit_at
            self.depth += 1
            self.add_node(method_name, is_checkpoint=checkpoint)
            yield
            # TODO  this should be as stop
        finally:
            self.depth -= 1
            self.current_node = previous_node
            if method_name == "stop":
                self.state = SupervisorState.IGNORING


def flow_control_method(func):
    """
    A decorator for the 'run' method in classes derived from BaseSubprocessFlow.
    It manages the automated tracking and state management of all callable methods within the class,
    enabling precise execution control and robust state management. The execution is monitored and
    controlled through a Supervisor instance, which maintains a method call graph.

    Tracking is only enabled when the call passes a ``memento`` (resume from a checkpoint) or a
    ``run_until`` method name (checkpoint and stop after that method) as keyword arguments;
    otherwise ``func`` is simply called.

    Parameters:
    - func (Callable): The original method, typically 'run', intended to be wrapped for enhanced
                       execution control.

    Functionality:
    - For the duration of the call, replaces every non-dunder instance method of the class (except
      ``run`` and the decorated method) with a wrapper whose behavior depends on the supervisor's
      state. The original method lookup is restored afterwards, even if the run raises.
    - Manages different states (RUNNING, LOOPING, DISABLED, IGNORING) to determine how methods are
      executed based on the current and past system states.
    - Utilizes context managers to handle the entry and exit of method calls in the graph, automatically
      tracking the depth of call nesting.
    - In tracked mode, sets ``self.supervisor`` to the Supervisor used, and ``self.call_history`` to the
      exported call graph (or None if no checkpoint was recorded).

    State Behaviors:
    - RUNNING: Executes and tracks methods normally, suitable for normal operations.
    - LOOPING: Replays the recorded calls leading to a checkpoint. Calls on the path to it are executed,
      and calls that finished before it are skipped. At the checkpoint call itself, the flow's
      ``load_state`` restores the saved state and the supervisor switches to RUNNING.
    - DISABLED: Methods execute without tracking; used while the flow's own ``save_state`` /
      ``load_state`` run on behalf of the supervisor.
    - IGNORING: Skips method execution (returns None), used after a stop condition (``run_until``
      reached or a ``stop`` call) to gracefully end the run.

    Usage:
    To utilize this decorator, apply it to the 'run' method in any subclass of BaseSubprocessFlow.

    Example::

        class MyProcessFlow(BaseSubprocessFlow):
            @flow_control_method
            def run(self, *, memento: Optional[BaseModel] = None):
                # Implementation of the run logic
                ...

    This setup is essential in complex algorithm flows where precise control, state recovery,
    and the ability to pause and resume are crucial.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        memento: FlowCheckPoint = kwargs.get("memento", None)
        run_until: str = kwargs.get("run_until", "")
        if not memento and not run_until:
            # Nothing to resume and nowhere to stop: run untracked.
            # TODO: the state is not saved in this mode for now.
            return func(self, *args, **kwargs)

        sup = Supervisor()
        sup.state = SupervisorState.RUNNING
        if memento:
            sup.load_state(memento)
            sup.enter_looping()
            # The first entry is the decorated method itself (node 1); it is re-registered below
            # and never goes through a wrapped call, so it must not be replayed.
            sup.call_list.pop(0)

        if run_until:
            sup.add_exit(run_until)

        # Register the decorated method as the first (root) call of the graph.
        sup.add_node(func.__name__)
        self.supervisor = sup

        # Method to wrap other methods
        def method_wrapper(method, sup: Supervisor):
            @wraps(method)
            def wrapped(*args, **kwargs):
                method_name = method.__name__
                # IGNORING (and skipped replay entries) fall through and return None.
                result = None
                if sup.state == SupervisorState.RUNNING:
                    with sup.visit_node(method_name):
                        result = method(*args, **kwargs)
                        if method_name == run_until:
                            # Save without tracking the save_state call itself, then skip the rest.
                            # NOTE(review): save_state receives this method's own arguments.
                            sup.state = SupervisorState.DISABLED
                            result = self.save_state(*args, **kwargs)
                            sup.state = SupervisorState.IGNORING
                        sup.check_result(result)

                elif sup.state == SupervisorState.LOOPING:
                    node_index, policy, _ = sup.call_list.pop(0)

                    if not sup.call_list:
                        # This is the checkpoint call: restore the saved state instead of running it,
                        # then resume normal tracked execution.
                        sup.state = SupervisorState.DISABLED
                        self.load_state(sup.current_checkpoint.checkpoint, tf_safe=sup.tf_safe)
                        # New node IDs continue after the checkpoint node. current_node is deliberately
                        # left at the caller's node (as visit_node would restore it), so later calls
                        # become siblings of the checkpoint node rather than its children; otherwise a
                        # later resume would treat the checkpoint call as an ancestor and re-execute it.
                        sup.call_counter = node_index
                        sup.state = SupervisorState.RUNNING
                    elif policy:
                        # On the path to the checkpoint: execute so the calls nested inside it are
                        # replayed against call_list.
                        caller_node = sup.current_node
                        sup.current_node = node_index
                        sup.depth += 1
                        try:
                            result = method(*args, **kwargs)
                        finally:
                            sup.depth -= 1
                            sup.current_node = caller_node
                    # else: a call that completed before the checkpoint; skip it.

                elif sup.state == SupervisorState.DISABLED:
                    result = method(*args, **kwargs)

                return result

            return wrapped

        excluded = {"run", func.__name__}
        missing = object()
        # attr name -> value it had in the instance __dict__ before wrapping (``missing`` if none).
        shadowed = {}
        for attr_name in dir(self.__class__):
            attr = getattr(self.__class__, attr_name)
            # Plain functions on the class only; properties, classmethods and dunders are skipped.
            if not inspect.isfunction(attr) or attr_name.startswith("__") or attr_name in excluded:
                continue
            method = getattr(self, attr_name)
            # Static methods resolve to plain functions on the instance and are left alone.
            if inspect.ismethod(method):
                shadowed[attr_name] = self.__dict__.get(attr_name, missing)
                setattr(self, attr_name, method_wrapper(method, sup))

        try:
            # Execute the original run function
            return func(self, *args, **kwargs)

        finally:
            # Undo the wrapping first, so the instance stays usable even if export() raises.
            for name, previous in shadowed.items():
                if previous is missing:
                    # Drop the wrapper so lookups fall back to the class method, instead of leaving a
                    # bound method (which references self) in the instance __dict__.
                    self.__dict__.pop(name, None)
                else:
                    setattr(self, name, previous)
            self.call_history = sup.export() if sup.current_checkpoint else None

    return wrapper
