from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    MatmulCorrectionType,
)
from hailo_model_optimization.algorithms.matmul_correction.correction_blocks import (
    MMCorrectionBlock,
    MMCorrectionBlock2,
    MMCorrectionBlock3,
    MMCorrectionWeights,
)
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class MatmulCorrection(OptimizationAlgorithm):
    """
    Add (or remove) zero-point correction for matmul layers with dynamic weights.

    When a matmul layer has dynamic weights and a non-trivial incoming zp, a
    correction term is needed to compensate for that zp in the matmul layer.
    Depending on the configured (or default) correction type, the compensation is
    implemented by:

    * ``ZP_COMP_WEIGHTS`` - appending a correction term as an additional row to the
      weights input produced by the preceding conv layer, which also changes the
      relevant input size of the matmul layer in the hn.
    * ``ZP_COMP_BLOCK`` / ``ZP_COMP_BLOCK_2`` / ``ZP_COMP_BLOCK_3`` - delegating to
      the matching ``MMCorrectionBlock*`` helper.
    * ``ZP_COMP_NONE`` - no correction is applied.
    * ``ZP_COMP`` - legacy mode in which the correction was applied elsewhere.

    With ``addition=False`` the algorithm runs in the opposite direction. It removes
    a previously added correction from matmuls whose first input has a non-negative
    lower output limit.
    """

    def __init__(self, model: HailoModel, model_config, logger_level, addition, **kwargs):
        """
        Args:
            model: The model whose matmul layers are corrected.
            model_config: Model configuration; ``matmul_correction`` holds the
                per-layer correction settings.
            logger_level: Logging level forwarded to the base algorithm.
            addition: ``True`` to add corrections, ``False`` to remove them.
            **kwargs: Forwarded to ``OptimizationAlgorithm``.
        """
        super().__init__(model, model_config, "Matmul Correction", logger_level, **kwargs)
        self._addition = addition
        # Set whenever a correction is actually added/removed, so output shapes get recomputed.
        self._structural_change = False
        # Set when any ZP_COMP_BLOCK* correction was added, so mixed precision is re-created.
        self._used_zp_comp_block = False
        self.mmcb = MMCorrectionBlock(model, model_config, self._logger, addition)
        self.mmcb2 = MMCorrectionBlock2(model, model_config, self._logger, addition)
        self.mmcb3 = MMCorrectionBlock3(model, model_config, self._logger, addition)
        self.mmcw = MMCorrectionWeights(model, model_config, self._logger, addition)

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` if ``lname`` is a model layer and its correction type is a known enum name, else ``None``."""
        if lname in self._model.layers and cfg["correction_type"].upper() in MatmulCorrectionType._member_names_:
            return cfg
        return None

    def get_algo_config(self):
        """Return the ``matmul_correction`` section of the model configuration."""
        return self._model_config.matmul_correction  # self._model_config.globals

    def finalize_global_cfg(self, algo_config):
        # Nothing to finalize for this algorithm.
        pass

    def should_skip_algo(self):
        """Skip the algorithm when the model contains no ``HailoMatmul`` layer."""
        return not any(isinstance(layer, HailoMatmul) for layer in self._model.layers.values())

    def _setup(self):
        # No setup required.
        pass

    def _run_int(self):
        """Apply/remove corrections, then refresh shapes and precision if needed."""
        self._matmul_correction()
        if self._structural_change:
            # Recompute output shapes; a leading ``None`` batch dimension is prepended to each input shape.
            input_shapes = [(None,) + shape for shape in self._model.get_input_shapes()]
            self._model.compute_output_shape(input_shapes)
        if self._used_zp_comp_block:
            # Re-run the mixed-precision configuration. NOTE(review): presumably so that
            # the layers added by the correction block get a precision config — confirm.
            algo = CreateMixedPrecision(
                model=self._model,
                model_config=self._model_config,
                logger_level=self._logger_level,
                logger=self._logger,
            )
            algo.run()

    def get_correction_type(self, mm_lname: str, pred_1_lname: str):
        """
        Resolve the correction type for matmul ``mm_lname``.

        An explicit per-layer setting in ``matmul_correction.layers`` wins. Otherwise,
        the default is ``ZP_COMP_WEIGHTS`` when the second predecessor
        (``pred_1_lname``) is a ``HailoConv``, and ``ZP_COMP_BLOCK`` when it is not.

        Raises:
            KeyError: if the configured correction type is not a ``MatmulCorrectionType`` name.
        """
        if mm_lname in self._model_config.matmul_correction.layers:
            c_type = self._model_config.matmul_correction.layers[mm_lname].correction_type
            self._logger.debug(f"Correction type defined for {mm_lname}. is {c_type.upper()}")
            return MatmulCorrectionType[c_type.upper()]
        # check if predecessor is conv layer
        if isinstance(self._model.layers[pred_1_lname], HailoConv):
            self._logger.debug(f"Correction type not defined for {mm_lname}. defaulting to ZP_COMP_WEIGHTS")
            return MatmulCorrectionType.ZP_COMP_WEIGHTS
        self._logger.debug(
            f"Correction type not defined for {mm_lname}. pred1 type is not conv! defaulting to ZP_COMP_BLOCK"
        )
        return MatmulCorrectionType.ZP_COMP_BLOCK

    def _matmul_correction(self):
        """
        Visit every ``HailoMatmul`` with ``transpose_matmul_input`` set, in topological order.

        In addition mode, add a correction using the second predecessor. In removal
        mode, remove the correction only when the first element of every output
        limval entry of the first predecessor is ``>= 0``.
        NOTE(review): presumably each entry is a (min, max) pair — confirm.
        """
        for lname in self._model.flow.toposort():
            layer = self._model.layers[lname]
            if not (isinstance(layer, HailoMatmul) and layer.transpose_matmul_input):
                continue
            matmul_preds = list(self._model.flow.predecessors_sorted(lname))
            if self._addition:
                self._add_correction(lname, layer, matmul_preds[1])
            else:
                conv0 = self._model.layers[matmul_preds[0]]
                if all(values[0] >= 0 for values in conv0.get_output_limvals()):
                    self._remove_correction(layer, matmul_preds[1])

    def _add_correction(self, mm_lname: str, matmul: HailoMatmul, conv_lname):
        """
        Add the resolved correction type to ``matmul`` and record it on the layer.

        ``ZP_COMP_NONE`` and ``ZP_COMP`` only record the type and make no structural
        change. Any other type marks the model as structurally changed.

        Raises:
            ValueError: for an unsupported correction type.
        """
        correction_type = self.get_correction_type(mm_lname, conv_lname)
        self._logger.debug(f"Adding correction {correction_type.value} to {mm_lname}")

        if correction_type == MatmulCorrectionType.ZP_COMP_WEIGHTS:
            self.mmcw._add_correction_weights(matmul, conv_lname)
            matmul.zp_correction_type = MatmulCorrectionType.ZP_COMP_WEIGHTS

        elif correction_type == MatmulCorrectionType.ZP_COMP_BLOCK:
            self.mmcb.add_correction_block(matmul.full_name)
            matmul.zp_correction_type = MatmulCorrectionType.ZP_COMP_BLOCK
            self._used_zp_comp_block = True

        elif correction_type == MatmulCorrectionType.ZP_COMP_BLOCK_2:
            self.mmcb2.add_correction_block(matmul.full_name)
            matmul.zp_correction_type = MatmulCorrectionType.ZP_COMP_BLOCK_2
            self._used_zp_comp_block = True

        elif correction_type == MatmulCorrectionType.ZP_COMP_BLOCK_3:
            self.mmcb3.add_correction_block(matmul.full_name)
            matmul.zp_correction_type = MatmulCorrectionType.ZP_COMP_BLOCK_3
            self._used_zp_comp_block = True

        elif correction_type == MatmulCorrectionType.ZP_COMP_NONE:
            matmul.zp_correction_type = MatmulCorrectionType.ZP_COMP_NONE
            return  # no correction needed
        elif correction_type == MatmulCorrectionType.ZP_COMP:
            # Legacy mode: the correction was done by someone else.
            matmul.zp_correction_type = MatmulCorrectionType.ZP_COMP
            return
        else:
            raise ValueError(f"Unsupported correction type: {correction_type}")
        matmul.zp_comp_added = self._addition
        self._structural_change = True

    def _remove_correction(self, matmul: HailoMatmul, transpose_pred):
        """
        Remove the correction previously recorded on ``matmul.zp_correction_type``.

        On success, the layer is reset to ``ZP_COMP_NONE`` and marked as having no
        zp compensation. ``ZP_COMP_NONE`` is a no-op.

        Raises:
            ValueError: for any other (unsupported) correction type.
        """
        correction_type = matmul.zp_correction_type
        self._logger.debug(f"Removing correction {correction_type} from {matmul.full_name}")
        if correction_type == MatmulCorrectionType.ZP_COMP_WEIGHTS:
            self.mmcw._remove_correction(matmul, transpose_pred)
        elif correction_type == MatmulCorrectionType.ZP_COMP_BLOCK:
            self.mmcb._remove_correction(matmul, transpose_pred)
        # Compare the layer's recorded type (previously these two branches compared
        # the bound method ``self.get_correction_type`` and were unreachable).
        elif correction_type == MatmulCorrectionType.ZP_COMP_BLOCK_2:
            self.mmcb2._remove_correction(matmul, transpose_pred)
        elif correction_type == MatmulCorrectionType.ZP_COMP_BLOCK_3:
            self.mmcb3._remove_correction(matmul, transpose_pred)
        elif correction_type == MatmulCorrectionType.ZP_COMP_NONE:
            return
        else:
            raise ValueError(f"Unsupported correction type: {correction_type}")

        matmul.zp_comp_added = False
        matmul.zp_correction_type = MatmulCorrectionType.ZP_COMP_NONE
        self._structural_change = True
