from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm

# Name of the activation parameter whose sign is flipped by the correction.
_INVERSE_ACT_FACTOR_KEY = "inverse_act_factor"

# Maps each activation type that needs the correction to the key of the
# activation parameter that has to be negated.
ACTIVATION_CORRECTION_DICT = {
    ActivationType.INV_POS: _INVERSE_ACT_FACTOR_KEY,
    ActivationType.INV_SQRT: _INVERSE_ACT_FACTOR_KEY,
}

# Layer types listed here must implement a `neg_weights` method.
# NOTE(review): this list is not referenced anywhere else in this module;
# confirm whether it is consumed externally or should gate the layer selection.
SUPPORTED_LAYERS = [HailoStandaloneActivation]


class ApuNegMantissaCorrection(OptimizationAlgorithm):
    """
    Fix negative slopes at the APU.

    Relevant only for Hailo8, where the mantissa of the slope in the APU is
    represented by uint10 and therefore cannot represent negative slopes.
    When the algorithm finds a monotonically decreasing activation function,
    it flips it to a monotonically increasing one and multiplies all the
    weights by -1.

    Layers are selected by activation name: every layer that is not a
    BaseHailoNonNNCoreLayer and whose activation name is a key of
    ACTIVATION_CORRECTION_DICT gets corrected.

    Note that the activation function must be monotonic.
    """

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        **kwargs,
    ):
        """Forward all arguments to the base algorithm under the name "Mantissa Correction"."""
        super().__init__(
            model,
            model_config,
            logger_level=logger_level,
            name="Mantissa Correction",
            **kwargs,
        )

    def _setup(self):
        """Do nothing; this algorithm performs no setup step."""

    def should_skip_algo(self):
        """Return True when the model has no layer that needs the correction."""
        return not self.layers_to_correct

    def get_algo_config(self):
        """Return the `globals` section of the model configuration."""
        return self._model_config.globals

    def _run_int(self):
        """Apply the negative-mantissa correction to every layer that needs it."""
        for target_layer in self.layers_to_correct:
            self.correct_layer_neg_mantissa(target_layer)

    @classmethod
    def correct_layer_neg_mantissa(cls, layer):
        """
        Negate the weights of `layer` and flip the sign of its activation factor.

        The factor key is looked up in ACTIVATION_CORRECTION_DICT by the layer's
        activation name (a KeyError is raised for an unlisted activation), and
        the matching entry is multiplied by -1 in both the native and the
        numeric activation parameters.
        """
        layer.neg_weights()
        factor_key = ACTIVATION_CORRECTION_DICT[layer.get_activation_name()]
        # Native params first, then numeric params - same order as the weights flip expects.
        for params_attr in ("act_native_params", "act_numeric_params"):
            getattr(layer.activation_atomic_op, params_attr)[factor_key] *= -1

    @property
    def layers_to_correct(self):
        """List the model layers that need correction, in topological order of the flow."""
        ordered_layers = (self._model.layers[layer_name] for layer_name in self._model.flow.toposort())
        return [candidate for candidate in ordered_layers if self.should_correct_layer(candidate)]

    @staticmethod
    def should_correct_layer(layer):
        """
        Tell whether `layer` needs the negative-mantissa correction.

        BaseHailoNonNNCoreLayer instances are never corrected; any other layer
        is corrected when its activation name is a key of
        ACTIVATION_CORRECTION_DICT.
        """
        if isinstance(layer, BaseHailoNonNNCoreLayer):
            return False
        return layer.get_activation_name() in ACTIVATION_CORRECTION_DICT

    def finalize_global_cfg(self, algo_config):
        """Do nothing; `algo_config` is left untouched."""
