from typing import Tuple

import torch

from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeatureMultiplierType
from hailo_model_optimization.saitama.framework.common.commom_funtions import CommonFunctions
from hailo_model_optimization.saitama.framework.common.saitama_definitions import DimsInfo, MACPrecisionConfig
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACEWMult(MACBase):
    """Element-wise multiplication MAC stage.

    Multiplies two inputs element-wise, optionally applies a zero-point
    correction and a power-of-two down-scale (when the module is flagged as
    encoded), sums the channels inside each reduce group and finally adds the
    bias provided by the base class.

    Inputs are treated as having channels on dim 1; the zero points and the
    bias are reshaped to ``(1, C, 1, 1)`` before broadcasting, so 4-D inputs
    are presumably expected (NOTE(review): confirm against callers).
    """

    # Buffer of shape (2, in_channels): row 0 is paired with the first input,
    # row 1 with the second input (see ``forward_mac``).
    zero_point: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        reduce_sum_groups: int,
        input_repeats: Tuple[DimsInfo, DimsInfo] = None,
        precision_config: MACPrecisionConfig = None,
        device=None,
        dtype=None,
    ):
        """Create the module and register a zero-initialised ``zero_point`` buffer.

        Args:
            in_channels: Number of columns of the ``zero_point`` buffer.
            out_channels: Forwarded to ``MACBase.__init__``.
            reduce_sum_groups: Number of groups dim 1 is split into before the
                per-group summation.
            input_repeats: Pair of ``DimsInfo`` applied to the first and second
                input respectively; defaults to two default-constructed
                ``DimsInfo`` objects.
            precision_config: Forwarded to ``MACBase.__init__``.
            device: Device for the buffer; also forwarded to ``MACBase``.
            dtype: Dtype for the buffer; also forwarded to ``MACBase``.
        """
        super().__init__(out_channels, precision_config=precision_config, device=device, dtype=dtype)
        if input_repeats is None:
            input_repeats = (DimsInfo(), DimsInfo())
        self.input_repeats = input_repeats
        self.reduce_sum_groups = reduce_sum_groups
        initial_zero_point = torch.zeros((2, in_channels), device=device, dtype=dtype)
        self.register_buffer("zero_point", initial_zero_point)
        self._is_encoded = False

    def forward_mac(self, x, y, **kwargs):
        """Multiply ``x`` by ``y``, reduce channel groups and add the bias.

        Args:
            x: First operand.
            y: Second operand.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The group-reduced product plus the bias, with dim 1 of size
            ``reduce_sum_groups``.
        """
        lhs = CommonFunctions.apply_repeat_interleave(x, self.input_repeats[0])
        rhs = CommonFunctions.apply_repeat_interleave(y, self.input_repeats[1])
        product = torch.mul(lhs, rhs)

        if self._is_encoded:
            zp_lhs = self.zero_point[0].view(1, -1, 1, 1)
            zp_rhs = self.zero_point[1].view(1, -1, 1, 1)
            # x*y - x*zp_y - y*zp_x equals (x - zp_x)*(y - zp_y) minus the
            # constant zp_x*zp_y term, which is not handled in this method.
            product = product - torch.mul(lhs, zp_rhs) - torch.mul(rhs, zp_lhs)
            product = product / 2**self.mac_shift

        # Split dim 1 into (groups, channels_per_group) and sum each group.
        channels_per_group = product.shape[1] // self.reduce_sum_groups
        grouped = product.unflatten(1, (self.reduce_sum_groups, channels_per_group))
        reduced = grouped.sum(dim=2, keepdim=False)

        return reduced + self.bias.get_weight().view(1, -1, 1, 1)

    def set_encoded(self, is_encoded: bool):
        """Enable or disable the zero-point correction and shift in ``forward_mac``."""
        self._is_encoded = is_encoded


class MACFeatureMultiplier(MACEWMult):
    """Feature multiplier built on top of ``MACEWMult``.

    Only ``FeatureMultiplierType.square`` is implemented: the single input is
    multiplied element-wise by itself through ``MACEWMult.forward_mac``.
    """

    def __init__(
        self,
        feature_multiplier_type,
        reduce_sum_groups,
        in_channels,
        out_channels,
        input_repeats=None,
        bias=True,
        precision_config=None,
        device=None,
        dtype=None,
    ):
        """Create the module.

        Args:
            feature_multiplier_type: Operation selector compared against
                ``FeatureMultiplierType`` members in ``forward_mac``.
            reduce_sum_groups: Forwarded to ``MACEWMult.__init__``.
            in_channels: Forwarded to ``MACEWMult.__init__``.
            out_channels: Forwarded to ``MACEWMult.__init__``.
            input_repeats: Forwarded to ``MACEWMult.__init__``.
            bias: Kept in the signature for backward compatibility only. It is
                not forwarded because ``MACEWMult.__init__`` has no ``bias``
                parameter.
            precision_config: Forwarded to ``MACEWMult.__init__``.
            device: Forwarded to ``MACEWMult.__init__``.
            dtype: Forwarded to ``MACEWMult.__init__``.
        """
        # Forward optional arguments by keyword. Passing ``bias`` positionally
        # supplied one argument more than ``MACEWMult.__init__`` accepts, so
        # construction raised TypeError (and would otherwise have shifted
        # ``bias`` into the ``precision_config`` slot).
        super().__init__(
            in_channels,
            out_channels,
            reduce_sum_groups,
            input_repeats=input_repeats,
            precision_config=precision_config,
            device=device,
            dtype=dtype,
        )
        # ``reduce_sum_groups`` is already stored by ``MACEWMult.__init__``.
        self.feature_multiplier_type = feature_multiplier_type

    def forward_mac(self, x, **kwargs):
        """Apply the configured feature multiplication to ``x``.

        Args:
            x: Input operand.
            **kwargs: Passed through to ``MACEWMult.forward_mac``.

        Returns:
            The result of ``MACEWMult.forward_mac(x, x)`` for the ``square`` type.

        Raises:
            NotImplementedError: If ``feature_multiplier_type`` is not
                ``FeatureMultiplierType.square``.
        """
        if self.feature_multiplier_type == FeatureMultiplierType.square:
            return super().forward_mac(x, x, **kwargs)
        raise NotImplementedError(f"Feature multiplier type {self.feature_multiplier_type} is not supported")
