from typing import Union

from hailo_model_optimization.acceleras.atomic_ops.maxpool_op import MaxPoolOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_PADDING_NEG_INF_VALUE,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import get_hn_padding


class HailoMaxPool(BaseHailoSingleAtomic):
    """HN `maxpool` layer backed by a single ``MaxPoolOp`` atomic op.

    The layer itself holds no pooling logic: padding, stride alignment,
    window size and strides are all forwarded to the wrapped ``MaxPoolOp``.
    """

    _hn_type = LayerType.MAXPOOL
    OP_NAME = "maxpool_op"

    def __init__(
        self,
        name: str,
        padding: Union[str, PaddingType] = "VALID",
        stride_align: Union[str, StrideAlignType] = "NW",
        pool_size=(3, 3),
        strides=(1, 1),
        logger=None,
        **kwargs,
    ):
        """Build the layer together with its inner max-pool op.

        Args:
            name: Layer name. The inner op is named ``f"{name}/{OP_NAME}"``.
            padding: Padding mode, forwarded to ``MaxPoolOp``.
            stride_align: Stride alignment, forwarded to ``MaxPoolOp``.
            pool_size: Pooling window size, forwarded to ``MaxPoolOp``.
            strides: Pooling strides, forwarded to ``MaxPoolOp``.
            logger: Logger handed to both the inner op and the base layer.
            **kwargs: Extra options passed straight to the base constructor.
        """
        op_options = {
            "padding": padding,
            "stride_align": stride_align,
            "pool_size": pool_size,
            "strides": strides,
            "logger": logger,
        }
        inner_op = MaxPoolOp(f"{name}/{self.OP_NAME}", **op_options)
        super().__init__(name=name, core_op=inner_op, logger=logger, **kwargs)
        # Assigned after the base constructor runs, so this value wins over
        # anything the base may have set; its meaning is defined elsewhere.
        self.encoding_const = False

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create a layer from its HN description and finalize it.

        Args:
            lname: Name to give the new layer.
            hn_element: HN layer entry. Its ``"params"`` mapping must provide
                ``"kernel_shape"`` and ``"strides"``, plus whatever
                ``get_hn_padding`` reads.
            logger: Optional logger.

        Returns:
            The new layer, after ``finalize_from_hn(hn_element)`` was applied.
        """
        hn_params = hn_element["params"]
        padding, stride_align = get_hn_padding(hn_params)
        # Keep entries 1 and 2 only. These are presumably the (height, width)
        # dims of NHWC-style vectors; TODO confirm against the HN schema.
        spatial = slice(1, 3)
        new_layer = cls(
            name=lname,
            padding=padding,
            stride_align=stride_align,
            pool_size=hn_params["kernel_shape"][spatial],
            strides=hn_params["strides"][spatial],
            logger=logger,
        )
        new_layer.finalize_from_hn(hn_element)
        return new_layer

    @property
    def is_precision_transparent(self) -> bool:
        """Report that this layer is precision transparent (always ``True``)."""
        # NOTE(review): presumably because a max only selects one of its inputs
        # rather than computing new values; confirm.
        return True

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify this layer for equalization.

        Args:
            predecessor_index: Ignored by this implementation; accepted for
                interface compatibility.

        Returns:
            An ``EquivClassification`` with ``is_source=False``. The handler is
            ``transparent`` when the atomic op's padding constant is ``0`` or
            ``DEFAULT_PADDING_NEG_INF_VALUE``, and ``unsupported`` otherwise.
        """
        # NOTE(review): presumably only these two padding constants stay unchanged
        # under a positive rescale, which lets a scale pass through the max; confirm.
        scale_safe_padding = (0, DEFAULT_PADDING_NEG_INF_VALUE)
        if self.atomic_op.padding_const_value in scale_safe_padding:
            handler = LayerHandlerType.transparent
        else:
            handler = LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)
