from typing import List, Tuple

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.commom_funtions import CommonFunctions
from hailo_model_optimization.saitama.framework.common.fake_quant import QuantEqKernel
from hailo_model_optimization.saitama.framework.common.saitama_definitions import DimsInfo, Encoding, MACPrecisionConfig
from hailo_model_optimization.saitama.framework.common.utils import init_encoding, qtype_to_range
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACFactorAndAdd(MACBase):
    """Element-wise weighted sum of two inputs plus a per-channel bias.

    ``forward_mac`` computes ``bias + FACTORS[0] * k0 * x + FACTORS[1] * k1 * y``.
    ``k0`` and ``k1`` are per-channel ``QuantEqKernel`` weights, initialized to ones.
    ``FACTORS`` is a class-level pair of constants that concrete subclasses must define,
    e.g. ``(1, 1)`` for addition or ``(1, -1)`` for subtraction.
    """

    # Per-input constant multipliers; not defined here, concrete subclasses must set them.
    FACTORS: Tuple[float, float]
    # One QuantEqKernel per input; at runtime this is an nn.ModuleList (see initialize_kernels).
    kernel: List[QuantEqKernel]

    def __init__(
        self,
        channels: int,
        input_repeats: Tuple[DimsInfo, DimsInfo] = None,
        precision_config: MACPrecisionConfig = None,
        device=None,
        dtype=None,
    ):
        """Build the MAC with a bias and two per-channel quantized kernels.

        Args:
            channels: number of channels, forwarded to ``MACBase``.
            input_repeats: per-input repeat-interleave info applied in ``forward_mac``.
                Defaults to ``(DimsInfo(), DimsInfo())``.
            precision_config: forwarded to ``MACBase``.
            device: torch device used for the kernel tensors.
            dtype: torch dtype used for the kernel tensors.
        """
        # TODO: add repeat support
        super().__init__(channels, bias=True, precision_config=precision_config, device=device, dtype=dtype)
        self.initialize_kernels(device=device, dtype=dtype)
        self.input_repeats = input_repeats if input_repeats is not None else (DimsInfo(), DimsInfo())

    def initialize_kernels(self, device=None, dtype=None):
        """Create one ones-initialized ``QuantEqKernel`` per input and store them in ``self.kernel``.

        Only the first kernel is created with ``is_independent_equalization=True``.
        ``forward_encoding`` copies the first kernel's output equalization vector into the second.
        """
        # The quantization range depends only on the weight qtype, so compute it once.
        quant_min, quant_max = qtype_to_range(self.weight_qtype)
        self.kernel = nn.ModuleList()
        for idx in range(2):
            quant_kernel = QuantEqKernel(
                quant_min=quant_min,
                quant_max=quant_max,
                value=torch.ones((self.out_channels,), device=device, dtype=dtype),
                num_groups=self.quantization_groups,
                channels_in=self.out_channels,
                channels_out=self.out_channels,
                cin_axis=0,
                cout_axis=0,
                is_independent_equalization=(idx == 0),
            )
            self.kernel.append(quant_kernel)

    def forward_mac(self, *inputs, **kwargs):
        """Return ``bias + sum_i FACTORS[i] * kernel[i] * repeat_interleave(inputs[i])``.

        Per-channel vectors are reshaped with ``view(1, -1, 1, 1)``, so they broadcast over
        dim 1. This presumably assumes NCHW inputs (TODO confirm). ``kwargs`` is accepted for
        interface compatibility and is not used.
        """
        # NOTE: do we want an alternative implementation without the conv? (or any other multiplication)
        #   with native scale, we'll miss the equalization error in the kernel, but we'll have it in the input
        result = self.bias.get_weight().view(1, -1, 1, 1)
        for idx, inp in enumerate(inputs):
            x = CommonFunctions.apply_repeat_interleave(inp, self.input_repeats[idx])
            x = x * self.kernel[idx].get_weight().view(1, -1, 1, 1) * self.FACTORS[idx]
            result = result + x
        return result

    def forward_encoding(self, encoding_x: Encoding, encoding_y: Encoding, **kwargs) -> Encoding:
        """Propagate the two input encodings through kernels, the weighted sum, bias and accumulator.

        Args:
            encoding_x: encoding of the first input.
            encoding_y: encoding of the second input.
            **kwargs: forwarded to every sub-encoder. ``verify_encoding=True`` additionally
                asserts that both kernel-output scales match.

        Returns:
            The accumulator output encoding.
        """
        verify_encoding = kwargs.get("verify_encoding", False)
        encoding_x = self.kernel[0].forward_encoding(encoding_x, **kwargs)
        encoding_y = self.kernel[1].forward_encoding(encoding_y, **kwargs)
        # Tie the second kernel's output equalization to the first (only kernel 0 is independent).
        self.kernel[1].equalization_vector_out.copy_(self.kernel[0].equalization_vector_out)

        if verify_encoding:
            assert torch.allclose(
                encoding_x.scale_by_group, encoding_y.scale_by_group
            ), "MACFactorAndAdd: input scales must match after kernel encoding"

        # forward_mac computes FACTORS[0] * x + FACTORS[1] * y with a shared scale. The zero
        # points of the affine inputs therefore combine with the same signed weights
        # (e.g. zp_x - zp_y for subtraction, not zp_x + zp_y).
        factor_x, factor_y = self.FACTORS
        zero_point = factor_x * encoding_x.zero_point_by_channel + factor_y * encoding_y.zero_point_by_channel

        encoding = init_encoding(
            scale_by_group=encoding_x.scale_by_group,
            scale_repeats=encoding_x.scale_repeats,
            zero_point_by_group=zero_point,
            zero_point_repeats=1,
            equalization_vector=encoding_x.equalization_vector,
        )
        encoding = self.bias.forward_encoding(encoding, **kwargs)
        encoding = self.accumulator_quantizer.forward_encoding(encoding, **kwargs)
        return encoding


class MACEWAdd(MACFactorAndAdd):
    """Element-wise addition of two inputs: both are weighted by +1."""

    FACTORS = (+1, +1)


class MACEWSub(MACEWAdd):
    """Element-wise subtraction of two inputs: the first is weighted by +1, the second by -1."""

    FACTORS = (+1, -1)
