import re
from enum import Enum
from typing import Any, Callable, Dict

import torch.nn as nn

from hailo_model_optimization.saitama.framework.fused_modules.fused_base import SubClusterModule


class MatchType(Enum):
    """
    Strategy used to decide whether an original graph node matches a pattern node.

    - REGEX: treat the pattern node's label as a wildcard pattern and test it
      against the original node's label.
    - CLASS: compare the class of the original node's layer with the class of
      the object stored in the pattern node's label.
    """

    REGEX = 1  # label-based, wildcard/regex matching
    CLASS = 2  # layer-class-based matching


def _pattern_to_regex(pattern: str) -> str:
    r"""
    Convert a pattern string (which may contain wildcards) into a regular expression.

    Supported wildcards:
      - "{}"  : Matches any sequence of characters (equivalent to ".*")
      - "{n}" : Matches one or more digits (equivalent to "\d+")
      - "{s}" : Matches one or more word characters (equivalent to "\w+")

    Examples:
      - "{}_reduce_max"  -> "^.*_reduce_max$"
      - "conv{n}"        -> "^conv\d+$"
      - "layer_{s}"      -> "^layer_\w+$"
      - "pool"           -> "^pool$" (i.e., an exact match)

    Args:
        pattern (str): The pattern string with optional wildcards.

    Returns:
        str: The converted regular expression pattern.
    """
    token_map = {
        "{}": ".*",
        "{n}": r"\d+",
        "{s}": r"\w+",
    }

    token_regex = re.compile(r"(\{\}|\{n\}|\{s\})")
    parts = token_regex.split(pattern)

    regex_parts = [token_map[part] if part in token_map else re.escape(part) for part in parts]

    return "^" + "".join(regex_parts) + "$"


def node_matcher(
    fuse_algo, match_type: MatchType, layers: Dict[str, nn.Module]
) -> Callable[[Dict[str, Any], Dict[str, Any]], bool]:
    """
    Build a node-comparison predicate for the requested match type.

    The returned callable takes ``(node_original, node_pattern)``, two node
    attribute dicts that each carry a ``"label"`` key, and returns True when
    the nodes are considered a match.

    Args:
        fuse_algo: The fusion algorithm instance driving the match. It is only
            inspected to check whether it is a ``NegExponentMatch``, which relaxes
            how ``SubClusterModule`` layers are compared.
        match_type (MatchType): ``MatchType.CLASS`` selects class-based matching;
            any other value selects regex-based label matching.
        layers (Dict[str, nn.Module]): Mapping from an original node's label to
            its layer module.

    Returns:
        Callable[[Dict[str, Any], Dict[str, Any]], bool]: The matching predicate.
    """

    def is_class_match(node_original: Dict[str, Any], node_pattern: Dict[str, Any]) -> bool:
        """
        Match nodes based on the class of their layers.

        The original node's layer is looked up in ``layers`` by its label. The
        pattern node's ``"label"`` is expected to hold a layer object (not a
        string), whose class is compared against it. When both are
        ``SubClusterModule`` instances, the classes of their ``mac`` and ``apu``
        sub-modules are compared instead. If ``fuse_algo`` is a
        ``NegExponentMatch``, only their ``is_activation_only`` flags are compared.

        Args:
            node_original (Dict[str, Any]): The original node; its "label" is a key into ``layers``.
            node_pattern (Dict[str, Any]): The pattern node; its "label" is a layer object.

        Returns:
            bool: True if the layers are considered the same kind of layer.
        """
        original_layer = layers[node_original["label"]]
        pattern_layer = node_pattern["label"]

        if isinstance(original_layer, SubClusterModule) and isinstance(pattern_layer, SubClusterModule):
            # Imported lazily; presumably to avoid a circular import with matching_structures.
            from hailo_model_optimization.saitama.translators.model_fuser.matching_structures import (
                NegExponentMatch,
            )

            if isinstance(fuse_algo, NegExponentMatch):
                # NegExponentMatch only requires the activation-only flag to agree;
                # the concrete mac/apu classes are deliberately not constrained.
                return original_layer.is_activation_only == pattern_layer.is_activation_only
            return (
                original_layer.mac.__class__ == pattern_layer.mac.__class__
                and original_layer.apu.__class__ == pattern_layer.apu.__class__
            )

        return original_layer.__class__ == pattern_layer.__class__

    def is_regex_match(node_original: Dict[str, str], node_pattern: Dict[str, str]) -> bool:
        """
        Match nodes by testing the original node's label against the pattern node's label.

        The pattern label may contain the wildcards ``{}``, ``{n}`` and ``{s}``
        and is converted with ``_pattern_to_regex``. The whole original label
        must match.

        Args:
            node_original (Dict[str, str]): The original node with a "label" key.
            node_pattern (Dict[str, str]): The pattern node whose "label" is a wildcard pattern.

        Returns:
            bool: True if the original node's label fully matches the pattern.
        """
        regex = _pattern_to_regex(node_pattern["label"])
        return re.fullmatch(regex, node_original["label"]) is not None

    return is_class_match if match_type == MatchType.CLASS else is_regex_match
