import torch

from hailo_model_optimization.saitama.framework.common.fake_quant import (
    QuantKernel,
)
from hailo_model_optimization.saitama.framework.common.utils import qtype_to_range
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACReduceSum(MACBase):
    """MAC module that performs a grouped sum-reduction over selected axes.

    The input is scaled by a single (non-trainable) quantized kernel value
    initialised to one, then every axis listed in ``axes`` is split into a
    number of groups and the elements inside each group are summed. Finally
    the bias is added, broadcast as ``(1, C, 1, 1)``.

    Args:
        out_channels: Forwarded to ``MACBase`` and used as the kernel's
            ``channels``.
        groups: Indexable collection of group counts. The entry used for a
            reduced axis ``a`` is ``groups[a - 1]``, i.e. it is indexed by the
            non-batch dimensions.
        axes: Iterable of the axes to reduce.
        bias: Forwarded to ``MACBase``.
        precision_config: Forwarded to ``MACBase``.
        device: Device on which the kernel value is created.
        dtype: Dtype with which the kernel value is created.
    """

    def __init__(self, out_channels, groups, axes, bias=True, precision_config=None, device=None, dtype=None):
        super().__init__(out_channels, bias, precision_config, device, dtype)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.axes = axes
        self.groups = groups
        self.initialize_kernel(**factory_kwargs)

    def initialize_kernel(self, dtype=None, device=None):
        """Create the constant, non-trainable quantized kernel (a single ``1``)."""
        quant_min, quant_max = qtype_to_range(self.weight_qtype)
        # ``torch.ones`` already yields a fresh tensor with the requested dtype
        # and device; re-wrapping it in ``torch.tensor(...)`` only made a
        # redundant copy and triggered PyTorch's copy-construct UserWarning.
        kernel_value = torch.ones(1, dtype=dtype, device=device)
        self.kernel = QuantKernel(
            quant_min=quant_min,
            quant_max=quant_max,
            value=kernel_value,
            num_groups=1,
            channels=self.out_channels,
            cout_axis=0,
            cin_axis=0,
            is_independent_encoding=False,
            requires_grad=False,
        )

    def forward_mac(self, x, **kwargs):
        """Scale ``x`` by the kernel, sum within groups along ``axes``, add bias.

        Args:
            x: Input tensor. The bias is added with shape ``(1, -1, 1, 1)``,
                so the reduced result is presumably 4-D (N, C, H, W) —
                TODO confirm against callers.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The tensor in which each reduced axis has size ``groups[axis - 1]``,
            with the bias added.
        """
        x = x * self.kernel.get_weight()
        for axis in self.axes:
            # Normalise negative axes: both the ``axis - 1`` group lookup and
            # the ``axis + 1`` sum index below are only meaningful for a
            # non-negative axis. The rank is unchanged by each iteration
            # (unflatten adds one dim, sum removes one), so ``x.dim()`` is stable.
            axis = axis % x.dim()
            num_groups = self.groups[axis - 1]
            # Split the axis into (num_groups, elements_per_group); unflatten
            # raises if the axis size is not divisible by num_groups.
            x = x.unflatten(axis, (num_groups, x.shape[axis] // num_groups))
            # Collapse the within-group dimension, leaving ``num_groups`` entries.
            x = torch.sum(x, dim=axis + 1, keepdim=False)
        # NOTE(review): assumes MACBase always provides ``self.bias`` (also when
        # constructed with bias=False) — confirm.
        return x + self.bias.get_weight().view(1, -1, 1, 1)
