from dataclasses import dataclass

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import InputSpec

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import BaseQuantElement, MACDataQuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import SHIFT_CALCULATE_BUFFER
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImportParamConfigMismatch,
    AccelerasPrematureQuantOperation,
)
from hailo_model_optimization.acceleras.utils.opt_utils import calculate_shifts


@dataclass
class ELWAWeightsLossy(BaseWeightLossyElements):
    """Container for the lossy element that is applied to the op's scalar kernel factor."""

    # Lossy element applied to the scaled kernel. ReduceSumOp installs an IdentityElement by
    # default and replaces it with a MACDataQuantElement in create_weight_quant_element.
    factor: BaseLossyElement


class ReduceSumOp(BaseAtomicOp):
    """
    This class emulates the reduce sum operation.

    The input is summed over ``reduce_axes`` (the reduced dimensions are kept) and the result is
    multiplied by a scalar ``kernel`` factor. The sum can be split into channel ``groups`` or into
    ``height_groups`` slices along axis 1, in which case one partial sum is produced per
    group/slice.
    """

    # Holds the lossy element that is applied to the kernel factor (see ``kernel_q``).
    weight_lossy_elements: ELWAWeightsLossy

    # The op reads a single input (``inputs[0]``) and declares a single output.
    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        groups=1,
        reduce_axes=None,
        height_groups=1,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Create a reduce-sum op with an identity (non-quantizing) kernel factor element.

        Args:
            name: op name, forwarded to the base class.
            groups: number of groups the last (channel) axis is split into; every group is
                summed separately.
            reduce_axes: axes of the 4D input to sum over. ``None`` (the default) means ``[3]``.
            height_groups: number of slices along axis 1 that are summed separately.
            logger: forwarded to the base class.
            fully_native: forwarded to the base class.
            **kwargs: forwarded to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self.weight_lossy_elements = ELWAWeightsLossy(
            factor=IdentityElement(name=f"{self.full_name}/ie:element_wise_add")
        )
        self._groups = groups  # groups is the number of groups in the last channel
        # Build a fresh list per instance: a list literal used as the default argument would be
        # one object shared (and mutable) across every op created with the default.
        self._reduce_axes = [3] if reduce_axes is None else reduce_axes
        self._height_groups = height_groups
        self.input_spec = InputSpec(ndim=4)  # TODO should be list??
        # Native kernel factor and its quantization parameters start as the identity.
        self.kernel = np.array(1, dtype=np.float32)
        self.kernel_scale = np.array(1, dtype=np.float32)
        self.kernel_zero_point = 0
        self.pre_acc_shift = 0

    def create_weight_quant_element(self, factor_bits=8, signed=True):
        """
        Replace the kernel-factor lossy element with a MAC-data quantization element.

        Args:
            factor_bits: bit width of the quantized factor.
            signed: whether the quantized factor is signed.
        """
        quant_factor = MACDataQuantElement(name=f"{self.full_name}/qe:factor", bits=factor_bits, signed=signed)
        self.weight_lossy_elements = ELWAWeightsLossy(factor=quant_factor)

    def enforce_encoding(self):
        """
        Infers the output zp and output scale based on the inputs scales and inputs zp.

        Side effects:
            * ``self._reduce_size`` - product of the reduced input dimensions divided by
              ``groups`` and ``height_groups``, cast to ``FLOAT_TYPE_TF``.
            * ``self.output_scale`` - ``input_scale * kernel_scale * 2**pre_acc_shift``; a scalar
              input scale is repeated ``groups`` times, a vector one is averaged per group.
            * ``self.output_zero_point`` - ``zp * kernel_q * reduce_size * 2**-pre_acc_shift``; for
              a vector zero point the first entry of every group is used.
        """
        reduced_dims = [self.input_shape[axis] for axis in self._reduce_axes]
        reduce_size = tf.math.reduce_prod(reduced_dims)
        reduce_size = reduce_size / self._groups
        reduce_size = reduce_size / self._height_groups
        self._reduce_size = tf.cast(reduce_size, self.FLOAT_TYPE_TF)

        if self.input_scale_is_scalar():
            output_scale_scalar = self.input_scale * self.kernel_scale * 2**self.pre_acc_shift
            self.output_scale = tf.repeat(output_scale_scalar, self._groups)
        else:
            groups_scale = tf.reduce_mean(tf.reshape(self.input_scale, (self._groups, -1)), axis=1)
            pre_shift_out_scale = groups_scale * self.kernel_scale
            self.output_scale = pre_shift_out_scale * tf.cast(2**self.pre_acc_shift, pre_shift_out_scale.dtype)

        # 2**-pre_acc_shift is computed once, in float, for both zero-point layouts.
        # pre_acc_shift may be an integer tensor (see create_hw_params); raising a Python float
        # to an integer tensor is not dtype-safe, so the exponent is cast explicitly.
        kernel_q = self.kernel_q
        shift_val = tf.cast(
            tf.pow(tf.constant(2.0, dtype=tf.float32), tf.cast(-self.pre_acc_shift, dtype=tf.float32)),
            kernel_q.dtype,
        )
        if len(tf.convert_to_tensor(self.input_zero_point).shape) == 0:
            zero_point = self.input_zero_points[0]
        else:
            # one zero point per group: take the first entry of every group
            zero_point = tf.reshape(self.input_zero_point, (self._groups, -1))[:, 0]
        self.output_zero_point = zero_point * kernel_q * self._reduce_size * shift_val

    def _get_max_kernel_value(self, max_output_per_channel, pre_acc_shift, utilize_wraparound=False, hw_shifts=None):
        """
        Find the largest power-of-two quantized kernel that keeps the accumulator in range.

        Notation:
            s_inp    - input scale
            s_kernel - kernel scale
            y        - native output of the op
            max_acc  - largest value the output lossy element can hold

        The accumulator holds y_q = y / s_acc with s_acc = s_inp * s_kernel * 2**shift, and it
        has to satisfy y_q <= max_acc. Rearranging:

            1 / s_kernel <= max_acc * 2**shift / (y / s_inp)
            kernel_q = |kernel| / s_kernel <= max_acc * 2**shift / ((y / s_inp) / |kernel|)

        The expected accumulator value ((y / s_inp) / |kernel|) is taken from the statistics,
        clipped by the theoretical maximum of the quantized input summed over the reduce size.

        Args:
            max_output_per_channel: statistics of the maximal native output value(s).
            pre_acc_shift: pre-accumulator shift to start from.
            utilize_wraparound: use the accumulator bins count instead of its max value as limit.
            hw_shifts: forwarded to ``calculate_shifts`` when no kernel shift is available.

        Returns:
            Tuple ``(max_kernel_q, pre_acc_shift)``.
        """
        largest_input_scale = np.max(self.input_scales[0])  # assume single scale
        acc_from_stats = (np.max(max_output_per_channel) / largest_input_scale) / np.abs(self.kernel)

        # Theoretical bound: largest zero-point-centred input magnitude times the elements summed.
        reduced_dims = [self.input_shape[axis] for axis in self._reduce_axes]
        reduce_size = (np.prod(reduced_dims) / self._groups) / self._height_groups
        largest_input_magnitude = np.max(
            np.maximum(
                np.abs(self.input_lossy_element.min_value - self.input_zero_point + 1),
                np.abs(self.input_lossy_element.max_value - self.input_zero_point),
            )
        )
        acc_theoretical = largest_input_magnitude * reduce_size

        # When the statistics exceed the theoretical bound the bound is used as is (no buffer).
        stats_exceed_bound = acc_from_stats > acc_theoretical
        expected_max_acc = acc_theoretical if stats_exceed_bound else acc_from_stats
        shift_buffer = 0 if stats_exceed_bound else SHIFT_CALCULATE_BUFFER

        output_element = self.output_lossy_element
        acc_limit = output_element.bins_count if utilize_wraparound else output_element.max_value

        max_kernel_shifts = np.log2(acc_limit / expected_max_acc) - shift_buffer + pre_acc_shift

        factor_bits = self.weight_lossy_elements.factor.bits
        candidates = [2**shift for shift in range(factor_bits - 1) if shift <= max_kernel_shifts]
        if candidates:
            return np.max(candidates), pre_acc_shift

        # No shift is available: the kernel must be 1 and the pre-accumulator shift is recalculated.
        accumulator_size = output_element.bits  # get accumulator
        pre_acc_shift, _ = calculate_shifts(
            expected_max_acc, accumulator_size, 0, hw_shifts=hw_shifts, utilize_wraparound=utilize_wraparound
        )
        return 1, pre_acc_shift

    def create_hw_params(self, max_output_per_channel, utilize_wraparound=False, hw_shifts=None, **kwargs):
        """
        Choose the quantized kernel and the pre-accumulator shift, then refresh the output encoding.

        Args:
            max_output_per_channel: statistics of the maximal native output value(s).
            utilize_wraparound: forwarded to ``_get_max_kernel_value``.
            hw_shifts: optional sequence; when given, its first entry is the initial shift.
            **kwargs: accepted but not used by this method.
        """
        factor_bits = self.weight_lossy_elements.factor.bits
        if hw_shifts is not None:
            initial_shift = hw_shifts[0]
        else:
            # 16 bit quantization doesn't support activation shift and in that case it's set to zero.
            initial_shift = tf.convert_to_tensor(0 if factor_bits in (15, 16) else 1)

        max_kernel_q, self.pre_acc_shift = self._get_max_kernel_value(
            max_output_per_channel, initial_shift, utilize_wraparound=utilize_wraparound, hw_shifts=hw_shifts
        )

        # The quantized kernel keeps the sign of the native kernel; the scale maps one to the other.
        self.kernel_scale = self.kernel / (np.sign(self.kernel) * max_kernel_q)
        self.kernel_zero_point = 0
        self.enforce_encoding()

    def export_independent_params(self):
        """Return the op's independent quantization parameters, each as a float32 numpy array."""
        raw_params = {
            "kernel_scale": self.kernel_scale,
            "kernel_zero_point": self.kernel_zero_point,
            "mac_shift": self.pre_acc_shift,
            "shift_delta": 0,
            "weight_bits": self.weight_lossy_elements.factor.bits,
            "reduce_size": self._reduce_size,
        }
        return {key: np.array(value, np.float32) for key, value in raw_params.items()}

    def import_independent_params(self, params):
        """
        Load the independent quantization parameters produced by ``export_independent_params``.

        Args:
            params: mapping with the keys ``weight_bits``, ``mac_shift``, ``shift_delta``,
                ``kernel_scale``, ``kernel_zero_point`` and ``reduce_size``.

        Raises:
            AccelerasPrematureQuantOperation: the kernel factor element is not a quant element yet.
            AccelerasImportParamConfigMismatch: the imported bit width differs from the configured one.
        """
        factor_element = self.weight_lossy_elements.factor
        if not isinstance(factor_element, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)

        configured_bits = factor_element.bits
        imported_bits = params["weight_bits"]
        if configured_bits != imported_bits:
            raise AccelerasImportParamConfigMismatch("factor_bits", configured_bits, imported_bits, self.full_name)

        # (attribute on the op, key in params) - assigned in this order
        attribute_to_key = (
            ("pre_acc_shift", "mac_shift"),
            ("shift_delta", "shift_delta"),
            ("kernel_scale", "kernel_scale"),
            ("kernel_zero_point", "kernel_zero_point"),
            ("_reduce_size", "reduce_size"),
        )
        for attribute, key in attribute_to_key:
            setattr(self, attribute, params[key])

    def export_quant_weights(self):  # TODO This needs to be deprecated or renamed
        """
        Return the quantized kernel tiled to a 4-entry kernel shape derived from the reduced axes.

        The shape entries are:
            0 - input height divided by ``height_groups`` when axis 1 is reduced, else 1
            1 - input width when axis 2 is reduced, else 1
            2 - channels per group when axis 3 is reduced, else 1
            3 - ``groups`` when axis 3 is reduced, else the number of input channels

        Returns:
            Dict with the single key ``quant_kernel``.
        """
        axes = self._reduce_axes
        kernel_height = self.input_shape[1] // self._height_groups if 1 in axes else 1  # height reduce
        kernel_width = self.input_shape[2] if 2 in axes else 1  # width reduce
        if 3 in axes:
            in_features, out_features = self.input_shape[3] // self._groups, self._groups
        else:
            in_features, out_features = 1, self.input_shape[3]

        shape = list(np.array([kernel_height, kernel_width, in_features, out_features], dtype=int))
        tiled_kernel = np.tile(self.kernel_q, shape)
        return {"quant_kernel": np.reshape(tiled_kernel, shape)}

    def export_weights(self):
        """Return the native (non-quantized) kernel factor as a numpy array under the key ``kernel``."""
        native_kernel = np.array(self.kernel)
        return {"kernel": native_kernel}

    def export_hw_params(self):
        """
        Return the hardware parameters of the op.

        Returns:
            Dict with the quantized kernel (int8 for factors of up to 8 bits, int16 otherwise),
            the kernel zero point (int32) and the pre-accumulator shift (uint8).
        """
        factor_bits = self.weight_lossy_elements.factor.bits
        if factor_bits <= 8:
            kernel_dtype = np.int8
        else:
            kernel_dtype = np.int16

        quant_kernel = self.export_quant_weights()["quant_kernel"]
        return {
            "kernel": quant_kernel.astype(kernel_dtype),
            "zp_kernel": np.array(self.kernel_zero_point, np.int32),
            "output_stage/mult_shift": np.array(self.pre_acc_shift, np.uint8),
        }

    def import_weights(self, layer_params):
        """
        Load the native kernel factor from ``layer_params``.

        The current kernel is kept when ``layer_params`` has no ``kernel`` entry (or it is ``None``).
        """
        imported_kernel = layer_params.get("kernel")
        if imported_kernel is None:
            return
        self.kernel = imported_kernel

    def _build(self, input_shape):
        """
        Cache the group sizes derived from ``input_shape``.

        Sets ``_input_group_size`` (reduced elements per channel group) and
        ``_input_height_group_size`` (rows of axis 1 per height group).
        """
        reduced_dims = [input_shape[axis] for axis in self._reduce_axes]
        reduced_elements = np.prod(reduced_dims)
        self._input_group_size = int(reduced_elements // self._groups)
        self._input_height_group_size = int(input_shape[1] // self._height_groups)

    def _compute_output_shape(self, input_shapes):
        """
        Compute the output shape for a single input shape.

        Every reduced axis (given as a positive or a negative index) is replaced by ``groups``;
        the other dimensions are kept. If dimension 1 of the result equals 1 it is set to
        ``height_groups``.

        Args:
            input_shapes: the dimensions of the input (iterated as one shape).

        Returns:
            The output shape as a list.
        """
        axes = self._reduce_axes
        reduce_axes = axes if isinstance(axes, (list, tuple)) else [axes]
        rank = len(input_shapes)

        shape = [
            self._groups if (axis in reduce_axes or axis - rank in reduce_axes) else dim
            for axis, dim in enumerate(input_shapes)
        ]

        if shape[1] == 1:
            shape[1] = self._height_groups

        return shape

    @property
    def kernel_q(self):
        """Quantized kernel: the kernel divided by its scale, passed through the factor lossy element."""
        scaled_kernel = tf.cast(self.kernel / self.kernel_scale, tf.float32)
        factor_element = self.weight_lossy_elements.factor
        return factor_element(scaled_kernel)

    @property
    def groups(self):
        """Number of groups the channel axis is split into, as given at construction."""
        return self._groups

    @property
    def height_groups(self):
        """Number of slices along axis 1 that are summed separately, as given at construction."""
        return self._height_groups

    @property
    def reduce_axes(self):
        """Axes of the input that are summed over, as given at construction."""
        return self._reduce_axes

    def call_hw_sim(self, inputs, **kwargs):
        """
        Emulate the hardware computation of the op.

        The input is multiplied by ``2**-pre_acc_shift`` and by the quantized kernel, passed
        through the output lossy element when ``bit_exact`` is set, and then reduced by
        ``call_native`` with a unit factor.

        The call_hw_sim of reduce_sum will also work for reduce mean with the following:
            1. reduce_sum - factor will be self.kernel_q which is the quant value of 1 (or -1)
            2. reduce_mean - factor will be self.kernel_q which is the quant value of 1/n (or -1/n)

        Args:
            inputs: sequence whose first element is the input tensor.
            **kwargs: forwarded to ``call_native``.
        """
        shift_val = tf.cast(
            tf.pow(tf.constant(2.0, dtype=tf.float32), tf.cast(-self.pre_acc_shift, dtype=tf.float32)), inputs[0].dtype
        )
        # Keep the product a plain tensor. Building a list and multiplying the list by kernel_q
        # only works through the tensor's reflected multiply, which stacks the list into a tensor
        # with an extra leading dimension; here the list is created explicitly for call_native.
        mult_res = inputs[0] * shift_val * self.kernel_q

        if self.bit_exact:
            mult_res = self.output_lossy_element(mult_res)

        return self.call_native([mult_res], factor=1, **kwargs)

    def call_native(self, inputs, factor=None, **kwargs):
        """
        Sum the input over the reduce axes (dimensions are kept) and multiply by ``factor``.

        The native call of reduce_sum will also work for reduce mean with the following:
            1. reduce_sum - factor will be self.kernel which is 1 (or -1)
            2. reduce_mean - factor will be self.kernel which is 1/n (or -1/n)

        Args:
            inputs: sequence whose first element is the input tensor.
            factor: multiplier of the sum; ``self.kernel`` is used when it is ``None``.
            **kwargs: accepted but not used by this method.

        Returns:
            The reduced tensor: one sum per channel group when ``groups > 1``, otherwise one
            partial sum per height slice (concatenated on axis 1) when ``height_groups > 1``,
            otherwise a single sum.
        """
        inp = inputs[0]
        if factor is None:
            factor = self.kernel

        if self._groups > 1:
            # The channel axis is split into (groups, channels per group), so the input becomes
            # 5D and a reduction over the original axis 3 has to target axis 4: {0, 1, 2, 4}.
            grouped_axes = [axis % 4 + (axis % 4) // 3 for axis in self._reduce_axes]
            channels_per_group = inp.shape[3] // self._groups
            grouped_input = tf.reshape(inp, [-1, inp.shape[1], inp.shape[2], self._groups, channels_per_group])
            grouped_sum = tf.reduce_sum(
                input_tensor=grouped_input, axis=grouped_axes, keepdims=True, name=f"{self.op_name}_group"
            )
            grouped_sum = grouped_sum * factor
            # merge the two trailing dimensions back into a single channel axis
            merged_shape = [-1, grouped_sum.shape[1], grouped_sum.shape[2], grouped_sum.shape[3] * grouped_sum.shape[4]]
            return tf.reshape(grouped_sum, merged_shape)

        if self._height_groups > 1:
            self._logger.debug(
                "Reduce sum layer with large input tensors is defused along the "
                "height dimension to avoid overflow in accumulator. "
                "Output is expected to be partial sum per slice.",
            )
            slice_height = self._input_height_group_size
            partial_sums = [
                tf.reduce_sum(
                    input_tensor=inp[:, index * slice_height : (index + 1) * slice_height, :, :],
                    axis=self._reduce_axes,
                    keepdims=True,
                    name=f"{self.op_name}_{index}",
                )
                * factor
                for index in range(self._height_groups)
            ]
            return tf.concat(partial_sums, axis=1, name=f"concat_{self.op_name}")

        return tf.reduce_sum(input_tensor=inp, axis=self._reduce_axes, keepdims=True) * factor

    @property
    def op_name(self) -> str:
        """Base name of the op; used as the prefix of the TensorFlow op names created in ``call_native``."""
        return "reduce_sum"

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation (always ``True``)."""
        return True
