from dataclasses import dataclass
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import InputSpec
from tensorflow.python.keras.utils import conv_utils

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_layer import TensorInitializer
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import BaseQuantElement, MACDataQuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    IgnoreHwLimitationAssertionPolicy,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImportParamConfigMismatch,
    AccelerasInitializationError,
    AccelerasNumerizationError,
    AccelerasPrematureQuantOperation,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import calculate_shifts
from hailo_model_optimization.acceleras.utils.padding_utils import handle_padding


@dataclass
class AvgpoolWeightsLossy(BaseWeightLossyElements):
    """Container for the lossy elements applied to the weights of an average-pool op."""

    # Lossy element applied to the pooling kernel value.
    kernel: BaseLossyElement


class AvgPoolOp(BaseAtomicOp):
    """
    Average pooling - without activation or element-wise add (but with bias), as implemented by the Hailo MAC.
    NOTE 1: The output of this op is the accumulator contents, with the appropriate scale
            (taking into account the accumulator shift, `pre_acc_shift`).
    NOTE 2: The implementation uses tf.nn.avg_pool, although the MAC performs no averaging, only convolution.

    """

    weight_lossy_elements: AvgpoolWeightsLossy

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        kernel_size,
        strides=(1, 1),
        padding: Union[str, PaddingType] = "VALID",
        stride_align: Union[str, StrideAlignType] = "NW",
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Set up the pooling geometry and the default (identity) kernel encoding.

        Args:
            name : op name, forwarded to the base class
            kernel_size : pooling window size, normalized to a 2-tuple
            strides : pooling strides, normalized to a 2-tuple
            padding : padding type; only VALID and SAME are accepted
            stride_align : stride alignment handed to the explicit padding step
            logger, fully_native, kwargs : forwarded to the base class

        Raises:
            AccelerasInitializationError: if the padding type is neither VALID nor SAME

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        # Until a quantization element is created, the kernel goes through an identity element.
        identity_kernel = IdentityElement(name=f"{self.full_name}/ie:avg_pool")
        self.weight_lossy_elements = AvgpoolWeightsLossy(kernel=identity_kernel)

        self.padding = PaddingType(padding)
        padding_is_supported = self.padding in {PaddingType.VALID, PaddingType.SAME}
        if not padding_is_supported:
            raise AccelerasInitializationError(f"Padding type {padding} is not supported in {type(self)}")
        self.padding_const_value = 0

        self.stride_align = StrideAlignType(stride_align)  # passed on to handle_padding
        self.pre_acc_shift = 0

        num_spatial_dims = 2
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, num_spatial_dims, "kernel_size")
        self.strides = conv_utils.normalize_tuple(strides, num_spatial_dims, "strides")
        self.input_spec = InputSpec(min_ndim=num_spatial_dims + 2)
        self.kernel = None
        self.shift_delta = 0  # TODO

        # The native weight is the reciprocal of the number of elements in the pooling window.
        window_area = np.prod(self.kernel_size)
        self.kernel_shape_prod = window_area
        self.weight = 1.0 / window_area
        self.kernel_scale = 1
        self.kernel_zero_point = 0

    def create_hw_params(self, max_final_accumulator_by_channel, hw_shifts=None, **kwargs):
        """
        Choose the kernel scale and accumulator shift, then refresh the output encoding.

        Args:
            max_final_accumulator_by_channel : accumulator magnitudes; divided by the pre-shift
                accumulator scale to estimate the largest expected accumulator value
            hw_shifts : forwarded to calculate_shifts
            kwargs : accepted but not used by this method

        Raises:
            AccelerasNumerizationError: if the shift delta is larger than 2 and the HW-limitation
                assertion policy is not set to enabled

        """
        self.quantized_kernel_candidate = 4  # preferred quantized value
        self.kernel_scale = self.weight / self.quantized_kernel_candidate

        scale_before_shift = self.input_scales[0] * self.kernel_scale
        accumulator_in_quant_units = max_final_accumulator_by_channel / scale_before_shift
        max_expected_accumulator = np.max(np.abs(accumulator_in_quant_units))
        accumulator_bits = self.output_lossy_element.bits
        pre_acc_shift, shift_delta = calculate_shifts(
            max_expected_accumulator,
            accumulator_bits,
            SHIFT_CALCULATE_BUFFER,
            hw_shifts=hw_shifts,
        )

        if shift_delta > 0:
            delta_too_large = shift_delta > 2
            if delta_too_large and self._ignore_hw_limitation_assertion != IgnoreHwLimitationAssertionPolicy.enabled:
                raise AccelerasNumerizationError(
                    f"Shift delta in {self.full_name} is larger than 2 ({shift_delta:.2f}), cannot quantize. "
                    "A possible solution is to use a pre-quantization model script command to reduce global "
                    "average-pool spatial dimensions, please refer to the user guide for more info.",
                )
            # The available HW shift is too small to prevent accumulator overflow, so the kernel
            # scale is enlarged instead, which makes the quantized kernel values smaller.
            shift_delta = np.ceil(shift_delta)
            self._logger.info(
                f"No shifts available for layer {self.full_name}, using max shift instead. delta={shift_delta:.04f}"
            )
            self.kernel_scale *= 2**shift_delta

        # shift_delta is no longer used for computation; it is kept for export (TODO: qnpz debug info)
        self.shift_delta = shift_delta
        self.pre_acc_shift = pre_acc_shift
        self.kernel_zero_point = 0
        self.enforce_encoding()

    @property
    def final_quantized_kernel(self):
        """Quantized kernel value, as returned by get_quant_kernel with training=False."""
        return self.get_quant_kernel(training=False)

    @staticmethod
    def safe_divide(x, y):
        """
        Divide x by y with tf.divide.

        When exactly one operand is a tensor, the other one is first converted to a tensor
        of the same dtype, so that both operands share a dtype.
        """
        x_is_tensor = tf.is_tensor(x)
        y_is_tensor = tf.is_tensor(y)
        if y_is_tensor and not x_is_tensor:
            x = tf.convert_to_tensor(x, dtype=y.dtype)
        elif x_is_tensor and not y_is_tensor:
            y = tf.convert_to_tensor(y, dtype=x.dtype)
        return tf.divide(x, y)

    def get_quant_kernel(self, training=False):
        """
        Return the kernel value in quantized units.

        The native weight is divided by the kernel scale, cast to float32 and passed through
        the kernel lossy element.

        Args:
            training : forwarded to the kernel lossy element

        """
        kernel_element = self.weight_lossy_elements.kernel
        weight_in_quant_units = self.safe_divide(self.weight, self.kernel_scale)
        return kernel_element(tf.cast(weight_in_quant_units, tf.float32), training=training)

    @property
    def padding_const_value_q(self):
        """
        Padding constant expressed in the quantized input domain.

        Computed as ``padding_const_value / input_scale + zero_point`` and passed through the
        input lossy element. A non-scalar input scale contributes only its first entry, and the
        zero point is the mean of the first input's zero points.
        """
        scale = self.input_scale
        scale_is_scalar = scale.shape == ()
        if not scale_is_scalar:
            scale = scale[0]
        zero_point = tf.reduce_mean(self.input_zero_points[0])
        return self.input_lossy_element(self.padding_const_value / scale + zero_point)

    def call_native(self, inputs, **kwargs):
        """
        Native (unquantized) forward pass: explicit padding followed by tf.nn.avg_pool.

        Args:
            inputs : sequence of input tensors; only the first one is used
            kwargs : accepted but not used by this method

        """
        # Padding is applied explicitly by handle_padding, hence the pooling itself uses "VALID".
        padded = handle_padding(
            inputs[0],
            self.padding,
            self.kernel_size,
            self.strides,
            self.padding_const_value,
            self.stride_align,
        )
        return tf.nn.avg_pool(padded, ksize=self.kernel_size, strides=self.strides, padding="VALID")

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Simulated-hardware forward pass on quantized inputs.

        Args:
            inputs : sequence of input tensors; only the first one is used
            training : forwarded to the kernel lossy element
            kwargs : accepted but not used by this method

        Returns:
            The result of lossy_avg_pool when bit_exact is set; otherwise a tf.nn.avg_pool based
            equivalent scaled by the quantized kernel, the window size and the accumulator shift.

        """
        padded = handle_padding(
            inputs[0],
            self.padding,
            self.kernel_size,
            self.strides,
            self.padding_const_value_q,
            self.stride_align,
        )
        quant_kernel = self.get_quant_kernel(training=training)

        if self.bit_exact:
            return self.lossy_avg_pool(
                padded,
                ksize=self.kernel_size,
                strides=self.strides,
                kernel=quant_kernel,
                pre_acc_shift=self.pre_acc_shift,
            )

        # Here we simulate the hardware. As the hardware does a convolution, the result of
        # tf.nn.avg_pool is multiplied by self.kernel_shape_prod - the window size, which is the
        # denominator in tf.nn.avg_pool. Besides that, the input is multiplied by the quantized
        # weight and shifted.
        pooled = tf.nn.avg_pool(padded, ksize=self.kernel_size, strides=self.strides, padding="VALID")
        return pooled * quant_kernel * self.kernel_shape_prod / 2**self.pre_acc_shift

    def lossy_avg_pool(self, inputs, ksize, strides, kernel, pre_acc_shift=0):
        """
        Pooling computed one channel at a time, with the output lossy element applied to every product.

        For each channel the pooling windows are extracted as patches, each element is multiplied
        by the kernel and divided by ``2**pre_acc_shift``, the output lossy element is applied to
        the products, and the products of each window are summed.

        Args:
            inputs : 4-D tensor, indexed here as [:, :, :, channel]
            ksize : (height, width) of the pooling window
            strides : (height, width) strides
            kernel : kernel value multiplying every window element
            pre_acc_shift : exponent of the power-of-two divisor applied to each product

        Returns:
            A tensor with the channel axis last.

        """
        window_h, window_w = ksize
        stride_h, stride_w = strides
        patch_sizes = [1, window_h, window_w, 1]
        patch_strides = [1, stride_h, stride_w, 1]

        num_channels = inputs.shape[-1]
        per_channel = tf.TensorArray(dtype=inputs.dtype, size=0, dynamic_size=True)
        for channel in range(num_channels):
            single_channel = inputs[:, :, :, channel : channel + 1]
            windows = tf.image.extract_patches(
                single_channel,
                sizes=patch_sizes,
                strides=patch_strides,
                rates=[1, 1, 1, 1],
                padding="VALID",
            )
            products = self.output_lossy_element(windows * kernel / 2**pre_acc_shift)
            per_channel = per_channel.write(channel, tf.reduce_sum(products, axis=-1))

        # The stacked result has the channel axis first; move it to the last position.
        return tf.transpose(per_channel.stack(), perm=(1, 2, 3, 0))

    def export_independent_params(self):
        """Return the shift, scale, zero-point and bit-width parameters as float32 numpy arrays."""
        raw_values = {
            "mac_shift": self.pre_acc_shift,
            "shift_delta": self.shift_delta,
            "kernel_scale": self.kernel_scale,
            "kernel_zero_point": self.kernel_zero_point,
            "weight_bits": self.weight_lossy_elements.kernel.bits,
        }
        return {key: np.array(value, np.float32) for key, value in raw_values.items()}

    def import_independent_params(self, params):
        """
        Load the parameters produced by export_independent_params.

        Args:
            params : mapping holding "weight_bits", "mac_shift", "shift_delta", "kernel_scale"
                and "kernel_zero_point"

        Raises:
            AccelerasPrematureQuantOperation: if the kernel lossy element is not a quantization element
            AccelerasImportParamConfigMismatch: if the imported bit width differs from the configured one

        """
        kernel_element = self.weight_lossy_elements.kernel
        if not isinstance(kernel_element, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)

        configured_bits = kernel_element.bits
        imported_bits = params["weight_bits"]
        if configured_bits != imported_bits:
            raise AccelerasImportParamConfigMismatch("kernel_bits", configured_bits, imported_bits, self.full_name)

        attribute_to_key = (
            ("pre_acc_shift", "mac_shift"),
            ("shift_delta", "shift_delta"),
            ("kernel_scale", "kernel_scale"),
            ("kernel_zero_point", "kernel_zero_point"),
        )
        for attribute, key in attribute_to_key:
            setattr(self, attribute, params[key])

    def export_hw_params(self):
        """
        Return the hardware parameters of this op as numpy values.

        The kernel is the quantized kernel value broadcast to ``self.kernel_shape``, stored as
        int8 when the kernel bit width is at most 8 and as int16 otherwise.
        """
        kernel_bits = self.weight_lossy_elements.kernel.bits
        if kernel_bits <= 8:
            kernel_dtype = np.int8
        else:
            kernel_dtype = np.int16

        kernel_value = self.get_quant_kernel(training=False).numpy().astype(kernel_dtype)
        full_kernel = np.ones(self.kernel_shape, dtype=kernel_dtype) * kernel_value
        return {
            "kernel": full_kernel,
            "zp_kernel": np.array(self.kernel_zero_point, np.int32),
            "padding_const_value": self.padding_const_value_q.numpy().astype(np.uint16),
            "output_stage/mult_shift": np.array(self.pre_acc_shift, np.uint8),
        }

    def export_quant_weights(self):
        """Return the quantized kernel (broadcast to ``self.kernel_shape``, float32) and the quantized padding constant."""
        kernel_value = self.final_quantized_kernel.numpy()
        full_kernel = np.ones(self.kernel_shape) * kernel_value
        padding_value = self.padding_const_value_q.numpy()
        return {
            "quant_kernel": np.array(full_kernel, np.float32),
            "padding_const_value": padding_value,
        }

    def create_weight_quant_element(self, kernel_bits, signed=True):
        """
        Replace the kernel lossy element with a MAC data quantization element.

        Args:
            kernel_bits : bit width of the quantized kernel
            signed : whether the quantized kernel is signed

        """
        kernel_element = MACDataQuantElement(bits=kernel_bits, signed=signed, name=f"{self.full_name}/qe:kernel")
        self.weight_lossy_elements = AvgpoolWeightsLossy(kernel=kernel_element)

    def _build(self, input_shape):
        """
        Derive the kernel shape from the input shape (follows keras/layers/convolutional.py).

        The kernel shape is the pooling window size followed by ``(1, input_channels)``, where
        the number of input channels is the last entry of ``input_shape``.
        """
        num_channels = int(input_shape[-1])
        self.kernel_shape = (*self.kernel_size, 1, num_channels)

    def _compute_output_shape(self, input_shape):
        """
        Return the output shape as ``[input_shape[0], rows, cols, input_shape[3]]``.

        Rows and columns are computed by conv_utils.conv_output_length from input axes 1 and 2,
        using the kernel size, the lower-cased padding name and the strides.
        """
        padding_name = self.padding.value.lower()
        rows, cols = (
            conv_utils.conv_output_length(input_shape[axis], window, padding_name, stride)
            for axis, window, stride in zip((1, 2), self.kernel_size, self.strides)
        )
        return [input_shape[0], rows, cols, input_shape[3]]

    def export_weights(self, apply_scale_factors=False):
        """
        Return the kernel (broadcast to ``self.kernel_shape``) and the padding constant.

        Args:
            apply_scale_factors : not supported yet; must be falsy

        Raises:
            NotImplementedError: if apply_scale_factors is truthy

        """
        full_kernel = np.ones(self.kernel_shape) * self.final_quantized_kernel
        if not apply_scale_factors:
            return {"kernel": full_kernel, "padding_const_value": self.padding_const_value}
        # TODO implement it when we need to.
        raise NotImplementedError("apply_scale_factors not supported yet")

    def enforce_encoding(self, *args, training=False, **kwargs):
        """
        Recompute the output scale and output zero point from the current encoding.

        output_scale      = input_scale * kernel_scale * 2**pre_acc_shift
        output_zero_point = input_zero_point * quantized_kernel * 2**(-pre_acc_shift) * window_size

        Args:
            training : forwarded to get_quant_kernel
            args, kwargs : accepted but not used by this method

        """
        self.output_scale = self.input_scales[0] * self.kernel_scale * 2**self.pre_acc_shift

        input_zero_point = self.input_zero_points[0]
        quant_kernel = self.get_quant_kernel(training=training)
        inverse_shift = 2 ** (-1.0 * self.pre_acc_shift)
        zero_point = input_zero_point * quant_kernel * inverse_shift * self.kernel_shape_prod
        self.output_zero_point = tf.cast(zero_point, tf.float32)

    @property
    def bit_exact_supported(self) -> bool:
        """Always True for this op (call_hw_sim has a dedicated bit_exact branch)."""
        return True

    def define_encodings(self, flow):
        """
        Register this op's kernel-scale and mac-shift encodings on ``flow``.

        When the padding is not VALID, the input zero-point encoding is additionally marked
        as scalar.
        """
        super().define_encodings(flow)
        prefix = self.full_name
        common = {"scalar": False, "shape": ()}

        flow.add_encoding(
            f"{prefix}/kernel_scale:0",
            EncodingType.Scale,
            initializer=TensorInitializer(self.kernel_scale),
            **common,
        )
        flow.add_encoding(
            f"{prefix}/mac_shift:0",
            EncodingType.Scale,
            initializer=TensorInitializer(self.pre_acc_shift),
            quant=True,
            quant_min=1.0,
            quant_max=4.0,
            **common,
        )

        if self.padding == PaddingType.VALID:
            return
        flow.get_encoding(f"{prefix}/input_zero_point:0").scalar = True

    def define_constraints(self, enc):
        """
        Register the relations between this op's encodings on ``enc``.

        The relations parallel the arithmetic of enforce_encoding for the output scale and the
        output zero point.
        NOTE(review): the argument conventions of the ``enc`` helpers (which argument is the
        result) are not visible from this file - confirm against the encoding API.
        """
        super().define_constraints(enc)
        prefix = self.full_name
        kernel_scale_key = f"{prefix}/kernel_scale:0"
        mac_shift_key = f"{prefix}/mac_shift:0"
        quant_kernel_label = "final_quantized_kernel"

        # 16 bit quantization doesn't support activation shift and in that case it's set to zero.
        if self.output_lossy_element.bits == 32:
            enc.identity(mac_shift_key, 0.0)

        # output scale: input_scale and kernel_scale combined, then shifted by mac_shift
        enc.mul(enc.dummy(0), f"{prefix}/input_scale:0", kernel_scale_key, inverse=True)
        enc.shift(f"{prefix}/output_scale:0", enc.dummy(0), mac_shift_key)

        # output zero point: quantized kernel first ...
        enc.div(enc.dummy(1), self.weight, kernel_scale_key)
        enc.cast(enc.dummy(2), enc.dummy(1))
        enc.lossy_element(enc.dummy(quant_kernel_label), enc.dummy(2), self.weight_lossy_elements.kernel)
        # ... then combined with the input zero point, the shift and the window size
        enc.mul(enc.dummy(3), f"{prefix}/input_zero_point:0", enc.dummy(quant_kernel_label))
        enc.shift(enc.dummy(3), enc.dummy(4), mac_shift_key)
        enc.mul(enc.dummy(5), enc.dummy(4), self.kernel_shape_prod)
        enc.cast(f"{prefix}/output_zero_point:0", enc.dummy(5))

    def define_const_constraints(self, enc):
        """Pin the kernel-scale and mac-shift encodings to this op's current values."""
        super().define_const_constraints(enc)
        prefix = self.full_name
        pinned_values = (
            ("kernel_scale", self.kernel_scale),
            ("mac_shift", self.pre_acc_shift),
        )
        for encoding_name, value in pinned_values:
            enc.identity(f"{prefix}/{encoding_name}:0", value)

    def update_encoding(self, encodings):
        """
        Read the kernel scale and the mac shift back from ``encodings``.

        Args:
            encodings : mapping from encoding name to value

        """
        super().update_encoding(encodings)
        prefix = self.full_name
        self.kernel_scale = encodings[f"{prefix}/kernel_scale:0"]
        self.pre_acc_shift = encodings[f"{prefix}/mac_shift:0"]

    def import_weights(self, layer_params: LayerParams):
        """
        Load the padding constant and, when a kernel is present, the kernel shape.

        Args:
            layer_params : layer parameters; "padding_const_value" falls back to the current
                value when missing, and only the shape of "kernel" is used

        """
        available = dict(layer_params)
        imported_padding = available.get("padding_const_value", self.padding_const_value)
        self.padding_const_value = np.float32(imported_padding)
        if "kernel" not in available:
            return
        self.kernel_shape = layer_params["kernel"].shape
