from typing import List, Tuple, Union

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.commom_funtions import CommonFunctions
from hailo_model_optimization.saitama.framework.common.fake_quant import QuantBias, QuantBiasDecomposed
from hailo_model_optimization.saitama.framework.common.saitama_definitions import DimsInfo, Encoding, MACPrecisionConfig
from hailo_model_optimization.saitama.framework.common.utils import init_encoding, qtype_to_range
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACEWMultOnAPU(MACBase):
    """Element-wise multiplication of two operands, modelled as a MAC on the APU.

    Each operand (``x`` -> index 0, ``y`` -> index 1) is repeat-interleaved
    according to ``input_repeats`` and gets its own bias added. The two are then
    multiplied element-wise. In encoded mode the product is divided by
    ``2 ** apu_shift``. ``forward_encoding`` multiplies the output scale by the
    same ``2 ** apu_shift`` factor (presumably to mirror that division).
    """

    bias: List[Union[QuantBias, QuantBiasDecomposed]]
    apu_shift: torch.Tensor

    def __init__(
        self,
        channels: int,
        input_repeats: Tuple[DimsInfo, DimsInfo] = None,
        precision_config: MACPrecisionConfig = None,
        device=None,
        dtype=None,
    ):
        """Build the module.

        Args:
            channels: forwarded to ``MACBase`` (bias is always enabled).
            input_repeats: one ``DimsInfo`` per operand (x, y) describing how the
                operand is repeat-interleaved. Defaults to
                ``(DimsInfo(1, 1, 1), DimsInfo(1, 1, 1))``.
            precision_config: forwarded to ``MACBase``.
            device: device for the ``apu_shift`` buffer; also forwarded to ``MACBase``.
            dtype: dtype for the ``apu_shift`` buffer; also forwarded to ``MACBase``.
        """
        super().__init__(channels, bias=True, precision_config=precision_config, device=device, dtype=dtype)
        if input_repeats is None:
            # NOTE(review): DimsInfo(1, 1, 1) presumably means "no repetition" — confirm.
            input_repeats = (DimsInfo(1, 1, 1), DimsInfo(1, 1, 1))
        self.input_repeats = input_repeats
        # Shift applied after the multiplication; starts at zero (i.e. a factor of 1).
        self.register_buffer("apu_shift", torch.zeros(1, device=device, dtype=dtype))
        self._is_encoded = False

    def set_encoded(self, is_encoded: bool):
        """Toggle whether ``forward_mac`` divides its output by ``2 ** apu_shift``."""
        self._is_encoded = is_encoded

    def initialize_bias(self, device=None, dtype=None):
        """Create two biases (one per operand) in a ``nn.ModuleList`` at ``self.bias``.

        Both biases share the shape from ``get_bias_shape()``. They also share the
        value range derived from ``self.accumulator_qtype``.
        """
        operand_biases = nn.ModuleList()
        self.bias = operand_biases
        for _operand in ("x", "y"):
            shape = self.get_bias_shape()
            q_lo, q_hi = qtype_to_range(self.accumulator_qtype)
            operand_biases.append(self.create_bias_weight(q_lo, q_hi, shape, device=device, dtype=dtype))

    def forward_mac(self, x, y, **kwargs):
        """Compute ``(repeat(x) + bias0) * (repeat(y) + bias1)``.

        If the module is in encoded mode, the product is also divided in place
        by ``2 ** apu_shift``.

        Each bias is reshaped to ``(1, -1, 1, 1)``, i.e. broadcast along dim 1.
        NOTE(review): this assumes x and y are 4-D with channels on dim 1 — confirm.
        ``kwargs`` is accepted but unused.
        """
        repeats = self.input_repeats
        lhs = CommonFunctions.apply_repeat_interleave(x, repeats[0])
        rhs = CommonFunctions.apply_repeat_interleave(y, repeats[1])
        # The kernel is not used by this MAC, so no MAC shift is applied here.
        lhs = lhs + self.bias[0].get_weight().view(1, -1, 1, 1)
        rhs = rhs + self.bias[1].get_weight().view(1, -1, 1, 1)
        product = torch.mul(lhs, rhs)
        if not self._is_encoded:
            return product
        product.div_(2**self.apu_shift)
        return product

    def forward_encoding(self, encoding_x: Encoding, encoding_y: Encoding, **kwargs):
        """Propagate the operand encodings through the biases and the multiplication.

        The output encoding is built as follows:

        - scale: the product of the two (post-bias) scales times ``2 ** apu_shift``.
        - scale repeats: the smaller of each operand's scale repeats times its
          channel repeat.
        - zero point: zero, with a zero-point repeat of 1.
        - equalization vector: the element-wise product of the channel-repeated
          equalization vectors.

        The result is passed through ``self.accumulator_quantizer.forward_encoding``.
        """
        enc_x = self.bias[0].forward_encoding(encoding_x, **kwargs)
        enc_y = self.bias[1].forward_encoding(encoding_y, **kwargs)
        rep_x = self.input_repeats[0].channels
        rep_y = self.input_repeats[1].channels
        equalization = enc_x.equalization_vector.repeat_interleave(
            rep_x
        ) * enc_y.equalization_vector.repeat_interleave(rep_y)
        product_scale = enc_x.scale_by_group * enc_y.scale_by_group * 2**self.apu_shift
        product_encoding = init_encoding(
            scale_by_group=product_scale,
            scale_repeats=min(enc_x.scale_repeats * rep_x, enc_y.scale_repeats * rep_y),
            zero_point_by_group=torch.zeros_like(enc_x.zero_point_by_group),
            zero_point_repeats=1,
            equalization_vector=equalization,
        )
        return self.accumulator_quantizer.forward_encoding(product_encoding, **kwargs)
