import logging
from enum import Enum

import numpy as np
from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerNegExponentConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import NegExponentConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    LayerFeaturePolicy,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import NegativeSlopeExponentNonFixable
from hailo_model_optimization.algorithms.algorithm_base import AlgoResults
from hailo_model_optimization.algorithms.neg_exponent_fixer.layer_splitter import LayerSplitter
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class Solutions(Enum):
    """Fix strategies that ``NegExponentFixer`` records in ``LayerFix.solution``."""

    # Split the layer (LayerSplitter) into two layers.
    LAYER_SPLIT = "LAYER_SPLIT"
    # Shift the activation's negative-slope exponent (increase the scale factor).
    SCALE_FACTOR = "SCALE_FACTOR"
    # Remove the activation offsets.
    OFFSET_REMOVAL = "OFFSET_REMOVAL"


# Output-range threshold (max - min of the layer's output limvals). A scale-factor shift that
# consumes all output bits is only reported as non-fixable when the output range exceeds this
# value; layers with a narrower range are treated as dead and tolerated.
DEAD_LAYER_TH = 0.001


class LayerFix(BaseModel):
    """Per-layer record of the negative-exponent check and of the fix applied to the layer.

    Instances are created and then mutated in place by ``NegExponentFixer`` while it checks
    and fixes the layer named ``layer_name``.
    """

    layer_name: str
    shift_need: int = Field(0, ge=0, description="The shift needed to fix the negative exponent")
    # NegExponentFixer.check_needs copies the previous ``shift_need`` into this field before recomputing.
    shift_reached: int = Field(0, ge=0, description="The shift reached by the layer")
    needs_fix: bool = Field(False, description="If the layer needs to be fixed")
    solution: Solutions = Field(Solutions.SCALE_FACTOR, description="The solution used to fix the layer")
    scale_shift: int = Field(0, ge=0, description="Total shift applied by increasing the scale factor")
    exp_shift: int = Field(
        0, ge=0, description="The shift needed to make the activation's assigned exponents non-negative"
    )
    offset_shift: int = Field(0, ge=0, description="The shift needed by the activation offsets")


class NegativeExponentResults(AlgoResults):
    """Results of ``NegExponentFixer``: the final ``LayerFix`` of the processed layer."""

    # Required field (explicit ``...``); holds the LayerFix produced by the algorithm run.
    layer_fix: LayerFix = Field(..., description="The layer fix result")


class NegExponentFixer(OptimizationAlgorithm):
    """
    Fix a negative exponent in the activation of a single layer.

    The layer named ``lname`` is checked for negative assigned exponents and for activation
    offsets that require a shift. When a fix is needed, one of the ``Solutions`` is applied:
    offset removal (when ``auto_remove_offset`` allows it), splitting the layer
    (``LayerSplitter``), or increasing the scale factor. The resulting ``LayerFix`` is stored
    in ``self._results.layer_fix``.
    """

    # A layer is a split candidate only if it supports at least one of these precision modes.
    SPLIT_PRECISIONS = {PrecisionMode.a8_w8_a16, PrecisionMode.a8_w4_a16, PrecisionMode.a16_w16_a16}
    _config: NegExponentConfig
    place_holder: LayerFix

    def __init__(
        self,
        model: HailoModel,
        model_config: ModelOptimizationConfig,
        lname: str,
        logger_level=logging.INFO,
        **kwargs,
    ):
        """
        Args:
            model: Model that contains the layer to fix.
            model_config: Optimization config, forwarded to the base algorithm.
            lname: Name of the layer to check and fix.
            logger_level: Logging level forwarded to the base algorithm.
            **kwargs: Extra arguments forwarded to ``OptimizationAlgorithm``.
        """
        super().__init__(model, model_config, "Negative Exponent Fix", logger_level=logger_level, **kwargs)
        self._results = NegativeExponentResults(name=self.__class__.__name__, layer_fix=LayerFix(layer_name=lname))
        self._config = self.get_algo_config()
        self._l_name = lname
        self.layer = model.layers[lname]

    #################### Algorithm Flow Control ####################
    # region Flow control

    def _run_int(self):
        """Check the layer, fix it if needed, and store the final ``LayerFix`` in the results."""
        l_fix = LayerFix(layer_name=self.layer.full_name, shift_need=0, shift_reached=0)
        l_fix = self.check_needs(l_fix)
        if l_fix.needs_fix:
            l_fix = self.fix_output(l_fix)
        self._results.layer_fix = l_fix

    def check_needs(self, l_fix: LayerFix) -> LayerFix:
        """Recompute whether the layer named by ``l_fix`` still needs a fix.

        ``l_fix`` is mutated in place and returned. If the layer has no activation atomic op,
        it is returned unchanged.

        The required shift is the maximum of:
          * ``exp`` shift: ``max(-assigned_exponent)`` of the activation, and
          * offset shift: ``act_op.get_offset_needed_shift()``.
        A fix is needed when this maximum is positive. When no fix is needed, ``shift_need``,
        ``exp_shift`` and ``offset_shift`` keep their previous values.
        """
        layer = self._model.layers[l_fix.layer_name]
        if layer.activation_atomic_op:
            act_op = layer.activation_atomic_op
            # Remember the shift required at the previous check before recomputing it.
            l_fix.shift_reached = l_fix.shift_need
            assigned_exp = act_op.get_assigned_exponent()
            shift_fix_offset = act_op.get_offset_needed_shift()

            shift_fix_exp = np.max(-assigned_exp)
            shift_fix = np.max([shift_fix_offset, shift_fix_exp])
            if shift_fix > 0:
                l_fix.needs_fix = True
                l_fix.shift_need = shift_fix
                l_fix.exp_shift = np.max([0, shift_fix_exp])
                l_fix.offset_shift = shift_fix_offset
            else:
                l_fix.needs_fix = False

        return l_fix

    def fix_output(self, l_fix: LayerFix) -> LayerFix:
        """Choose and apply a solution based on the layer configuration and HW support.

        Order of attempts:
          1. Offset removal, if an offset shift is needed and ``auto_remove_offset`` allows it.
             If that alone resolves the issue, return immediately.
          2. Layer split, if the split gain reaches ``split_threshold``, the layer supports a
             split precision, ``rank > 0`` and no special case forbids it.
          3. Otherwise, increase the scale factor.

        Returns:
            The ``LayerFix`` describing the final state. For a split this is the fix of the
            second split part, not of the original layer.
        """
        layer = self._model.layers[l_fix.layer_name]
        layer_config = self._config.layers[l_fix.layer_name]

        if l_fix.offset_shift > 0 and self._check_feature(layer_config.auto_remove_offset):
            l_fix = self._remove_offset_solution(l_fix)
            if not l_fix.needs_fix:
                return l_fix

        # Split conditions: we pay one bit on the offset when we do the split.
        virtual_gain = l_fix.exp_shift - (l_fix.offset_shift + 1)
        if (
            min(l_fix.shift_need, virtual_gain) >= layer_config.split_threshold
            and (set(layer.SUPPORTED_PRECISION_MODE) & self.SPLIT_PRECISIONS)
            and layer_config.rank > 0
            and self._check_special_cases(l_fix)
        ):
            l_fix = self._split_layer_solution(l_fix)

        else:
            l_fix = self._increase_scale_factor(l_fix)
        self._log_negative_exponent_shift(l_fix)
        return l_fix

    # endregion
    #################### Solutions ####################
    # region Solutions

    def _increase_scale_factor(self, l_fix: LayerFix) -> LayerFix:
        """Shift the layer's negative-slope exponent by ``shift_need`` and recheck the layer.

        The applied shift is accumulated into ``l_fix.scale_shift``.
        """
        l_fix.solution = Solutions.SCALE_FACTOR
        layer = self._model.layers[l_fix.layer_name]
        layer.update_negative_slope_exponent_shift(l_fix.shift_need)
        l_fix.scale_shift += l_fix.shift_need
        l_fix = self.check_needs(l_fix)
        return l_fix

    def _split_layer_solution(self, l_fix: LayerFix) -> LayerFix:
        """Split the layer in two and fix the resulting parts.

        If the first part still needs a fix and the original activation's
        ``assertion_negative_slope()`` holds, the layer cannot be fixed and
        ``NegativeSlopeExponentNonFixable`` is raised. If the second part still needs a fix,
        its scale factor is increased.

        Returns:
            The ``LayerFix`` of the second split part.
        """
        l_fix.solution = Solutions.LAYER_SPLIT
        layer = self._model.layers[l_fix.layer_name]
        cfg = self._config.layers[l_fix.layer_name]
        splitter = LayerSplitter(self._model, self._model_config, self._logger)

        a_layer, b_layer = splitter.split_layer(layer, auto_clip=self._check_feature(cfg.auto_clip), rank=cfg.rank)
        c_l_fix = self.check_needs(LayerFix(layer_name=a_layer))
        b_l_fix = self.check_needs(LayerFix(layer_name=b_layer))

        if c_l_fix.needs_fix:
            # NOTE(review): this inspects the pre-split ``layer`` object, not the first split part;
            # confirm that is intended.
            if layer.activation_atomic_op.assertion_negative_slope():
                # Required fix shift is more than the output bits (will zero out the results).
                # Report the shift the split part still needs: ``l_fix.scale_shift`` is never
                # increased on the split path, so it would always report a 0-bit shift.
                # output_bits=15 is hard-coded (presumably the split part's output width -- TODO confirm).
                raise NegativeSlopeExponentNonFixable(
                    output_bits=15, fix_shift=c_l_fix.shift_need, lname=layer.full_name
                )

        if b_l_fix.needs_fix:
            b_l_fix = self._increase_scale_factor(b_l_fix)
        # NOTE(review): the LAYER_SPLIT tag set on ``l_fix`` above is not carried to ``b_l_fix``,
        # so the recorded solution is SCALE_FACTOR (default or set by _increase_scale_factor).
        return b_l_fix

    def _remove_offset_solution(self, l_fix: LayerFix) -> LayerFix:
        """Remove the activation offsets, re-encode the activation and recheck the layer."""
        l_fix.solution = Solutions.OFFSET_REMOVAL
        layer = self._model.layers[l_fix.layer_name]
        act_op = layer.activation_atomic_op
        act_op.remove_offsets = True
        # Rebuild the HW params and encoding so the offset removal takes effect before rechecking.
        act_op.create_hw_params(act_op.input_scale)
        act_op.enforce_encoding()
        l_fix.offset_shift = 0
        l_fix = self.check_needs(l_fix)

        return l_fix

    # endregion
    #################### Helper Functions ####################
    # region Helper Functions

    def _check_feature(self, feature: LayerFeaturePolicy):
        """Return True if ``feature`` is ``allowed`` or ``enabled``."""
        return feature in (LayerFeaturePolicy.allowed, LayerFeaturePolicy.enabled)

    def _check_special_cases(self, l_fix: LayerFix) -> bool:
        """Return False if the layer must not be split into mixed precision.

        Current special cases: a ``force_shift`` is set in the translation config, or the split
        would reproduce the same layer. Kept open so further special cases can be added.
        """
        # If a force shift is given, we can't split the layer.
        force_shift = self._model_config.translation_config.layers[l_fix.layer_name].force_shift
        if force_shift is not None:
            return False

        if self._check_same_layer_solution(l_fix):
            return False

        return True

    def _check_same_layer_solution(self, l_fix: LayerFix) -> bool:
        """Return True if splitting would produce the same layer as the original.

        This is the case when the activation is linear and the layer's configured precision
        mode is already one of ``SPLIT_PRECISIONS``.
        """
        layer = self._model.layers[l_fix.layer_name]
        precision = self._model_config.precision_config.layers[l_fix.layer_name].precision_mode

        if (
            layer.activation_atomic_op
            and layer.activation_atomic_op.act_name == ActivationType.LINEAR
            and precision in self.SPLIT_PRECISIONS
        ):
            # The split will not do anything.
            return True

        return False

    def _log_negative_exponent_shift(self, l_fix: LayerFix):
        """Report how many output bits the scale-factor shift costs.

        * ``scale_shift <= output_bits / 2``: verbose log.
        * ``scale_shift < output_bits``: warning.
        * Otherwise, if the activation's ``assertion_negative_slope()`` holds and the output
          range exceeds ``DEAD_LAYER_TH``: raise ``NegativeSlopeExponentNonFixable``.
        * Otherwise nothing is reported.
        """
        layer = self._model.layers[l_fix.layer_name]
        output_bits = layer.get_output_lossy_elements()[0].bits
        output_limvals = layer.get_output_limvals()[0]
        limvals_range = output_limvals[1] - output_limvals[0]
        log_msg = f"Reducing output bits of {l_fix.layer_name} by {l_fix.scale_shift} bits"
        if l_fix.scale_shift <= output_bits / 2:
            # Required fix shift is less than half of the bits.
            self._logger.verbose(log_msg)
        elif l_fix.scale_shift < output_bits:
            # Required fix shift is more than half of the bits.
            self._logger.warning(f"{log_msg} (More than half)")
        elif layer.activation_atomic_op.assertion_negative_slope() and limvals_range > DEAD_LAYER_TH:
            # Required fix shift is more than the output bits (will zero out the results).
            raise NegativeSlopeExponentNonFixable(
                output_bits=output_bits,
                fix_shift=l_fix.scale_shift,
                lname=layer.full_name,
            )

    # endregion
    #################### Algorithm Base Methods ####################
    # region Abstract Methods

    def _setup(self):
        super()._setup()

    def should_skip_algo(self) -> bool:
        """
        Here we decide whether to skip the algorithm based on the algorithm configuration.
        This algorithm is never skipped.
        """
        return False

    def get_algo_config(self):
        """Return the ``negative_exponent`` section of the model optimization config."""
        return self._model_config.negative_exponent

    def finalize_global_cfg(self, algo_config):
        """
        Finalize the algorithm's config (values that are not layer specific).

        Ensures the processed layer has an entry in ``algo_config.layers`` (a default
        ``LayerNegExponentConfig`` is added if missing) before delegating to the base class.
        """
        if self.layer.full_name not in algo_config.layers:
            algo_config.layers[self.layer.full_name] = LayerNegExponentConfig()
        super().finalize_global_cfg(algo_config)

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` unchanged; no per-layer validation is done."""
        return cfg
