from hailo_model_optimization.acceleras.utils.acceleras_definitions import PrecisionMode
from hailo_sdk_common.exceptions.exceptions import SDKException
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType, ResizeMethod

# Defaults consumed by InterLayerPrecisionMode (constructor defaults and property fallbacks).
DEFAULT_ACCUMULATOR_SIZE = 16
DEFAULT_APU_MODE = 8
DEFAULT_OUTPUT_ACTIVATION_BITS = 8
DEFAULT_X_POINTS_MAX_VALUE = 1.0e8  # note: a float, like the 1e16 used for 32-bit accumulators

# Bias / shift widths.
DEFAULT_EBIAS = 7  # SHIFT_INT_BIAS in HW's APU.
SHIFT_INT32_BIAS = 11
W_BIAS = 13
W_INT32_BIAS = 19
DEFAULT_BETA = 3

# Upper limits; MAX_ALLOWED_NEGATIVE_SLOPE is scaled in InterLayerPrecisionMode.max_allowed_negative_slope.
MAX_ALLOWED_NEGATIVE_SLOPE = 4
MAX_ALLOWED_OVERFLOW_OFFSETS = 4


class InterLayerPrecisionMode:
    """Precision configuration at the boundary between a layer and its successors.

    Stores the accumulator width, the APU mode (the activation size required by the
    successor layers, see ``from_precision_mode``) and the weight-encoding flags, and
    derives the bit widths and limits exposed by the properties below.
    """

    def __init__(
        self,
        accumulator_size=DEFAULT_ACCUMULATOR_SIZE,
        apu_mode=DEFAULT_APU_MODE,
        zero_point_weights=False,
        use_4bit_weights=False,
        exponential_mode_4bit_weights=False,
    ):
        """
        Args:
            accumulator_size: Accumulator width in bits (``from_precision_mode`` passes twice
                the layer's input activation size).
            apu_mode: Output activation mode in bits (``from_precision_mode`` derives it as 8
                or 16 from the successors' precision mode).
            zero_point_weights: Whether weights use a zero point; costs one weight bit.
            use_4bit_weights: Whether weights are 4-bit.
            exponential_mode_4bit_weights: Whether 4-bit weights use exponential mode.
        """
        self._accumulator_size = accumulator_size
        self._apu_mode = apu_mode
        self._zero_point_weights = zero_point_weights
        self._use_4bit_weights = use_4bit_weights
        self._exponential_mode_4bit_weights = exponential_mode_4bit_weights

    def __repr__(self):
        # Exponential mode is checked before plain 4-bit: from_precision_mode sets both
        # flags for a8_w4_exp, and the more specific label must win.
        if self._zero_point_weights:
            description = " (zero point weights)"
        elif self._exponential_mode_4bit_weights:
            description = " (4-bit weights exponential mode)"
        elif self._use_4bit_weights:
            description = " (4-bit weights)"
        else:
            description = ""
        return f"<< InterLayerPrecisionMode {self._accumulator_size}->{self._apu_mode}{description} >>"

    def is_mode(self, accumulator_size, apu_mode):
        """Return True if this mode has exactly the given accumulator size and APU mode."""
        return self._accumulator_size == accumulator_size and self._apu_mode == apu_mode

    @property
    def accumulator_size(self):
        return self._accumulator_size

    @property
    def apu_mode(self):
        return self._apu_mode

    @property
    def use_4bit_weights(self):
        return self._use_4bit_weights

    @property
    def exponential_mode_4bit_weights(self):
        return self._exponential_mode_4bit_weights

    @property
    def weight_bits(self):
        """Number of weight bits (always an int).

        4 for non-exponential 4-bit weights; otherwise half the accumulator width, minus
        one bit when zero-point weights are used.
        """
        if self._use_4bit_weights and not self._exponential_mode_4bit_weights:
            return 4
        # Floor division keeps the result an int, matching the 4-bit branch above.
        return self._accumulator_size // 2 - int(self._zero_point_weights)

    @property
    def beta(self):
        """10 in the 32->8 mode, DEFAULT_BETA otherwise."""
        return 10 if self.is_mode(32, 8) else DEFAULT_BETA

    @property
    def max_allowed_negative_slope(self):
        # When we allow a retry on negative slope, we actually allow some bits to be taken from the activation.
        # Hence, if there are more bits in the APU, we may allow more bits to be taken from it.
        return MAX_ALLOWED_NEGATIVE_SLOPE * 2 if self._apu_mode == 16 else MAX_ALLOWED_NEGATIVE_SLOPE

    @property
    def ebias(self):
        """DEFAULT_EBIAS in the 16->8 mode, SHIFT_INT32_BIAS otherwise."""
        return DEFAULT_EBIAS if self.is_mode(16, 8) else SHIFT_INT32_BIAS

    @property
    def shifter_bias_max_value(self):
        """Largest shifter bias: 2**(W_BIAS-1)-1 in the 16->8 mode, 2**W_INT32_BIAS-1 otherwise."""
        return 2 ** (W_BIAS - 1) - 1 if self.is_mode(16, 8) else 2**W_INT32_BIAS - 1

    @property
    def activation_computation_max_value(self):
        """Largest signed value representable in the accumulator width."""
        return 2 ** (self._accumulator_size - 1) - 1

    @property
    def input_activation_bits(self):
        return 15 if self._accumulator_size == 32 else 8

    @property
    def x_points_mask_max_value(self):
        return 1e16 if self._accumulator_size == 32 else DEFAULT_X_POINTS_MAX_VALUE

    @property
    def output_activation_bits(self):
        return 15 if self._apu_mode == 16 else DEFAULT_OUTPUT_ACTIVATION_BITS

    @property
    def zero_point_weights(self):
        return self._zero_point_weights

    @classmethod
    def _is_precision_mode_forwarder(cls, layer):
        """Return True if *layer* is treated as forwarding precision between its neighbours.

        For such layers, the neighbours' precision modes are looked up through the layer
        rather than read from it. This covers the data-movement ops listed below and any
        resize that is not bilinear.
        """
        if layer.op == LayerType.resize and layer._method != ResizeMethod.bilinear:
            return True
        return layer.op in (
            LayerType.reduce_max,
            LayerType.concat,
            LayerType.shortcut,
            LayerType.external_pad,
            LayerType.depth_to_space,
            LayerType.space_to_depth,
            LayerType.slice,
            LayerType.feature_shuffle,
            LayerType.feature_splitter,
            LayerType.row_splitter,
            LayerType.width_splitter,
            LayerType.format_conversion,
        )

    @classmethod
    def _ensure_same_act_size(cls, precisions, relation, layer):
        """Return ``precisions[0]`` after checking all modes share one activation size.

        Args:
            precisions: Non-empty list of precision modes collected from neighbour layers.
            relation: "Successors" or "Predecessors"; used only in the error message.
            layer: The layer whose neighbours were inspected; used only in the error message.

        Raises:
            SDKException: If the modes map to more than one activation size.
        """
        if len({cls._precision_mode_act_size(precision) for precision in precisions}) != 1:
            raise SDKException(f"{relation} of layer {layer.name} have different precision modes")
        return precisions[0]

    @classmethod
    def _get_successor_precision(cls, layer, hailo_nn):
        """Return the precision mode of *layer*'s successors, looking through forwarders.

        Returns the layer's own precision mode if it has no successors.

        Raises:
            SDKException: If the successors disagree on activation size.
        """
        successors = list(hailo_nn.successors(layer))
        curr_precision = layer.precision_config.precision_mode
        if not successors:
            return curr_precision
        successors_precisions = [
            cls._get_successor_precision(succ, hailo_nn)
            if cls._is_precision_mode_forwarder(succ)
            else succ.precision_config.precision_mode
            for succ in successors
        ]
        return cls._ensure_same_act_size(successors_precisions, "Successors", layer)

    @classmethod
    def _get_current_precision(cls, layer, hailo_nn):
        """Return the precision mode that effectively applies at *layer*'s input.

        For a non-forwarder, or a forwarder without predecessors, this is the layer's own
        precision mode. For a forwarder, it is the mode of its predecessors, looking back
        through any chain of forwarders.

        Raises:
            SDKException: If the predecessors disagree on activation size.
        """
        predecessors = list(hailo_nn.predecessors(layer))
        curr_precision = layer.precision_config.precision_mode
        if not predecessors or not cls._is_precision_mode_forwarder(layer):
            return curr_precision
        predecessors_precisions = [
            # Recurse with this same method; the original called an undefined
            # `_get_predecessors_precision`, which raised AttributeError.
            cls._get_current_precision(pred, hailo_nn)
            if cls._is_precision_mode_forwarder(pred)
            else pred.precision_config.precision_mode
            for pred in predecessors
        ]
        return cls._ensure_same_act_size(predecessors_precisions, "Predecessors", layer)

    @classmethod
    def from_hailo_nn(cls, hailo_nn, layer_name):
        """Build the inter-layer precision mode for the layer named *layer_name* in *hailo_nn*."""
        layer = hailo_nn.get_layer_by_name(layer_name)
        current_precision_mode = cls._get_current_precision(layer, hailo_nn)
        successor_precision_mode = cls._get_successor_precision(layer, hailo_nn)

        return cls.from_precision_mode(
            current_precision_mode,
            successor_precision_mode,
        )

    @classmethod
    def from_precision_mode(cls, current_precision_mode, successor_precision_mode):
        """Build an instance from the layer's precision mode and its successors' mode.

        The accumulator is twice the current activation size. The APU mode is the
        successors' activation size. Weight flags come from ``current_precision_mode.reduce()``
        and are all False when ``current_precision_mode`` is None.
        """
        apu_mode = cls._precision_mode_act_size(successor_precision_mode)
        in_bits = cls._precision_mode_act_size(current_precision_mode)

        if current_precision_mode is None:
            use_4bit_weights = zero_point_weights = exponential_mode_4bit_weights = False
        else:
            # NOTE(review): reduce() is called once here instead of three times; this
            # assumes it is side-effect free.
            reduced_mode = current_precision_mode.reduce()
            use_4bit_weights = reduced_mode in (PrecisionMode.a8_w4, PrecisionMode.a8_w4_exp)
            zero_point_weights = reduced_mode in (PrecisionMode.a16_w16, PrecisionMode.a16_w16_a16)
            exponential_mode_4bit_weights = reduced_mode == PrecisionMode.a8_w4_exp

        return cls(
            accumulator_size=2 * in_bits,
            apu_mode=apu_mode,
            zero_point_weights=zero_point_weights,
            use_4bit_weights=use_4bit_weights,
            exponential_mode_4bit_weights=exponential_mode_4bit_weights,
        )

    @classmethod
    def _precision_mode_act_size(cls, precision_mode):
        """Return 16 for the listed 16-bit activation modes and 8 for anything else (including None)."""
        sixteen_bit_modes = (
            PrecisionMode.a16_w16,
            PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w16_a8,
            PrecisionMode.a16_w16_non_zero,
        )
        return 16 if precision_mode in sixteen_bit_modes else 8
