from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import EmulationType
from hailo_model_optimization.acceleras.utils.flow_state_utils import AtomicOpState

# Upper bound on the shift that ReduceSumPPUOp.create_hw_params derives from
# statistics; a larger computed shift is clamped to this value with a warning.
MAX_SHIFT_MU = 7


@dataclass
class ReduceSumWeightsLossy(BaseWeightLossyElements):
    """Pair of lossy elements that ReduceSumPPUOp stores as its ``weight_lossy_elements``."""

    # Its bits/signed are used to verify the raw reduced sum in call_bit_exact.
    clip1: BaseLossyElement
    # Applied to the sum after the shift (reduce_sum_logic / call_bit_exact).
    clip2: BaseLossyElement


class ReduceSumPPUOp(BaseAtomicOp):
    """
    This op is used to emulate the reduce sum operation in the PPU of layer norm/softmax.

    There are 3 types of reduce sum, selected by the constructor flags:

    * type 1 (rms_norm=False, square=True,  is_softmax=False):
      square the input, multiply by the number of channels (``f_out``),
      reduce sum and then shift the result.
    * type 2 (rms_norm=False, square=False, is_softmax=True/False):
      reduce sum and then shift the result.
    * type 3 (rms_norm=True,  square=False, is_softmax=False):
      the op is redundant and simply outputs zeros.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        is_softmax=False,
        groups=1,
        reduce_axes=(3,),
        logger=None,
        fully_native=None,
        rms_norm=False,
        square=False,
        **kwargs,
    ):
        """
        Args:
            name: The op name, forwarded to the base class.
            is_softmax: Whether the op belongs to a softmax layer (otherwise layer norm).
            groups: The number of groups to reduce (only supported for softmax).
            reduce_axes: The axis to reduce.
            logger: Forwarded to the base class.
            fully_native: Forwarded to the base class.
            rms_norm: Whether the layer is RMS normalization (Root Mean Square Layer Normalization) -
                f(x) = x / sqrt(mean(x^2))
            square: Whether the input is squared (and multiplied by ``f_out``) before the sum.

        Raises:
            ValueError: On an unsupported flag combination (groups > 1 without softmax,
                rms_norm or square together with softmax, groups > 1 together with rms_norm).

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        if groups > 1 and not is_softmax:
            raise ValueError("We dont support groups >1 for layer norm")

        if rms_norm and is_softmax:
            raise ValueError("We dont support rms_norm in softmax layer")

        if square and is_softmax:
            raise ValueError("We dont support square in softmax layer")

        if groups > 1 and rms_norm:
            raise ValueError("groups>1 and rms_norm cannot be used together")

        self._reduce_axes = reduce_axes
        self._square = square
        self._groups = groups
        self._rms_norm = rms_norm

        # While lossless, call_hw_sim falls back to decode -> native -> encode.
        self.is_lossless = True
        # Placeholder; replaced by the last input dimension in _build.
        self.f_out = 1
        self.is_softmax = is_softmax
        if square or self.is_softmax:
            self.set_type_emulation(EmulationType.DOUBLE)

        # NOTE(review): "ie:cip1" looks like a typo of "ie:clip1" (compare "ie:clip2" below and
        # "qe:clip1" in create_weight_quant_element). Left as-is in case the name is referenced
        # elsewhere - confirm before renaming.
        self.weight_lossy_elements = ReduceSumWeightsLossy(
            clip1=IdentityElement(name=f"{self.full_name}/ie:cip1"),
            clip2=IdentityElement(name=f"{self.full_name}/ie:clip2"),
        )
        self._shift_cfg = 0.0

    @property
    def rms_norm(self) -> bool:
        """Whether the op was configured as RMS norm (type 3, zero output)."""
        return self._rms_norm

    @property
    def bit_exact_supported(self) -> bool:
        """This op provides a bit exact implementation (see call_bit_exact)."""
        return True

    @property
    def shift_cfg(self):
        """The configured shift; the reduced sum is divided by ``2**shift_cfg``."""
        return self._shift_cfg

    def _build(self, input_shape):
        """Record the size of the last input dimension as ``f_out``."""
        self.f_out = input_shape[-1]

    def _compute_output_shape(self, input_shapes):
        """The output shape is reported as identical to the input shape."""
        return input_shapes

    def disable_lossy(self, **kwargs):
        """Disable the lossy elements and mark the op as lossless."""
        super().disable_lossy(**kwargs)
        self.is_lossless = True

    def enable_lossy(self, **kwargs):
        """Enable the lossy elements and mark the op as lossy."""
        super().enable_lossy(**kwargs)
        self.is_lossless = False

    def create_hw_params(self, force_shift=None):
        """
        get hw params for the layer
            - force_shift: set the shift value - if None then calculate the shift value
              from the output statistics (clamped to [0, MAX_SHIFT_MU]).

        Side effects: marks the op as lossy, stores the shift and re-runs enforce_encoding.
        """

        if force_shift is not None:
            shift_cfg = force_shift
        else:
            out_stats = self.get_output_stats(0)
            max_final_out_by_channel = np.maximum(np.abs(out_stats.min), np.abs(out_stats.max), dtype=np.float32)
            # NOTE(review): enforce_encoding squares input_scale when square=True, while the
            # statistics here are divided by the un-squared scale - confirm this is intended.
            current_max_from_stats_per_channel = max_final_out_by_channel / self.input_scale

            current_max_from_stats = np.max(current_max_from_stats_per_channel)
            max_value_in_hw = self.output_lossy_element.max_value
            # Smallest power-of-two shift that brings the maximal value from stats
            # down to max_value_in_hw; never negative.
            shift_cfg = -1 * np.floor(np.log2(max_value_in_hw / current_max_from_stats))
            shift_cfg = max(shift_cfg, 0)
            if np.max(shift_cfg) > MAX_SHIFT_MU:
                shift_delta = shift_cfg - MAX_SHIFT_MU
                shift_cfg = MAX_SHIFT_MU
                # TODO maybe add a failure in the future and solve it in a different way
                self._logger.warning(f"Shift value is too high, reducing it by {shift_delta} to {MAX_SHIFT_MU}")
        self.is_lossless = False
        self._shift_cfg = shift_cfg
        self.enforce_encoding()

    def _reduce_sum_call_native(self, inp, should_repeat=True, **kwargs):
        """
        the native reduce_sum

        Args:
            inp: The tensor to reduce. The grouped path indexes ``inp.shape[1..3]``,
                so it requires a rank-4 input.
            should_repeat: When True the reduced values are repeated along the first
                reduce axis ``inp_channels // groups`` times.
        """
        if self._groups > 1:
            # Change axes to be in the range in {0, 1, 2, 4}
            reduce_axes = [a % 4 + (a % 4) // 3 for a in self._reduce_axes]
            # Split the last dimension into (groups, channels_per_group).
            group_input = tf.reshape(inp, [-1, inp.shape[1], inp.shape[2], self._groups, inp.shape[3] // self._groups])
            reduce_sum = tf.reduce_sum(
                input_tensor=group_input, axis=reduce_axes, keepdims=True, name=f"{self.name}_group"
            )
            # Merge the two trailing dimensions back into a single one.
            val = tf.reshape(
                reduce_sum, [-1, reduce_sum.shape[1], reduce_sum.shape[2], reduce_sum.shape[3] * reduce_sum.shape[4]]
            )
        else:
            val = tf.reduce_sum(input_tensor=inp, axis=self._reduce_axes, keepdims=True)
        if should_repeat:
            to_repeat = tf.shape(inp)[-1] // self._groups
            val = tf.repeat(val, to_repeat, axis=self._reduce_axes[0])
        return val

    def call_native(self, inputs, **kwargs):
        """Native (float) flow: zeros for type 3, otherwise the (optionally squared) reduce sum."""
        inp = inputs[0]

        # type 3 - just returns zeros
        if self.rms_norm:
            return tf.zeros_like(inp)

        # type 1
        if self._square:
            # mult inp**2 by f_out
            f_out = tf.cast(self.f_out, self.FLOAT_TYPE_TF)
            inp = tf.math.square(inp) * f_out

        # type 1-2
        return self._reduce_sum_call_native(inp)

    def reduce_sum_logic(self, inputs):
        """Lossy numeric flow used by call_hw_sim: reduce sum, divide by 2**shift_cfg, apply clip2."""
        inp = inputs[0]
        ## type 3
        if self.rms_norm:
            return tf.zeros_like(inp)
        ## type 1
        if self._square:
            # mult inp**2 by f_out
            inputs_square = tf.math.square(inp)
            inp = inputs_square * self.f_out
        ## type 1-2
        sum_x = self._reduce_sum_call_native(inp)

        shift_cfg = 2**self.shift_cfg
        x_after_shift = sum_x / tf.cast(shift_cfg, self.FLOAT_TYPE_TF)
        return self.weight_lossy_elements.clip2(x_after_shift)

    def call_bit_exact(self, inputs, **kwargs):
        """
        Bit exact flow. Intermediate tensors are checked with ``_verify_data_dtype``
        (bit width and signedness); the expected widths are listed next to each check.
        """
        inp = inputs[0]
        # type 3
        if self.rms_norm:
            return tf.zeros_like(inp)

        if self.is_softmax:
            self._verify_data_dtype(inp, 32, False, f"inputs_{self.name}")
        else:
            self._verify_data_dtype(inp, 16, True, f"inputs_{self.name}")
        ## type 1
        if self._square:
            # mult inp**2 by f_out
            inputs_square = tf.math.square(inp)
            self._verify_data_dtype(inputs_square, 32, False, f"inputs_square_{self.name}")
            self._verify_data_dtype(self.f_out, 12, False, f"f_out_{self.name}")

            f_out = tf.cast(self.f_out, self.INT_TYPE_TF)
            inp = inputs_square * f_out

            self._verify_data_dtype(inp, 44, False, f"inputs_squre_{self.name}")

        ## type 1-2
        # The repeat is postponed to the end so the checks run on the reduced tensor.
        sum_x = self._reduce_sum_call_native(inp, should_repeat=False)
        self._verify_data_dtype(
            sum_x, self.weight_lossy_elements.clip1.bits, self.weight_lossy_elements.clip1.signed, "sum_x"
        )  # (uint56 - square, int28)

        x_after_shift = self.signed_shift_bankers_rounding(sum_x, self.shift_cfg)

        if self._square:
            self._verify_data_dtype(x_after_shift, 56, False, "x2_after_shift_square")
        else:
            # clip only if not square
            x_after_shift = self.hw_simulation_by_lossy_element(x_after_shift, self.weight_lossy_elements.clip2)
        self._verify_data_dtype(
            x_after_shift,
            self.weight_lossy_elements.clip2.bits,
            self.weight_lossy_elements.clip2.signed,
            "clipped_sum_x_after_shift",
        )  # (uint56 - square, int20)
        x_after_shift = tf.repeat(x_after_shift, tf.shape(inp)[-1] // self._groups, axis=-1)  # handle groups

        return x_after_shift

    def call_hw_sim(self, inputs, **kwargs):
        """
        Numeric flow. While the op is lossless the inputs are decoded, passed through
        call_native and re-encoded; otherwise reduce_sum_logic is applied directly.
        """
        if not self.is_lossless:
            return self.reduce_sum_logic(inputs)

        inputs_native = self._decode_inputs(inputs)
        result = self.call_native(inputs_native, **kwargs)
        result = result if isinstance(result, list) else [result]
        return self._encode_outputs(result)

    def create_weight_quant_element(self, bit_clip1, bit_clip2):
        """
        Replace the identity clip elements with quantized ones.

        Args:
            bit_clip1: Bit width of clip1.
            bit_clip2: Bit width of clip2.

        Both elements are unsigned when the op squares its input or belongs to softmax.
        """
        signed = not (self._square or self.is_softmax)
        self.weight_lossy_elements = ReduceSumWeightsLossy(
            clip1=QuantElement(signed=signed, bits=bit_clip1, wraparound=False, name=f"{self.full_name}/qe:clip1"),
            clip2=QuantElement(signed=signed, bits=bit_clip2, wraparound=False, name=f"{self.full_name}/qe:clip2"),
        )

    def enforce_encoding(self, *args, **kwargs):
        """
        Derive the output scale from the input scale and the shift:
        ``output_scale = input_scale (squared if square) * 2**shift_cfg``,
        repeated to cover all output channels.
        """
        input_scale = self.input_scale
        output_channels = self.output_shape[-1]

        if np.array(input_scale).shape != ():
            # Vector scale: keep the first scale of every group.
            input_scale = np.array(input_scale).reshape((self._groups, -1))[:, 0]
            # Floor division keeps the repeat count an integer, as np.repeat requires.
            to_repeat = output_channels // self._groups
        else:
            to_repeat = output_channels

        if self._square:
            input_scale = input_scale**2

        output_scale = input_scale * (2**self.shift_cfg)
        self.output_scale = np.repeat(output_scale, to_repeat)

    def export_independent_params(self):
        """Export the shift as a float32 under the key ``shift_cfg``."""
        return {
            "shift_cfg": np.float32(self.shift_cfg),
        }

    def import_independent_params(self, params):
        """Restore the shift from ``params["shift_cfg"]``."""
        self._shift_cfg = params["shift_cfg"]

    def export_hw_params(self):
        """Export the shift as a uint8 array under the key ``shift_<op name>``."""
        return {
            f"shift_{self.name}": np.array(self.shift_cfg, np.uint8),
        }

    def import_flow_state(self, atomic_state: AtomicOpState):
        """Restore the base flow state and the ``is_lossless`` flag written by export_flow_state."""
        super().import_flow_state(atomic_state)
        # Must target the same attribute that call_hw_sim / export_flow_state read.
        self.is_lossless = atomic_state.aops_dict_kwgs["is_lossless"]

    def export_flow_state(self) -> AtomicOpState:
        """Export the base flow state extended with the ``is_lossless`` flag."""
        aops_state = super().export_flow_state()
        aops_state.aops_dict_kwgs["is_lossless"] = self.is_lossless
        return aops_state
