from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_decompose import HailoConvDecompose
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_decompose_pluto import HailoConvDecomposePluto
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerFeaturePolicy, OptimizationTarget
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm

# Layer types eligible for the switch to a decompose layer.
# Every supported layer type must implement a `neg_weights` function.
# Membership is tested against the exact type (`type(layer) in SUPPORTED_LAYERS`),
# so subclasses of a listed type are not matched implicitly.
SUPPORTED_LAYERS = [HailoConv]


class SwitchDecompose16Bits(OptimizationAlgorithm):
    """
    Swap configured 16-bit layers for their 16-bit decompose counterparts.

    A layer is swapped when its exact type is listed in SUPPORTED_LAYERS and its
    entry in the layer-decomposition config carries the ``enabled`` policy.
    """

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        **kwargs,
    ):
        super().__init__(
            model,
            model_config,
            logger_level=logger_level,
            name="Switch 16bits Decompose",
            **kwargs,
        )

    def _setup(self):
        """No setup is needed for this algorithm."""

    def should_skip_algo(self):
        """Return True when no layer is selected for switching."""
        return not self.layers_to_switch

    def get_algo_config(self):
        """Return the ``layer_decomposition`` section of the model config."""
        return self._model_config.layer_decomposition

    def _run_int(self):
        """Rebuild every selected layer as a decompose layer and swap it into the model."""
        for old_layer in self.layers_to_switch:
            # Capture the weights and the hn description of the layer being replaced.
            exported_weights = old_layer.export_weights()
            layer_hn = old_layer.to_hn()
            # The PLUTO optimization target has its own decompose implementation.
            decompose_cls = (
                HailoConvDecomposePluto
                if OptimizationTarget.PLUTO == self._model.optimization_target
                else HailoConvDecompose
            )
            replacement = decompose_cls.from_hn(old_layer.full_name, layer_hn, old_layer._logger)
            replacement.import_weights(exported_weights)
            self._model.replace_layer(replacement, old_layer)

    @property
    def layers_to_switch(self):
        """
        List the layer objects to switch, in the order yielded by ``flow.toposort()``.

        The list is rebuilt from the model and the config on every access.
        """
        decomposition_cfg = self.get_algo_config()
        selected = []
        for lname in self._model.flow.toposort():
            candidate = self._model.layers[lname]
            # Exact type match only; see SUPPORTED_LAYERS.
            if type(candidate) not in SUPPORTED_LAYERS:
                continue
            # Layers without a config entry are left untouched.
            if lname not in decomposition_cfg.layers:
                continue
            if decomposition_cfg.layers[lname].policy == LayerFeaturePolicy.enabled:
                selected.append(candidate)
        return selected

    def finalize_global_cfg(self, algo_config):
        """Nothing to finalize in the global config for this algorithm."""

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` for a supported layer type, otherwise an empty dict."""
        if type(self._model.layers[lname]) in SUPPORTED_LAYERS:
            return cfg
        return {}
