from typing import Optional

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.framework.apu_modules.apu_base import APUBase
from hailo_model_optimization.saitama.framework.common.saitama_definitions import Encoding, size_2_int
from hailo_model_optimization.saitama.framework.common.utils import _pair, parse_explicit_padding


class APUMaxpool2d(APUBase):
    """2D max-pooling APU module built on :func:`torch.nn.functional.max_pool2d`.

    Padding may be numeric (int or pair) or a string. A string is resolved to
    explicit per-side padding by ``parse_explicit_padding``. Padded positions
    are filled with ``-inf`` so they never win the max.

    Args:
        kernel_size: Pooling window size (int or pair).
        stride: Window stride (int or pair). Defaults to ``kernel_size``.
        padding: Numeric padding (int or pair), or a padding string understood
            by ``parse_explicit_padding``.
        dilation: Window dilation (int or pair).
        spatial_shape: Forwarded to ``parse_explicit_padding`` when ``padding``
            is a string; unused otherwise.
        device: Accepted for interface compatibility; unused here.
        dtype: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If the resolved padding mode is not ``"constant"`` or
            ``"zeros"``.
    """

    def __init__(
        self,
        kernel_size: size_2_int,
        stride: Optional[size_2_int] = None,
        padding: size_2_int = 0,
        dilation: size_2_int = 1,
        spatial_shape=None,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride) if (stride is not None) else self.kernel_size
        self.padding = padding
        self.dilation = _pair(dilation)

        # "zeros" means: rely on max_pool2d's own symmetric padding in forward();
        # any other mode means forward() pads explicitly with F.pad first.
        self.padding_mode = "zeros"
        # Fill value for explicit padding: -inf can never be selected by a max.
        self._padding_value = -torch.inf

        if isinstance(padding, str):
            # The pool falls back to stride == kernel_size when stride is None,
            # so the padding must be resolved for that same effective stride
            # (passing None would compute it for the wrong stride).
            effective_stride = stride if stride is not None else kernel_size
            # NOTE(review): the fourth argument is hard-coded to (1, 1); if it is
            # the dilation, self.dilation should probably be passed instead —
            # confirm against parse_explicit_padding's signature.
            padding, padding_mode, explicit_padding = parse_explicit_padding(
                padding, kernel_size, effective_stride, (1, 1), spatial_shape
            )
            self._reversed_padding_repeated_twice = explicit_padding
            # NOTE(review): explicit_padding is unpacked here as
            # (h_beg, h_end, w_beg, w_end), but forward() hands the same tuple to
            # F.pad, which orders the last dimension first (w_beg, w_end, h_beg,
            # h_end). Unless H and W padding coincide, one of these orders is
            # wrong — confirm parse_explicit_padding's output order.
            pad_beg_h, pad_end_h, pad_beg_w, pad_end_w = explicit_padding
            if pad_beg_h == pad_end_h and pad_beg_w == pad_end_w:
                # Symmetric padding: max_pool2d's built-in padding suffices.
                self.padding = (pad_beg_h, pad_beg_w)
            else:
                # Asymmetric padding: forward() pads explicitly via F.pad.
                self.padding = padding
                self.padding_mode = padding_mode
            self.has_padding = any(p != 0 for p in self._reversed_padding_repeated_twice)
        else:
            # Keep has_padding available regardless of how padding was given.
            self.has_padding = any(p != 0 for p in _pair(padding))

        valid_padding_modes = {"constant", "zeros"}
        if self.padding_mode not in valid_padding_modes:
            raise ValueError(
                f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{self.padding_mode}'"
            )

    def forward(self, x, **kwargs):
        """Max-pool ``x``; extra keyword arguments are accepted and ignored.

        When ``padding_mode`` is not ``"zeros"``, ``x`` is first padded with
        ``-inf`` via ``F.pad`` and pooled without further padding. Otherwise
        ``max_pool2d`` applies ``self.padding`` itself.
        """
        padding = self.padding
        if self.padding_mode != "zeros":
            x = nn.functional.pad(
                x,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
                value=self._padding_value,
            )
            # Padding has already been applied explicitly above.
            padding = 0
        return nn.functional.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            padding,
            self.dilation,
        )

    def forward_encoding(self, encoding: Encoding, **kwargs) -> Encoding:
        """Return ``encoding`` unchanged; this module does not modify it."""
        return encoding
