from itertools import chain
from typing import List, Tuple, Type

from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerMatmulDecompositionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerFeaturePolicy, PrecisionMode
from hailo_model_optimization.algorithms.matmul_decompose.matmul_decompose_blocks import (
    BaseDecomposeBlock,
    MatmulDecompose168,
    MatmulDecompose1616,
)
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector

# Layer classes this algorithm can decompose. Membership is tested by exact type
# (``type(layer) in SUPPORTED_LAYERS``), so subclasses of these classes are not matched.
# NOTE(review): the previous comment here said "supported layer must include neg_weights
# function", but nothing in this module calls ``neg_weights`` -- confirm whether that
# requirement applies to matmul decomposition or was copied from another algorithm.
SUPPORTED_LAYERS = [HailoMatmul]


class DecomposeMatmul(OptimizationAlgorithm):
    """
    Switch 16-bit matmul layers to their decomposed implementation.

    A layer is handled when it passes ``_check_suppoerted_layer`` and its entry in the
    ``matmul_decomposition`` config has ``policy == LayerFeaturePolicy.enabled``. The decompose
    block class is selected from :attr:`BLOCKS` by that entry's ``precision_mode``.
    """

    # Configured precision mode -> decompose block class instantiated for a layer in that mode.
    BLOCKS = {
        PrecisionMode.a16_w16: MatmulDecompose1616,
        PrecisionMode.a16_w8: MatmulDecompose168,
    }

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        dataset,
        **kwargs,
    ):
        """
        Args:
            model: The model whose matmul layers may be decomposed.
            model_config: Model optimization config; its ``matmul_decomposition`` section
                selects the layers to handle (see :meth:`get_algo_config`).
            logger_level: Forwarded to :class:`OptimizationAlgorithm`.
            dataset: Stored as ``_unbatched_dataset`` and passed to :class:`StatsCollector`
                in :meth:`_run_int`.
            **kwargs: Forwarded to :class:`OptimizationAlgorithm`.
        """
        super().__init__(model, model_config, logger_level=logger_level, name="MatmulDecompose", **kwargs)
        self._unbatched_dataset = dataset

    def _setup(self):
        """No preparation is needed before :meth:`_run_int`."""

    def should_skip_algo(self):
        """Return True when no layer is selected for decomposition."""
        return not self.layers_to_create()

    def get_algo_config(self):
        """Return the ``matmul_decomposition`` section of the model optimization config."""
        return self._model_config.matmul_decomposition

    def _run_int(self):
        """
        Decompose every layer returned by :meth:`layers_to_create`.

        Steps:
            1. Build one decompose block per layer, let it add its correction block to the
               model and update the model optimization config.
            2. Run :class:`CreateMixedPrecision` on the model.
            3. Collect statistics for the layers each block reports via ``collect_stats_layers()``.
            4. Let each block create its precision splits.
        """
        # Step 1: Create decompose blocks
        decompose = []
        for layer, block in self.layers_to_create():
            decomp = block(layer.full_name, self._model.flow, self._logger)
            decomp.add_correction_block(self._model)
            decomp.update_mo_config(self._model_config)
            decompose.append(decomp)

        # Step 2: Add precision to layers
        algo = CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level,
            logger=self._logger,
        )
        algo.run()

        # Step 3: Collect stats only for the layers the blocks ask for.
        layers_to_handle = list(chain.from_iterable(x.collect_stats_layers() for x in decompose))
        stats_collector = StatsCollector(
            self._model,
            self._model_config,
            self._logger_level,
            self._unbatched_dataset,
            layers_to_handle=layers_to_handle,
            logger=self._logger,
        )
        stats_collector.run()

        # Step 4: Create the precision splits.
        for comp in decompose:
            comp.create_splits(self._model)

    def layers_to_create(self) -> List[Tuple[HailoMatmul, Type[BaseDecomposeBlock]]]:
        """
        Return the layers to decompose, in topological order, paired with their block class.

        Returns:
            A list of ``(layer, block_cls)`` pairs; ``block_cls`` is a class from
            :attr:`BLOCKS` (it is instantiated later in :meth:`_run_int`).

        Raises:
            KeyError: if an enabled layer's ``precision_mode`` has no entry in :attr:`BLOCKS`.
                The lookup ``config.layers[...]`` presumably also raises ``KeyError`` when
                :meth:`finalize_global_cfg` has not populated the layer's entry -- confirm.
        """
        res = []
        config = self.get_algo_config()
        # ``layers.get`` may yield None for unknown names; the support check rejects None.
        for layer in map(self._model.layers.get, self._model.flow.toposort()):
            if not self._check_suppoerted_layer(layer):
                continue
            l_config = config.layers[layer.full_name]
            if l_config.policy != LayerFeaturePolicy.enabled:
                continue
            block = self.BLOCKS.get(l_config.precision_mode)
            if block is None:
                # Still a KeyError (as the plain dict lookup was) so existing handlers keep working,
                # but with a message that names the layer and the supported modes.
                raise KeyError(
                    f"Matmul decomposition of layer {layer.full_name} does not support precision mode "
                    f"{l_config.precision_mode}; supported modes: {list(self.BLOCKS)}"
                )
            res.append((layer, block))

        return res

    def finalize_global_cfg(self, algo_config):
        """
        Ensure every supported layer has an entry in ``algo_config.layers``.

        Missing entries are filled with ``LayerMatmulDecompositionConfig.get_default()``;
        existing entries are left untouched.
        """
        for layer in self._model.layers.values():
            if self._check_suppoerted_layer(layer):
                algo_config.layers.setdefault(layer.full_name, LayerMatmulDecompositionConfig.get_default())

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` for a supported layer ``lname``, or an empty dict for any other layer."""
        if not self._check_suppoerted_layer(self._model.layers[lname]):
            cfg = {}
        return cfg

    # NOTE: the method name is misspelled; it is kept as-is so callers and subclass overrides still match.
    def _check_suppoerted_layer(self, layer: HailoMatmul) -> bool:
        """
        Return True if ``layer`` can be decomposed by this algorithm.

        The check uses the exact type (subclasses of the classes in ``SUPPORTED_LAYERS`` do not
        qualify, and ``None`` is rejected) and excludes matmuls with ``transpose_matmul_input`` set.
        """
        return type(layer) in SUPPORTED_LAYERS and not layer.transpose_matmul_input
