from copy import deepcopy
from typing import List, Tuple, Type

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum import HailoReduceSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum_a16_pre_act_sum import HailoReduceSumA16PreActSum
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import PrecisionConfig
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class SwitchLayersByPrecision(OptimizationAlgorithm):
    """
    Replace layers that cannot run their configured precision mode with an equivalent layer class that can.

    For every layer whose exact type is a key of ``REPLACEMENT_MAP``, the configured precision mode is
    looked up (falling back to the layer's default precision config). If the layer itself does not list
    that mode in ``SUPPORTED_PRECISION_MODE``, the candidate replacement classes are tried in order. The
    first one that supports the mode and reports HW support for the current optimization target is
    swapped in, rebuilt from the original layer's HN and given its exported weights.

    The original intent, per the previous description, is switching 16-bit layers to their 16-bit
    decomposed variants.
    """

    # Exact layer class -> ordered tuple of candidate replacement classes (first match wins).
    REPLACEMENT_MAP = {
        HailoElementwiseMult: (HailoElementwiseMultOnMac,),
        HailoReduceSum: (HailoReduceSumA16PreActSum,),
    }

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        **kwargs,
    ):
        super().__init__(model, model_config, logger_level=logger_level, name="Switch Layers", **kwargs)
        # Lazily populated by get_algo_config() with a private copy of the precision config.
        self._config: PrecisionConfig = None

    def _setup(self):
        pass

    def should_skip_algo(self):
        """Skip the algorithm when no layer needs to be replaced."""
        return len(self._layers_to_switch()) == 0

    def get_algo_config(self) -> PrecisionConfig:
        """Return a deep copy of the model's precision config, created on first use and cached."""
        if self._config is None:
            # Deep copy so this algorithm never mutates the original model config.
            self._config = deepcopy(self._model_config.precision_config)
        return self._config

    def _run_int(self):
        """Find the layers needing replacement and replace them in the model."""
        layers_to_switch = self._layers_to_switch()
        self._switch_layer(layers_to_switch)

    def _layers_to_switch(self) -> List[Tuple[BaseHailoLayer, Type[BaseHailoLayer]]]:
        """
        Collect the layers that must be replaced, in topological order.

        Returns:
            A list of ``(layer, replacement_class)`` pairs. ``replacement_class`` is the first entry of
            ``REPLACEMENT_MAP[type(layer)]`` that supports the layer's configured precision mode and is
            supported by the HW for ``self.optimization_target``.
        """
        to_switch = []
        precision_cfg = self.get_algo_config()
        for layer in map(self._model.layers.get, self._model.flow.toposort()):
            # Exact type lookup (not isinstance). Presumably this keeps subclasses, e.g. an already
            # switched layer, from matching; NOTE(review): confirm.
            candidates = self.REPLACEMENT_MAP.get(type(layer))
            if candidates is None:
                continue
            layer_cfg = precision_cfg.layers.get(layer.full_name, layer.get_default_precision_config())
            precision_mode = layer_cfg.precision_mode
            # The layer already supports this precision mode, so there is nothing to switch.
            if precision_mode in layer.SUPPORTED_PRECISION_MODE:
                continue
            # Look for a replacement that supports the mode and is supported by the target HW.
            for replacement in candidates:
                # Throw-away instance, built only to query the replacement's capabilities.
                dummy_replacement: BaseHailoLayer = replacement("")
                if (
                    precision_mode in dummy_replacement.SUPPORTED_PRECISION_MODE
                    and dummy_replacement.is_supported_by_hw(self.optimization_target, layer_cfg)
                ):
                    to_switch.append((layer, replacement))
                    break
        return to_switch

    def _switch_layer(self, layers_to_switch: List[Tuple[BaseHailoLayer, Type[BaseHailoLayer]]]) -> None:
        """
        Replace each layer in the model with a new instance of its replacement class.

        The new layer is built from the old layer's HN under the same full name and receives the old
        layer's exported weights before it is swapped into the model.
        """
        for layer, replacement in layers_to_switch:
            weights = layer.export_weights()
            new_layer = replacement.from_hn(layer.full_name, layer.to_hn(), self._logger)
            new_layer.import_weights(weights)
            self._model.replace_layer(new_layer, layer)

    def finalize_global_cfg(self, algo_config):
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg
