from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils import conv_utils

from hailo_model_optimization.acceleras.atomic_ops._misc_internals import get_tf_same_padding
from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_PADDING_NEG_INF_VALUE,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.padding_utils import diy_pad


class MaxPoolOp(BaseAtomicOp):
    """
    MaxPool core operation.

    The forward pass always calls ``tf.nn.max_pool`` with VALID padding. When ``padding`` is
    SAME, the input is first padded explicitly ("DIY" padding via ``diy_pad``). The padded
    border then holds ``padding_const_value`` in native mode, or its quantized counterpart
    in hw-sim mode.

    Args:
        name: Name of the op.
        padding: Padding mode, as a string or ``PaddingType`` (e.g. ``"VALID"`` or ``"SAME"``).
            Only VALID and SAME are implemented by the forward pass.
        stride_align: Stride alignment, as a string or ``StrideAlignType``. It is only
            forwarded to ``diy_pad`` when SAME padding is applied.
        pool_size: Pooling window size for spatial axes 1 and 2 (rows, cols).
        strides: Pooling strides for spatial axes 1 and 2 (rows, cols).
        logger: Optional logger, forwarded to ``BaseAtomicOp``.
        fully_native: Forwarded to ``BaseAtomicOp``.
        **kwargs: Forwarded to ``BaseAtomicOp``.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        padding: Union[str, PaddingType] = "VALID",
        stride_align: Union[str, StrideAlignType] = "NW",
        pool_size=(3, 3),
        strides=(1, 1),
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Normalize string arguments into their enum types.
        self.padding = PaddingType(padding)
        # Only consumed by diy_pad when SAME padding is applied.
        self.stride_align = StrideAlignType(stride_align)
        self.pool_size = pool_size
        self.strides = strides
        # Native-domain value written into the padded border. import_weights() may replace it.
        self.padding_const_value = DEFAULT_PADDING_NEG_INF_VALUE

    @property
    def padding_const_value_q(self):
        """
        Return the padding constant expressed in the quantized (hw-sim) domain.

        If ``padding_const_value`` still equals ``DEFAULT_PADDING_NEG_INF_VALUE``, it is not
        quantized:

        - it is kept as-is when the input lossy element is lossless;
        - it is replaced by 0 otherwise.

        Any other value is mapped to ``value / input_scale + mean(input_zero_point)`` and then
        passed through the input lossy element.

        Returns:
            A float32 tensor in the sentinel case; otherwise whatever the input lossy element
            returns.
        """
        uses_sentinel = self.padding_const_value == DEFAULT_PADDING_NEG_INF_VALUE
        if not uses_sentinel:
            raw_scale = self.input_scale
            is_scalar_scale = raw_scale.shape == ()
            # A non-scalar (presumably per-channel) scale contributes only its first entry.
            scale = raw_scale if is_scalar_scale else raw_scale[0]
            # define_encodings() marks the input zero point as scalar when padding is not
            # VALID; the mean collapses the zero-point tensor to one value either way.
            zero_point = tf.reduce_mean(self.input_zero_points[0])
            return self.input_lossy_element(self.padding_const_value / scale + zero_point)

        # NOTE(review): 0 is presumably the lowest quantized input value, standing in for the
        # sentinel when the input is lossy; confirm against the lossy element's range.
        keep_sentinel = self.input_lossy_elements[0].is_lossless
        return tf.cast(self.padding_const_value if keep_sentinel else 0, tf.float32)

    def call_hw_sim(self, inputs, **kwargs):
        """Max-pool the single input; SAME padding uses the quantized padding constant."""
        source = inputs[0]
        return self._call_maxpool(source, self.padding_const_value_q)

    def create_weight_quant_element(self, **kwargs):
        """No-op: this op has no weights to quantize."""

    def call_native(self, inputs, **kwargs):
        """Max-pool the single input; SAME padding uses the native padding constant."""
        source = inputs[0]
        return self._call_maxpool(source, self.padding_const_value)

    def _call_maxpool(self, inputs, pad_value):
        """
        Apply max pooling to ``inputs``.

        SAME padding is emulated by padding explicitly with ``pad_value`` and then pooling
        with VALID padding. The border contents are therefore controlled by ``pad_value``.

        Args:
            inputs: Input tensor; axes 1 and 2 are treated as the spatial axes.
            pad_value: Constant written into the padded border (SAME only).

        Raises:
            AccelerasImplementationError: If ``self.padding`` is neither VALID nor SAME.
        """
        pooled_input = inputs
        if self.padding == PaddingType.SAME:
            top, bottom, left, right = get_tf_same_padding(
                [1, 1],  # presumably the dilation rate - TODO confirm against get_tf_same_padding
                *inputs.shape[1:3],
                *self.pool_size,
                *self.strides,
            )
            pooled_input = diy_pad(inputs, pad_value, self.stride_align, top, bottom, left, right)
        elif self.padding != PaddingType.VALID:
            raise AccelerasImplementationError(f"Padding type {self.padding.value} is not supported")

        return tf.nn.max_pool(pooled_input, ksize=self.pool_size, strides=self.strides, padding="VALID")

    def _compute_output_shape(self, input_shape):
        """Return the output shape ``[input_shape[0], rows, cols, input_shape[3]]``."""
        mode = self.padding.value.lower()
        out_rows, out_cols = (
            conv_utils.conv_output_length(input_shape[dim + 1], self.pool_size[dim], mode, self.strides[dim])
            for dim in (0, 1)
        )
        return [input_shape[0], out_rows, out_cols, input_shape[3]]

    def export_hw_params(self):
        """Export the quantized padding constant cast to ``np.uint16``."""
        quantized = self.padding_const_value_q.numpy()
        return {"padding_const_value": quantized.astype(np.uint16)}

    def export_quant_weights(self):
        """Export the quantized padding constant as a numpy value."""
        quantized = self.padding_const_value_q.numpy()
        return {"padding_const_value": quantized}

    def export_weights(self):
        """Export the native padding constant."""
        native = self.padding_const_value
        return {"padding_const_value": native}

    def import_weights(self, layer_params: LayerParams):
        """
        Load ``padding_const_value`` from ``layer_params`` and store it as ``np.float32``.

        If the key is absent, the current value is kept (re-cast to ``np.float32``).
        """
        stored = dict(layer_params).get("padding_const_value", self.padding_const_value)
        self.padding_const_value = np.float32(stored)

    def enforce_encoding(self):
        """
        Copy the input encoding to the output.

        Max pooling only selects values and does no arithmetic on them, so the output uses
        the same scale and zero point as the input.
        NOTE: input has the wider range, so it makes sense to propagate forward.
        """
        self.output_scale = self.input_scales[0]
        self.output_zero_point = self.input_zero_points[0]

    def define_encodings(self, flow):
        """
        Register encodings.

        When padding is not VALID, the input zero point is additionally marked as scalar.
        """
        super().define_encodings(flow)
        if self.padding == PaddingType.VALID:
            return
        # A single padding constant is quantized with one zero point (see padding_const_value_q).
        flow.get_encoding(f"{self.full_name}/input_zero_point:0").scalar = True

    def define_constraints(self, enc):
        """Tie the output scale and zero point to the input scale and zero point."""
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")

    @property
    def bit_exact_supported(self) -> bool:
        """Report that this op supports bit-exact emulation (always ``True``)."""
        return True
