from typing import Tuple

import einops
import torch
from torch import Tensor

from hailo_model_optimization.saitama.framework.common.commom_funtions import Reshaping
from hailo_model_optimization.saitama.framework.common.fake_quant import (
    QuantWeight,
)
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    DimsInfo,
    Encoding,
    MACPrecisionConfig,
)
from hailo_model_optimization.saitama.framework.common.utils import init_encoding, qtype_to_range
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class QuantZpVal(QuantWeight):
    """Quantized weight whose value is derived from an input encoding.

    The stored weight is the per-channel product ``scale * zero_point`` of the
    encoding handed to :meth:`forward_encoding`.
    """

    def forward_encoding(self, encoding: Encoding, **kwargs):
        """Refresh ``scale`` and ``weight`` from ``encoding`` and pass it through.

        Args:
            encoding: Encoding providing ``scale_by_channel`` and
                ``zero_point_by_channel``.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The same ``encoding`` object that was passed in.
        """
        channel_scale = encoding.scale_by_channel
        self.scale = channel_scale
        self.weight = channel_scale * encoding.zero_point_by_channel
        return encoding


class MACMatmul(MACBase):
    """Matmul MAC between two inputs with zero-point compensation terms.

    Both inputs are reshaped (and tiled) through ``Reshaping`` before being
    multiplied. The trailing ``zp_comp_rank`` entries along dim 3 of the
    reshaped second input are split off and used to build a compensation term
    from the zero point held in ``self.zp``.

    NOTE(review): ``input_qtype``, ``mac_shift``, ``out_channels`` and
    ``accumulator_quantizer`` are read from ``self`` but not set in this class;
    presumably they are provided by ``MACBase`` - confirm.
    """

    feed_repeat: Tensor
    zp: QuantWeight

    def __init__(
        self,
        chanels_in_0: int,
        channels_out: int,
        groups: int = 1,
        transpose_input: bool = False,
        zp_comp_rank: int = 0,
        bias: bool = False,
        window: DimsInfo = None,
        input_tiles: Tuple[DimsInfo, DimsInfo] = None,
        precision_config: MACPrecisionConfig = None,
        device=None,
        dtype=None,
    ):
        """Set up the module state, the zero-point holder and the ``feed_repeat`` buffer.

        Args:
            chanels_in_0: Length of the zero-point vector created for the first input.
            channels_out: Forwarded to ``MACBase``.
            groups: Group count used when reshaping inputs and splitting per-channel vectors.
            transpose_input: Forwarded to the reshape of the second input only.
            zp_comp_rank: Number of trailing entries along dim 3 of the reshaped second
                input reserved for zero-point compensation; also the length of ``feed_repeat``.
            bias: Forwarded to ``MACBase``.
            window: Forwarded to the ``Reshaping`` helpers; defaults to ``DimsInfo(1, 1, 1)``.
            input_tiles: Tiling info for (first, second) input; defaults to a pair of
                ``DimsInfo(1, 1, 1)``.
            precision_config: Forwarded to ``MACBase``.
            device: Device used for the tensors created here.
            dtype: Dtype used for the tensors created here.
        """
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__(channels_out, bias=bias, precision_config=precision_config, **factory_kwargs)

        if input_tiles is None:
            input_tiles = (DimsInfo(1, 1, 1), DimsInfo(1, 1, 1))

        self._is_encoded = False
        self.groups = groups
        self.transpose_input = transpose_input
        self.zp_comp_rank = zp_comp_rank
        self.input_tiles = input_tiles
        if window is None:
            window = DimsInfo(1, 1, 1)
        self.window = window

        self.inizialize_zp(chanels_in_0, **factory_kwargs)
        feed_repeat_init = torch.zeros(self.zp_comp_rank, **factory_kwargs)
        self.register_buffer("feed_repeat", feed_repeat_init)

    def inizialize_zp(self, channels, **d_d):
        """Create ``self.zp`` as an all-zero ``QuantZpVal`` with ``channels`` entries.

        Args:
            channels: Number of entries in the zero-point vector (one group per entry).
            **d_d: Tensor factory keyword arguments (``device`` / ``dtype``).
        """
        range_min, range_max = qtype_to_range(self.input_qtype)
        initial_value = torch.zeros(channels, **d_d)
        self.zp = QuantZpVal(
            quant_min=range_min,
            quant_max=range_max,
            value=initial_value,
            channels=channels,
            num_groups=channels,
            axis=0,
            is_independent_weight=False,
        )

    def forward_mac(self, inp_0: Tensor, inp_1: Tensor, **kwargs) -> Tensor:
        """Multiply the two inputs and add the zero-point compensation terms.

        Args:
            inp_0: First matmul operand.
            inp_1: Second matmul operand; transposed by the reshape when
                ``transpose_input`` is set.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The matmul result after ``Reshaping.reshape_matmul_output``.
        """
        # NOTE: as long as we work in native scale, we don't need to apply the mac shift

        lhs = Reshaping.reshape_matmul_input(inp_0, self.groups, self.window, False)
        rhs = Reshaping.reshape_matmul_input(inp_1, self.groups, self.window, self.transpose_input)
        lhs, rhs = Reshaping.tile_inputs(lhs, self.input_tiles[0], rhs, self.input_tiles[1])

        # Separate the real operand from the trailing compensation entries on dim 3.
        rhs_data, rhs_comp_terms = self._split_inp(rhs)

        if self.zp_comp_rank > 0:
            zp_comp = self._build_zp_comp(rhs_comp_terms)
        else:
            zp_comp = 0

        product = lhs @ rhs_data

        if self._is_encoded:
            accumulated = product + zp_comp
            matmul_results = accumulated / 2**self.mac_shift
        else:
            zp_per_group = einops.rearrange(self.zp.get_weight(), "(g c) -> g c", g=self.groups)
            # After groups this can be cleaner and faster
            # Contract the zero point with the second operand over the channel axis,
            # then insert a singleton axis so it broadcasts against the product.
            clean_zp_comp = einops.einsum(zp_per_group, rhs_data, " g c, b g h c w-> b g h w")
            clean_zp_comp = einops.rearrange(clean_zp_comp, "... c ->  ... 1 c")
            matmul_results = product + clean_zp_comp + zp_comp

        return Reshaping.reshape_matmul_output(matmul_results, self.groups, self.window)

    def forward_encoding(self, inp_0_encoding: Encoding, inp_1_encoding: Encoding, **kwargs) -> Encoding:
        """Derive the output encoding from the two input encodings.

        The first encoding is also pushed into ``self.zp``. The output scale per
        group is the product of the first per-group scale of each input and
        ``2**mac_shift``, repeated over ``out_channels // groups`` channels; the
        zero point is all zeros.

        Args:
            inp_0_encoding: Encoding of the first input.
            inp_1_encoding: Encoding of the second input.
            **kwargs: Forwarded to the nested ``forward_encoding`` calls.

        Returns:
            The encoding produced by ``accumulator_quantizer.forward_encoding``.
        """
        inp_0_encoding = self.zp.forward_encoding(inp_0_encoding, **kwargs)

        # Keep only the first channel's scale of every group -> shape (g, 1).
        grouped_0 = einops.rearrange(inp_0_encoding.scale_by_channel, "(g c) -> g 1 c", g=self.groups)
        grouped_1 = einops.rearrange(inp_1_encoding.scale_by_channel, "(g c) -> g 1 c", g=self.groups)
        scale_0 = grouped_0[..., 0]
        scale_1 = grouped_1[..., 0]

        group_scale = scale_0 * scale_1 * 2**self.mac_shift
        channels_per_group = self.out_channels // self.groups
        scale = einops.repeat(group_scale, "g c -> g (c rc)", rc=channels_per_group).flatten()

        out_encoding = init_encoding(scale, 1, torch.zeros_like(scale), 1)
        return self.accumulator_quantizer.forward_encoding(out_encoding, **kwargs)

    def _split_inp(self, inp: Tensor) -> Tuple[Tensor, Tensor]:
        """Split ``inp`` along dim 3 into the data part and the last ``zp_comp_rank`` entries."""
        comp_size = self.zp_comp_rank
        data_size = inp.shape[3] - comp_size
        data_part, comp_part = torch.split(inp, [data_size, comp_size], dim=3)
        return data_part, comp_part

    def _build_zp_comp(self, sum_w_decompose: torch.Tensor) -> torch.Tensor:
        """Build the compensation term from the split-off entries of the second input.

        The entries are weighted by ``feed_repeat`` and summed over dim 3, then
        multiplied by the first zero-point value of each group.
        """
        # Align feed_repeat with dim 3 of a 5-D tensor for broadcasting.
        repeat_weights = self.feed_repeat.view(1, 1, 1, -1, 1)
        weighted_sum = torch.sum(sum_w_decompose * repeat_weights, dim=3, keepdim=True)

        zp_grouped = einops.rearrange(self.zp.get_weight(), "(g c) -> 1 g 1 1 c 1", g=self.groups)
        # Index 0 on the channel axis: one zero-point value per group, shape (1, g, 1, 1, 1).
        zp_first = zp_grouped[..., 0, :]
        return zp_first * weighted_sum

    def set_encoded(self, is_encoded: bool):
        """Select which branch ``forward_mac`` takes (encoded vs. non-encoded)."""
        self._is_encoded = is_encoded
