import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.fake_quant import QuantEqKernel
from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding
from hailo_model_optimization.saitama.framework.common.utils import qtype_to_range
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACDense(MACBase):
    """Fully-connected (dense) MAC layer built on ``MACBase``.

    The weight lives in a ``QuantEqKernel`` whose value has shape
    ``(out_channels, in_channels)`` (output axis 0, input axis 1). During
    ``forward_mac`` all non-batch input dimensions are flattened into one
    feature axis before the matrix multiply. Trailing singleton dimensions are
    then appended to the result so its rank matches the input's rank.
    """

    def __init__(self, in_channels, out_channels, bias=True, precision_config=None, device=None, dtype=None):
        """Create the layer and allocate its quantized kernel.

        Args:
            in_channels: Size of the flattened input feature axis, i.e. the
                product of all non-batch input dimensions.
            out_channels: Number of output features.
            bias: Forwarded to ``MACBase``; presumably controls whether a bias
                term is created (``self.bias`` may be ``None`` -- see
                ``forward_mac``).
            precision_config: Forwarded to ``MACBase``.
            device: Device for the kernel tensor; also forwarded to ``MACBase``.
            dtype: Dtype for the kernel tensor; also forwarded to ``MACBase``.
        """
        super().__init__(out_channels, bias, precision_config, device, dtype)
        self.in_channels = in_channels

        self.initialize_kernel(device=device, dtype=dtype)

    def initialize_kernel(self, device=None, dtype=None):
        """Allocate ``self.kernel`` as a ``QuantEqKernel`` of shape (out, in).

        The backing tensor comes from ``torch.empty`` and is therefore
        uninitialized; presumably real weights are assigned elsewhere.
        The quantization limits come from ``qtype_to_range(self.weight_qtype)``.
        ``self.weight_qtype`` and ``self.quantization_groups`` are read here
        but not set in this class (presumably provided by ``MACBase`` --
        verify).

        Args:
            device: Device for the kernel tensor.
            dtype: Dtype for the kernel tensor.
        """
        kernel = torch.empty(self.out_channels, self.in_channels, device=device, dtype=dtype)
        quant_min, quant_max = qtype_to_range(self.weight_qtype)
        self.kernel = QuantEqKernel(
            quant_min=quant_min,
            quant_max=quant_max,
            value=kernel,
            num_groups=self.quantization_groups,
            channels_in=self.in_channels,
            channels_out=self.out_channels,
            # Layout matches nn.functional.linear's weight: (out_features, in_features).
            cin_axis=1,
            cout_axis=0,
        )

    def forward_mac(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply the dense multiply-accumulate (plus optional bias) to ``x``.

        Args:
            x: Input tensor whose first dimension is the batch. All remaining
                dimensions are flattened, and their product must equal
                ``in_channels``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Tensor of shape ``(batch, out_channels, 1, ..., 1)`` with the same
            rank as ``x``, or rank 2 when ``x`` has rank <= 2.
        """
        in_rank = x.dim()
        # Flatten every non-batch dimension (channels and any spatial dims)
        # into a single feature axis so a plain linear op can be applied.
        flat = x.reshape(x.shape[0], -1)
        bias = self.bias.get_weight() if self.bias is not None else None
        out = nn.functional.linear(flat, self.kernel.get_weight(), bias=bias)
        # Restore the input rank with trailing singleton dims,
        # e.g. a rank-4 input yields (batch, out_channels, 1, 1).
        trailing = [1] * (in_rank - out.dim())
        return out.reshape(*out.shape, *trailing)

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Propagate ``encoding`` through the kernel, bias and accumulator stages.

        The bias stage is skipped when the layer has no bias
        (``self.bias is None``), mirroring how ``forward_mac`` treats a missing
        bias.

        Args:
            encoding: Incoming encoding to transform.
            **kwargs: Forwarded unchanged to each stage's ``forward_encoding``.

        Returns:
            The encoding after the accumulator quantizer.
        """
        encoding = self.kernel.forward_encoding(encoding, **kwargs)
        if self.bias is not None:
            encoding = self.bias.forward_encoding(encoding, **kwargs)
        encoding = self.accumulator_quantizer.forward_encoding(encoding, **kwargs)
        return encoding
