"""
IMPORTANT: Read this before you start working on this module.
Module: model_fuse_patterns

This module provides multiple examples of fuse structures used to match and fuse
sub-graphs within a PyTorch model. Each class here inherits from `BaseMatchStructure`
and implements the required static (or class) method `create_structure(layer_mode)`.


-----------------------------------------------------
How to Define Your Own Fuse Structure (`BaseMatchStructure`)
-----------------------------------------------------
1. **Create a Subclass**
   Inherit from `BaseMatchStructure` in your new class.

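2. **Implement `create_structure(layer_mode)`**
   - Define a small directed graph (using NetworkX) where each node is either a placeholder string
     (e.g. with `MatchType.REGEX`) or a template module instance such as
     `SubClusterModule(ReduceMax(), nn.Identity())` (with `MatchType.CLASS`).
   - Add edges (`graph.add_edge(...)`) to reflect the order in which these nodes occur.
   - Create a `builder` function that takes `mapped_layers` (a mapping from each pattern node to the
     model module it matched), plus any extra arguments it needs (e.g. `input_shapes`), and returns a
     fused `nn.Module`.
   - Raise `NotImplementedError` for any `layer_mode` the pattern does not support.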

3. **Return a `PatternMatch`**
   - `PatternMatch` includes:
       - `name`: A unique string name for your pattern.
       - `graph`: The `nx.DiGraph` describing layer connectivity.
       - `builder`: The function that creates the fused module from matched submodules.
       - `match_type`: Defines how nodes are matched against the model (e.g., `MatchType.REGEX`,
         `MatchType.CLASS`, etc.).

4. **(Optional) Register the Class**
   - Decorate the finished class with `@FUSER_REGISTRY` so it is automatically discovered.

By following these steps, you can quickly implement specialized fusions that combine multiple
layers or operations into a single optimized module for improved performance and readability.
"""

from typing import Dict

import networkx as nx
import torch.nn as nn

from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeatureMultiplierType, LayerType
from hailo_model_optimization.saitama.framework.apu_modules.activation_factory import Exp, InvPos, InvSqrt
from hailo_model_optimization.saitama.framework.fused_modules.fused_base import (
    SubClusterModule,
    SubClusterWithActivationModule,
)
from hailo_model_optimization.saitama.translators.hailo_translator.base_hailo_translator import LayerMode
from hailo_model_optimization.saitama.translators.hailo_translator.native_modules import (
    EWMult,
    EWSub,
    FeatureMultiplier,
    FusedGroupedNormalization,
    FusedLayerNormalization,
    FusedRMSNormalization,
    FusedSoftmax,
    ReduceMax,
    ReduceMean,
    ReduceSum,
    ResizeNN,
)
from hailo_model_optimization.saitama.translators.model_fuser.base_fuse_structure import (
    BaseMatchStructure,
    PatternMatch,
)
from hailo_model_optimization.saitama.translators.model_fuser.matching_tools import MatchType
from hailo_model_optimization.saitama.translators.translator_registry import ClassRegister

# Registry of fuse-pattern classes. Every class below decorated with ``@FUSER_REGISTRY`` is
# recorded here so it can be discovered automatically (see the module docstring).
# NOTE(review): presumably ClassRegister only accepts BaseMatchStructure subclasses — confirm.
FUSER_REGISTRY = ClassRegister(BaseMatchStructure)


@FUSER_REGISTRY
class SoftMaxMatch(BaseMatchStructure):
    """Pattern that fuses a decomposed softmax sub-graph into a single ``FusedSoftmax`` module."""

    @classmethod
    def create_structure(cls, layer_mode) -> PatternMatch:
        """
        Create a softmax pattern match structure.
                             │
                      ┌──────▼──────┐
                      │    input    │
                      └──┬────────┬─┘
               ┌─────────▼──┐     │
               │ reduce max │     │
               └─────────┬──┘     │
                       ┌─▼────────▼───┐
                       │  ew_sub+exp  │
                       └─┬──────┬─────┘
         ┌───────────────▼────┐ │
         │reduce sum + inv_pos│ │
         └───────────────┬────┘ │
                       ┌─▼──────▼───┐
                       │  ew_mult   │
                       └────────────┘

        The external ``input`` is not a pattern node; only the four layers shown below it are.

        Args:
            layer_mode: Translator layer mode. Only ``LayerMode.NATIVE`` is supported.

        Returns:
            PatternMatch: The ``"softmax"`` pattern, matched with ``MatchType.CLASS``.

        Raises:
            NotImplementedError: If ``layer_mode`` is not ``LayerMode.NATIVE``.
        """
        if layer_mode != LayerMode.NATIVE:
            raise NotImplementedError(f"Layer mode {layer_mode} is not supported for SoftMaxMatch")

        # Template nodes. Each is a distinct module instance and therefore a distinct graph node.
        # With MatchType.CLASS these presumably act as class templates (TODO confirm in matcher).
        reduce_max = SubClusterModule(ReduceMax(), nn.Identity())
        sub_exp = SubClusterModule(EWSub(), Exp())
        sum_inv_pos = SubClusterModule(ReduceSum(), InvPos())
        mult = EWMult()

        graph = nx.DiGraph()
        graph.add_edges_from(
            [
                (reduce_max, sub_exp),
                (sub_exp, sum_inv_pos),
                (sum_inv_pos, mult),
                # Skip connection: the exp(x - max) output also feeds the final multiply.
                (sub_exp, mult),
            ]
        )

        def builder(mapped_layers: Dict[nn.Module, nn.Module], **kwargs) -> nn.Module:
            """Build a ``FusedSoftmax`` from the matched layers.

            ``mapped_layers`` maps each template node above to the model module it matched.
            Extra keyword arguments are accepted and ignored.
            """
            # The fused softmax uses the axis of the matched reduce-max layer.
            # Only the first entry of its ``mac.axis`` is used.
            axis = mapped_layers[reduce_max].mac.axis
            return FusedSoftmax(axis[0])

        return PatternMatch("softmax", graph, builder, MatchType.CLASS)


@FUSER_REGISTRY
class LayerNormMatch(BaseMatchStructure):
    """Pattern that fuses a decomposed layer normalization into ``FusedLayerNormalization``."""

    @classmethod
    def create_structure(cls, layer_mode) -> PatternMatch:
        """
        Create a layer normalization pattern match structure.
                       |          |
        ┌──────────────▼─┐        │
        │   reduce mean  │        │
        └──────────────┬─┘        │
                    ┌──▼──────────▼───┐
                    │      ew sub     │
                    └──┬──────────┬───┘
        ┌──────────────▼──────┐   │
        │  feature multiplier │   │
        └──────────────┬──────┘   │
        ┌──────────────▼──┐       │
        │ conv + inv_sqrt │       │
        └───────────────┬─┘       │
                    ┌───▼─────────▼───┐
                    │     ew mult     │
                    └─────────────────┘

        Args:
            layer_mode: Translator layer mode. Only ``LayerMode.NATIVE`` is supported.

        Returns:
            PatternMatch: The ``"layer_normalization"`` pattern, matched with ``MatchType.CLASS``.

        Raises:
            NotImplementedError: If ``layer_mode`` is not ``LayerMode.NATIVE``.
        """
        if layer_mode != LayerMode.NATIVE:
            raise NotImplementedError(f"Layer mode {layer_mode} is not supported for LayerNormMatch")

        # Template nodes. Each is a distinct module instance and therefore a distinct graph node.
        # nn.Conv2d(1, 1, 1) is a placeholder; presumably only its class matters (TODO confirm).
        reduce_mean = SubClusterModule(ReduceMean(), nn.Identity())
        sub = SubClusterModule(EWSub(), nn.Identity())
        square = SubClusterModule(
            FeatureMultiplier(FeatureMultiplierType.square, reduce_sum_groups=1), nn.Identity()
        )
        conv_inv_sqrt = SubClusterModule(nn.Conv2d(1, 1, 1), InvSqrt())
        mult = EWMult()

        graph = nx.DiGraph()
        graph.add_edges_from(
            [
                (reduce_mean, sub),
                (sub, square),
                # Skip connection: the centered input is multiplied by the inverse-sqrt branch.
                (sub, mult),
                (square, conv_inv_sqrt),
                (conv_inv_sqrt, mult),
            ]
        )

        def builder(mapped_layers: Dict[nn.Module, nn.Module], input_shapes: Dict[nn.Module, list]) -> nn.Module:
            """Build a ``FusedLayerNormalization`` from the matched layers.

            ``mapped_layers`` maps each template node above to the model module it matched.
            ``input_shapes`` maps model modules to their input shapes.
            """
            # Dimension 1 of the matched reduce-mean input shape.
            # Presumably this is the normalized feature/channel size — confirm the layout.
            return FusedLayerNormalization(input_shapes[mapped_layers[reduce_mean]][1])

        return PatternMatch("layer_normalization", graph, builder, MatchType.CLASS)


@FUSER_REGISTRY
class GroupedNormMatch(BaseMatchStructure):
    """Pattern that fuses a decomposed grouped normalization into ``FusedGroupedNormalization``."""

    @classmethod
    def create_structure(cls, layer_mode) -> PatternMatch:
        """
        Create a grouped normalization pattern match structure.
                           |          |
            ┌──────────────▼─┐        │
        2X  │   reduce mean  │        │
            └──────────────┬─┘        │
            ┌──────────────▼─┐        │
            │   resize       │        │
            └──────────────┬─┘        │
                        ┌──▼──────────▼─┐
                        │      ew sub   │
                        └──┬──────────┬─┘
            ┌──────────────▼──────┐   │
            │  feature multiplier │   │
            └──────────────┬──────┘   │
            ┌──────────────▼──┐       │
            │      conv       │       │
            └───────────────┬─┘       │
            ┌───────────────▼──────┐  │
            │reduce mean + inv_sqrt│  │
            └────────────────────┬─┘  │
                 ┌──────────────▼─┐   │
                 │   resize       │   │
                 └──────────┬─────┘   │
                        ┌───▼─────────▼───┐
                        │     ew mult     │
                        └─────────────────┘

        Args:
            layer_mode: Translator layer mode. Only ``LayerMode.NATIVE`` is supported.

        Returns:
            PatternMatch: The ``"grouped_normalization"`` pattern, matched with ``MatchType.CLASS``.

        Raises:
            NotImplementedError: If ``layer_mode`` is not ``LayerMode.NATIVE``.
        """
        if layer_mode != LayerMode.NATIVE:
            raise NotImplementedError(f"Layer mode {layer_mode} is not supported for GroupedNorm")

        # Template nodes. The three ReduceMean templates are separate instances, so they stay
        # separate graph nodes even though they share a class.
        reduce_mean_1 = SubClusterModule(ReduceMean(), nn.Identity())
        reduce_mean_2 = SubClusterModule(ReduceMean(), nn.Identity())
        resize_1 = SubClusterModule(ResizeNN(), nn.Identity())
        sub = SubClusterModule(EWSub(), nn.Identity())
        square = SubClusterModule(
            FeatureMultiplier(FeatureMultiplierType.square, reduce_sum_groups=1), nn.Identity()
        )
        conv = SubClusterModule(nn.Conv2d(1, 1, 1), nn.Identity())
        reduce_mean_inv_sqrt = SubClusterModule(ReduceMean(), InvSqrt())
        resize_2 = SubClusterModule(ResizeNN(), nn.Identity())
        mult = EWMult()

        graph = nx.DiGraph()
        graph.add_edges_from(
            [
                # Mean branch: two chained reduce-means, then a resize back for the subtraction.
                (reduce_mean_1, reduce_mean_2),
                (reduce_mean_2, resize_1),
                (resize_1, sub),
                (sub, square),
                # Skip connection: the centered input is multiplied by the inverse-sqrt branch.
                (sub, mult),
                # Variance branch: square -> conv -> reduce-mean + inv_sqrt -> resize.
                (square, conv),
                (conv, reduce_mean_inv_sqrt),
                (reduce_mean_inv_sqrt, resize_2),
                (resize_2, mult),
            ]
        )

        def builder(mapped_layers: Dict[nn.Module, nn.Module], input_shapes: Dict[nn.Module, list]) -> nn.Module:
            """Build a ``FusedGroupedNormalization`` from the matched layers.

            ``mapped_layers`` maps each template node above to the model module it matched.
            ``input_shapes`` maps model modules to their input shapes.
            """
            first_mean = mapped_layers[reduce_mean_1]
            return FusedGroupedNormalization(
                # Presumably the group count, read from the first reduce-mean's MAC — confirm.
                first_mean.mac.groups[0],
                # Dimension 1 of that layer's input shape (presumably channels — confirm layout).
                input_shapes[first_mean][1],
            )

        return PatternMatch("grouped_normalization", graph, builder, MatchType.CLASS)


@FUSER_REGISTRY
class RMSNormMatch(BaseMatchStructure):
    """Pattern that fuses a decomposed RMS normalization into ``FusedRMSNormalization``."""

    @classmethod
    def create_structure(cls, layer_mode) -> PatternMatch:
        """
        Create a RMS normalization pattern match structure.
                       |          |
        ┌──────────────▼──────┐   │
        │  feature multiplier │   │
        └──────────────┬──────┘   │
        ┌──────────────▼──┐       │
        │ conv + inv_sqrt │       │
        └───────────────┬─┘       │
                    ┌───▼─────────▼───┐
                    │     ew mult     │
                    └─────────────────┘

        The shared external input feeds both the feature multiplier and the final multiply.
        It is not itself a pattern node.

        Args:
            layer_mode: Translator layer mode. Only ``LayerMode.NATIVE`` is supported.

        Returns:
            PatternMatch: The ``"rms_normalization"`` pattern, matched with ``MatchType.CLASS``.

        Raises:
            NotImplementedError: If ``layer_mode`` is not ``LayerMode.NATIVE``.
        """
        if layer_mode != LayerMode.NATIVE:
            raise NotImplementedError(f"Layer mode {layer_mode} is not supported for RMSNormMatch")

        # Template nodes. nn.Conv2d(1, 1, 1) is a placeholder; presumably only its class matters.
        square = SubClusterModule(
            FeatureMultiplier(FeatureMultiplierType.square, reduce_sum_groups=1), nn.Identity()
        )
        conv_inv_sqrt = SubClusterModule(nn.Conv2d(1, 1, 1), InvSqrt())
        mult = EWMult()

        graph = nx.DiGraph()
        graph.add_edges_from(
            [
                (square, conv_inv_sqrt),
                (conv_inv_sqrt, mult),
            ]
        )

        def builder(mapped_layers: Dict[nn.Module, nn.Module], input_shapes: Dict[nn.Module, list]) -> nn.Module:
            """Build a ``FusedRMSNormalization`` from the matched layers.

            ``mapped_layers`` maps each template node above to the model module it matched.
            ``input_shapes`` maps model modules to their input shapes.
            """
            # Dimension 1 of the matched feature-multiplier input shape.
            # Presumably this is the normalized feature/channel size — confirm the layout.
            return FusedRMSNormalization(input_shapes[mapped_layers[square]][1])

        return PatternMatch("rms_normalization", graph, builder, MatchType.CLASS)


@FUSER_REGISTRY
class NegExponentMatch(BaseMatchStructure):
    """Pattern that merges a sub-cluster layer with a following standalone ``Exp`` activation."""

    @staticmethod
    def create_structure(layer_mode) -> PatternMatch:
        """
        Create a Negative exponent pattern match structure.
                       |
        ┌──────────────▼──────┐
        │  sub cluster layer  │
        └──────────────┬──────┘
        ┌──────────────▼────────┐
        │ standalone activation │
        └───────────────────────┘

        NOTE(review): the name suggests exp(-x) with the negation supplied by the preceding
        layer. The template itself only fixes an ``Exp`` activation — confirm.

        Args:
            layer_mode: Translator layer mode. Only ``LayerMode.NATIVE`` is supported.

        Returns:
            PatternMatch: The ``"NegExponent"`` pattern, matched with ``MatchType.CLASS``.

        Raises:
            NotImplementedError: If ``layer_mode`` is not ``LayerMode.NATIVE``. Previously an
                unsupported mode fell through and failed with an opaque ``KeyError``.
        """
        if layer_mode != LayerMode.NATIVE:
            raise NotImplementedError(f"Layer mode {layer_mode} is not supported for NegExponentMatch")

        # Template nodes. nn.Conv2d(1, 1, 1) is a placeholder; presumably only its class matters.
        conv = SubClusterModule(nn.Conv2d(1, 1, 1), nn.Identity())
        activation = SubClusterModule(nn.Identity(), Exp(), is_activation_only=True)

        graph = nx.DiGraph()
        graph.add_edge(conv, activation)

        def builder(mapped_layers: Dict[nn.Module, nn.Module], **kwargs) -> nn.Module:
            """Combine the matched layer and its standalone activation into one module.

            ``mapped_layers`` maps each template node above to the model module it matched.
            Extra keyword arguments are accepted and ignored.
            """
            return SubClusterWithActivationModule(
                mapped_layers[conv],
                mapped_layers[activation],
            )

        return PatternMatch("NegExponent", graph, builder, MatchType.CLASS)


# @FUSER_REGISTRY
# class PrecisionChangeMatch(BaseMatchStructure):
#     @staticmethod
#     def create_structure() -> PatternMatch:
#         source_pattern = "{@name}"
#         neg_exp_pattern = "{}precision_change{}"

#         graph = nx.DiGraph()
#         graph.add_edge(source_pattern, neg_exp_pattern)

#         def builder(mapped_layers: Dict[str, nn.Module]) -> nn.Module:
#             # we can just update the precision of the source layer, and remove the precision change
#             layer = nn.Sequential(
#                 mapped_layers[source_pattern],
#                 mapped_layers[neg_exp_pattern],
#             )
#             return layer

#         return PatternMatch("@name", graph, builder, MatchType.REGEX)
