from typing import Optional, Tuple

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.fake_quant import QuantEqKernel, QuantPadValue
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    Encoding,
    MACPrecisionConfig,
    size_2_int,
)
from hailo_model_optimization.saitama.framework.common.utils import (
    PaddingUtils,
    _pair,
    _reverse_repeat_tuple,
    parse_explicit_padding,
    qtype_to_range,
)
from hailo_model_optimization.saitama.framework.mac_modules.mac_base import MACBase


class MACConv2d(MACBase):
    """2D convolution MAC layer with a quantized kernel, optional bias and a quantized padding value.

    The constructor mirrors ``torch.nn.Conv2d``. ``padding`` may be an int / pair, or one of the
    strings ``"same"``, ``"valid"`` or ``"same_tensorflow"``; string paddings are resolved into an
    explicit padding (and padding mode) by ``parse_explicit_padding``, which also receives
    ``spatial_shape``. Only the ``"constant"`` and ``"zeros"`` padding modes are accepted.
    """

    padding_value: QuantPadValue

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: size_2_int,
        stride: size_2_int = 1,
        padding: size_2_int = 0,  # may also be "same", "valid" or "same_tensorflow"
        dilation: size_2_int = 1,
        groups: int = 1,
        bias: bool = True,  # Do we allow bias=False, or is it required to compensate for the zp multiplication?
        padding_mode: str = "zeros",
        spatial_shape: Optional[Tuple[int, int]] = None,  # Spatial shape will be used for padding calculation
        precision_config: Optional[MACPrecisionConfig] = None,
        device=None,
        dtype=None,
    ):
        """Build the layer and allocate its kernel and padding-value parameters.

        Raises:
            ValueError: if ``groups`` is not positive, if ``in_channels`` or ``out_channels`` is not
                divisible by ``groups``, if ``padding`` is an unknown string, or if the (possibly
                parsed) padding mode is not ``"constant"`` / ``"zeros"``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(out_channels, bias=bias, precision_config=precision_config, **factory_kwargs)

        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        valid_padding_strings = {"same", "valid", "same_tensorflow"}
        if isinstance(padding, str) and padding not in valid_padding_strings:
            raise ValueError(f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}")

        # TODO: add padding mode based on hailo behavior
        self.in_channels = in_channels

        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = padding if isinstance(padding, str) else _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups

        self.padding_mode = padding_mode
        self.expected_spatial_shape = spatial_shape
        # `_reversed_padding_repeated_twice` is the padding to be passed to
        # `F.pad` if needed (e.g., for non-zero padding types that are
        # implemented as two ops: padding + conv). `F.pad` accepts paddings in
        # reverse order than the dimension.
        if isinstance(self.padding, str):
            # String paddings are resolved into the padding handed to conv2d, the padding mode,
            # and the explicit per-side padding used for manual padding.
            padding, padding_mode, explicit_padding = parse_explicit_padding(
                padding, kernel_size, stride, dilation, spatial_shape
            )
            self.padding = padding
            self.padding_mode = padding_mode
            self._reversed_padding_repeated_twice = explicit_padding
        else:
            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)

        self._is_encoded = False
        self.force_encoded_padding = False
        # True when any side of the input actually gets padded.
        self.has_padding = any(p != 0 for p in self._reversed_padding_repeated_twice)

        valid_padding_modes = {"constant", "zeros"}
        if self.padding_mode not in valid_padding_modes:
            raise ValueError(
                f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'"
            )

        self.initialize_kernel(**factory_kwargs)
        qmin, qmax = qtype_to_range(self.input_qtype)
        self.padding_value = QuantPadValue(qmin, qmax, 0.0, self.in_channels, **factory_kwargs)

    @property
    def is_dw(self):
        """True when ``groups == in_channels`` (depthwise grouping)."""
        # NOTE(review): this is also True for a plain conv with in_channels == groups == 1, which
        # makes `initialize_kernel` pick cin_axis=0; confirm that is intended.
        return self.in_channels == self.groups

    def initialize_kernel(self, device=None, dtype=None):
        """Allocate an uninitialized kernel of shape ``(out_channels, in_channels // groups, kh, kw)``.

        The tensor is wrapped in ``QuantEqKernel`` using the weight qtype range. The input-channel
        axis is 0 for depthwise layers and 1 otherwise.
        """
        kernel = torch.empty(
            (self.out_channels, self.in_channels // self.groups, *self.kernel_size), device=device, dtype=dtype
        )
        quant_min, quant_max = qtype_to_range(self.weight_qtype)
        cin_axis = 0 if self.is_dw else 1
        self.kernel = QuantEqKernel(
            quant_min=quant_min,
            quant_max=quant_max,
            value=kernel,
            num_groups=self.quantization_groups,
            channels_in=self.in_channels,
            channels_out=self.out_channels,
            cin_axis=cin_axis,
            cout_axis=0,
        )

    def set_encoded(self, is_encoded: bool):
        """Switch encoded mode on or off and decide whether padding must be applied manually.

        Manual padding (``force_encoded_padding``) is required only when the layer is encoded,
        actually pads, and the encoded pad value is non-zero somewhere. In that case the implicit
        zero padding of ``conv2d`` would be wrong.
        """
        self._is_encoded = is_encoded
        encoded_pad = self.padding_value.get_encoded_weight()
        # `torch.as_tensor` handles Python scalars and tensors of any rank (including 0-d), unlike
        # the builtin `any()`, which needs an iterable of single-element values.
        non_zero_pad = bool(torch.as_tensor(encoded_pad).ne(0).any())
        self.force_encoded_padding = is_encoded and self.has_padding and non_zero_pad

    def forward_mac(self, inp: torch.Tensor, **kwargs):
        """Convolve ``inp`` with the kernel and bias values returned by their ``get_weight()``.

        A 1x1, stride-1, ungrouped, unpadded conv is computed as a ``linear`` over the channel axis.
        Every other case goes through ``conv2d``. That path first pads explicitly with the padding
        value when ``padding_mode`` is not ``"zeros"`` or when encoded padding is forced.
        """
        # NOTE: as long as we work in native scale, we don't need to apply the mac shift
        kernel = self.kernel.get_weight()
        bias = self.bias.get_weight() if self.bias is not None else None
        # The linear shortcut cannot add border pixels, so it is only valid without padding.
        is_pointwise = (
            self.kernel_size == (1, 1) and self.stride == (1, 1) and self.groups == 1 and not self.has_padding
        )
        if is_pointwise:
            in_shape = inp.shape
            is_4d = inp.ndim == 4
            if is_4d:
                # (N, C, H, W) -> (N, H*W, C): `linear` contracts over the last axis.
                # `reshape` (not `view`) so non-contiguous inputs are accepted too.
                inp = inp.reshape(in_shape[0], in_shape[1], -1).transpose(1, 2)
            # NOTE(review): non-4D inputs are passed to `linear` as-is, i.e. channels are assumed to
            # be on the last axis — confirm against callers.
            result = nn.functional.linear(inp, kernel.reshape(self.out_channels, -1), bias=bias)
            if is_4d:
                # (N, H*W, C_out) -> (N, C_out, H, W)
                result = result.transpose(1, 2)
                result = result.reshape(in_shape[0], result.size(1), *in_shape[2:])
            return result

        padding = self.padding
        if self.padding_mode != "zeros" or self.force_encoded_padding:
            # NOTE: as long as we run in native scaling, we don't need custom pad value, if we change that we'll need use the zp
            pad_value = self.padding_value.get_weight()
            inp = PaddingUtils.manual_padding(inp, pad_value, self._reversed_padding_repeated_twice)
            # The input is already padded, so conv2d must not pad again.
            padding = _pair(0)

        return nn.functional.conv2d(inp, kernel, bias, self.stride, padding, self.dilation, self.groups)

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Propagate ``encoding`` through the padding value, kernel, bias (if any) and accumulator quantizer."""
        # The padding value's return is discarded: the kernel receives the incoming encoding.
        self.padding_value.forward_encoding(encoding, **kwargs)
        encoding = self.kernel.forward_encoding(encoding, **kwargs)
        # Mirror `forward_mac`, which also treats the bias as optional.
        if self.bias is not None:
            encoding = self.bias.forward_encoding(encoding, **kwargs)
        encoding = self.accumulator_quantizer.forward_encoding(encoding, **kwargs)
        return encoding

    def extra_repr(self):
        """Summarize the constructor configuration, in ``torch.nn.Conv2d`` style."""
        # Read attributes directly instead of formatting with `self.__dict__`, which only holds
        # plain instance attributes and would miss properties or attributes stored elsewhere.
        s = (
            f"{self.in_channels}, {self.out_channels}, {self.kernel_size}"
            f", stride={self.stride}, padding={self.padding}, dilation={self.dilation}, groups={self.groups}"
            f", padding_mode={self.padding_mode}"
        )
        if self.expected_spatial_shape is not None:
            s += f", spatial_shape={self.expected_spatial_shape}"
        return s
