import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_PADDING_NEG_INF_VALUE,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import ProtobufExportError
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationTypes,
    DefuseType,
    LayerType,
    PaddingType,
    PaddingTypes,
    TemporaryPaddingType,
)
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import input_to_output_height_width
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class PoolingLayer(LayerWithActivation):
    """
    Handles both maxpool and avgpool layers.

    ``__init__`` defaults the op to maxpool; the class-method constructors (``create``, ``from_hn``,
    ``from_pb``, ``from_layer``) assign the pooling op from their input.

    ``kernel_shape`` and ``strides`` are 4-element lists; index 1 is height and index 2 is width
    (``kernel_shape[3]`` is the features entry, see ``to_pb``).
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    # Whether each pooling op needs quantized weights; ops missing here are treated as False
    # (with a warning) by ``requires_quantized_weights``.
    _REQUIRES_QUANTIZED_WEIGHTS = {
        LayerType.maxpool: False,
        LayerType.avgpool: True,
    }
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.maxpool
        self._kernel_shape = None
        self._strides = None
        self._padding = None
        # Optional (top, bottom, left, right, front, back) padding that is added to the input
        # dimensions when computing the output shape.
        self._external_padding_value = None
        self._padding_const_value = 0
        # When True, kernel and strides are resized to cover the whole spatial input once it is known.
        self._set_kernel_to_input_shape = False
        self._dilations = None
        self._required_padding_correction = False
        self._count_include_pad = True
        self._ceil_mode = False

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        op,
        kernel_shape,
        strides,
        padding,
        padding_vals=None,
        should_set_kernel_to_input_shape=False,
        output_shapes=None,
        count_include_pad=True,
        ceil_mode=False,
    ):
        """
        Build a pooling layer for the given op.

        Args:
            original_name: Forwarded to the base-class ``create``.
            input_vertex_order: Forwarded to the base-class ``create``.
            op: Pooling op (``LayerType.maxpool`` or ``LayerType.avgpool``).
            kernel_shape: Kernel shape; index 1 is height, index 2 is width.
            strides: Strides; index 1 is height, index 2 is width.
            padding: Padding type.
            padding_vals: Explicit ``(top, bottom, left, right, front, back)`` padding. Only stored when
                ``padding`` is ``TemporaryPaddingType.external_undecided``.
            should_set_kernel_to_input_shape: Resize kernel/strides to the full spatial input once known.
            output_shapes: Forwarded to the base-class ``create``.
            count_include_pad: Stored as-is (presumably whether padded elements count toward an average).
            ceil_mode: Add one output row/column when the kernel does not tile the input evenly.

        Returns:
            The new ``PoolingLayer``.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes=output_shapes)
        layer.op = op
        layer.kernel_shape = kernel_shape
        layer.strides = strides
        layer.padding = padding
        if layer.padding == TemporaryPaddingType.external_undecided:
            layer.external_padding_value = padding_vals
        # Maxpool pads with a "negative infinity" constant; avgpool pads with zeros.
        layer.padding_const_value = DEFAULT_PADDING_NEG_INF_VALUE if op == LayerType.maxpool else 0
        layer.count_include_pad = count_include_pad
        layer.ceil_mode = ceil_mode
        layer.toggle_set_kernel_to_input_shape(should_set_kernel_to_input_shape)
        return layer

    @property
    def is_activation_fusible(self):
        """True only for avgpool layers."""
        return self.op == LayerType.avgpool

    def set_input_shapes(self, input_shapes, validate=True):
        """
        Set the input shapes and, for global pooling, resize kernel and strides to the new input.

        Global pooling here means either the layer was flagged via ``toggle_set_kernel_to_input_shape``
        or it was a global average pool with respect to the *previous* input shape.
        """
        original_input_shape = self.input_shape
        super().set_input_shapes(input_shapes, validate)
        if self._set_kernel_to_input_shape or self.is_global_avg_pool(original_input_shape):
            self._kernel_shape = [1, self.input_shape[1], self.input_shape[2], 1]
            self._strides = [1, self.input_shape[1], self.input_shape[2], 1]

    @property
    def kernel_short_description(self):
        """Short ``" (HxW/stride)"`` summary of the kernel, using the height stride."""
        return f" ({self.kernel_shape[1]}x{self.kernel_shape[2]}/{self.strides[1]})"

    def to_hn(self, should_get_default_params=False):
        """
        Serialize to an HN dict.

        ``required_padding_correction`` is only written when it is True.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["kernel_shape"] = self._kernel_shape
        result["params"]["strides"] = self._strides
        result["params"]["padding"] = self._padding.value
        result["params"]["activation"] = self._activation.value
        if self._required_padding_correction:
            result["params"]["required_padding_correction"] = self._required_padding_correction
        result["params"]["count_include_pad"] = self._count_include_pad
        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """
        Serialize to a protobuf node.

        Raises:
            ProtobufExportError: If the op is neither maxpool nor avgpool.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        if self._op == LayerType.maxpool:
            node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_MAXPOOL
        elif self._op == LayerType.avgpool:
            node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_AVGPOOL
        else:
            raise ProtobufExportError(f"Unexpected op {self._op} in {self}")

        _, node.kernel_shape.height, node.kernel_shape.width, node.kernel_shape.features = self.kernel_shape
        _, node.strides.height, node.strides.width, _ = self.strides
        _, node.dilations.height, node.dilations.width, _ = self.dilations
        node.padding = pb_wrapper.PADDING_TYPE_TO_PB[self.padding]
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        return node

    def _calc_output_shape(self):
        """Compute the ``[-1, height, width, features]`` output shape from the current input shape."""
        _, in_h, in_w, in_f = self.input_shape
        # Fold any explicit external padding into the effective input dimensions.
        if self._external_padding_value:
            top, bottom, left, right, front, back = self._external_padding_value
            in_h = in_h + top + bottom
            in_w = in_w + right + left
            in_f = in_f + front + back
        input_shape = [-1, in_h, in_w, in_f]

        # edge case - global average pool in case the input shape isn't known when creating the layer
        if self._set_kernel_to_input_shape:
            self._kernel_shape = [1, input_shape[1], input_shape[2], 1]
            # Use a separate list so that later stride mutations (e.g. via stride_height) do not also
            # change the kernel shape.
            self._strides = list(self._kernel_shape)

        input_w, input_f = input_shape[2], input_shape[3]

        # A transposed-input global avgpool defuse swaps the width and features axes.
        defuse_types = self._defuse_params.get("defuse_types")
        if defuse_types and DefuseType.global_avgpool_transposed_input in defuse_types:
            input_w, input_f = input_f, input_w
        input_shape = [1, input_shape[1], input_w, input_f]

        output_h, output_w = input_to_output_height_width(
            input_shape,
            self._kernel_shape[1:3],
            self._strides,
            self._padding,
        )
        if self.defuse_type is DefuseType.super_dw and "defuse_original_features" in self.defuse_params:
            output_f = self._defuse_params.get("defuse_original_features")
        elif "defuse_features" in self.defuse_params and self.defuse_type not in [
            DefuseType.none,
            DefuseType.compute_lanes,
        ]:
            output_f = self.defuse_features
        else:
            output_f = input_shape[3]

        # ceil_mode: one extra output whenever the kernel does not tile the (padded) input evenly.
        if self.ceil_mode:
            output_h += 1 if (input_shape[1] - self.kernel_shape[1]) % self.strides[1] else 0
            output_w += 1 if (input_shape[2] - self.kernel_shape[2]) % self.strides[2] else 0

        return [-1, output_h, output_w, output_f]

    @property
    def requires_quantized_weights(self):
        """Whether this op needs quantized weights; unknown ops log a warning and return False."""
        if self._op not in PoolingLayer._REQUIRES_QUANTIZED_WEIGHTS:
            # Implicit concatenation: a backslash continuation inside the literal would embed the
            # source indentation in the log message.
            self._logger.warning(
                f"Pooling layer {self.name} of type {self._op.value} does not specify whether "
                "quantized weights are required. Assuming False.",
            )
            return False
        return PoolingLayer._REQUIRES_QUANTIZED_WEIGHTS[self._op]

    @property
    def defuse_features(self):
        """
        Number of output features after defusing.

        Falls back (with a warning) to the kernel's features entry when the value is missing or zero, or
        when the layer is not defused.
        """
        # Defuse features can't be zero. If they are, we assume that there was a problem in the hn
        if "defuse_features" in self.defuse_params and self.defuse_type is not DefuseType.none:
            features = self._defuse_params.get("defuse_features")
            if features:
                return features
        self._logger.warning(
            f"Layer {self.name} has defuse_features=0. Assuming invalid hn, and using kernel_features={self.kernel_shape[-1]}",
        )
        return self.kernel_shape[-1]

    @property
    def kernel_shape(self):
        return self._kernel_shape

    @kernel_shape.setter
    def kernel_shape(self, kernel_shape):
        # Copy so the layer never shares (and later mutates) the caller's sequence.
        self._kernel_shape = list(kernel_shape) if kernel_shape else kernel_shape

    @property
    def strides(self):
        return self._strides

    @strides.setter
    def strides(self, strides):
        # Copy, like kernel_shape: avoids aliasing the caller's list and lets the stride_height /
        # stride_width setters work even when a tuple was passed in.
        self._strides = list(strides) if strides else strides

    @property
    def stride_height(self):
        return self.strides[1]

    @property
    def stride_width(self):
        return self.strides[2]

    @stride_height.setter
    def stride_height(self, value):
        self.strides[1] = value

    @stride_width.setter
    def stride_width(self, value):
        self.strides[2] = value

    @property
    def padding(self):
        return self._padding

    @padding.setter
    def padding(self, padding):
        self._padding = padding

    @property
    def external_padding_value(self):
        return self._external_padding_value

    @external_padding_value.setter
    def external_padding_value(self, external_padding_value):
        self._external_padding_value = external_padding_value

    @property
    def padding_const_value(self):
        return self._padding_const_value

    @padding_const_value.setter
    def padding_const_value(self, padding_const_value):
        self._padding_const_value = padding_const_value

    @property
    def count_include_pad(self):
        return self._count_include_pad

    @count_include_pad.setter
    def count_include_pad(self, count_include_pad):
        self._count_include_pad = count_include_pad

    @property
    def dilations(self):
        return self._dilations

    @dilations.setter
    def dilations(self, dilations):
        self._dilations = dilations

    @property
    def required_padding_correction(self):
        return self._required_padding_correction

    @required_padding_correction.setter
    def required_padding_correction(self, required_padding_correction):
        self._required_padding_correction = required_padding_correction

    @property
    def ceil_mode(self):
        return self._ceil_mode

    @ceil_mode.setter
    def ceil_mode(self, ceil_mode):
        self._ceil_mode = ceil_mode

    def toggle_set_kernel_to_input_shape(self, should_set_kernel_to_input_shape=False):
        """Enable/disable resizing kernel and strides to the full spatial input shape."""
        self._set_kernel_to_input_shape = should_set_kernel_to_input_shape

    def _is_global_pooling(self, pooling_op, input_shape):
        """
        True when the op is ``pooling_op`` and the kernel spans the whole spatial input.

        ``input_shape`` defaults to the layer's current input shape when falsy.
        """
        if not input_shape:
            input_shape = self.input_shape

        # in case there is no padding the stride value is irrelevant
        # if there is padding the stride value should be equal to the kernel size
        stride_condition = (
            [stride == kernel_shape for stride, kernel_shape in zip(self.strides[1:3], self.kernel_shape[1:3])]
            if self.op == pooling_op and self.padding != PaddingType.valid
            else [self.op == pooling_op] * 2
        )
        return (
            all(stride_condition) and self.kernel_shape[1] == input_shape[1] and self.kernel_shape[2] == input_shape[2]
        )

    def is_global_avg_pool(self, input_shape=None):
        """True for an avgpool whose kernel spans the whole spatial input (see ``_is_global_pooling``)."""
        return self._is_global_pooling(LayerType.avgpool, input_shape)

    def is_global_max_pool(self, input_shape=None):
        """True for a maxpool whose kernel spans the whole spatial input (see ``_is_global_pooling``)."""
        return self._is_global_pooling(LayerType.maxpool, input_shape)

    def is_tiled_avg_pool(self):
        """True for an avgpool whose kernel equals its stride in both spatial dims (non-overlapping windows)."""
        return (
            self.op == LayerType.avgpool
            and self.kernel_shape[1] == self.stride_height
            and self.kernel_shape[2] == self.stride_width
        )

    @classmethod
    def from_hn(cls, hn):
        """
        Build a layer from an HN dict.

        If the HN does not state ``required_padding_correction`` explicitly, it is inferred: an avgpool with
        non-valid padding and ``count_include_pad`` False requires padding correction.
        """
        layer = super().from_hn(hn)
        params = hn["params"]
        layer._kernel_shape = params["kernel_shape"]
        layer._strides = params["strides"]
        layer._padding = PaddingTypes[params["padding"]]
        if hn["type"] == LayerType.avgpool.value:
            if "activation" in params:
                layer._activation = ActivationTypes[params["activation"]]
            layer._op = LayerType.avgpool
        layer.dilations = params.get("dilations", [1, 1, 1, 1])
        # Must be loaded before the padding-correction inference below, which reads it; otherwise the
        # inference would only ever see the constructor default.
        layer.count_include_pad = params.get("count_include_pad", False)
        if "required_padding_correction" in params:
            layer.required_padding_correction = params["required_padding_correction"]
        elif layer.op == LayerType.avgpool and layer.padding != PaddingType.valid and not layer.count_include_pad:
            # the avg pooling layer requires padding correction
            layer.required_padding_correction = True
        else:
            layer.required_padding_correction = False
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node; missing dilations default to ``[1, 1, 1, 1]``."""
        layer = super().from_pb(pb, pb_wrapper)
        layer._kernel_shape = [1, pb.kernel_shape.height, pb.kernel_shape.width, pb.kernel_shape.features]
        layer._strides = [1, pb.strides.height, pb.strides.width, 1]
        layer._padding = pb_wrapper.PADDING_PB_TO_TYPE[pb.padding]
        if pb.type == pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_AVGPOOL:
            layer._op = LayerType.avgpool
            if pb.HasField("activation"):
                layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        if pb.dilations.HasField("height") and pb.dilations.HasField("width"):
            layer.dilations = [1, pb.dilations.height, pb.dilations.width, 1]
        else:
            layer.dilations = [1, 1, 1, 1]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer copying op, kernel, strides, padding, padding constant (and avgpool activation)."""
        layer = super().from_layer(old_layer)
        layer.op = old_layer.op
        layer.kernel_shape = old_layer.kernel_shape.copy()
        layer.strides = old_layer.strides.copy()
        layer.padding = old_layer.padding
        layer.padding_const_value = old_layer.padding_const_value
        if layer.op == LayerType.avgpool:
            layer.activation = old_layer.activation
        return layer

    def move_params(self, layer):
        """Move base-class params from ``layer``; also take its padding constant if it is a pooling layer."""
        super().move_params(layer)
        if layer.op in [LayerType.maxpool, LayerType.avgpool]:
            self._padding_const_value = layer.padding_const_value

    @property
    def kernel_height(self):
        return self.kernel_shape[1]

    @property
    def kernel_width(self):
        return self.kernel_shape[2]

    @kernel_height.setter
    def kernel_height(self, value):
        self.kernel_shape[1] = value

    @kernel_width.setter
    def kernel_width(self, value):
        self.kernel_shape[2] = value

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def ibc_supported(self):
        """
        IBC support status: unsupported for maxpool and for global/tiled avgpool, supported otherwise.

        Raises:
            KeyError: If the op is not a known pooling op.
        """
        if self.op == LayerType.maxpool:
            return LayerSupportStatus.unsupported
        elif self.op == LayerType.avgpool:
            if self.is_global_avg_pool() or self.is_tiled_avg_pool():
                return LayerSupportStatus.unsupported
            else:
                return LayerSupportStatus.supported
        else:
            raise KeyError(f"{self.op.name} is unknown op type for pooling layer")
