from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, List, Optional, Tuple

import networkx as nx
import torch.nn as nn
from networkx.algorithms import isomorphism

from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.saitama.framework.common.saitama_definitions import SaitmaBuilding
from hailo_model_optimization.saitama.translators.hailo_translator.base_hailo_translator import LayerMode
from hailo_model_optimization.saitama.translators.model_fuser.matching_tools import (
    MatchType,
    node_matcher,
)

# Factory that creates a fused module. It is invoked as
# ``builder(matched_layers, input_shapes=...)``, where ``matched_layers`` maps each
# pattern-graph node to the model module it matched. Because of the extra keyword
# argument, the alias does not pin down the parameter list.
LayerBuilder = Callable[..., nn.Module]


@dataclass
class PatternMatch:
    """
    Description of one fusion pattern, as returned by `BaseMatchStructure.create_structure()`.
    """

    # Base name of the fused nodes/layers; `BaseMatchStructure.fuse()` appends "_<index>" to it.
    name: str
    # Pattern graph that is searched for as a sub-graph of the model graph.
    pattern: nx.DiGraph
    # Called with {pattern node: matched module} and an `input_shapes=` keyword to build the fused module.
    builder: LayerBuilder
    # How pattern nodes are compared against model layers (forwarded to `node_matcher`).
    match_type: MatchType = MatchType.CLASS


class BaseMatchStructure:
    """
    Base class for defining a sub-graph fusion pattern and implementing its fusion logic.

    -------------------------
    Overview and Responsibilities
    -------------------------
    1. **Defining a Fuse Pattern**
       Each subclass must override `create_structure(layer_mode)` (as a static method or a class
       method) so that it returns a `PatternMatch`. This pattern describes:
         - A small directed graph (via NetworkX) whose nodes are placeholders or class references.
         - A `builder` function that receives the matched submodules (keyed by pattern node) plus an
           `input_shapes` keyword argument, and creates a fused PyTorch module.
         - A name and a match type (e.g., `MatchType.REGEX` or `MatchType.CLASS`).

    2. **Finding and Replacing Matches**
       The `fuse()` method applies the pattern to a `SaitmaBuilding` (its `model_flow` graph,
       `layers` dict and `input_shapes`):
         - It repeatedly searches the model graph for a sub-graph matching the pattern (using
           `find_pattern_matches()`), taking the match that starts earliest in topological order.
         - For each match, it replaces the matched nodes in the model graph with a single
           fused node (via `swap_graph_by_node()`).
         - It then replaces the matched layers in the `layers` dict with the single fused module
           created by `builder` (via `replace_modules()`).

    3. **Extension Points**
       - **`create_structure()`**: Must return a `PatternMatch` describing your new fuse pattern.
       - **`find_pattern_matches()`**: Relies on NetworkX's `DiGraphMatcher` to locate sub-graph
         isomorphisms. For custom matching logic, pass a different node-comparison callable as its
         `compare_funtion` argument, or override the method.

    -------------------------
    Usage
    -------------------------
    Subclasses typically look like:

    .. code-block:: python

        class MyCustomMatch(BaseMatchStructure):
            @staticmethod
            def create_structure(layer_mode) -> PatternMatch:
                # 1) Build the pattern graph (node names below are placeholders)
                graph = nx.DiGraph()
                graph.add_edge("<pattern_a>", "<pattern_b>")

                # 2) Define a builder function
                def builder(matched_layers: Dict[str, nn.Module], input_shapes=None) -> nn.Module:
                    # Combine matched_layers into a single fused module
                    return MyFusedModule(...)

                # 3) Return a PatternMatch
                return PatternMatch(
                    name="my_custom_fusion",
                    pattern=graph,
                    builder=builder,
                    match_type=MatchType.REGEX
                )

    When `fuse()` is called, it will:
      - Search for matches of the above `PatternMatch` in the model's graph.
      - Remove the matched sub-graph.
      - Insert a single new node and a single new fused PyTorch module.
      - Continue until no more matches are found.

    This design simplifies replacing repetitive model sub-graphs with more optimized fused counterparts,
    potentially improving both performance and maintainability of your model.

    -------------------------
    Public Methods
    -------------------------
    - **create_structure(layer_mode)** -> PatternMatch
      Must be implemented by subclasses to define the fuse pattern's graph, builder, name, and match type.

    - **fuse(saitama_info: SaitmaBuilding, layer_mode: LayerMode) -> SaitmaBuilding**
      Applies the defined pattern repeatedly on a given `SaitmaBuilding` (model flow, layers dict and
      input shapes), returning an updated `SaitmaBuilding`.

    - **replace_modules(...) -> Dict[str, nn.Module]**
      Replaces matched modules in the layers dict with the newly fused module.

    - **find_pattern_matches(...) -> Optional[Dict[str, str]]**
      Searches the target graph (`target_graph`) for sub-graphs that match the pattern graph
      (`pattern_graph`) using a node comparison function, and returns the match (a mapping from
      target node to pattern node) whose earliest node comes first in topological order, or
      `None` if nothing matches.

    - **swap_graph_by_node(...) -> ModelFlow**
      Removes the matched sub-graph from the model flow and adds a new node representing the fused sub-graph.

    -------------------------
    Notes
    -------------------------
    - The default `match_type` is `MatchType.CLASS`. For string-based matching, set it to `MatchType.REGEX`.
    - `SaitmaBuilding` bundles the model flow, the layers dict and the input shapes. Subclasses can
      extend or override the `fuse` logic to handle custom behaviors when the sub-graph is replaced.
    - `fuse()` modifies the model flow and the layers dict of the given `SaitmaBuilding` in place.

    """

    @classmethod
    def create_structure(cls, layer_mode: LayerMode) -> PatternMatch:
        """
        Return the `PatternMatch` describing this fusion.

        Subclasses must override this method to return a PatternMatch describing
        the structure (pattern graph), builder function, name, and match type
        of the desired fusion pattern. `fuse()` calls it as
        ``self.create_structure(layer_mode)``, so an override may be a static
        method taking only ``layer_mode`` or a class method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def fuse(self, saitama_info: SaitmaBuilding, layer_mode: LayerMode) -> SaitmaBuilding:
        """
        Replace every occurrence of this structure's pattern with a single fused layer.

        Matches are found and replaced one at a time (earliest in topological order first)
        until none remain. Each fused node/layer is named ``f"{pattern.name}_{i}"`` with the
        next index ``i`` that is not already a node of the graph or a key of the layers dict.

        Args:
            saitama_info: Model flow, layers dict and input shapes to fuse. Its model flow and
                layers dict are modified in place.
            layer_mode: Forwarded to `create_structure()` to build the pattern.

        Returns:
            A new `SaitmaBuilding` wrapping the updated model flow, layers and the same input shapes.
        """
        model_flow = saitama_info.model_flow
        layers = saitama_info.layers
        input_shapes = saitama_info.input_shapes
        pattern_match = self.create_structure(layer_mode)

        index = 0
        while True:
            matching = self.find_pattern_matches(
                model_flow, pattern_match.pattern, node_matcher(self, pattern_match.match_type, layers)
            )
            if not matching:
                break

            name, index = self._next_free_name(pattern_match.name, index, model_flow, layers)
            subgraph = model_flow.subgraph(set(matching.keys())).copy()
            model_flow = self.swap_graph_by_node(model_flow, subgraph, name)
            layers = self.replace_modules(layers, input_shapes, matching, pattern_match.builder, name)

        return SaitmaBuilding(model_flow, layers, input_shapes)

    @staticmethod
    def _next_free_name(
        base: str,
        index: int,
        model_flow: ModelFlow,
        layers: Dict[str, nn.Module],
    ) -> Tuple[str, int]:
        """Return the first unused ``f"{base}_{i}"`` with ``i >= index``, together with ``i + 1``."""
        # Reusing an existing name would make `add_node` merge the fused node into an
        # unrelated existing node and overwrite that node's module in `layers`.
        while True:
            candidate = f"{base}_{index}"
            index += 1
            if candidate not in layers and not model_flow.has_node(candidate):
                return candidate, index

    @staticmethod
    def replace_modules(
        layers: Dict[str, nn.Module],
        input_shapes: Dict[str, Tuple[int, ...]],
        matching: Dict[str, str],
        builder: LayerBuilder,
        name: str,
    ) -> Dict[str, nn.Module]:
        """
        Swap the matched modules in ``layers`` for one fused module built by ``builder``.

        Args:
            layers: Mapping from layer name to module; modified in place.
            input_shapes: Forwarded to ``builder`` as the ``input_shapes`` keyword argument.
            matching: Mapping from target (model) node name to pattern node.
            builder: Called as ``builder({pattern node: module}, input_shapes=input_shapes)``.
            name: Key under which the fused module is stored.

        Returns:
            ``layers`` (the same dict, modified in place).
        """
        # Re-key the matched modules by pattern node so the builder can address them by role.
        modules_by_pattern_node = {
            pattern_node: layers[target_node] for target_node, pattern_node in matching.items()
        }
        # Build before mutating, so a failing builder leaves `layers` untouched.
        new_layer = builder(modules_by_pattern_node, input_shapes=input_shapes)

        # Delete the old layers before inserting the fused one, so the fused layer survives
        # even if `name` coincides with one of the matched layer names.
        for target_node in matching:
            del layers[target_node]
        layers[name] = new_layer
        return layers

    @staticmethod
    def find_pattern_matches(
        target_graph: nx.DiGraph,
        pattern_graph: nx.DiGraph,
        compare_funtion: Optional[Callable] = None,
    ) -> Optional[Dict[str, str]]:
        """
        Find the earliest occurrence of ``pattern_graph`` inside ``target_graph``.

        Args:
            target_graph: Graph to search; must also provide ``toposort()`` (e.g. a `ModelFlow`).
                It is copied, not modified.
            pattern_graph: Pattern to look for; copied, not modified.
            compare_funtion: Optional ``node_match`` callable for `DiGraphMatcher`. It receives the
                node attribute dicts, in which a ``"label"`` entry holds the node's own name.

        Returns:
            A mapping from target node to pattern node for the match whose earliest target node
            (by topological order) comes first, or ``None`` when there is no match.
        """
        target = target_graph.copy()
        pattern = pattern_graph.copy()

        # Expose each node's name as a 'label' attribute so `node_match`, which only sees
        # attribute dicts, can inspect it.
        for graph in (target, pattern):
            for node in graph.nodes():
                graph.nodes[node]["label"] = node

        # Match the pattern against sub-graphs of the target graph.
        matcher = isomorphism.DiGraphMatcher(target, pattern, node_match=compare_funtion)
        matches = list(matcher.subgraph_isomorphisms_iter())
        if not matches:
            return None

        # Rank matches by the topological position of their earliest target node.
        topo_order = {node: i for i, node in enumerate(target_graph.toposort())}

        def earliest_position(mapping: Dict[str, str]) -> int:
            return min(topo_order[node] for node in mapping)

        # `min` returns the first of equally ranked matches, same as a stable sort.
        return min(matches, key=earliest_position)

    @staticmethod
    def swap_graph_by_node(hailo_model: ModelFlow, match: ModelFlow, new_node: str) -> ModelFlow:
        """
        Collapse the nodes of ``match`` inside ``hailo_model`` into the single node ``new_node``.

        Edges entering the matched nodes from outside are re-pointed to ``new_node`` and edges
        leaving them towards outside nodes are re-sourced from ``new_node``, keeping their
        attributes. The matched nodes are then removed, and any of them found in
        ``hailo_model.output_layer_order`` is replaced in place by ``new_node``.

        Args:
            hailo_model: Model graph; modified in place.
            match: Graph whose nodes are the ``hailo_model`` nodes to replace.
            new_node: Name of the node that replaces them.

        Returns:
            ``hailo_model`` (the same object, modified in place).
        """
        # Walk nodes in the sub-graph's own order instead of a set's, whose iteration order for
        # string nodes varies between interpreter runs; the set is used only for membership tests.
        ordered_nodes = list(match.nodes())
        nodes_to_replace = set(ordered_nodes)

        # Collect boundary edges before mutating the graph.
        incoming_edges = [
            (pred, node)
            for node in ordered_nodes
            for pred in hailo_model.predecessors_sorted(node)
            if pred not in nodes_to_replace
        ]
        outgoing_edges = [
            (node, succ)
            for node in ordered_nodes
            for succ in hailo_model.successors_sorted(node)
            if succ not in nodes_to_replace
        ]

        hailo_model.add_node(new_node)

        # NOTE(review): assuming ModelFlow is a simple DiGraph, an outside node connected to several
        # matched nodes ends up with one edge to/from `new_node`, carrying the last edge's attributes.
        for pred, old_target in incoming_edges:
            data = hailo_model.get_edge_data(pred, old_target)
            hailo_model.remove_edge(pred, old_target)
            hailo_model.add_edge(pred, new_node, **data)

        for old_source, succ in outgoing_edges:
            data = hailo_model.get_edge_data(old_source, succ)
            hailo_model.remove_edge(old_source, succ)
            hailo_model.add_edge(new_node, succ, **data)

        for node in ordered_nodes:
            if hailo_model.has_node(node):
                hailo_model.remove_node(node)
                # Keep the model's output ordering pointing at the fused node.
                if node in hailo_model.output_layer_order:
                    hailo_model.output_layer_order[hailo_model.output_layer_order.index(node)] = new_node

        return hailo_model
