from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.flow_state_utils import AtomicOpState


@dataclass
class SquareWeightsLossy(BaseWeightLossyElements):
    """Weight lossy elements used by the square op defined in this module."""

    # Lossy element applied to the input after it was divided by 2**square_shift.
    # The op starts with an IdentityElement here and swaps in an unsigned 8-bit
    # QuantElement in create_weight_quant_element.
    shift_mu: BaseLossyElement

class LootSquareOp(BaseAtomicOp):
    """
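    This class emulates an element-wise square operation.

    The native flow squares the input directly. The hardware simulation divides the input by
    2**square_shift, passes it through the shift_mu lossy element and then squares it with
    square_lut (a lookup table when the op is lossy, plain arithmetic when it is lossless).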
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        """
        Create the op in lossless mode with an identity shift element and neutral LUT parameters.

        Args:
            name: name of the op, forwarded to the base class.
            logger: forwarded to the base class.
            fully_native: forwarded to the base class.
            **kwargs: forwarded to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        # Until a quantization element is created, the shifted input passes through an identity element.
        identity_shift = IdentityElement(name=f"{self.full_name}/ie:shift_mu")
        self.weight_lossy_elements = SquareWeightsLossy(shift_mu=identity_shift)
        self.is_lossles = True

        # Neutral defaults: unit scales, zero input zero-point and no shift.
        self._lut_square_scale_in = self._lut_square_scale_out = 1
        self._lut_square_zp_in = 0
        self._square_shift = 0

    def export_quant_weights(self):
        return {}

    def export_independent_params(self):
        return {
            "lut_square_scale_in": np.float32(self._lut_square_scale_in),
            "lut_square_scale_out": np.float32(self._lut_square_scale_out),
            "lut_square_zp_in": np.float32(self._lut_square_zp_in),
            "square_shift": np.float32(self._square_shift),
        }

    def import_independent_params(self, params):
        """
        Load the LUT parameters and rebuild the square lookup table from them.

        Args:
            params: mapping that holds "lut_square_scale_in", "lut_square_scale_out",
                "lut_square_zp_in" and "square_shift"; every value is stored as np.float32.
        """
        # Each parameter is stored on the private attribute named "_" + its key.
        for key in ("lut_square_scale_in", "lut_square_scale_out", "lut_square_zp_in", "square_shift"):
            setattr(self, f"_{key}", np.float32(params[key]))

        # The table depends on the parameters above, so it is rebuilt after they are set
        # (8 input bits, 16 output bits).
        lut_keys, lut_values = self._get_square_lut(8, 16)
        self._square_lut_keys = lut_keys
        self._square_lut_values = lut_values
        self._square_lut = self._build_lut(self._square_lut_keys, self._square_lut_values)

    def create_hw_params(self, scale_1, scale_2, zp, expected_max_accumulato_mu, **kwargs):
        """
        Set the LUT hardware parameters, rebuild the square LUT and switch the op to lossy mode.

        Args:
            scale_1: stored as the LUT input scale.
            scale_2: stored as the LUT output scale.
            zp: stored as the LUT input zero point.
            expected_max_accumulato_mu: not used by this method.
            **kwargs: not used by this method.
        """
        self._square_shift = 2
        self._lut_square_scale_in, self._lut_square_zp_in, self._lut_square_scale_out = scale_1, zp, scale_2

        # The table depends on the parameters above, so it is rebuilt after they are set
        # (8 input bits, 16 output bits).
        lut_keys, lut_values = self._get_square_lut(8, 16)
        self._square_lut_keys = lut_keys
        self._square_lut_values = lut_values
        self._square_lut = self._build_lut(self._square_lut_keys, self._square_lut_values)

        self.is_lossles = False
        self.enforce_encoding()

    def disable_lossy(self, **kwargs):
        """
        Disable lossy behaviour through the base class and mark this op as lossless.

        While the op is lossless, square_lut computes the square arithmetically
        instead of using the lookup table.
        """
        super().disable_lossy(**kwargs)
        self.is_lossles = True

    def enable_lossy(self, **kwargs):
        """
        Enable lossy behaviour through the base class and mark this op as lossy.

        While the op is lossy, square_lut reads its result from the lookup table,
        which is built in create_hw_params / import_independent_params.
        """
        super().enable_lossy(**kwargs)
        self.is_lossles = False

    @staticmethod
    def calc_lut(inp, lut_func, s_in, zp_in, s_out, quant=True):
        inp_native = (inp - zp_in) * s_in
        out_native = lut_func(inp_native)
        out_quant = out_native / s_out
        if quant:
            return np.round(out_quant)
        else:
            return out_quant

    def square_lut(self, inp):
        if self.is_lossles:
            return self.calc_lut(
                inp,
                lambda x: x**2,
                self._lut_square_scale_in,
                self._lut_square_zp_in,
                self._lut_square_scale_out,
                quant=False,
            )
        inp_type = inp.dtype
        lut_result = self._square_lut.lookup(tf.cast(inp, self._square_lut.key_dtype))
        return tf.cast(lut_result, inp_type)

    def _build_lut(self, keys, values):
        """
        Build a static hash table that maps keys to values.

        Args:
            keys: table keys.
            values: table values, matched to keys by position.

        Returns:
            A tf.lookup.StaticHashTable that returns -1 for keys missing from the table.
        """
        return tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            default_value=-1,
        )

    def _get_square_lut(self, in_bits, out_bits):
        keys = np.arange(2**in_bits)
        out_max_val = (2**out_bits) - 1
        values = self.calc_lut(
            keys,
            lambda x: x**2,
            self._lut_square_scale_in,
            self._lut_square_zp_in,
            self._lut_square_scale_out,
            quant=True,
        )
        values = np.clip(values, 0, out_max_val).astype(np.int64)
        return keys, values

    def _compute_square(self, mu):
        mu_after_shift = mu / (2.0**self._square_shift)
        mu_after_shift = self.weight_lossy_elements.shift_mu(mu_after_shift)
        mu2 = self.square_lut(mu_after_shift)
        return mu2

    def call_native(self, inputs, **kwargs):
        """
        Native flow: element-wise square of the first input.

        Args:
            inputs: indexable collection of inputs; only inputs[0] is used.
            **kwargs: not used by this method.
        """
        first_input = inputs[0]
        return tf.math.square(first_input)

    def call_hw_sim(self, inputs, **kwargs):
        return self._compute_square(inputs[0])

    def create_weight_quant_element(self, **kwargs):
        """
        Replace the weight lossy elements with a quantizing shift element.

        The new shift_mu element is an unsigned 8-bit QuantElement without wraparound.

        Args:
            **kwargs: not used by this method.
        """
        quantized_shift = QuantElement(
            signed=False,
            bits=8,
            wraparound=False,
            name=f"{self.full_name}/qe:shift_mu",
        )
        self.weight_lossy_elements = SquareWeightsLossy(shift_mu=quantized_shift)

    def enforce_encoding(self, *args, **kwargs):
        pass

    def import_flow_state(self, atomic_state: AtomicOpState):
        """
        Restore the op from a flow state, including its lossless flag.

        Args:
            atomic_state: state to load; its aops_dict_kwgs must contain the "is_lossless" entry.
        """
        super().import_flow_state(atomic_state)
        # The flag is read only after the base class has processed the state.
        state_kwargs = atomic_state.aops_dict_kwgs
        self.is_lossles = state_kwargs["is_lossless"]

    def export_flow_state(self) -> AtomicOpState:
        """
        Export the flow state of the op, including its lossless flag.

        Returns:
            The state produced by the base class with the "is_lossless" entry added to aops_dict_kwgs.
        """
        flow_state = super().export_flow_state()
        flow_state.aops_dict_kwgs["is_lossless"] = self.is_lossles
        return flow_state
