from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_decompose import HailoConvDecompose
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_decompose_pluto import HailoConvDecomposePluto
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector

# Layer types handled by CreateDecompose16Bits. Membership is checked by exact
# type (``type(layer) in SUPPORTED_LAYERS``), so subclasses are NOT selected
# unless they are listed here explicitly.
# Each listed type must provide ``create_splits()`` and
# ``create_weight_split(layer_cfg, optimization_target)``, which the algorithm
# calls on every matching layer.
# Supported layers were also documented as needing a ``neg_weights`` function.
# NOTE(review): ``neg_weights`` is not called in this file — confirm whether
# it is still required by code elsewhere.
SUPPORTED_LAYERS = [HailoConvDecompose, HailoConvDecomposePluto]


class CreateDecompose16Bits(OptimizationAlgorithm):
    """
    Switch supported 16-bit layers to their 16-bit decomposed form.

    Every layer whose exact type is listed in ``SUPPORTED_LAYERS`` is first
    split (``create_splits``). Statistics are then collected over the dataset.
    Next, the weight split is built from that layer's entry in the
    ``layer_decomposition`` config together with the optimization target.
    Finally, statistics are collected once more.
    """

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        dataset,
        **kwargs,
    ):
        """
        Args:
            model: the acceleras model whose supported layers are decomposed.
            model_config: optimization config; its ``layer_decomposition``
                section supplies the per-layer decomposition settings.
            logger_level: forwarded to the base algorithm and to the
                statistics collector.
            dataset: dataset used for statistics collection (kept as
                ``_unbatched_dataset``).
            **kwargs: forwarded unchanged to ``OptimizationAlgorithm``.
        """
        super().__init__(
            model,
            model_config,
            logger_level=logger_level,
            name="create_decompose_16bits",
            **kwargs,
        )
        self._unbatched_dataset = dataset

    def _setup(self):
        """No per-run setup is required by this algorithm."""

    def should_skip_algo(self):
        """Return True when the model contains no supported layer to decompose."""
        return not self.layers_to_create

    def get_algo_config(self):
        """Return the ``layer_decomposition`` section of the model config."""
        return self._model_config.layer_decomposition

    def _run_int(self):
        """Split the supported layers, then build their weight splits.

        Statistics are collected after the structural split and again after
        the weight split. NOTE(review): presumably ``create_weight_split``
        relies on the statistics gathered by the first collection — confirm.
        """
        # Structural split of every supported layer.
        for decompose_layer in self.layers_to_create:
            decompose_layer.create_splits()

        collector = StatsCollector(
            self._model,
            self._model_config,
            self._logger_level,
            self._unbatched_dataset,
            logger=self._logger,
        )
        collector.run()

        # Weight split, driven by each layer's own decomposition config.
        decomposition_cfg = self.get_algo_config()
        for decompose_layer in self.layers_to_create:
            layer_cfg = decomposition_cfg.layers[decompose_layer.full_name]
            decompose_layer.create_weight_split(layer_cfg, self.optimization_target)

        # Refresh statistics after the weights were split.
        collector.run()

    @property
    def layers_to_create(self):
        """Supported layers of the model, in topological order.

        Matching is by exact type, so a subclass of a supported type is only
        selected if it is itself listed in ``SUPPORTED_LAYERS``.
        """
        ordered_layers = (self._model.layers[name] for name in self._model.flow.toposort())
        return [candidate for candidate in ordered_layers if type(candidate) in SUPPORTED_LAYERS]

    def finalize_global_cfg(self, algo_config):
        """No global config finalization is needed; ``algo_config`` is ignored."""

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` unchanged for supported layers, an empty dict otherwise."""
        is_supported = type(self._model.layers[lname]) in SUPPORTED_LAYERS
        return cfg if is_supported else {}
