from abc import ABC, abstractmethod
from typing import Generic, List, Type, TypeVar

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer_decompose import BaseHailoLayerDecompose
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_quant_weight_group import HailoConvQuantWeightGroup
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import OptimizationTarget

# To implement a new strategy, you need to:
# 1. Create a new class that inherits from DecompositionStrategy
# 2. Implement the abstract methods should_apply and decompose, which decide whether the strategy applies to a layer and create the new decomposed layer, respectively.
# 3. Register the new class by adding it to DecompositionRegistry._strategies. The position in that list determines the order in which the strategies are tried. Each layer can be decomposed only once, so the order of the strategies matters.
# 4. Optionally, implement the get_decomposition_details method to provide additional information about the decomposition for logging purposes.
# +-------------------------+
# | DecompositionRegistry   |
# |-------------------------|
# | +_strategies: List      |
# |-------------------------|
# | +get_strategies()       |
# +-------------------------+
#             |
#             |
#             v
# +-------------------------+
# | DecompositionStrategy   |
# |-------------------------|
# | Generic[T, K], ABC      |
# |-------------------------|
# | +should_apply(...)      |
# | +decompose(...)         |
# | +get_decomposition_details(...)|
# +-------------------------+
#             ^
#             |
#             |
# +---------------------------------------------+
# | QuantWeightGroupsDecompositionStrategy      |
# |---------------------------------------------|
# | +should_apply(...)                          |
# | +decompose(...)                             |
# | +get_decomposition_details(...)             |
# +---------------------------------------------+
#             ^
#             |
#             |
# +-------------------------+
# | LayerDecompose          |
# |-------------------------|
# | -_layers_to_decompose   |
# |-------------------------|
# | +_setup()               |
# | +_run_int()             |
# | +_decompose_layer()     |
# +-------------------------+


# T: the layer type a strategy decomposes *from*. It is currently bounded by
# BaseHailoLayerDecompose; to generalize the algorithm to decompose arbitrary
# layers, relax this bound.
T = TypeVar("T", bound=BaseHailoLayerDecompose)
# K: the layer type a strategy produces (the return type of ``decompose``).
K = TypeVar("K", bound=BaseHailoLayer)


# Base Strategy class
class DecompositionStrategy(Generic[T, K], ABC):
    """Abstract interface for a single layer-decomposition strategy.

    Type parameters:
        T: Type of the layer being decomposed.
        K: Type of the layer produced by ``decompose``.

    Every hook is a static method, so strategies are used as classes (the
    registry stores strategy types) rather than as instances.
    """

    @staticmethod
    @abstractmethod
    def should_apply(
        layer: T,
        optimization_target: OptimizationTarget,
        model_flow: ModelFlow,
        layer_config: LayerPrecisionConfig,
    ) -> bool:
        """Decide whether this strategy should decompose ``layer``.

        Args:
            layer: Candidate layer.
            optimization_target: Optimization target of the current run.
            model_flow: Model graph, available for inspecting the layer's
                predecessors and successors.
            layer_config: Precision configuration of ``layer``.

        Returns:
            True if ``decompose`` should be applied to ``layer``.
        """

    @staticmethod
    @abstractmethod
    def decompose(layer: T, layer_config: LayerPrecisionConfig) -> K:
        """Build and return the new layer that replaces ``layer``.

        Args:
            layer: Layer to decompose.
            layer_config: Precision configuration of ``layer``.
        """

    @staticmethod
    def get_decomposition_details(new_layer: K) -> str:
        """Return an optional, strategy-specific suffix for log messages.

        The default implementation contributes nothing (an empty string).
        """
        return ""


# Concrete strategies with automatic registration
class QuantWeightGroupsDecompositionStrategy(DecompositionStrategy[HailoConv, HailoConvQuantWeightGroup]):
    """Decompose a ``HailoConv`` into a ``HailoConvQuantWeightGroup``.

    Applies when the layer's precision config requests more than one
    quantization weight group.
    """

    @staticmethod
    def _quantization_weight_groups(layer_config: LayerPrecisionConfig) -> int:
        """Return the configured number of quantization weight groups.

        A missing ``quantization_weight_groups`` attribute, or a falsy value
        such as ``None`` or 0, is treated as a single group.
        """
        return getattr(layer_config, "quantization_weight_groups", 1) or 1

    @staticmethod
    def should_apply(
        layer,
        optimization_target: OptimizationTarget,
        model_flow: ModelFlow,
        layer_config: LayerPrecisionConfig,
    ) -> bool:
        """Return True for ``HailoConv`` layers configured with more than one group.

        ``optimization_target`` and ``model_flow`` are part of the strategy
        interface but are not needed for this decision.
        """
        if not isinstance(layer, HailoConv):
            return False
        return QuantWeightGroupsDecompositionStrategy._quantization_weight_groups(layer_config) > 1

    @staticmethod
    def decompose(layer, layer_config: LayerPrecisionConfig) -> HailoConvQuantWeightGroup:
        """Build a ``HailoConvQuantWeightGroup`` from ``layer`` using the configured group count."""
        groups = QuantWeightGroupsDecompositionStrategy._quantization_weight_groups(layer_config)
        return HailoConvQuantWeightGroup.from_conv(layer, groups)

    @staticmethod
    def get_decomposition_details(new_layer) -> str:
        """Return a log suffix describing the number of quantization weight groups, if known."""
        if hasattr(new_layer, "quantization_weight_groups"):
            return f" with {new_layer.quantization_weight_groups} quantization weight groups"
        return ""


# region DecompositionRegistry


class DecompositionRegistry:
    """Ordered registry of decomposition strategies.

    Strategies are tried in the order they appear in ``_strategies``. To add a
    strategy, insert its class into the list at the desired position.
    """

    # Strategy classes (not instances), grouped from most specific to most general.
    _strategies: List[Type["DecompositionStrategy"]] = [
        # HW_arch specific strategies
        # ----- To fill here -----
        # Precision Config specific strategies
        # ----- To fill here -----
        # General strategies
        QuantWeightGroupsDecompositionStrategy,
    ]

    @classmethod
    def get_strategies(cls) -> List[Type["DecompositionStrategy"]]:
        """Return the registered strategies in the order they should be tried.

        A shallow copy is returned so callers cannot accidentally reorder or
        mutate the shared, class-level registry.
        """
        return list(cls._strategies)
