import numpy as np
import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.fake_quant import (
    QuantKernelScalar,
    QuantPadValue,
)
from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding
from hailo_model_optimization.saitama.framework.common.utils import (
    PaddingUtils,
    _pair,
    parse_explicit_padding,
    qtype_to_range,
)
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACAvgpool2d(MACBase):
    """Quantization-aware 2D average pooling built on ``MACBase``.

    The pooling is computed as a per-window *sum* of ``x * kernel``
    (``avg_pool2d`` with ``divisor_override=1``). ``kernel`` is a single
    quantized scalar initialised to ``1 / (kernel_h * kernel_w)``. An
    optional bias is added to the result.

    Padding is either left to ``avg_pool2d`` (symmetric zero padding), or
    applied explicitly with ``PaddingUtils.manual_padding`` using a quantized
    pad value. The explicit path is taken when the padding mode is
    ``"constant"``, or when encoded mode needs a non-zero encoded pad value.
    """

    # Quantized value used when padding is applied explicitly (one entry per channel).
    padding_value: QuantPadValue

    def __init__(
        self,
        out_channels,
        kernel_size,
        stride,
        padding,
        count_include_pad=True,
        spatial_shape=None,
        bias=True,
        precision_config=None,
        device=None,
        dtype=None,
    ):
        """Build the pooling layer.

        Args:
            out_channels: Number of channels. Average pooling preserves the
                channel count, so this is presumably also the input channel
                count; the bias and pad value are sized by it.
            kernel_size: Pooling window, an int or an ``(h, w)`` pair.
            stride: Window stride, an int or a pair; ``None`` means
                ``kernel_size``.
            padding: Either a string handled by ``parse_explicit_padding``,
                or an int / ``(h, w)`` pair of symmetric padding.
            count_include_pad: Forwarded to ``avg_pool2d``. Because
                ``divisor_override=1`` is used there, it does not affect the
                output.
            spatial_shape: Input spatial shape, forwarded to
                ``parse_explicit_padding`` for string padding.
            bias: Whether the layer has a bias (forwarded to ``MACBase``).
            precision_config: Forwarded to ``MACBase``.
            device: Factory kwarg for created tensors/modules.
            dtype: Factory kwarg for created tensors/modules.

        Raises:
            ValueError: If the resolved padding mode is not ``"constant"`` or
                ``"zeros"``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(out_channels, bias=bias, precision_config=precision_config, **factory_kwargs)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride) if (stride is not None) else self.kernel_size
        self.count_include_pad = count_include_pad

        self._is_encoded = False
        self.force_encoded_padding = False
        self.padding_mode = "zeros"
        qmin, qmax = qtype_to_range(self.input_qtype)
        self.padding_value = QuantPadValue(qmin, qmax, 0.0, self.out_channels, **factory_kwargs)

        if isinstance(padding, str):
            # Use the resolved (paired) kernel/stride. This way stride=None means
            # stride=kernel_size here too, consistent with self.stride.
            padding, padding_mode, explicit_padding = parse_explicit_padding(
                padding, self.kernel_size, self.stride, (1, 1), spatial_shape
            )
            self._reversed_padding_repeated_twice = explicit_padding
            pad_beg_h, pad_end_h, pad_beg_w, pad_end_w = explicit_padding
            if pad_beg_h == pad_end_h and pad_beg_w == pad_end_w:
                # Symmetric padding: avg_pool2d can apply it directly.
                self.padding = (pad_beg_h, pad_beg_w)
            else:
                self.padding = padding
                self.padding_mode = padding_mode
        else:
            # Numeric padding is symmetric per axis. Store the explicit form in
            # the same (beg_h, end_h, beg_w, end_w) layout the string branch
            # unpacks, so forward_mac / set_encoded can rely on it.
            # NOTE(review): confirm this order against PaddingUtils.manual_padding.
            pad_h, pad_w = _pair(padding)
            self.padding = (pad_h, pad_w)
            self._reversed_padding_repeated_twice = (pad_h, pad_h, pad_w, pad_w)
        self.has_padding = any(p != 0 for p in self._reversed_padding_repeated_twice)

        valid_padding_modes = {"constant", "zeros"}
        if self.padding_mode not in valid_padding_modes:
            raise ValueError(
                f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{self.padding_mode}'"
            )

        self.initialize_kernel(**factory_kwargs)
        self.register_load_state_dict_post_hook(self.update_padding_mode)

    @staticmethod
    def _has_nonzero(value) -> bool:
        """Return True if a scalar or tensor-like pad value has any non-zero element."""
        if isinstance(value, (int, float)):
            return value != 0
        # Reduce with torch so 0-d and multi-element tensors are both handled.
        return bool(torch.as_tensor(value).ne(0).any())

    def initialize_kernel(self, dtype=None, device=None):
        """Create the quantized scalar kernel holding ``1 / (kernel_h * kernel_w)``."""
        quant_min, quant_max = qtype_to_range(self.weight_qtype)
        spatial_size = np.prod(self.kernel_size)
        kernel_value = float(1 / spatial_size)  # python conversion to avoid np dtype
        self.kernel = QuantKernelScalar(
            quant_min=quant_min,
            quant_max=quant_max,
            value=torch.tensor([kernel_value], dtype=dtype, device=device),
            num_groups=1,
            channels=self.out_channels,
            cin_axis=0,
            cout_axis=0,
            spatial_size=int(spatial_size),
        )

    def set_encoded(self, is_encoded: bool):
        """Toggle encoded mode and decide whether padding must be applied explicitly.

        Explicit padding is forced only when the layer is encoded, has
        padding, and the encoded pad value is non-zero. avg_pool2d's built-in
        padding can only insert zeros.
        """
        self._is_encoded = is_encoded
        non_zero_pad = self._has_nonzero(self.padding_value.get_encoded_weight())
        self.force_encoded_padding = is_encoded and self.has_padding and non_zero_pad

    def forward_mac(self, x, **kwargs):
        """Compute ``window_sum(x * kernel) + bias``.

        Args:
            x: Input tensor. Presumably NCHW, since bias is broadcast as
                ``(1, C, 1, 1)`` — confirm against callers.

        Returns:
            The pooled tensor, with bias added when the layer has one.
        """
        padding = self.padding
        if (self.padding_mode == "constant") or self.force_encoded_padding:
            # avg_pool2d only zero-pads symmetrically, so non-zero pad values
            # are materialised explicitly and pooling then runs unpadded.
            pad_value = self.padding_value.get_weight()
            x = PaddingUtils.manual_padding(x, pad_value, self._reversed_padding_repeated_twice)
            padding = 0

        kernel = self.kernel.get_weight()

        # divisor_override=1 turns avg_pool2d into a window sum; the averaging
        # factor is carried by the quantized kernel multiplied in beforehand.
        pooling_res = nn.functional.avg_pool2d(
            x * kernel,
            self.kernel_size,
            stride=self.stride,
            padding=padding,
            ceil_mode=False,
            count_include_pad=self.count_include_pad,
            divisor_override=1,
        )

        if self.bias is None:
            return pooling_res
        return pooling_res + self.bias.get_weight().view(1, -1, 1, 1)

    def update_padding_mode(self, *args, **kwargs):
        """Post-hook for ``load_state_dict``: switch to explicit constant padding.

        The switch happens when the layer pads and the loaded pad value has
        any non-zero element. Hook arguments are ignored.
        """
        if self.has_padding and self._has_nonzero(self.padding_value.weight):
            self.padding_mode = "constant"

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Propagate ``encoding`` through the pad value, kernel, bias and accumulator."""
        # The pad value sees the input encoding, but its return value is not
        # used; the kernel/bias chain continues from the input encoding.
        self.padding_value.forward_encoding(encoding, **kwargs)
        encoding = self.kernel.forward_encoding(encoding, **kwargs)
        if self.bias is not None:
            encoding = self.bias.forward_encoding(encoding, **kwargs)
        encoding = self.accumulator_quantizer.forward_encoding(encoding, **kwargs)
        return encoding
