from abc import ABC, abstractmethod
from typing import Tuple, Union

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding
from hailo_model_optimization.saitama.framework.common.utils import _pair, init_encoding


class STEFunction(torch.autograd.Function):
    """Element-wise rounding with a straight-through gradient.

    The forward pass rounds every element of the input; the backward pass
    returns the incoming gradient unchanged, as if rounding were the identity.
    """

    @staticmethod
    def forward(ctx, x):
        # Out-of-place rounding: the input tensor itself is left untouched.
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient along as-is.
        return grad_output


def ste_round(x):
    """Round ``x``, using the straight-through estimator when autograd is active.

    With gradients enabled the result comes from ``STEFunction``. With
    gradients disabled ``x`` is rounded IN PLACE and returned, so the caller's
    tensor is modified.
    """
    if not torch.is_grad_enabled():
        # No gradient is being recorded, so overwrite the input instead of allocating.
        return x.round_()
    return STEFunction.apply(x)


class BaseQuant(nn.Module, ABC):
    """Abstract interface implemented by the quantization modules in this file.

    Concrete subclasses must provide ``set_encoded``, ``forward_encoding`` and
    ``enable_quantization``; ``is_encoded`` is supplied here.
    """

    # Whether quantization is applied; assigned by concrete subclasses.
    _quant_enabled: bool
    # Flag reported by ``is_encoded``; assigned by concrete subclasses.
    _is_encoded: bool

    @abstractmethod
    def set_encoded(self, value: bool):
        """Set the encoded flag of the module to ``value``."""

    @abstractmethod
    def forward_encoding(self, *enc: Encoding, **kwargs) -> Encoding:
        """Consume the given encoding(s) and return an encoding."""

    @abstractmethod
    def enable_quantization(self, value: bool):
        """Turn quantization on (``True``) or off (``False``)."""

    def is_encoded(self):
        """Return the current value of the encoded flag."""
        return self._is_encoded


class StaticFakeQuant(BaseQuant):
    """
    A PyTorch module for static fake quantization using an externally supplied encoding.

    The module owns a ``scale`` and a ``zero_point`` tensor with one entry per group. Unless it
    is flagged as encoded, ``forward`` encodes the input (``x / scale + zero_point``), rounds and
    clamps it to ``[quant_min, quant_max]`` and decodes it again, so data is expected in the
    native scale and is returned in the native scale. When the module is flagged as encoded the
    input is only rounded and clamped.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        quant_min: int,
        quant_max: int,
        channels: int,
        num_groups: Union[int, Tuple[int, int]],
        axis: int,
        *args,
        quant_enabled: bool = True,
        is_independent_encoding: Union[Tuple[bool, bool], bool] = False,
        dtype=None,
        device=None,
        **kwargs,
    ):
        """
        Initializes the StaticFakeQuant module.
        Args:
            quant_min (int): Minimum quantization value.
            quant_max (int): Maximum quantization value.
            channels (int): Number of channels.
            num_groups (Union[int, Tuple[int, int]]): Number of groups for scale and zero point.
            axis (int): Axis for quantization.
            *args: Additional positional arguments, forwarded to the parent constructor.
            quant_enabled (bool, optional): Flag to enable quantization. Defaults to True.
            is_independent_encoding (Union[Tuple[bool, bool], bool], optional): Flag to indicate independent encoding for scale and zero point.
                    Defaults to False. Independent encoding will be stored as parameters, otherwise as buffers.
            dtype (optional): Data type for the encoding. Defaults to None.
            device (optional): Device for the encoding. Defaults to None.
            **kwargs: Additional keyword arguments, forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.axis = axis
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.channels = channels
        self.num_groups_scale, self.num_groups_zero_point = _pair(num_groups)
        # Number of consecutive channels that share one scale / zero-point entry.
        self.channels_per_group_scale = self.channels // self.num_groups_scale
        self.channels_per_group_zero_point = self.channels // self.num_groups_zero_point
        scale_as_parameter, zp_as_parameter = _pair(is_independent_encoding)
        self._initialize_encoding(scale_as_parameter, zp_as_parameter, dtype=dtype, device=device)
        self._quant_enabled = quant_enabled
        self._is_encoded = False

    def set_encoded(self, is_encoded: bool):
        """Set the encoded flag; when True, ``forward`` skips the encode/decode steps."""
        self._is_encoded = is_encoded

    def _initialize_encoding(self, scale_as_parameter, zp_as_parameter, dtype=None, device=None):
        """
        Create ``scale`` (all ones) and ``zero_point`` (all zeros), one entry per group.

        Each tensor is registered as a parameter when its flag is True, otherwise as a buffer.
        """
        scale = torch.ones(self.num_groups_scale, dtype=dtype, device=device)
        if scale_as_parameter:
            self.register_parameter("scale", nn.Parameter(scale))
        else:
            self.register_buffer("scale", scale)

        zero_point = torch.zeros(self.num_groups_zero_point, dtype=dtype, device=device)
        if zp_as_parameter:
            self.register_parameter("zero_point", nn.Parameter(zero_point))
        else:
            self.register_buffer("zero_point", zero_point)

    def forward(self, x: torch.Tensor):
        """
        Fake-quantize ``x``.

        Returns a copy of ``x`` when quantization is disabled, the rounded and clamped copy when
        the module is flagged as encoded, and otherwise the encode -> quantize -> decode result.
        """
        # Work on a fresh tensor: the steps below operate in place.
        x = self.to_tensor(x)
        if not self._quant_enabled:
            return x
        if self._is_encoded:
            return self.quantize(x)
        x = self._encode(x)
        x = self.quantize(x)
        return self._decode(x)

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Adopt the per-channel scale and zero point of ``encoding`` and return it unchanged."""
        self.scale = encoding.scale_by_channel
        self.zero_point = encoding.zero_point_by_channel
        return encoding

    def extra_repr(self):
        """Build the summary string shown in the module's ``repr``."""
        s1 = "quant_enabled={_quant_enabled}, encoded={_is_encoded}"
        if self._quant_enabled:
            s1 += ", quant_min={quant_min}, quant_max={quant_max}, channels={channels}"
            if self.num_groups_scale == self.num_groups_zero_point:
                s1 += ", num_groups={num_groups_scale}"
            else:
                s1 += ", num_groups_scale={num_groups_scale}, num_groups_zero_point={num_groups_zero_point}"
            s1 += ", axis={axis}"

        # The placeholders above are resolved from the instance attributes.
        return s1.format(**self.__dict__)

    def _encode(self, x: torch.Tensor):
        """Map ``x`` to the quantized domain IN PLACE: ``x / scale + zero_point``."""
        scale = self.get_encode_scale(x.ndim)
        zero_point = self.get_encode_zero_point(x.ndim)
        return x.div_(scale).add_(zero_point)

    def get_encode_scale(self, ndim: Union[int, None] = None):
        """
        Return the scale used for encoding.

        With a single group, or one channel per group, ``scale`` is used as stored; otherwise
        every group entry is repeated for each channel of its group. When ``ndim`` is given the
        result is reshaped (see ``view``) to broadcast along ``self.axis``.
        """
        if self.num_groups_scale == 1 or self.channels_per_group_scale == 1:
            scale = self.scale
        else:
            scale = self.scale.repeat_interleave(self.channels_per_group_scale)
        if ndim is not None:
            scale = self.view(scale, ndim)
        return scale

    def get_encode_zero_point(self, ndim: Union[int, None] = None):
        """
        Return the zero point used for encoding.

        With a single group, or one channel per group, ``zero_point`` is used as stored;
        otherwise every group entry is repeated for each channel of its group. When ``ndim`` is
        given the result is reshaped (see ``view``) to broadcast along ``self.axis``.
        """
        if self.num_groups_zero_point == 1 or self.channels_per_group_zero_point == 1:
            zero_point = self.zero_point
        else:
            # Expand the zero point itself (the scale was expanded here by mistake before).
            zero_point = self.zero_point.repeat_interleave(self.channels_per_group_zero_point)
        if ndim is not None:
            zero_point = self.view(zero_point, ndim)
        return zero_point

    def _decode(self, x: torch.Tensor):
        """Map ``x`` back from the quantized domain IN PLACE: ``(x - zero_point) * scale``."""
        scale = self.get_encode_scale(x.ndim)
        zero_point = self.get_encode_zero_point(x.ndim)
        return x.sub_(zero_point).mul_(scale)

    def view(self, data: torch.Tensor, ndim: int, axis=None):
        """
        Reshape ``data`` to ``ndim`` dimensions, all of size 1 except ``axis``.

        ``axis`` defaults to ``self.axis``; the resulting view broadcasts against an
        ``ndim``-dimensional tensor along that axis.
        """
        if axis is None:
            axis = self.axis
        shape = [1] * ndim
        shape[axis] = -1
        return data.view(shape)

    def quantize(self, x: torch.Tensor):
        """Round ``x`` (straight-through gradient) and clamp it in place to the quantization range."""
        x = ste_round(x)
        x = x.clamp_(self.quant_min, self.quant_max)
        return x

    def enable_quantization(self, quant_enabled: bool):
        """Turn quantization on or off for subsequent ``forward`` calls."""
        self._quant_enabled = quant_enabled

    @staticmethod
    def to_tensor(x):
        """Return a new tensor equal to ``x`` so later in-place ops do not touch the input."""
        return x + 0


class AccumulatorFakeQuant(StaticFakeQuant):
    """
    StaticFakeQuant variant whose forward pass wraps values around the quantization range
    (modular arithmetic) instead of clamping them.
    """

    def forward(self, x):
        """
        Wrap ``x`` into the quantization range.

        Returns a copy of ``x`` when quantization is disabled. When the module is flagged as
        encoded the range bounds are used directly; otherwise they are first multiplied by
        ``scale`` and reshaped to broadcast along ``self.axis``.
        """
        x = self.to_tensor(x)
        if not self._quant_enabled:
            return x

        if self._is_encoded:
            low = self.quant_min
            span = self.quant_max - self.quant_min + 1
        else:
            low = self.view(self.quant_min * self.scale, x.ndim)
            span = self.view((self.quant_max - self.quant_min + 1) * self.scale, x.ndim)
        # Shift so the range starts at zero, take the modulo, then shift back.
        return (x - low) % (span) + low


class QuantWeight(StaticFakeQuant):
    """
    A class that wraps a weight tensor and allows for static or online quantization.
    This class subclasses StaticFakeQuant and provides functionality to handle
    quantized weights either in a static or online manner.

    While not frozen, ``get_weight`` runs the stored weight through ``forward`` on every call.
    After ``freeze`` the processed weight is stored and returned as-is.
    """

    weight: torch.Tensor

    def __init__(
        self,
        quant_min: int,
        quant_max: int,
        value: torch.Tensor,
        num_groups: int,
        channels: int,
        axis: int,
        *args,
        is_independent_encoding: Union[Tuple[bool, bool], bool] = False,
        is_independent_weight: bool = True,
        requires_grad: bool = True,
        **kwargs,
    ):
        """
        Initializes the QuantWeight module.
        Args:
            quant_min (int): Minimum quantization value.
            quant_max (int): Maximum quantization value.
            value (torch.Tensor): Initial weight; its dtype and device are used for the encoding.
            num_groups (int): Number of groups for scale and zero point.
            channels (int): Number of channels.
            axis (int): Axis for quantization.
            *args: Additional positional arguments, forwarded to StaticFakeQuant.
            is_independent_encoding (Union[Tuple[bool, bool], bool], optional): Forwarded to
                    StaticFakeQuant. Defaults to False.
            is_independent_weight (bool, optional): Store the weight as a parameter (True) or as
                    a buffer (False). Defaults to True.
            requires_grad (bool, optional): ``requires_grad`` of the weight parameter. Defaults to True.
            **kwargs: Additional keyword arguments, forwarded to StaticFakeQuant.
        """
        super().__init__(
            quant_min,
            quant_max,
            channels,
            num_groups,
            axis,
            *args,
            is_independent_encoding=is_independent_encoding,
            dtype=value.dtype,
            device=value.device,
            **kwargs,
        )
        self.initilize_weigth(value, requires_grad, is_independent_weight)
        self.frozen = False

    def initilize_weigth(self, value, requires_grad, is_independent_weight):
        """Register ``value`` as the ``weight`` parameter, or as a buffer when not independent."""
        if is_independent_weight:
            self.weight = nn.Parameter(value, requires_grad=requires_grad)
        else:
            self.register_buffer("weight", value)

    def set_weight(self, weight):
        """Copy ``weight`` into the stored weight without recording gradients."""
        with torch.no_grad():
            self.weight.copy_(weight)

    def get_weight(self):
        """Return the stored weight as-is when frozen, otherwise its ``forward`` result."""
        if self.frozen:
            return self.weight
        else:
            return self(self.weight)

    def set_encoded(self, is_encoded: bool):
        """
        Set the encoded flag.

        When the weight is frozen and the flag actually changes, the stored weight is encoded
        (or decoded) so that it matches the new flag.
        """
        change_encode = self._is_encoded != is_encoded
        if self.frozen and change_encode:
            x = self.weight
            if is_encoded:
                self.set_weight(self._encode(x))
            else:
                self.set_weight(self._decode(x))
        super().set_encoded(is_encoded)

    def freeze(self):
        """
        Freezes the weights of the model by setting the weights to a static state and
        disabling gradient updates for the weights.
        This method performs the following steps:
        1. Sets the weights to a static state using the `set_weight` method.
        2. Disables gradient updates for the weights, scale and zero point by setting
           `requires_grad` to False.

        Raises:
            ValueError: If the weights are already frozen.
        """
        if not self.frozen:
            self.set_weight(self.get_weight())
            # A subclass may replace the weight with a non-tensor value in set_weight.
            if isinstance(self.weight, torch.Tensor):
                self.weight.requires_grad = False
            self.scale.requires_grad = False
            self.zero_point.requires_grad = False
            self.frozen = True
        else:
            raise ValueError("Weights are already frozen")

    def enable_quantization(self, quant_enabled):
        """
        Turn quantization on or off.

        On a frozen weight, enabling quantization quantizes the stored weight once (encoding
        and decoding around it unless the module is flagged as encoded). Disabling quantization
        on a frozen weight is rejected.

        Raises:
            ValueError: If quantization is being disabled after the weight was frozen.
        """
        quant_changed = self._quant_enabled != quant_enabled
        if self.frozen and quant_changed:
            # Branch on the REQUESTED state: the stored value can be quantized now, but a
            # quantized stored value cannot be turned back into the original one.
            if quant_enabled:
                x = self.to_tensor(self.weight)
                if not self._is_encoded:
                    x = self._encode(x)
                x = self.quantize(x)
                if not self._is_encoded:
                    x = self._decode(x)
                self.set_weight(x)
            else:
                raise ValueError("Can't disable quantization after freeze")
        super().enable_quantization(quant_enabled)

    def forward(self, x: torch.Tensor):
        """
        Process ``x`` according to the quantization and encoded flags.

        Returns a copy of ``x`` when quantization is disabled and the module is not flagged as
        encoded. Otherwise ``x`` is encoded, quantized if quantization is enabled, and decoded
        unless the module is flagged as encoded.
        """
        x = self.to_tensor(x)
        if not self._quant_enabled and not self._is_encoded:
            return x
        else:
            x = self._encode(x)
            if self._quant_enabled:
                x = self.quantize(x)
            if not self._is_encoded:
                x = self._decode(x)
            return x

    def forward_encoding(self, *enc: Encoding, **kwargs) -> Encoding:
        """Not supported on the generic weight wrapper; always raises NotImplementedError."""
        raise NotImplementedError(f"forward_encoding is not implemented for {self.__class__.__name__}")

    def extra_repr(self):
        """Prefix the parent summary with the frozen state."""
        s1 = super().extra_repr()
        return f"frozen={self.frozen}, {s1}"

    def get_encoded_weight(self):
        """Return the weight in the encoded domain, encoding a copy of it when needed."""
        x = self.get_weight()
        if not self._is_encoded:
            # Copy first: _encode works in place.
            x = self.to_tensor(x)
            x = self._encode(x)
        return x


class QuantPadValue(QuantWeight):
    """
    QuantWeight holding a single padding value.

    The value starts as a one-element tensor. ``set_weight`` (called by ``freeze``, among
    others) replaces it with a plain Python scalar, and ``to_tensor`` broadcasts it back to the
    shape of ``scale``.
    """

    def __init__(
        self,
        quant_min: int,
        quant_max: int,
        value: float,
        channels,
        *args,
        dtype=None,
        device=None,
        **kwargs,
    ):
        """
        Initializes the QuantPadValue module.
        Args:
            quant_min (int): Minimum quantization value.
            quant_max (int): Maximum quantization value.
            value (float): Padding value.
            channels: Number of channels; also used as the number of groups.
            *args: Additional positional arguments, forwarded to QuantWeight.
            dtype (optional): Data type of the padding-value tensor. Defaults to None.
            device (optional): Device of the padding-value tensor. Defaults to None.
            **kwargs: Additional keyword arguments, forwarded to QuantWeight.
        """
        pad_value = torch.tensor([value], dtype=dtype, device=device)
        super().__init__(
            quant_min,
            quant_max,
            pad_value,
            channels,  # num_groups
            channels,  # channels
            0,  # axis
            *args,
            **kwargs,
        )

    def forward_encoding(self, encoding: Encoding, **kwargs):
        """Adopt the per-channel scale and zero point of ``encoding`` and return it unchanged."""
        self.scale = encoding.scale_by_channel
        self.zero_point = encoding.zero_point_by_channel
        # Return the encoding, as StaticFakeQuant.forward_encoding does.
        return encoding

    def to_tensor(self, x):
        """Return ``x`` as a new tensor broadcast against the shape of ``scale``."""
        return x + torch.zeros_like(self.scale)

    def set_weight(self, weight):
        """
        Store ``weight`` as a plain Python scalar.

        Accepts a tensor whose elements are all (approximately) equal, a Python number, or any
        other indexable whose first element provides ``.item()``.
        """
        if isinstance(weight, torch.Tensor):
            # Flatten so 0-dim and multi-dimensional tensors are handled too.
            flat = weight.reshape(-1)
            assert torch.allclose(flat, flat[0]), "pad value must be identical for all elements"
            scalar = flat[0].item()
        elif isinstance(weight, (int, float)):
            # e.g. the value returned by get_weight() once the module is frozen
            scalar = weight
        else:
            scalar = weight[0].item()
        # Drop the registered parameter/buffer (or previous scalar) before storing the scalar.
        del self.weight
        self.weight = scalar

    def set_encoded(self, is_encoded: bool):
        """
        Set the encoded flag.

        When frozen and the flag actually changes, the stored scalar is expanded with
        ``to_tensor``, encoded (or decoded) and stored again as a scalar.
        """
        change_encode = self._is_encoded != is_encoded
        if self.frozen and change_encode:
            x = self.to_tensor(self.weight)
            if is_encoded:
                self.set_weight(self._encode(x))
            else:
                self.set_weight(self._decode(x))
        # Deliberately bypasses QuantWeight.set_encoded, whose conversion is replaced above.
        StaticFakeQuant.set_encoded(self, is_encoded)


class QuantBias(QuantWeight):
    """QuantWeight for bias values: rounding only, no clamping."""

    # NOTE: should not apply clipping to bias
    # wraparound should be applied in the MAC unit

    def quantize(self, x: torch.Tensor):
        """Round ``x`` with a straight-through gradient; the value is not clamped."""
        return ste_round(x)

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """
        Load the bias encoding from ``encoding`` and return the encoding that follows it.

        ``scale`` receives the per-channel scale and ``zero_point`` the NEGATED per-channel zero
        point (both copied in place). The returned encoding keeps the incoming group scale,
        repeats and equalization vector, with an all-zero zero point.
        """
        self.scale.copy_(encoding.scale_by_channel)
        negated_zero_point = -encoding.zero_point_by_channel
        self.zero_point.copy_(negated_zero_point)

        output_fields = dict(
            scale_by_group=encoding.scale_by_group,
            scale_repeats=encoding.scale_repeats,
            zero_point_by_group=torch.zeros_like(encoding.scale_by_group),
            zero_point_repeats=encoding.scale_repeats,
            equalization_vector=encoding.equalization_vector,
        )
        return init_encoding(**output_fields)


class QuantBiasDecomposed(QuantWeight):
    """
    QuantWeight for a bias that is rounded on a grid defined by ``factor``, ``feed_repeat`` and
    ``mac_shift``: the value is divided by ``factor * feed_repeat / 2**mac_shift``, rounded, and
    multiplied back. ``factor`` and ``feed_repeat`` are treated as at least 1.
    """

    factor: nn.Parameter
    feed_repeat: nn.Parameter
    mac_shift: torch.Tensor

    def __init__(
        self,
        quant_min,
        quant_max,
        value,
        num_groups,
        channels,
        axis,
        *args,
        is_independent_encoding=False,
        requires_grad=True,
        **kwargs,
    ):
        """
        Initializes the QuantBiasDecomposed module.
        Args:
            quant_min: Minimum quantization value.
            quant_max: Maximum quantization value.
            value: Initial bias tensor; its dtype and device are used for the new parameters.
            num_groups: Number of groups for scale and zero point.
            channels: Number of channels.
            axis: Axis for quantization.
            *args: Additional positional arguments, forwarded to QuantWeight.
            is_independent_encoding (optional): Forwarded to QuantWeight. Defaults to False.
            requires_grad (bool, optional): ``requires_grad`` of the weight, ``factor`` and
                    ``feed_repeat`` parameters. Defaults to True.
            **kwargs: Additional keyword arguments, forwarded to QuantWeight.
        """
        super().__init__(
            quant_min,
            quant_max,
            value,
            num_groups,
            channels,
            axis,
            *args,
            is_independent_encoding=is_independent_encoding,
            requires_grad=requires_grad,
            **kwargs,
        )
        self.factor = nn.Parameter(
            torch.ones([1], dtype=value.dtype, device=value.device),
            requires_grad=requires_grad,
        )
        self.feed_repeat = nn.Parameter(
            torch.ones([1], dtype=value.dtype, device=value.device),
            requires_grad=requires_grad,
        )
        # Integer shift stored as a buffer (not trained).
        mac_shift = torch.zeros(1, dtype=torch.int32, device=value.device)
        self.register_buffer("mac_shift", mac_shift)

    def quantize(self, x):
        """
        Round ``x`` on the grid defined by ``factor``, ``feed_repeat`` and ``mac_shift``.

        The same clamped (>= 1) ``factor`` and ``feed_repeat`` are used both to scale the value
        down before rounding and to scale it back up afterwards, so the two steps are exact
        inverses of each other apart from the rounding itself.
        """
        factor = torch.maximum(self.factor, torch.ones_like(self.factor))
        repeat = torch.maximum(self.feed_repeat, torch.ones_like(self.feed_repeat))

        shift_factor = 2**self.mac_shift
        x = x * shift_factor / (repeat * factor)  # bias INT8 input
        x = ste_round(x)
        # Scale back with the clamped values, not the raw parameters, to undo the division above.
        x = x * factor * repeat / shift_factor
        return x

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """
        Load the bias encoding from ``encoding`` and return the encoding that follows it.

        ``scale`` receives the per-channel scale and ``zero_point`` the NEGATED per-channel zero
        point (both copied in place). The returned encoding keeps the incoming group scale,
        repeats and equalization vector, with an all-zero zero point.
        """
        self.scale.copy_(encoding.scale_by_channel)
        self.zero_point.copy_(-encoding.zero_point_by_channel)
        return init_encoding(
            scale_by_group=encoding.scale_by_group,
            scale_repeats=encoding.scale_repeats,
            zero_point_by_group=torch.zeros_like(encoding.scale_by_group),
            zero_point_repeats=encoding.scale_repeats,
            equalization_vector=encoding.equalization_vector,
        )


class QuantKernel(QuantWeight):
    """
    QuantWeight for kernels.

    The encoded kernel carries no zero-point offset and is additionally divided by
    ``2**mac_shift``; rounding and clamping are done with the shift removed and the zero point
    added back.
    """

    mac_shift: torch.Tensor

    # NOTE: this class tries to handle the matrix scale provided by acceleras.
    #   In the future we'll need to export the equalization vectors and kernel proper scale, if we want the scales to be trainable
    # NOTE: torch.addcmul can be useful to multiply the kernel by equalization vectors
    def __init__(
        self,
        quant_min,
        quant_max,
        value,
        num_groups,
        channels,
        cout_axis,
        cin_axis,
        *args,
        is_independent_encoding=False,
        requires_grad=True,
        **kwargs,
    ):
        """
        Initializes the QuantKernel module.
        Args:
            quant_min: Minimum quantization value.
            quant_max: Maximum quantization value.
            value: Initial kernel tensor.
            num_groups: Number of groups for scale and zero point.
            channels: Number of channels along ``cout_axis``.
            cout_axis: Kernel axis used as the quantization axis.
            cin_axis: Kernel axis stored as ``self.cin_axis``.
            *args: Additional positional arguments, forwarded to QuantWeight.
            is_independent_encoding (optional): Forwarded to QuantWeight. Defaults to False.
            requires_grad (bool, optional): Forwarded to QuantWeight. Defaults to True.
            **kwargs: Additional keyword arguments, forwarded to QuantWeight.
        """
        super().__init__(
            quant_min,
            quant_max,
            value,
            num_groups,
            channels,
            cout_axis,
            *args,
            is_independent_encoding=is_independent_encoding,
            requires_grad=requires_grad,
            **kwargs,
        )
        self.cin_axis = cin_axis
        self.register_buffer("mac_shift", torch.zeros(1, dtype=value.dtype, device=value.device))

    def _encode(self, x):
        """Encode ``x`` as in the parent, then remove the zero point and apply the MAC shift."""
        zero_point = self.get_encode_zero_point(x.ndim)
        encoded = super()._encode(x)
        return self.apply_mac_shift(encoded - zero_point)

    def _decode(self, x):
        """Inverse of ``_encode``: undo the MAC shift, add the zero point back, decode as in the parent."""
        zero_point = self.get_encode_zero_point(x.ndim)
        unshifted = self.deapply_mac_shift(x)
        return super()._decode(unshifted + zero_point)

    def apply_mac_shift(self, x: torch.Tensor):
        """Divide ``x`` IN PLACE by ``2**mac_shift``."""
        return x.div_(2**self.mac_shift)

    def deapply_mac_shift(self, x: torch.Tensor):
        """Multiply ``x`` IN PLACE by ``2**mac_shift``."""
        return x.mul_(2**self.mac_shift)

    def quantize(self, x):
        """Round and clamp ``x`` with the MAC shift removed and the zero point restored."""
        zero_point = self.get_encode_zero_point(x.ndim)
        with_offset = self.deapply_mac_shift(x) + zero_point
        clamped = super().quantize(with_offset)
        return self.apply_mac_shift(clamped - zero_point)

    def calculate_bias_residue(self, kernel_encoded, zero_point_in, kernel_in_axis):
        """
        Contract the encoded kernel with the input zero point.

        The first (up to) two kernel axes are labelled ``i`` and ``j``; ``zero_point_in`` is
        matched to the axis selected by ``kernel_in_axis`` and every axis except ``i`` is summed.
        """
        labels = "ij"
        kernel_labels = labels[: kernel_encoded.ndim]
        zero_point_label = labels[kernel_in_axis]
        equation = f"{kernel_labels}...,{zero_point_label}->i"
        return torch.einsum(equation, kernel_encoded, zero_point_in)


class QuantKernelScalar(QuantKernel):
    """QuantKernel that derives its output encoding from a single (scalar) input scale."""

    def __init__(
        self,
        quant_min,
        quant_max,
        value,
        num_groups,
        channels,
        cout_axis,
        cin_axis,
        spatial_size,
        *args,
        is_independent_encoding=(True, False),
        requires_grad=False,
        **kwargs,
    ):
        """
        Initializes the QuantKernelScalar module.
        Args:
            quant_min: Minimum quantization value.
            quant_max: Maximum quantization value.
            value: Initial kernel tensor.
            num_groups: Number of groups for scale and zero point.
            channels: Number of channels along ``cout_axis``.
            cout_axis: Kernel axis used as the quantization axis.
            cin_axis: Kernel axis used for the bias-residue contraction.
            spatial_size: Factor the bias residue is multiplied by in ``forward_encoding``.
            *args: Additional positional arguments, forwarded to QuantKernel.
            is_independent_encoding (optional): Forwarded to QuantKernel. Defaults to (True, False).
            requires_grad (bool, optional): Forwarded to QuantKernel. Defaults to False.
            **kwargs: Additional keyword arguments, forwarded to QuantKernel.
        """
        super().__init__(
            quant_min,
            quant_max,
            value,
            num_groups,
            channels,
            cout_axis,
            cin_axis,
            *args,
            is_independent_encoding=is_independent_encoding,
            requires_grad=requires_grad,
            **kwargs,
        )
        self.spatial_size = spatial_size

    def forward_encoding(self, encoding: Encoding, verify_encoding=False, **kwargs) -> Encoding:
        """
        Build the output encoding from the input encoding and the encoded kernel.

        Args:
            encoding (Encoding): input encoding to the module
            verify_encoding (bool, optional): When True, assert that the input scale (after
                    dividing out the equalization vector) is the same for every channel.

        Returns:
            Encoding: scale is the first input-scale entry times ``scale`` times
            ``2**mac_shift``; zero point is the bias residue times ``spatial_size``; the
            equalization vector is taken from the input encoding.
        """
        encoded_kernel = self.get_encoded_weight()
        input_zero_point = encoding.zero_point_by_channel
        residue = self.calculate_bias_residue(encoded_kernel, input_zero_point, self.cin_axis)
        residue = residue * self.spatial_size

        input_scale = encoding.scale_by_channel / encoding.equalization_vector
        if verify_encoding:
            assert torch.allclose(input_scale, input_scale[0])

        # Only the first entry of the input scale is used below.
        return init_encoding(
            scale_by_group=input_scale[0] * self.scale * 2**self.mac_shift,
            scale_repeats=self.channels_per_group_scale,
            zero_point_by_group=residue,
            zero_point_repeats=1,
            equalization_vector=encoding.equalization_vector,
        )


class QuantEqWeight(QuantWeight):
    """Weight that can be equalized, but doesn't receive input data
    The only use case for this atm is ConstInput
    """

    equalization_vector_out: nn.Parameter

    def __init__(
        self,
        quant_min,
        quant_max,
        value,
        num_groups,
        channels,
        axis,
        *args,
        is_independent_encoding=(True, False),
        requires_grad=True,
        **kwargs,
    ):
        """
        Initializes the QuantEqWeight module.
        Args:
            quant_min: Minimum quantization value.
            quant_max: Maximum quantization value.
            value: Initial weight tensor; its dtype and device are used for the equalization vector.
            num_groups: Number of groups for scale and zero point.
            channels: Number of channels; also the length of the equalization vector.
            axis: Axis for quantization and equalization.
            *args: Additional positional arguments, forwarded to QuantWeight.
            is_independent_encoding (optional): Forwarded to QuantWeight. Defaults to (True, False).
            requires_grad (bool, optional): Forwarded to QuantWeight. Defaults to True.
            **kwargs: Additional keyword arguments, forwarded to QuantWeight.
        """
        super().__init__(
            quant_min,
            quant_max,
            value,
            num_groups,
            channels,
            axis,
            *args,
            is_independent_encoding=is_independent_encoding,
            requires_grad=requires_grad,
            **kwargs,
        )
        # All-ones equalization vector, stored as a parameter that does not require gradients.
        ones = torch.ones(channels, dtype=value.dtype, device=value.device)
        self.register_parameter("equalization_vector_out", nn.Parameter(ones, requires_grad=False))

    def equalize(self, x: torch.Tensor):
        """Divide ``x`` IN PLACE by the equalization vector along ``self.axis``."""
        divisor = self.view(self.equalization_vector_out, x.ndim, axis=self.axis)
        # Note div in place to avoid memory allocation with large kernels
        return x.div_(divisor)

    def dequalize(self, x: torch.Tensor):
        """Multiply ``x`` IN PLACE by the equalization vector along ``self.axis``."""
        multiplier = self.view(self.equalization_vector_out, x.ndim, axis=self.axis)
        # Note mul in place to avoid memory allocation with large kernels
        return x.mul_(multiplier)

    def _encode(self, x):
        """Equalize ``x``, then encode it as in the parent."""
        equalized = self.equalize(x)
        return super()._encode(equalized)

    def _decode(self, x):
        """Decode ``x`` as in the parent, then undo the equalization."""
        decoded = super()._decode(x)
        return self.dequalize(decoded)

    def forward_encoding(self, **kwargs) -> Encoding:
        """Return an encoding built from this module's own scale, zero point and equalization vector."""
        return init_encoding(
            scale_by_group=self.scale,
            scale_repeats=self.channels_per_group_scale,
            zero_point_by_group=self.zero_point,
            zero_point_repeats=self.channels_per_group_zero_point,
            equalization_vector=self.equalization_vector_out,
        )


class QuantEqKernel(QuantKernel):
    """
    QuantKernel with equalization vectors on both the input-channel and output-channel axes.

    The kernel is divided by both vectors before encoding and multiplied by them after decoding.
    """

    equalization_vector_in: torch.Tensor
    equalization_vector_out: nn.Parameter
    mac_shift: torch.Tensor

    # NOTE: this class tries to handle the matrix scale provided by acceleras.
    #   In the future we'll need to export the equalization vectors and kernel proper scale, if we want the scales to be trainable
    # NOTE: torch.addcmul can be useful to multiply the kernel by equalization vectors
    def __init__(
        self,
        quant_min,
        quant_max,
        value,
        num_groups,
        channels_in,
        channels_out,
        cout_axis,
        cin_axis,
        *args,
        is_independent_encoding=(True, False),
        is_independent_equalization=True,
        requires_grad=True,
        **kwargs,
    ):
        """
        Initializes the QuantEqKernel module.
        Args:
            quant_min: Minimum quantization value.
            quant_max: Maximum quantization value.
            value: Initial kernel tensor; its dtype and device are used for the equalization vectors.
            num_groups: Number of groups for scale and zero point.
            channels_in: Length of the input equalization vector.
            channels_out: Number of output channels; also the length of the output equalization vector.
            cout_axis: Kernel axis of the output channels (the quantization axis).
            cin_axis: Kernel axis of the input channels.
            *args: Additional positional arguments, forwarded to QuantKernel.
            is_independent_encoding (optional): Forwarded to QuantKernel. Defaults to (True, False).
            is_independent_equalization (bool, optional): Store the output equalization vector as
                    a parameter (True) or as a buffer (False). Defaults to True.
            requires_grad (bool, optional): Forwarded to QuantKernel. Defaults to True.
            **kwargs: Additional keyword arguments, forwarded to QuantKernel.
        """
        # The axes are passed positionally BEFORE *args: passing them as keywords after *args
        # made any extra positional argument collide with ``cout_axis``.
        super().__init__(
            quant_min,
            quant_max,
            value,
            num_groups,
            channels_out,
            cout_axis,
            cin_axis,
            *args,
            is_independent_encoding=is_independent_encoding,
            requires_grad=requires_grad,
            **kwargs,
        )
        equalization_vector_in = torch.ones(channels_in, dtype=value.dtype, device=value.device)
        self.register_buffer("equalization_vector_in", equalization_vector_in)
        equalization_vector = torch.ones(channels_out, dtype=value.dtype, device=value.device)
        if is_independent_equalization:
            equalization_vector_out = nn.Parameter(equalization_vector, requires_grad=False)
            self.register_parameter("equalization_vector_out", equalization_vector_out)
        else:
            self.register_buffer("equalization_vector_out", equalization_vector)

    def equalize(self, x: torch.Tensor):
        """Divide ``x`` IN PLACE by the input vector along ``cin_axis`` and the output vector along ``axis``."""
        eq_in_vec = self.view(self.equalization_vector_in, x.ndim, axis=self.cin_axis)
        eq_out_vec = self.view(self.equalization_vector_out, x.ndim, axis=self.axis)
        # Note div in place to avoid memory allocation with large kernels
        x = x.div_(eq_in_vec).div_(eq_out_vec)
        return x

    def dequalize(self, x: torch.Tensor):
        """Multiply ``x`` IN PLACE by the input vector along ``cin_axis`` and the output vector along ``axis``."""
        eq_in_vec = self.view(self.equalization_vector_in, x.ndim, axis=self.cin_axis)
        eq_out_vec = self.view(self.equalization_vector_out, x.ndim, axis=self.axis)
        # Note mul in place to avoid memory allocation with large kernels
        x = x.mul_(eq_in_vec).mul_(eq_out_vec)
        return x

    def _encode(self, x):
        """Equalize ``x``, then encode it as in QuantKernel."""
        return super()._encode(self.equalize(x))

    def _decode(self, x):
        """Decode ``x`` as in QuantKernel, then undo the equalization."""
        return self.dequalize(super()._decode(x))

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """
        Absorb the input encoding and return the output encoding.

        The input equalization vector is overwritten (in place) with the reciprocal of the
        per-channel input scale. The returned encoding uses ``scale * 2**mac_shift`` as group
        scale, the bias residue of the encoded kernel as zero point, and the output
        equalization vector.
        """
        self.equalization_vector_in.copy_(torch.reciprocal(encoding.scale_by_channel))
        # Encode the kernel only after the input vector has been updated above.
        kernel = self.get_encoded_weight()
        zero_point = encoding.zero_point_by_channel
        bias_residue = self.calculate_bias_residue(kernel, zero_point, self.cin_axis)

        accumulator_scale_by_group = self.scale * 2**self.mac_shift
        scale_repeats = self.channels_per_group_scale
        zero_point_repeats = 1

        return init_encoding(
            scale_by_group=accumulator_scale_by_group,
            scale_repeats=scale_repeats,
            zero_point_by_group=bias_residue,
            zero_point_repeats=zero_point_repeats,
            equalization_vector=self.equalization_vector_out,
        )
