import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops._misc_internals import hailo_reciprocal
from hailo_model_optimization.acceleras.utils.opt_utils import verify_data_dtype

# Module-level settings used by the LUT emulation class below.
# Not referenced in the visible part of this file; the name suggests the
# mantissa width in bits -- TODO confirm against external users.
mantissa_BITS = 16
# Copied into InverseFuncLUTCalculator.debug at construction time; when True,
# the bit-width checks in _verify_data_dtype are executed.
DEBUG = False
# Floating-point dtype that LUT inputs and mantissa parts are cast to.
TF_TYPE = tf.float64
# Integer dtype the quantized LUT values are stored in.
TYPE_INT = tf.int64


class InverseFuncLUTCalculator:
    """
    Calculates an inverse function using a lookup table (LUT).

    It is used to calculate either
        1. the inverse (1 / x)
        2. the inverse square root (1 / sqrt(x))

    The input is split into a mantissa and an exponent. The mantissa is
    quantized to 16 bits; its upper 8 bits index the LUT (value and slope)
    and its lower 8 bits are used for a first-order correction:
    f(x_0) + f'(x_0) * (x - x_0).
    """

    def __init__(self, sqrt=True):
        """
        Args:
            sqrt (bool): True to approximate the inverse square root,
                False to approximate the plain inverse.
        """
        self.function = self._get_lut_function(sqrt)
        self.lut_scale_out = self._calc_output_scale(sqrt)

        # 8-bit keys, 16-bit values.
        self.lut_keys, self.lut_values, self.lut_deriv_values = self._get_lut(8, 16)
        self.lut_table = self._build_lut(self.lut_keys, self.lut_values)
        self.lut_deriv_table = self._build_lut(self.lut_keys, self.lut_deriv_values)

        # The inverse square root needs an even exponent so it can be halved.
        self.exp_even = sqrt
        self.debug = DEBUG

    def _calc_output_scale(self, sqrt, bits=16):
        """
        Calculate the output scale of the LUT.

        The scale maps the function value at the smallest mantissa
        (0.25 for sqrt, 0.5 otherwise) onto the largest ``bits``-bit integer.

        Args:
            sqrt (bool): Selects the smallest mantissa value (see above).
            bits (int): Bit width of the LUT output.

        Returns:
            The real value represented by one LUT output unit.
        """
        min_mantissa = 0.25 if sqrt else 0.5
        max_val = tf.cast(np.array(min_mantissa), TF_TYPE)

        # NOTE: with bits=16 this is 2 / (2**16 - 1) in both modes, assuming
        # hailo_reciprocal returns 1/x for these inputs.
        return self.function(max_val) / (2**bits - 1)

    def _get_lut_function(self, sqrt=True):
        """
        Return the real-valued function the LUT approximates.

        Args:
            sqrt (bool): True for the inverse square root, False for the inverse.

        Returns:
            A callable taking a tensor ``x`` and returning a tensor.
        """

        def inv(x):
            eps = 1e-36
            return hailo_reciprocal(x, epsilon=eps, tf_type=x.dtype)

        def inv_sqrt(x):
            eps = 1e-36
            # Clamp before the square root so it is never taken of a value below eps.
            return hailo_reciprocal(tf.math.sqrt(tf.maximum(x, eps)), tf_type=x.dtype)

        return inv_sqrt if sqrt else inv

    def _build_lut(self, keys, values):
        """Build a static hash table mapping ``keys`` to ``values`` (-1 for unknown keys)."""
        init = tf.lookup.KeyValueTensorInitializer(keys, values)
        return tf.lookup.StaticHashTable(init, default_value=-1)

    @staticmethod
    def _lookup_as_input_dtype(table, inp):
        """Look ``inp`` up in ``table`` and cast the result back to ``inp``'s dtype."""
        keys = tf.cast(inp, table.key_dtype)
        return tf.cast(table.lookup(keys), inp.dtype)

    # emulation function
    def lut_emulation(self, inp, bit_exact=False, is_lossless=False):
        """
        Calculate the LUT function for ``inp``.

        Args:
            inp: LUT key(s); the value is divided by 2**8 before the function is applied.
            bit_exact (bool): Use the precomputed hash table instead of
                evaluating the function.
            is_lossless (bool): Evaluate the function without rounding or
                clipping. Takes precedence over ``bit_exact``.

        Returns:
            The function value expressed in LUT output units.
        """
        if is_lossless:
            return self.calc_lut(inp, self.function, self.lut_scale_out, quant=False)
        if bit_exact:
            return self._lookup_as_input_dtype(self.lut_table, inp)
        return self.calc_lut(inp, self.function, self.lut_scale_out, quant=True)

    def lut_deriv_emulation(self, inp, bit_exact=False, is_lossless=False):
        """
        Calculate the slope of the LUT function at ``inp``.

        Args:
            inp: LUT key(s).
            bit_exact (bool): Use the precomputed slope hash table instead of
                evaluating the finite difference.
            is_lossless (bool): Evaluate the finite difference without
                quantization. Takes precedence over ``bit_exact``.

        Returns:
            The slope expressed in LUT output units per LSB step.
        """
        if is_lossless:
            return self.calc_lut_deriv(inp, self.function, self.lut_scale_out, quant=False)
        if bit_exact:
            return self._lookup_as_input_dtype(self.lut_deriv_table, inp)
        return self.calc_lut_deriv(inp, self.function, self.lut_scale_out, quant=True)

    # emulation function end

    def fp_split(self, x, float_type):
        """
        Split ``x`` into a mantissa and a power-of-two exponent.

        The exponent is floor(log2(x)) + 1, rounded up to the next even number
        when ``self.exp_even`` is set, so that x == mantissa * 2**exponent.
        log2 is applied to ``x`` directly, so non-positive inputs do not
        produce a finite split.

        Args:
            x: Input tensor.
            float_type: Floating-point dtype used for the computation.

        Returns:
            tuple: (mantissa, exponent), both of dtype ``float_type``.
        """
        x = tf.cast(x, float_type)
        index = tf.floor(tf.experimental.numpy.log2(x)) + 1
        if self.exp_even:
            # Adds 1 to odd exponents only.
            index = index + index % 2
        index = tf.cast(index, float_type)
        mantissa = x / (2**index)
        return mantissa, index

    @staticmethod
    def calc_lut(inp, lut_func, s_out, quant=True):
        """
        Evaluate ``lut_func`` for 8-bit LUT keys and express it in output units.

        Args:
            inp: LUT key(s); divided by 2**8 to obtain the real argument.
            lut_func: Real-valued function to evaluate.
            s_out: Real value of one output unit.
            quant (bool): Round and clip the result to the unsigned 16-bit range.

        Returns:
            ``lut_func(inp / 2**8) / s_out``, rounded and clipped when ``quant``.
        """
        bits_in = 8.0
        bits_out = 16
        out_max_val = (2**bits_out) - 1

        inp_native = tf.cast(inp / 2**bits_in, TF_TYPE)
        out_quant = lut_func(inp_native) / s_out
        if not quant:
            return out_quant
        return tf.clip_by_value(tf.round(out_quant), 0, out_max_val)

    def calc_lut_deriv(self, inp, lut_func, s_out, quant=True):
        """
        Calculate the slope of the LUT function as a finite difference.

        The slope is (lut(inp) - lut(inp - 1)) / 2**8; floor division is used
        when ``quant`` is set.

        NOTE(review): this is a backward difference, while the precomputed
        slope table built in ``_get_lut`` uses ``np.diff`` (forward difference).
        Confirm which one is intended.

        Args:
            inp: LUT key(s).
            lut_func: Real-valued function to evaluate.
            s_out: Real value of one output unit.
            quant (bool): Quantize the LUT values and floor the slope.

        Returns:
            The slope per LSB step, in LUT output units.
        """
        # Use the arguments that were passed in rather than silently falling
        # back to self.function / self.lut_scale_out.
        current = self.calc_lut(inp, lut_func, s_out, quant=quant)
        previous = self.calc_lut(inp - 1, lut_func, s_out, quant=quant)
        step = current - previous
        if quant:
            return step // (2**8)
        return step / (2**8)

    def _get_lut(self, in_bits, out_bits):
        """
        Build the LUT keys, values and slope values.

        Args:
            in_bits (int): Number of key bits; keys are 0 .. 2**in_bits - 1.
            out_bits (int): Number of value bits; values are clipped to
                0 .. 2**out_bits - 1.

        Returns:
            tuple: (keys, values, values_deriv). ``values_deriv[i]`` is
            ``(values[i + 1] - values[i]) // 2**8``, with 0 for the last key.
        """
        keys = np.arange(2**in_bits)
        # Key 0 is evaluated at 1 to avoid problems with a zero argument.
        keys[0] = 1
        val = self.calc_lut(keys, self.function, self.lut_scale_out, quant=True)
        out_max_val = (2**out_bits) - 1

        values = tf.cast(tf.clip_by_value(val, 0, out_max_val), TYPE_INT)
        values_deriv = np.concatenate((np.diff(values), np.array([0]))) // (2**8)
        keys_to_send = np.arange(2**in_bits)
        return keys_to_send, values, values_deriv

    def _split_mantissa(self, mantissa_quant_16):
        """
        Split a 16-bit quantized mantissa into its upper and lower 8 bits.

        Returns:
            tuple: (MSB part, LSB part).
        """
        bits = 8
        MSB_mantissa_quant = mantissa_quant_16 // (2**bits)
        LSB_mantissa_quant = mantissa_quant_16 % (2**bits)
        return MSB_mantissa_quant, LSB_mantissa_quant

    def lut_vals_helper(self, inputs, float_type, int_type, bit_exact, is_lossless=False):
        """
        Helper function to calculate the lookup table (LUT) values for the inverse function.

        Args:
            inputs (tf.Tensor): The input tensor for which the LUT values are to be calculated.
            float_type (tf.DType): The floating-point data type to be used.
            int_type (tf.DType): The integer data type to be used.
            bit_exact (bool): Flag indicating whether to perform bit-exact calculations.
            is_lossless (bool): Flag indicating whether the calculation should be lossless.

        Returns:
            tuple: A tuple containing:
                - lut_mantissa (tf.Tensor): The calculated LUT mantissa values.
                - real_exp (tf.Tensor): The calculated real exponent values
                  (the exponent is halved when ``self.exp_even`` is set).
        """
        # get mantissa and exponent
        mantissa_native, exponent = self.fp_split(inputs, float_type=float_type)

        # calc f(x_0) + f'(x_0) * (x - x_0)
        mantissa_quant_16 = tf.math.floor(mantissa_native * (2**16))
        MSB_mantissa_quant, LSB_mantissa_quant = self._split_mantissa(mantissa_quant_16)  # get x_0

        LSB_mantissa_quant = tf.cast(LSB_mantissa_quant, TF_TYPE)
        MSB_mantissa_quant = tf.cast(MSB_mantissa_quant, TF_TYPE)

        lut_MSB_mantissa_quant = self.lut_emulation(
            MSB_mantissa_quant, bit_exact=bit_exact, is_lossless=is_lossless
        )  # get f(x_0)
        derivative = self.lut_deriv_emulation(
            MSB_mantissa_quant, bit_exact=bit_exact, is_lossless=is_lossless
        )  # get f'(x_0)

        diff = derivative * LSB_mantissa_quant  # get f'(x_0) * (x - x_0)
        lut_mantissa = lut_MSB_mantissa_quant + diff

        if bit_exact:
            exponent = tf.cast(exponent, int_type)
            real_exp = exponent / 2 if self.exp_even else exponent
            real_exp = tf.cast(real_exp, int_type)
            lut_mantissa = tf.cast(lut_mantissa, int_type)
            self._verify_data_dtype(exponent, 6, False, "exponet")
            self._verify_data_dtype(MSB_mantissa_quant, 8, False, "MSB_mantissa_quant")
            self._verify_data_dtype(LSB_mantissa_quant, 8, False, "LSB_mantissa_quant")
            self._verify_data_dtype(lut_MSB_mantissa_quant, 16, False, "lut_MSB_mantissa_quant")
            # diff is checked with its sign flipped, as an unsigned value.
            self._verify_data_dtype(-diff, 16, False, "diff")

            self._verify_data_dtype(lut_mantissa, 16, False, "lut_mantissa")
        else:
            real_exp = exponent / 2.0 if self.exp_even else exponent

        return lut_mantissa, real_exp

    def get_hw_lut(self):
        """
        Create the LUT table for the hardware.

        Each entry packs the negated slope into the upper 16 bits and the
        LUT value into the lower 16 bits of a 32-bit word.

        Returns:
            np.ndarray: The packed entries, dtype ``np.uint32``.
        """
        packed = [
            np.uint32((int(deriv) << 16) | int(value))
            for value, deriv in zip(self.lut_values, -self.lut_deriv_values)
        ]
        return np.array(packed, dtype=np.uint32)

    # region debugging
    def _verify_data_dtype(self, data_1, bit_width, signed, name):
        """Run ``verify_data_dtype`` on the data, but only when debugging is enabled."""
        if self.debug:
            verify_data_dtype(data_1, bit_width, signed, name)
