from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.flow_state_utils import AtomicOpState
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams

# Bit width of the inverse-sqrt LUT output: table values are clipped to [0, 2**INV_LUT_OUT_BITS - 1].
INV_LUT_OUT_BITS = 9
# Bit width of the inverse-sqrt LUT input: the table holds 2**INV_LUT_IN_BITS keys (0 .. 2**16 - 1),
# and the lossy clip element quantizes the variance to this many unsigned bits.
INV_LUT_IN_BITS = 16  # should actually be 14 bits; deliberately treated as 16 for now


@dataclass
class InvWeightsLossy(BaseWeightLossyElements):
    """Weight lossy elements of ``RootvarOp``."""

    # Applied to the variance (x2 - mu2) right before the inverse-sqrt LUT lookup.
    # RootvarOp initializes it as an IdentityElement and replaces it with an unsigned
    # INV_LUT_IN_BITS-wide QuantElement in ``create_weight_quant_element``.
    clip: BaseLossyElement


class RootvarOp(BaseAtomicOp):
    """
    Emulates the inverse square root of a variance: ``1 / sqrt(x2 - mu2 + epsilon)``.

    The op takes two inputs ``(x2, mu2)`` and uses their difference as the variance.
    Presumably ``x2`` is E[x^2] and ``mu2`` is E[x]^2 -- TODO confirm against the producing layers.

    Execution modes:

    * native: evaluates the formula in floating point.
    * lossless HW-sim (``is_lossles``): evaluates the same formula on the scaled domain
      ``(v * lut_inv_scale_in) -> 1/sqrt(. + epsilon) / lut_inv_scale_out`` without rounding.
    * lossy HW-sim: passes the variance through the ``clip`` lossy element, then looks it
      up in a precomputed integer table. The table has ``2**INV_LUT_IN_BITS`` keys, and its
      values are rounded and clipped to ``[0, 2**INV_LUT_OUT_BITS - 1]``.
    """

    num_inputs = 2
    num_outputs = 1

    # Default epsilon, shared with ``import_weights``. It is also exposed as a class-level
    # fallback so that ``create_hw_params`` / ``import_independent_params`` / ``call_native``
    # do not fail with AttributeError when no weights were imported. An instance value set
    # by ``import_weights`` always takes precedence over this class attribute.
    _DEFAULT_EPSILON = 1e-06
    epsilon = _DEFAULT_EPSILON

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self.weight_lossy_elements = InvWeightsLossy(
            clip=IdentityElement(name=f"{self.full_name}/ie:clip"),
        )
        self.is_lossles = True  # TODO correct this spelling in a different PR
        self._lut_inv_scale_in = 1
        self._lut_inv_scale_out = 1

    def export_quant_weights(self):
        return {}

    def export_independent_params(self):
        """Export the LUT input/output scales as float32 values."""
        return {
            "lut_inv_scale_in": np.float32(self._lut_inv_scale_in),
            "lut_inv_scale_out": np.float32(self._lut_inv_scale_out),
        }

    def import_independent_params(self, params):
        """Restore the LUT scales from ``params`` and rebuild the inverse-sqrt table."""
        self._lut_inv_scale_in = np.float32(params["lut_inv_scale_in"])
        self._lut_inv_scale_out = np.float32(params["lut_inv_scale_out"])
        self._rebuild_inv_lut()

    def import_weights(self, layer_params: LayerParams, **kwargs):
        """Read ``epsilon`` from ``layer_params``, falling back to 1e-06."""
        self.epsilon = layer_params.get("epsilon", self._DEFAULT_EPSILON)

    def export_weights(self):
        return {"epsilon": self.epsilon}

    def create_hw_params(self, scale_2, scale_3, **kwargs):
        """
        Set the LUT scales, build the lookup table and switch to lossy mode.

        Args:
            scale_2: scale applied to the LUT input (stored as ``lut_inv_scale_in``).
            scale_3: scale of the LUT output (stored as ``lut_inv_scale_out``).
        """
        self._lut_inv_scale_in = scale_2
        self._lut_inv_scale_out = scale_3
        self._rebuild_inv_lut()
        self.is_lossles = False
        self.enforce_encoding()

    def disable_lossy(self, **kwargs):
        super().disable_lossy(**kwargs)
        self.is_lossles = True

    def enable_lossy(self, **kwargs):
        # NOTE(review): inv_lut requires the table built by create_hw_params /
        # import_independent_params; enabling lossy mode before that fails on lookup.
        super().enable_lossy(**kwargs)
        self.is_lossles = False

    @staticmethod
    def calc_lut(inp, lut_func, s_in, zp_in, s_out, quant=True):
        """
        Evaluate ``lut_func`` on de-quantized input and re-scale the result.

        Computes ``lut_func((inp - zp_in) * s_in) / s_out``; when ``quant`` is true the
        result is rounded with ``np.round`` (so it is only meant for eager/NumPy inputs).
        """
        inp_native = (inp - zp_in) * s_in
        out_native = lut_func(inp_native)
        out_quant = out_native / s_out
        if quant:
            return np.round(out_quant)
        else:
            return out_quant

    def inv_lut(self, inp):
        """
        Apply the inverse-sqrt mapping to ``inp``.

        In lossless mode the formula is evaluated directly on the scaled domain. Otherwise
        ``inp`` is cast to the table's key dtype (``tf.cast`` truncates toward zero) and
        looked up; keys missing from the table yield -1. The result keeps ``inp``'s dtype.
        """
        if self.is_lossles:
            return self.calc_lut(
                inp,
                self._inv_sqrt,
                self._lut_inv_scale_in,
                0,
                self._lut_inv_scale_out,
                quant=False,
            )
        inp_type = inp.dtype
        lut_result = self._inv_lut.lookup(tf.cast(inp, self._inv_lut.key_dtype))
        return tf.cast(lut_result, inp_type)

    def _inv_sqrt(self, x):
        """Return ``1 / sqrt(x + epsilon)``; works on NumPy arrays and TF tensors alike."""
        return 1.0 / tf.math.sqrt(x + self.epsilon)

    def _rebuild_inv_lut(self):
        """Recompute the LUT keys/values from the current scales and (re)build the table."""
        self._inv_lut_keys, self._inv_lut_values = self._get_inv_lut(INV_LUT_IN_BITS, INV_LUT_OUT_BITS)
        self._inv_lut = self._build_lut(self._inv_lut_keys, self._inv_lut_values)

    def _build_lut(self, keys, values):
        """Wrap ``keys``/``values`` in a static hash table whose default value is -1."""
        init = tf.lookup.KeyValueTensorInitializer(keys, values)
        table = tf.lookup.StaticHashTable(init, default_value=-1)
        return table

    def _get_inv_lut(self, in_bits, out_bits):
        """
        Build the integer inverse-sqrt table.

        Returns:
            (keys, values): keys are ``0 .. 2**in_bits - 1``; values are the rounded
            ``calc_lut`` results clipped to ``[0, 2**out_bits - 1]`` as int64.
        """
        keys = np.arange(2**in_bits)
        val = self.calc_lut(
            keys,
            self._inv_sqrt,
            self._lut_inv_scale_in,
            0,
            self._lut_inv_scale_out,
            quant=True,
        )
        out_max_val = (2**out_bits) - 1
        values = np.clip(val, 0, out_max_val).astype(np.int64)
        return keys, values

    def _compute_variance(self, x2, mu2, training=False):
        """
        Compute the variance ``x2 - mu2``, pass it through the ``clip`` lossy element and
        return its inverse-sqrt LUT output (despite the name, not the variance itself).
        """
        var = x2 - mu2
        var = self.weight_lossy_elements.clip(var, training=training)

        result = self.inv_lut(var)
        return result

    def call_native(self, inputs, **kwargs):
        """Floating-point reference: ``1 / sqrt(inputs[0] - inputs[1] + epsilon)``."""
        return self._inv_sqrt(inputs[0] - inputs[1])

    def call_hw_sim(self, inputs, training=False, **kwargs):
        return self._compute_variance(inputs[0], inputs[1], training=training)

    def create_weight_quant_element(self, **kwargs):
        """Replace the identity clip with an unsigned, saturating INV_LUT_IN_BITS quant element."""
        self.weight_lossy_elements = InvWeightsLossy(
            clip=QuantElement(signed=False, bits=INV_LUT_IN_BITS, wraparound=False, name=f"{self.full_name}/qe:clip"),
        )

    def import_flow_state(self, atomic_state: AtomicOpState):
        super().import_flow_state(atomic_state)
        self.is_lossles = atomic_state.aops_dict_kwgs["is_lossless"]

    def export_flow_state(self) -> AtomicOpState:
        aops_state = super().export_flow_state()
        aops_state.aops_dict_kwgs["is_lossless"] = self.is_lossles
        return aops_state
