from typing import Dict

from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.algorithms.matmul_decompose.matmul_decompose_blocks import (
    BaseDecomposeBlock,
    search_matmul_decomp_blocks,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class DecomposeMatmulFix(OptimizationAlgorithm):
    """
    Apply matmul-decomposition fixes to every matching block in the model.

    During setup the model is scanned with ``search_matmul_decomp_blocks``,
    which yields a mapping of layer name -> decomposition block. Running the
    algorithm then calls ``fix_matmuls`` on each of those blocks, passing the
    model being optimized. Per the original author's note, the intent is to
    switch 16-bit layers to a 16-bit decomposition (NOTE(review): the exact
    transformation lives in the block implementations, not here).

    The algorithm is never skipped, and its configuration hooks are
    pass-through/no-op.
    """

    # Layer name -> decomposition block found by ``_setup``; consumed by ``_run_int``.
    components: Dict[str, BaseDecomposeBlock]

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        **kwargs,
    ):
        super().__init__(
            model,
            model_config,
            name="MatmulDecomposeFix",
            logger_level=logger_level,
            **kwargs,
        )

    def should_skip_algo(self):
        """Report that this algorithm always runs (never skipped)."""
        skip = False
        return skip

    def _setup(self):
        """Discover the matmul-decomposition blocks present in the model."""
        found_blocks = search_matmul_decomp_blocks(self._model)
        self.components = found_blocks

    def get_algo_config(self):
        """Return the ``matmul_decomposition`` section of the model config."""
        model_cfg = self._model_config
        return model_cfg.matmul_decomposition

    def _run_int(self):
        """Invoke ``fix_matmuls`` on each discovered block, in discovery order."""
        for layer_name, block in self.components.items():
            self._logger.debug(f"Applying {block} to {layer_name}")
            block.fix_matmuls(self._model)

    def finalize_config(self):
        """No-op: this algorithm has nothing to finalize in its config."""

    def finalize_global_cfg(self, algo_config):
        """No-op: the global algorithm config is accepted as-is."""

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return the per-layer config unchanged (no layer-level validation)."""
        validated = cfg
        return validated
