from typing import Tuple, Union

import torch

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    APU_EXP_BIAS_BITS,
    APU_EXP_BIAS_BITS_D,
    APU_FINAL_SHIFT,
    APU_FINAL_SHIFT_D,
    APU_MANTISSA_BITS,
    APU_OFFSET_BITS,
    APU_OFFSET_BITS_D,
    ActivationType,
)
from hailo_model_optimization.saitama.framework.apu_modules.apu_base import APUBase
from hailo_model_optimization.saitama.framework.common.fake_quant import (
    QuantWeight,
    StaticFakeQuant,
)
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    APUPrecisionConfig,
    Encoding,
    QType,
)
from hailo_model_optimization.saitama.framework.common.utils import init_encoding, qtype_to_range

# Upper bound on the number of linear pieces of the piecewise-linear activation.
# APUActivation allocates MAX_PIECES slopes and offsets and MAX_PIECES - 1 thresholds;
# the number of pieces actually evaluated is tracked separately (activation_pieces).
MAX_PIECES = 9


class QuantSlopes(QuantWeight):
    """
    This class defines the slope of the piecewise-linear approximation of the activation function.

    Each slope is kept as a mantissa/exponent pair. When the module is not frozen, the
    effective slope is ``self(mantissas) / 2 ** (exponents + exponent_bias)``, i.e. the
    mantissas passed through this module's forward and scaled by a power of two.

    Buffers:
        mantissas: ``(pieces_len, num_groups_in)`` mantissa values (allocated uninitialized).
        exponents: ``(pieces_len, num_groups_in)`` exponent values (allocated uninitialized).
        q_groups_factor: ``(num_groups_in,)`` per-group factor; values are divided by it
            before encoding and multiplied by it after decoding.
    """

    mantissas: torch.Tensor
    exponents: torch.Tensor
    q_groups_factor: torch.Tensor

    def __init__(
        self,
        pieces_len: int,
        quant_min: int,
        quant_max: int,
        value: torch.Tensor,
        exponent_bias: int,
        num_groups_in: int,  # e.g. quantization groups
        num_groups_out: int,
        channels: int,
        is_independent_encoding: Union[Tuple[bool, bool], bool] = (True, False),
        requires_grad=False,
        **kwargs,
    ):
        """
        Args:
            pieces_len: number of rows (pieces) in the mantissa/exponent buffers.
            quant_min: lower bound of the quantized range, forwarded to the base class.
            quant_max: upper bound of the quantized range, forwarded to the base class.
            value: initial weight tensor; also provides device and dtype for the buffers.
            exponent_bias: constant added to the exponents when reconstructing the slopes.
            num_groups_in: number of columns (groups) in the buffers.
            num_groups_out: number of groups forwarded to the base class.
            channels: number of channels, forwarded to the base class.
            is_independent_encoding: forwarded to the base class.
            requires_grad: forwarded to the base class.
            **kwargs: accepted for signature compatibility and ignored.
        """
        super().__init__(
            quant_min,
            quant_max,
            value,
            num_groups_out,
            channels,
            axis=1,
            is_independent_encoding=is_independent_encoding,
            requires_grad=requires_grad,
        )
        self.initialize_parameters(pieces_len, num_groups_in, value.device, value.dtype)
        self.exponent_bias = exponent_bias

    def initialize_parameters(self, pieces_len, num_groups_in, device=None, dtype=None):
        """Register the mantissa, exponent and per-group factor buffers."""
        tensor_config = {"device": device, "dtype": dtype}
        # NOTE: mantissas/exponents are uninitialized memory until they are filled elsewhere.
        mantissas = torch.empty((pieces_len, num_groups_in), **tensor_config)
        exponents = torch.empty((pieces_len, num_groups_in), **tensor_config)
        q_groups_factor = torch.ones((num_groups_in,), **tensor_config)

        self.register_buffer("mantissas", mantissas)
        self.register_buffer("exponents", exponents)
        self.register_buffer("q_groups_factor", q_groups_factor)

    def get_weight(self):
        """Return the stored weight when frozen, otherwise rebuild the slopes from mantissa/exponent."""
        if self.frozen:
            return self.weight
        return self(self.mantissas) / 2 ** (self.exponents + self.exponent_bias)

    def get_weight_repeated(self):
        """Return the slopes with every group repeated ``channels // num_groups_in`` times along ``self.axis``."""
        repeats = self.channels // len(self.q_groups_factor)
        return self.get_weight().repeat_interleave(repeats, dim=self.axis)

    def _encode(self, x):
        # Remove the per-group factor before the base-class encoding.
        return super()._encode(x / self.q_groups_factor)

    def _decode(self, x):
        # Re-apply the per-group factor after the base-class decoding.
        return super()._decode(x) * self.q_groups_factor

    def get_encode_scale(self, ndim: int = None):
        """
        Return the encoding scale.

        With ``ndim`` given, the scale is reshaped through ``self.view``; with the default
        ``ndim=None`` the raw scale is returned (previously this path raised
        ``UnboundLocalError`` because the local was never assigned).
        """
        if ndim is None:
            return self.scale
        return self.view(self.scale, ndim)

    def get_encode_zero_point(self, ndim: int = None):
        """
        Return the encoding zero point.

        With ``ndim`` given, the zero point is reshaped through ``self.view``; with the
        default ``ndim=None`` the raw zero point is returned (previously this path raised
        ``UnboundLocalError`` because the local was never assigned).
        """
        if ndim is None:
            return self.zero_point
        return self.view(self.zero_point, ndim)

    def forward_encoding(self, encoding: Encoding, verify_encoding=False, **kwargs) -> Encoding:
        """
        Propagate the input encoding through the slopes.

        Updates ``q_groups_factor`` to ``1 / encoding.factor_by_group`` and returns a new
        encoding whose scale is ``max(encoding.scale_by_group) * self.scale``; the zero point
        and equalization vector are carried over from ``encoding``.

        Args:
            encoding: encoding of the incoming data.
            verify_encoding: when True, assert that ``scale_by_group / factor_by_group``
                matches the maximal group scale for every group.
            **kwargs: ignored.
        """
        self.q_groups_factor = 1 / encoding.factor_by_group
        scale = encoding.scale_by_group.max()

        if verify_encoding:
            # NOTE: the entire vector here should have the same value.
            scale_by_group = encoding.scale_by_group / encoding.factor_by_group
            assert torch.allclose(scale, scale_by_group)
        output_scale = scale * self.scale
        scale_repeats = self.channels_per_group_scale
        return init_encoding(
            scale_by_group=output_scale,
            scale_repeats=scale_repeats,
            zero_point_by_group=encoding.zero_point_by_group,
            zero_point_repeats=encoding.zero_point_repeats,
            equalization_vector=encoding.equalization_vector,
        )


class QuantOffsets(QuantWeight):
    """
    Quantized weight wrapper that applies a power-of-two "final shift" around the base encoding.

    Encoded values are divided by ``2 ** final_shift``; the shift is undone before decoding
    and before the base-class quantization.

    NOTE(review): both shift helpers operate IN PLACE (``div_`` / ``mul_``), so any tensor
    handed to them (including the arguments of ``_decode`` and ``quantize``) is modified.
    Confirm callers only pass temporaries.
    """

    final_shift: torch.Tensor

    def __init__(
        self,
        quant_min,
        quant_max,
        value,
        num_groups,
        channels,
        *args,
        is_independent_encoding=False,
        requires_grad=False,
        **kwargs,
    ):
        """Forward everything to the base class (with ``axis=0``) and register the ``final_shift`` buffer."""
        super().__init__(
            quant_min,
            quant_max,
            value,
            num_groups,
            channels,
            *args,
            axis=0,
            is_independent_encoding=is_independent_encoding,
            requires_grad=requires_grad,
            **kwargs,
        )
        # Single-element buffer initialised to one, on the same dtype/device as ``value``.
        self.register_buffer("final_shift", torch.ones(1, dtype=value.dtype, device=value.device))

    def _encode(self, x):
        encoded = super()._encode(x)
        return self.apply_final_shift(encoded)

    def _decode(self, x):
        unshifted = self.deapply_final_shift(x)
        return super()._decode(unshifted)

    def apply_final_shift(self, x: torch.Tensor):
        """Divide ``x`` in place by ``2 ** final_shift`` and return it."""
        divisor = 2**self.final_shift
        return x.div_(divisor)

    def deapply_final_shift(self, x: torch.Tensor):
        """Multiply ``x`` in place by ``2 ** final_shift`` and return it."""
        multiplier = 2**self.final_shift
        return x.mul_(multiplier)

    def quantize(self, x):
        """Undo the final shift, run the base-class quantization, then apply the shift again."""
        return self.apply_final_shift(super().quantize(self.deapply_final_shift(x)))

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """
        Set this module's scale from the incoming encoding and return the outgoing encoding.

        The scale becomes the shifted ``encoding.scale_by_group``; the returned encoding keeps
        the incoming scale fields and equalization vector and carries this module's shifted
        zero point.
        """
        # NOTE(review): the shift is in place; whether ``to_tensor`` copies its input is not
        # visible here - confirm ``encoding.scale_by_group`` / ``self.zero_point`` are not aliased.
        self.scale = self.apply_final_shift(self.to_tensor(encoding.scale_by_group))

        shifted_zero_point = self.apply_final_shift(self.to_tensor(self.zero_point))
        repeats = self.channels_per_group_zero_point

        return init_encoding(
            scale_by_group=encoding.scale_by_group,
            scale_repeats=encoding.scale_repeats,
            zero_point_by_group=shifted_zero_point,
            zero_point_repeats=repeats,
            equalization_vector=encoding.equalization_vector,
        )


class QuantThresholds(QuantWeight):
    """Quantized weight wrapper for the thresholds that separate the piecewise-linear pieces."""

    def get_weight_repeated(self):
        """Return the weight with every group repeated ``channels // num_groups_scale`` times along ``self.axis``."""
        repeats = self.channels // self.num_groups_scale
        return self.get_weight().repeat_interleave(repeats, dim=self.axis)

    def get_encode_scale(self, ndim: int = None):
        """
        Return the encoding scale.

        With ``ndim`` given, the scale is reshaped through ``self.view``; with the default
        ``ndim=None`` the raw scale is returned (previously this path raised
        ``UnboundLocalError`` because the local was never assigned).
        """
        if ndim is None:
            return self.scale
        return self.view(self.scale, ndim)

    def get_encode_zero_point(self, ndim: int = None):
        """
        Return the encoding zero point.

        With ``ndim`` given, the zero point is reshaped through ``self.view``; with the
        default ``ndim=None`` the raw zero point is returned (previously this path raised
        ``UnboundLocalError`` because the local was never assigned).
        """
        if ndim is None:
            return self.zero_point
        return self.view(self.zero_point, ndim)

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Adopt the incoming per-group scale as this module's scale and pass the encoding through unchanged."""
        self.scale = encoding.scale_by_group
        return encoding


class APUActivation(APUBase):
    """
    Simulates the APU HW activation as a quantized piecewise-linear approximation (PWLA).

    The module owns quantized slopes, offsets and thresholds for up to ``MAX_PIECES``
    linear pieces and a fake-quantizer for the output. Only the first
    ``activation_pieces`` pieces are evaluated in ``forward``; that count is persisted
    through the module's extra state.
    """

    slopes: QuantSlopes
    offsets: QuantOffsets
    thresholds: QuantThresholds
    output_quantizer: StaticFakeQuant

    def __init__(
        self,
        channels: int,
        activation: ActivationType,
        precision_config: APUPrecisionConfig,
        dtype=None,
        device=None,
    ) -> None:
        """
        Args:
            channels: number of channels of the activation input.
            activation: activation type; stored on the module (not otherwise used in this class).
            precision_config: provides the accumulator/output qtypes and the number of
                quantization groups.
            dtype: dtype of the created tensors.
            device: device of the created tensors.

        Raises:
            ValueError: if the (accumulator bits, output bits) pair is not supported.
        """
        super().__init__()
        self.max_pieces = MAX_PIECES
        self.activation_pieces = 1
        self.activation = activation
        self.init_precision(precision_config)

        # (accumulator bits, output bits) -> (first_bias, second_bias, offsets_bits, shift_data)
        # NOTE(review): the (16, 8) row pairs the non-_D exponent bias with the _D final shift
        # while (32, 8) does the opposite; kept exactly as in the original - confirm intent.
        precision_table = {
            (16, 8): (APU_EXP_BIAS_BITS, APU_FINAL_SHIFT_D, APU_OFFSET_BITS, 0),
            (32, 8): (APU_EXP_BIAS_BITS_D, APU_FINAL_SHIFT, APU_OFFSET_BITS_D, 0),
            (32, 15): (APU_EXP_BIAS_BITS_D, APU_FINAL_SHIFT_D, APU_OFFSET_BITS_D, 0),
            (16, 15): (APU_EXP_BIAS_BITS_D, APU_FINAL_SHIFT_D, APU_OFFSET_BITS_D, 8),
        }
        precision_key = (self.accumulator_qtype.bits, self.output_qtype.bits)
        if precision_key not in precision_table:
            raise ValueError("Unsupported accumulator and output qtypes")
        (
            self.first_bias,
            self.second_bias,
            self.offsets_bits,
            self.shift_data,
        ) = precision_table[precision_key]

        self.initialize_parameters(channels, dtype, device)
        self.initialize_output_quantizer(channels, dtype, device)

    def initialize_output_quantizer(self, channels, dtype=None, device=None):
        """Create the output fake-quantizer covering the range of ``output_qtype``, one group per channel."""
        quant_min, quant_max = qtype_to_range(self.output_qtype)
        self.output_quantizer = StaticFakeQuant(
            quant_min=quant_min,
            quant_max=quant_max,
            num_groups=channels,
            channels=channels,
            axis=1,
            is_independent_encoding=False,
            dtype=dtype,
            device=device,
        )

    def initialize_parameters(self, channels, dtype=None, device=None):
        """Create the quantized slopes, offsets and thresholds sub-modules."""
        # NOTE: assumes broader range, if we want explicit range for hailo8 - add a fix here
        mantissa_qtype = QType(bits=APU_MANTISSA_BITS + 1, signed=True)
        quant_min, quant_max = qtype_to_range(mantissa_qtype)
        weight = torch.ones((self.max_pieces, self.quantization_groups), device=device, dtype=dtype)  # slopes
        num_groups_out = 1

        self.slopes = QuantSlopes(
            self.max_pieces,
            quant_min=quant_min,
            quant_max=quant_max,
            value=weight,
            exponent_bias=self.first_bias + self.second_bias - self.shift_data,
            num_groups_in=self.quantization_groups,
            num_groups_out=num_groups_out,
            channels=channels,
        )

        # One offset per piece, quantized to a signed ``offsets_bits`` range.
        offsets = torch.zeros((self.max_pieces,), dtype=dtype, device=device)
        offsets_qtype = QType(bits=self.offsets_bits, signed=True)
        quant_min, quant_max = qtype_to_range(offsets_qtype)
        self.offsets = QuantOffsets(
            quant_min=quant_min,
            quant_max=quant_max,
            value=offsets,
            num_groups=num_groups_out,
            channels=channels,
        )

        # N pieces are separated by N - 1 thresholds; allocated uninitialized.
        thresholds = torch.empty(
            (
                self.max_pieces - 1,
                self.quantization_groups,
            ),
            dtype=dtype,
            device=device,
        )
        quant_min, quant_max = qtype_to_range(self.accumulator_qtype)
        self.thresholds = QuantThresholds(
            quant_min=quant_min,
            quant_max=quant_max,
            value=thresholds,
            axis=1,
            num_groups=self.quantization_groups,
            channels=channels,
        )

    def init_precision(self, precision_config: APUPrecisionConfig):
        """Copy the accumulator/output qtypes and the quantization group count from the config."""
        self.accumulator_qtype = precision_config.accumulator_qtype
        self.output_qtype = precision_config.output_qtype
        self.quantization_groups = precision_config.quantization_groups

    def forward(self, x, **kwargs):
        """
        simulates the APU HW - piecewise-linear approximation
        """
        result = self._forward_pwla(x)
        return self.output_quantizer(result)

    def _forward_pwla(self, x, **kwargs):
        """
        Evaluate the piecewise-linear function on ``x``.

        Starts from piece 0 and, for every following active piece ``i``, overrides the
        result wherever ``x >= thresholds[i - 1]``.

        The per-channel parameters are reshaped to broadcast against an input whose
        channel axis is 1, with as many trailing singleton axes as ``x`` has dimensions
        after the channel axis. For 4-D inputs this is the same view the code always used;
        other ranks previously broadcast incorrectly.
        """
        thresholds = self.thresholds.get_weight_repeated()
        slopes = self.slopes.get_weight_repeated()
        offsets = self.offsets.get_weight()

        # (pieces, channels) -> (pieces, 1, channels, 1, ..., 1) so that indexing a piece
        # yields a tensor broadcastable against x.
        trailing_dims = (1,) * max(x.dim() - 2, 0)
        slopes = slopes.view(slopes.shape[0], 1, slopes.shape[1], *trailing_dims)
        thresholds = thresholds.view(thresholds.shape[0], 1, thresholds.shape[1], *trailing_dims)

        result = x * slopes[0] + offsets[0]
        for i in range(1, self.activation_pieces):
            result = torch.where(torch.ge(x, thresholds[i - 1]), x * slopes[i] + offsets[i], result)
        return result

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Propagate the encoding through thresholds, slopes, offsets and the output quantizer, in that order."""
        encoding = self.thresholds.forward_encoding(encoding, **kwargs)
        encoding = self.slopes.forward_encoding(encoding, **kwargs)
        encoding = self.offsets.forward_encoding(encoding, **kwargs)
        encoding = self.output_quantizer.forward_encoding(encoding, **kwargs)
        return encoding

    def get_extra_state(self):
        """Return the non-tensor state to persist: the number of active pieces."""
        return {
            "activation_pieces": self.activation_pieces,
        }

    def set_extra_state(self, state):
        """Restore the number of active pieces from ``state``."""
        self.activation_pieces = int(state["activation_pieces"])
