from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.flow_state_utils import AtomicOpState

# Number of mantissa bits kept for the multiplier computed in
# ReduceMeanNormOp.get_div_cfg (the mantissa is scaled by 2**MANTISA_BITS).
MANTISA_BITS = 6
# Not referenced in the visible part of this file — presumably consumed by
# other modules; TODO confirm before changing or removing.
MU_SHIFT = 2
# Not referenced in the visible part of this file — presumably the bit width
# of the reduce-sum accumulator; TODO confirm against the users of this value.
REDUCE_SUM_BITS = 10


@dataclass
class ReduceSumWeightsLossy(BaseWeightLossyElements):
    """Pair of lossy elements used by ``ReduceMeanNormOp.call_hw_sim``."""

    # Applied to the raw sum over the last axis, before the multiplication.
    clip1: BaseLossyElement
    # Applied to the result after the multiplication and the power-of-two division.
    clip2: BaseLossyElement


class ReduceMeanNormOp(BaseAtomicOp):
    """
    This class emulates the reduce mean operation (optionally on the squared input).

    Natively the op is a plain ``tf.reduce_mean`` over ``reduce_axes``. The
    hardware-simulation path sums over the last axis, multiplies the sum by
    ``mlt_cfg`` (rounded to an integer when the op is lossy) and divides by
    ``2 ** shift_cfg``; the intermediate results go through the ``clip1`` and
    ``clip2`` lossy elements.

    NOTE(review): ``call_native`` honours ``reduce_axes`` whereas
    ``call_hw_sim``, ``_build`` and ``create_hw_params`` always work on the
    last axis. The two paths only agree when ``reduce_axes`` designates the
    last axis — confirm callers never pass anything else.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, reduce_axes=3, logger=None, fully_native=None, square=False, **kwargs):
        """
        Args:
            name: op name, forwarded to the base class.
            reduce_axes: axis (or axes) reduced by the native implementation.
            logger: forwarded to the base class.
            fully_native: forwarded to the base class.
            square: when True, the input is squared before being reduced.
            **kwargs: forwarded to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._square = square
        self._reduce_axes = reduce_axes

        # The attribute spelling ("is_lossles") is kept as-is: it is a public
        # attribute and may be read outside of this class.
        self.is_lossles = True
        # Multiplier; resolved lazily in _build or set by create_hw_params.
        self._mlt_cfg = None
        # Right shift matching the multiplier (see get_div_cfg).
        self._shift_cfg = 0
        # Extra shift supplied to create_hw_params; subtracted from _shift_cfg.
        self._mult_shift = 0

        # Start with pass-through elements; create_weight_quant_element swaps
        # them for real quantization elements.
        self.weight_lossy_elements = ReduceSumWeightsLossy(
            clip1=IdentityElement(name=f"{self.full_name}/ie:clip1"),
            clip2=IdentityElement(name=f"{self.full_name}/ie:clip2"),
        )

    def disable_lossy(self, **kwargs):
        """Disable lossy behaviour in the base class and flag the op as lossless."""
        super().disable_lossy(**kwargs)
        self.is_lossles = True

    def enable_lossy(self, **kwargs):
        """Enable lossy behaviour in the base class and flag the op as lossy."""
        super().enable_lossy(**kwargs)
        self.is_lossles = False

    def _build(self, input_shape):
        """Default the multiplier to ``1 / input_shape[-1]`` if none was set yet."""
        if self._mlt_cfg is None:
            self._mlt_cfg = 1.0 / input_shape[-1]

    def get_div_cfg(self, factor):
        """
        Split ``factor`` into a multiplier and a right shift.

        ``np.frexp`` gives ``factor == mantissa * 2**exp``; the mantissa is
        scaled by ``2**MANTISA_BITS`` and the shift compensates for it, so
        ``factor == mlt_cfg / 2**shift_cfg`` as long as the cap of 63 on the
        shift is not hit.

        Returns:
            Tuple ``(mlt_cfg, shift_cfg)``. ``mlt_cfg`` is not rounded here;
            rounding happens in the ``mlt_cfg`` property.
        """
        bits = MANTISA_BITS
        mlt_cfg, exp = np.frexp(factor)  # mlt_cfg*2** exp
        shift_cfg = min(63, bits - exp)
        mlt_cfg = mlt_cfg * 2**bits
        return mlt_cfg, shift_cfg

    @property
    def mlt_cfg(self):
        """Multiplier used by ``call_hw_sim``: raw when lossless, rounded to ``int`` otherwise."""
        if self.is_lossles:
            return self._mlt_cfg
        return int(np.round(self._mlt_cfg))

    @property
    def shift_cfg(self):
        """Effective right shift: the stored shift minus the extra ``mult_shift``."""
        return self._shift_cfg - self._mult_shift

    def create_hw_params(self, s_in=1, s_out=1, shift=0, **kwargs):
        """
        Derive the multiplier / shift pair from the input and output scales.

        The target factor is ``(1 / input_shape[-1]) * (s_in / s_out)``. The
        intermediate values are kept on the instance (``f_out``, ``s_in``,
        ``s_out``, ``factor``), the encoding is re-enforced and the op is
        flagged as lossy.

        Args:
            s_in: input scale.
            s_out: output scale.
            shift: extra shift, stored as ``_mult_shift``.
            **kwargs: accepted but not used here.
        """
        self._mult_shift = shift
        ratio = s_in / s_out
        f_out = self.input_shape[-1]
        factor = (1.0 / f_out) * ratio
        self.f_out = f_out
        self.s_in = s_in
        self.s_out = s_out
        self.factor = factor
        mlt1, shift1 = self.get_div_cfg(factor)
        self._mlt_cfg = mlt1  # x_2
        self._shift_cfg = shift1
        self.enforce_encoding()
        self.is_lossles = False

    def call_native(self, inputs, **kwargs):
        """Return the mean of ``inputs[0]`` (squared first if requested) over ``reduce_axes``, keeping dims."""
        inp = inputs[0]
        if self._square:
            inp = tf.math.square(inp)
        return tf.reduce_mean(input_tensor=inp, axis=self._reduce_axes, keepdims=True)

    def call_hw_sim(self, inputs, **kwargs):
        """
        Emulate the hardware computation of the mean.

        Steps: optionally remove the input zero point and square, sum over the
        last axis, pass through ``clip1``, multiply by ``mlt_cfg``, divide by
        ``2 ** shift_cfg`` and pass through ``clip2``.
        """
        inp = inputs[0]
        if self._square:
            # remove zero_point
            inp = inp - self.input_zero_point
            inp = tf.math.square(inp)

        sum_x = tf.reduce_sum(input_tensor=inp, axis=-1, keepdims=True)

        clipped_sum_x = self.weight_lossy_elements.clip1(sum_x)

        mean_x_tag = tf.multiply(clipped_sum_x, self.mlt_cfg)
        # Float base on purpose: shift_cfg is usually a NumPy integer (it comes
        # from the np.frexp exponent). With an integer base NumPy raises
        # ValueError on a negative shift and silently overflows on a large one
        # (including the cap of 63 applied in get_div_cfg).
        divisor = 2.0 ** self.shift_cfg
        mean_x_clipped = self.weight_lossy_elements.clip2(mean_x_tag / divisor)
        return mean_x_clipped

    def create_weight_quant_element(self, bit_clip1, bit_clip2):
        """
        Replace both lossy elements with unsigned, non-wrapping ``QuantElement`` instances.

        Args:
            bit_clip1: bit width of the element applied to the sum.
            bit_clip2: bit width of the element applied to the final result.
        """
        signed = False
        self.weight_lossy_elements = ReduceSumWeightsLossy(
            clip1=QuantElement(signed=signed, bits=bit_clip1, wraparound=False, name=f"{self.full_name}/qe:clip1"),
            clip2=QuantElement(signed=signed, bits=bit_clip2, wraparound=False, name=f"{self.full_name}/qe:clip2"),
        )

    def enforce_encoding(self, *args, **kwargs):
        """
        Propagate the input encoding to the output for the non-squared case.

        The output scale is the input scale divided by ``2 ** mult_shift`` and
        the output zero point is the input zero point multiplied by the same
        amount. Nothing is set here when ``square`` is enabled.
        """
        if not self._square:
            self.output_scale = self.input_scales[0] / (2**self._mult_shift)
            self.output_zero_point = self.input_zero_points[0] * (2**self._mult_shift)

    def export_independent_params(self):
        """Return the raw (unrounded) multiplier and shifts as ``np.float32`` values."""
        return {
            "mult_shift": np.float32(self._mult_shift),
            "mlt_cfg": np.float32(self._mlt_cfg),
            "shift_cfg": np.float32(self._shift_cfg),
        }

    def import_independent_params(self, params):
        """Restore the values written by ``export_independent_params``."""
        self._mult_shift = params["mult_shift"]
        self._mlt_cfg = params["mlt_cfg"]
        self._shift_cfg = params["shift_cfg"]

    def import_flow_state(self, atomic_state: AtomicOpState):
        """Restore the base-class state and the lossless flag stored under ``"is_lossless"``."""
        super().import_flow_state(atomic_state)
        self.is_lossles = atomic_state.aops_dict_kwgs["is_lossless"]

    def export_flow_state(self) -> AtomicOpState:
        """Export the base-class state and add the lossless flag under ``"is_lossless"``."""
        aops_state = super().export_flow_state()
        aops_state.aops_dict_kwgs["is_lossless"] = self.is_lossles
        return aops_state
