import torch

from hailo_model_optimization.saitama.framework.common.fake_quant import QuantEqKernel
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    Encoding,
    MACPrecisionConfig,
)
from hailo_model_optimization.saitama.framework.common.utils import (
    qtype_to_range,
)
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACNorm(MACBase):
    """Per-channel affine MAC layer computing ``inp * kernel + bias`` along ``axis``.

    ``kernel`` and ``bias`` each hold one value per output channel. In :meth:`forward_mac`
    they are reshaped so that they broadcast over every input dimension except ``axis``.

    NOTE(review): ``out_channels``, ``weight_qtype``, ``quantization_groups``, ``bias`` and
    ``accumulator_quantizer`` are not assigned in this class. They are presumably set up by
    ``MACBase.__init__``; confirm against MACBase.
    """

    def __init__(
        self,
        channels: int,
        axis: int = 1,
        # TODO: Do we allow bias=False, or is it required to compensate for the zp multiplication?
        # NOTE(review): forward_mac / forward_encoding always dereference ``self.bias``, so
        # bias=False presumably only works if MACBase still provides a bias object.
        bias: bool = True,
        precision_config: MACPrecisionConfig = None,
        device=None,
        dtype=None,
    ):
        """Build the layer and allocate its quantized per-channel kernel.

        Args:
            channels: number of channels; forwarded to ``MACBase.__init__``.
            axis: input dimension that indexes channels in :meth:`forward_mac`.
            bias: forwarded to ``MACBase.__init__``.
            precision_config: forwarded to ``MACBase.__init__`` (may be ``None``).
            device: torch device for the kernel (and forwarded to the base class).
            dtype: torch dtype for the kernel (and forwarded to the base class).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(channels, bias=bias, precision_config=precision_config, **factory_kwargs)
        self.axis = axis
        self.initialize_kernel(**factory_kwargs)

    def initialize_kernel(self, device=None, dtype=None):
        """Create ``self.kernel``, a ``QuantEqKernel`` with one value per output channel.

        The backing tensor comes from ``torch.empty``, so its contents are uninitialized.
        NOTE(review): real values are presumably assigned or loaded elsewhere before use.
        """
        # Explicit 1-tuple: a 1-D tensor with one entry per output channel.
        kernel = torch.empty((self.out_channels,), device=device, dtype=dtype)
        quant_min, quant_max = qtype_to_range(self.weight_qtype)
        # In/out channel counts are both out_channels, and both are indexed by axis 0 of the
        # 1-D kernel, matching the elementwise multiply in forward_mac.
        self.kernel = QuantEqKernel(
            quant_min=quant_min,
            quant_max=quant_max,
            value=kernel,
            num_groups=self.quantization_groups,
            channels_in=self.out_channels,
            channels_out=self.out_channels,
            cin_axis=0,
            cout_axis=0,
        )

    def forward_mac(self, inp: torch.Tensor, **kwargs):
        """Return ``inp * kernel + bias`` with kernel/bias broadcast along ``self.axis``.

        Args:
            inp: input tensor. Its size along ``self.axis`` must be broadcast-compatible
                with the per-channel kernel/bias length.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            Tensor of the broadcast result shape.
        """
        # NOTE: as long as we work in native scale, we don't need to apply the mac shift
        kernel = self.kernel.get_weight()
        bias = self.bias.get_weight()
        # Shape of all ones except -1 at the channel axis, e.g. (1, -1, 1, 1) for a 4-D
        # input with axis=1, so the 1-D kernel/bias broadcast over all other dimensions.
        shape = [1] * inp.ndim
        shape[self.axis] = -1
        return inp * kernel.view(shape) + bias.view(shape)

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Propagate ``encoding`` through the kernel, then the bias, then the accumulator quantizer.

        ``kwargs`` are passed unchanged to each stage.
        """
        encoding = self.kernel.forward_encoding(encoding, **kwargs)
        encoding = self.bias.forward_encoding(encoding, **kwargs)
        encoding = self.accumulator_quantizer.forward_encoding(encoding, **kwargs)
        return encoding

    def extra_repr(self):
        """Summary shown in ``repr(module)``: channel count and channel axis."""
        # Plain attribute access (rather than formatting from ``self.__dict__``) also works
        # when an attribute is a property or is stored by nn.Module outside ``__dict__``.
        return f"{self.out_channels}, axis={self.axis}"
