from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    QuantElement,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EXP_OUT_BITS,
    EmulationType,
)
from hailo_model_optimization.acceleras.utils.opt_utils import calc_lut_table

# Bit-width of the unsigned mantissa used to represent the scale change (see ExpLutOp).
MANTISSA_BITS = 16
EXP_BITS = 5  # NOTE(review): not referenced in this part of the file; presumably the exponent bit-width — confirm.


LUT_OUT_HIGH_BITS = 16  # output bit-width of the "high" LUT, which tabulates e^-x
LUT_OUT_LOW_BITS = 13  # output bit-width of the "low" LUT, which tabulates 1 - e^-x

BITS_LUT_PER_TABLE = 8  # key bit-width of each LUT: 2**8 = 256 entries per table


def minus_exp_func(x):
    """Return ``e^{-x}`` element-wise; this is the function the exp LUT op approximates."""
    negated = -x
    return tf.math.exp(negated)


def flip_minus_exp_func(x):
    """Return ``1 - e^{-x}`` element-wise (the function tabulated by the "low" LUT)."""
    decay = tf.math.exp(-x)
    return 1 - decay


@dataclass
class ScaleMatchWeightsLossy(BaseWeightLossyElements):
    """Lossy elements applied to the weights of ``ExpLutOp``.

    ``ExpLutOp.__init__`` starts with ``IdentityElement`` for both fields, and
    ``ExpLutOp.create_weight_quant_element`` replaces them with ``QuantElement`` instances.
    Field order is the positional constructor order.
    """

    mantissa: BaseLossyElement  # quantizes the scale-change mantissa (MANTISSA_BITS unsigned bits)
    clip1: BaseLossyElement  # applied to the rescaled input right before the LUT lookup


class ExpLutOp(BaseAtomicOp):
    """
    Approximates ``exp(-x)`` with a pair of 256-entry lookup tables.

    The input is multiplied by a scale change (``mantissa * 2**exponent``) and passed through
    ``clip1``. The result is the LUT input; when it is 16 bits wide it is split into a
    high-byte key and a low-byte key:

    * high table: ``e^{-x_high}``
    * low table:  ``1 - e^{-x_low}``

    The two table outputs are combined as ``high * 1 - high * low = e^{-x}``
    (see ``_minus_exp_lut_emulation``).
    """

    weight_lossy_elements: ScaleMatchWeightsLossy
    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        logger=None,
        fully_native=True,
        bit_exact=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, bit_exact=bit_exact, **kwargs)
        # Identity weight elements until create_weight_quant_element() installs quantizers.
        self.weight_lossy_elements = ScaleMatchWeightsLossy(mantissa=IdentityElement(), clip1=IdentityElement())
        self.set_type_emulation(EmulationType.DOUBLE)

        self.func = minus_exp_func
        # Neutral scale change (1.0 * 2**0) until create_hw_params()/import_independent_params() run.
        self.mantissas_candidate = np.array(1.0, dtype=self.FLOAT_TYPE_NP)
        self.exponent = 0

        self.create_lut_tables()

    def create_weight_quant_element(self, clip_bits, **kwargs):
        """
        Install quantizing weight elements.

        Args:
            clip_bits: unsigned bit-width of ``clip1``, i.e. of the LUT input
                (see ``get_lut_in_exp_bits``).
        """
        self._logger.debug(f"name {self.full_name} mantisaa: {MANTISSA_BITS}")

        weights_quant_elem = ScaleMatchWeightsLossy(
            mantissa=QuantElement(
                signed=False, bits=MANTISSA_BITS, wraparound=False, name=f"{self.full_name}/qe:mantissa"
            ),
            clip1=QuantElement(signed=False, bits=clip_bits, wraparound=False, name=f"{self.full_name}/qe:clip1"),
        )
        self.weight_lossy_elements = weights_quant_elem

    def export_weights(self):
        """
        export the activation params for the layers. The weight should be returned as dict.
        Returns: dict of weights (always empty for this op)

        """
        return dict()

    @property
    def mantissa_q(self):
        """The scale-change mantissa after the ``mantissa`` lossy element (inference mode)."""
        return self.weight_lossy_elements.mantissa(self.mantissas_candidate, training=False)

    def update_mantissa_exponent_decomposition(self):
        """
        Decompose ``self.scale_change`` into a mantissa and a power-of-two exponent.

        ``scale_change == mantissas_candidate * 2**exponents``. The exponent is chosen so that a
        positive mantissa, once rounded, fits in ``MANTISSA_BITS`` unsigned bits.

        Returns:
            mantissas_candidate: the mantissa before rounding, same shape as ``scale_change``.
            exponents: the exponents, same shape as ``scale_change``.
        """
        scale_change = self.scale_change
        # NOTE(review): a zero scale_change gives log2(0) = -inf and therefore a NaN mantissa.
        exponents = tf.math.ceil(tf.experimental.numpy.log2(tf.math.abs(scale_change))) - MANTISSA_BITS
        mantissas_candidate = scale_change * tf.math.pow(2.0, -1.0 * exponents)
        # A mantissa that would round up to 2**MANTISSA_BITS does not fit: halve it and bump the exponent.
        mantissa_overflow = mantissas_candidate >= 2**MANTISSA_BITS - 0.5
        exponents = tf.where(mantissa_overflow, exponents + 1, exponents)
        mantissas_candidate = tf.where(mantissa_overflow, mantissas_candidate / 2, mantissas_candidate)

        # if exponents > 0:
        #     exponents = 0
        return mantissas_candidate, exponents

    def create_lut_tables(self):
        """Compute the LUT input/output scales, the quantized table values and their lookup tables."""
        ## lut scales
        self.lut_scale_in_low = self._calc_exp_out_scale_by_output_bits(16)
        self.lut_scale_in_high = self._calc_exp_out_scale_by_output_bits(8)
        self.scale_high_out = 1 / (2**LUT_OUT_HIGH_BITS - 1)
        self.scale_lut_out = 1 / (2**EXP_OUT_BITS - 1)

        # lut_tables
        lut_keys = np.arange(2**BITS_LUT_PER_TABLE)
        self.lut_values_high = self._emulate_lut_high(lut_keys, quant=True)
        self.lut_values_low = self._emulate_lut_low(lut_keys, quant=True)
        self.lut_table_high = self._build_lut(lut_keys, self.lut_values_high)
        self.lut_table_low = self._build_lut(lut_keys, self.lut_values_low)

    def get_lut_in_exp_bits(self):
        """Bit-width of the LUT input, i.e. of the ``clip1`` element."""
        return self.weight_lossy_elements.clip1.bits

    def create_hw_params(self):
        """Derive the scale change from the input scale and decompose it into mantissa/exponent."""
        self.scale_change = self._clac_scale_change()
        self.mantissas_candidate, self.exponent = self.update_mantissa_exponent_decomposition()
        self.create_output_encoding_candidates(0)
        self.enforce_encoding()

    def _clac_scale_change(self):
        """Ratio between the op's input scale and the LUT input scale for the current clip bits."""
        in_lut_bits = self.get_lut_in_exp_bits()
        lut_scale_in_for_scale_change = self._calc_exp_out_scale_by_output_bits(in_lut_bits)
        return self.input_scale[0] / lut_scale_in_for_scale_change

    def _calc_exp_out_scale_by_output_bits(self, bits_in_lut):
        """
        LUT input scale for a ``bits_in_lut``-bit unsigned LUT input.

        The full input range spans ``[0, |ln(2**-17)|]``: for larger x, ``e^{-x}`` is below the
        smallest bin of a 16-bit output.
        """
        bits_out_softmax = 16
        smallest_value_bins = np.log(
            1 / (2 ** (bits_out_softmax + 1))
        )  # inputs beyond |smallest_value_bins| produce a zero output

        return np.abs(smallest_value_bins) / (2**bits_in_lut - 1)

    def enforce_encoding(self, training=False):
        pass
        # if self.lut_scale_in is not None:
        #     self.output_scale = self.scale_high_out * self.scale_low_out

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def _compute_output_shape(self, input_shape):
        return input_shape

    @property
    def quant_zero(self):
        """
        Integer that represents 1.0 in the low table's output scale (``scale_low_out = 1 / quant_zero``).

        Chosen as the power of two just above ``max_val_unit / (1 - e^{-x_max})``, where ``x_max`` is the
        largest low-table input.
        """
        max_val_unit = 2**BITS_LUT_PER_TABLE - 1
        max_low_input = max_val_unit * self.lut_scale_in_low
        out_low = np.exp(-max_low_input)
        flip_out = 1 - out_low

        scale_low_out_candidate = flip_out / max_val_unit
        return 2 ** np.ceil(np.log2(1 / scale_low_out_candidate))

    @property
    def scale_low_out(self):
        """Output scale of the low table."""
        return 1 / self.quant_zero

    def _build_lut(self, keys, values):
        """Build a static key -> value lookup table; unknown keys map to -1."""
        init = tf.lookup.KeyValueTensorInitializer(keys, values)
        table = tf.lookup.StaticHashTable(init, default_value=-1)
        return table

    def _emulate_lut_low(self, inp, quant=True):
        """Evaluate the low table ``1 - e^{-x}`` on integer keys ``inp``."""
        return calc_lut_table(
            inp=inp,
            lut_func=flip_minus_exp_func,
            zp_in=0,
            s_in=self.lut_scale_in_low,
            zp_out=0,
            s_out=self.scale_low_out,
            bits_out=LUT_OUT_LOW_BITS,
            signed=False,
            quant=quant,
        )

    def _emulate_lut_high(self, inp, quant=True):
        """Evaluate the high table ``e^{-x}`` on integer keys ``inp``."""
        return calc_lut_table(
            inp=inp,
            lut_func=minus_exp_func,
            zp_in=0,
            s_in=self.lut_scale_in_high,
            zp_out=0,
            s_out=self.scale_high_out,
            bits_out=LUT_OUT_HIGH_BITS,
            signed=False,
            quant=quant,
        )

    def _minus_exp_lut_emulation(self, data_int, bit_exact=False, is_lossless=False):
        """
        Emulate the two-table evaluation of ``exp(-x)`` on the integer LUT input ``data_int``.

        With a 16-bit LUT input, ``data_int`` is split into a high byte (``data_high``) and a low
        byte (``data_low``). Otherwise the whole value is the high-table key and the low key is 0.

        high table : e^{-x_high}        (output scale ``scale_high_out``)
        low table  : 1 - e^{-x_low}     (output scale ``scale_low_out = 1 / quant_zero``)

        In the low table's integer domain ``quant_zero`` represents 1.0, so we compute
            high * quant_zero - high * low  ~  e^{-x_high} * 1 - e^{-x_high} * (1 - e^{-x_low})
                                             = e^{-x_high} * e^{-x_low}
                                             = e^{-x}
        The result is in units of ``scale_high_out * scale_low_out``.
        """
        bins = 2**BITS_LUT_PER_TABLE
        in_exp_bits = self.get_lut_in_exp_bits()
        if in_exp_bits == 16:
            # NOTE(review): the high key is read with lut_scale_in_high (|L|/255), which is not exactly
            # 256 * lut_scale_in_low (256*|L|/65535), so x_high + x_low only approximates x.
            # Presumably matches the HW tables — confirm.
            data_high = tf.floor(data_int / bins)

            bins = tf.cast(bins, self.INT_TYPE_TF) if bit_exact else bins
            data_high = tf.cast(data_high, self.INT_TYPE_TF) if bit_exact else data_high

            data_low = data_int - data_high * bins
        else:
            data_high = data_int
            data_low = data_int - data_high  # all zeros: low table contributes (1 - e^0) = 0

        out_high = self.lut_emulation(data_high, bit_exact=bit_exact, is_lossless=is_lossless, is_high=True)
        out_low = self.lut_emulation(data_low, bit_exact=bit_exact, is_lossless=is_lossless, is_high=False)
        if bit_exact:
            quant_zero_int = tf.cast(self.quant_zero, self.INT_TYPE_TF)
            out_high = tf.cast(out_high, self.INT_TYPE_TF)
            out_low = tf.cast(out_low, self.INT_TYPE_TF)
            out_mul_q = out_high * quant_zero_int - out_high * out_low
        else:
            out_mul_q = out_high * self.quant_zero - out_high * out_low
        return out_mul_q

    # region Call functions for the different emulation types
    def lut_emulation(self, inp, bit_exact=False, is_lossless=False, is_high=False):
        """
        Evaluate the high (``is_high=True``) or low table on ``inp``.

        * ``is_lossless``: compute the table function without output quantization.
        * ``bit_exact``: look the keys up in the prebuilt hash table.
        * otherwise: compute the quantized table values directly.
        """
        if is_high:
            calc_func = self._emulate_lut_high
            lut_table = self.lut_table_high
        else:
            calc_func = self._emulate_lut_low
            lut_table = self.lut_table_low

        if is_lossless:
            return calc_func(inp, quant=False)

        if bit_exact:
            inp_type = inp.dtype
            inp1 = tf.cast(inp, lut_table.key_dtype)
            lut_result = lut_table.lookup(inp1)
            return tf.cast(lut_result, inp_type)
        else:
            return calc_func(inp, quant=True)

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Float simulation of the HW pipeline.

        Steps: scale change (``mantissa_q * 2**exponent``), ``clip1``, then the two-table exp(-x) LUT.
        """
        inp = inputs[0]
        mantissa_q = tf.cast(self.mantissa_q, self.FLOAT_TYPE_TF)
        mul = tf.multiply(inp, mantissa_q)

        shift_exp = tf.cast(2**self.exponent, self.FLOAT_TYPE_TF)
        mul_after_shift = mul * shift_exp
        data_int = self.weight_lossy_elements.clip1(mul_after_shift)

        lut_res = self._minus_exp_lut_emulation(data_int, bit_exact=False, is_lossless=False)
        return lut_res

    def call_bit_exact(self, inputs, **kwargs):
        """
        Bit-exact integer emulation of the HW pipeline.

        Same steps as ``call_hw_sim``, but in ``INT_TYPE_TF`` arithmetic:
        * the power-of-two exponent is applied as a bit shift (left for a positive exponent,
          right without rounding for a negative one);
        * the LUTs are read through their hash tables.
        """
        inp = tf.cast(inputs[0], self.INT_TYPE_TF)

        mantissa_q = tf.cast(self.mantissa_q, self.INT_TYPE_TF)
        mul = tf.multiply(inp, mantissa_q)
        shift_to_hw = tf.cast(-self.exponent, self.INT_TYPE_TF)
        # NOTE(review): Python-level branch on a tensor; assumes a scalar exponent and eager execution.
        if shift_to_hw < 0:
            # support left shift
            mul_after_shift = tf.bitwise.left_shift(mul, -shift_to_hw)
        else:
            # support right shift
            mul_after_shift = tf.bitwise.right_shift(mul, shift_to_hw)
            # mul_after_shift = self.bankers_round_with_shift(mul, shift_to_hw, bankers_round=shift_to_hw)
        data_int = self.weight_lossy_elements.clip1(mul_after_shift)
        lut_res = self._minus_exp_lut_emulation(data_int, bit_exact=True, is_lossless=False)

        self._verify_data_dtype(inp, 16, False, "exp_in_before_scale_change")
        self._verify_data_dtype(data_int, self.weight_lossy_elements.clip1.bits, False, "exp_in_after_scale_change")
        self._verify_data_dtype(lut_res, self.output_lossy_element.bits, False, "exp_out_denuminator")
        return lut_res

    def call_native(self, inputs, **kwargs):
        """Reference float implementation: ``exp(-x)``."""
        return self.func(inputs[0])

    # endregion

    def export_independent_params(self):
        return {"scale_change": self.scale_change}

    def import_independent_params(self, params):
        # self.lut_scale_in = params["lut_scale_in"]
        self.scale_change = params["scale_change"]
        self.mantissas_candidate, self.exponent = self.update_mantissa_exponent_decomposition()

    def export_quant_weights(self):
        """Quantized weights: both LUTs, the scale-change mantissa/exponent and the LUT input bit-width."""
        # This is under the assumption that enable_lossy was called.
        mantissa = self.weight_lossy_elements.mantissa(self.mantissas_candidate)
        exponent = self.exponent
        lut_table_low = np.array(self.lut_values_low, dtype=np.uint16)
        lut_table_high = np.array(self.lut_values_high, dtype=np.uint16)
        lut_in_exp_bits = np.array(self.get_lut_in_exp_bits(), dtype=np.uint16)
        return {
            "lut_table_high": lut_table_high,
            "lut_table_low": lut_table_low,
            "mantissa": np.round(np.float32(mantissa)),
            "exponent": np.round(np.float32(exponent)),
            "lut_in_exp_bits": lut_in_exp_bits,
        }

    def export_hw_params(self):
        """``export_quant_weights`` cast to HW dtypes; mantissa/exponent keys are prefixed with the op name."""
        params = self.export_quant_weights()
        return {
            "lut_table_high": params["lut_table_high"].astype(np.uint16),
            "lut_table_low": params["lut_table_low"].astype(np.uint16),
            "lut_in_exp_bits": params["lut_in_exp_bits"].astype(np.uint8),
            f"{self.name}_mantissa": params["mantissa"].astype(np.uint16),
            f"{self.name}_exponent": params["exponent"].astype(np.int8),
        }

    def get_output_limvals(self, output_index: int):
        """
        Output (min, max) limits.

        Uses collected output statistics when available. Otherwise the limits are derived by
        applying ``exp(-x)`` to the input limits.
        """
        if f"outputs_{output_index}" in self.stats_managers:
            return super().get_output_limvals(output_index)
        # The original code called self.act_func, which this op never defines; its function is self.func.
        input_limvals = np.asarray(self.get_input_limvals(0), dtype=self.FLOAT_TYPE_NP)
        output_limvals = np.array(self.func(input_limvals))
        # exp(-x) is strictly decreasing, so the input (min, max) maps to the output (max, min).
        return np.flip(output_limvals, axis=0)
