from typing import List, Tuple, Type

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerTranslationConfig
from hailo_model_optimization.algorithms.layer_decompose.decomposition_strategy import (
    DecompositionRegistry,
    DecompositionStrategy,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


# Main algorithm class
class LayerDecompose(OptimizationAlgorithm):
    """Switch/Decompose Layers based on registered decomposition strategies.

    ``_setup`` matches every model layer against the strategies returned by
    ``DecompositionRegistry.get_strategies()``, keeping the first strategy whose
    ``should_apply`` accepts the layer. ``_run_int`` then replaces each matched layer
    with the layer returned by ``strategy.decompose`` and copies the original layer's
    translation config to the new layer's name.
    """

    def __init__(self, model, model_config, logger_level, **kwargs):
        super().__init__(model, model_config, name="switch hailo layer decompose", logger_level=logger_level, **kwargs)
        # (layer name, strategy class) pairs, filled by ``_setup`` and consumed by ``_run_int``.
        self._layers_to_decompose: List[Tuple[str, Type[DecompositionStrategy]]] = []

    def get_algo_config(self):
        """Return the model's precision config as this algorithm's config."""
        return self._model_config.precision_config

    def should_skip_algo(self):
        """Never skip; whether any layer is decomposed is decided per layer in ``_setup``."""
        return False

    def finalize_global_cfg(self, algo_config):
        """No global configuration needs finalizing for this algorithm."""

    def _setup(self):
        """Identify layers that need decomposition based on registered strategies.

        Each layer gets at most one strategy: the first (in registry priority order)
        whose ``should_apply`` returns a truthy value. Layers with no precision-config
        entry are checked against an empty dict.
        """
        # Start from a clean list so calling setup again does not queue a layer twice;
        # a duplicate would make ``_run_int`` look up / decompose an already-replaced layer.
        self._layers_to_decompose = []
        layers_precision_cfg = self._model_config.precision_config.layers
        # Get strategies from registry in priority order
        strategies = DecompositionRegistry.get_strategies()
        for lname, layer in self._model.iterate_layers():
            layer_precision_config = layers_precision_cfg.get(lname, {})

            # Check strategies in priority order; select the first matching strategy.
            for strategy in strategies:
                if strategy.should_apply(
                    layer, self._model.optimization_target, self._model.flow, layer_precision_config
                ):
                    self._layers_to_decompose.append((lname, strategy))
                    break

    def _run_int(self):
        """Apply all identified decomposition operations.

        The (updated) model config is re-imported into the model only when at least
        one layer was queued for decomposition.
        """
        for lname, strategy in self._layers_to_decompose:
            self._decompose_layer(self._model.layers[lname], strategy)

        if self._layers_to_decompose:
            self._model.import_config(self._model_config)

    def _decompose_layer(self, layer: BaseHailoLayer, strategy: Type[DecompositionStrategy]):
        """Apply a decomposition strategy to a layer and update model configuration.

        Args:
            layer: the layer to replace.
            strategy: the strategy class selected for ``layer`` in ``_setup``.

        The new layer produced by ``strategy.decompose`` replaces ``layer`` in the model,
        and a copy of the original layer's translation config (or the default one if
        none is found) is stored under the new layer's name.
        """
        # NOTE(review): ``_setup`` tolerates a missing precision entry (``.get(..., {})``),
        # but this lookup indexes directly and would raise KeyError for such a layer if a
        # strategy accepted it — confirm every decomposed layer always has an entry.
        layer_precision_config = self._model_config.precision_config.layers[layer.full_name]
        new_layer = strategy.decompose(layer, layer_precision_config)

        self._model.replace_layer(new_layer, layer)
        # presumably ``finalize_layer_cfg`` (base class) returns a name -> config mapping — confirm
        translation_layers_cfg = self.finalize_layer_cfg(self._model_config.translation_config.layers)
        # Copy so the new layer's translation config does not alias the original's.
        self._model_config.translation_config.layers[new_layer.name] = translation_layers_cfg.get(
            layer.full_name, LayerTranslationConfig.get_default()
        ).copy()

        # Get strategy-specific details for logging
        details = strategy.get_decomposition_details(new_layer)

        self._logger.log(
            self._logger_level,
            f"Algorithm {self._name} switched {layer.full_name} to {new_layer.name}{details}",
        )

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` unchanged (no per-layer validation/adjustment is performed)."""
        return cfg
