from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.atomic_ops.inverse_lut_calc import InverseFuncLUTCalculator
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EmulationType,
)


@dataclass
class SoftmaxWeightsLossy(BaseWeightLossyElements):
    """Weight lossy elements held by FinalizeSoftmaxOp."""

    # Applied to the computed softmax value in FinalizeSoftmaxOp.softmax_logic and
    # FinalizeSoftmaxOp.call_bit_exact.
    clip_before_apu: BaseLossyElement


class FinalizeSoftmaxOp(BaseAtomicOp):
    """
    This op gets the numerator and the denominator of the softmax and calculates the softmax.
    The op is calculated as follows:
        inv(denominator) = lut(denominator_mantissa) >> (denominator_exponent + shift_to_plus)
        result = numerator * inv(denominator)
    """

    num_inputs = 2
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        """
        Build the op in its lossless default state.

        After construction the op is lossless, shift_to_plus is 0, the inverse LUT
        calculator is created with sqrt=False, clip_before_apu is an IdentityElement
        and the emulation type is DOUBLE.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        self.f_out = None
        self.is_lossless = True
        self._shift_to_plus = 0
        self.lut_inv = InverseFuncLUTCalculator(sqrt=False)

        # Placeholder element; create_weight_quant_element swaps in a QuantElement.
        placeholder_clip = IdentityElement()
        self.weight_lossy_elements = SoftmaxWeightsLossy(clip_before_apu=placeholder_clip)

        self.set_type_emulation(EmulationType.DOUBLE)

    def _compute_output_shape(self, input_shape):
        """Return the first entry of input_shape as the output shape."""
        first_input_shape = input_shape[0]
        return first_input_shape

    def disable_lossy(self, **kwargs):
        """Delegate to the base class, then mark this op as lossless."""
        super().disable_lossy(**kwargs)
        # While this flag is set, call_hw_sim takes the decode -> call_native -> encode path.
        self.is_lossless = True

    def enable_lossy(self, **kwargs):
        """Delegate to the base class, then mark this op as lossy."""
        super().enable_lossy(**kwargs)
        # While this flag is cleared, call_hw_sim runs softmax_logic on its inputs.
        self.is_lossless = False

    @property
    def bit_exact_supported(self) -> bool:
        """Always True: this op provides a call_bit_exact implementation."""
        return True

    def _calc_shift_to_plus(self):
        """
        Derive shift_to_plus from the statistics collected for output 0.

        The largest absolute value found in the output stats (min and max) is compared
        with output_lossy_element.max_value expressed in the un-shifted output scale
        (scale_e_x * scale_out_inv). The returned value is
        floor(log2(ratio) / 2 - 1) * 2, i.e. twice a floored quantity.
        """
        unshifted_scale = self.scale_e_x * self.scale_out_inv

        abs_peak_of_max = np.max(np.abs(self.get_output_stats(0).max))
        abs_peak_of_min = np.max(np.abs(self.get_output_stats(0).min))
        max_value = np.maximum(abs_peak_of_max, abs_peak_of_min)

        representable_max = self.output_lossy_element.max_value * unshifted_scale
        log_headroom = np.log2(representable_max / max_value)
        return np.floor(log_headroom / 2 - 1) * 2

    def create_hw_params(self):
        """Compute shift_to_plus, refresh the output encoding and switch the op to lossy mode."""
        # Must be set before enforce_encoding, which divides the output scale by 2**shift_to_plus.
        self._shift_to_plus = self._calc_shift_to_plus()
        self.enforce_encoding()
        self.is_lossless = False

    def softmax_logic(self, inputs):
        """
        Calculate the softmax as numerator * inv(denominator).

        inputs[0] is the numerator and inputs[1] is the denominator (sum of e^x).
        The inverse of the denominator comes from the LUT helper as a
        mantissa/exponent pair; shift_to_plus is subtracted from the exponent before
        the division, and the result is passed through the clip_before_apu element.
        """
        e_x, e_x_sum = inputs[0], inputs[1]

        mantissa, exponent = self._get_mantissa_exp_vals(e_x_sum, bit_exact=False)
        exponent = exponent - tf.cast(self._shift_to_plus, self.FLOAT_TYPE_TF)

        unclipped = self._compute_softmax(e_x, mantissa, exponent)
        return self.weight_lossy_elements.clip_before_apu(unclipped)

    def _compute_softmax_bit_exact(self, numerator, inv_mantissa, inv_exponent):
        """
        Integer flavour of the division: floor(numerator * inv_mantissa / 2**inv_exponent).

        Operands and intermediates are handed to _verify_data_dtype together with a
        bit width and a boolean flag (presumably signedness - confirm against the
        base class). The product is computed first; product and exponent are then
        cast to the float type for the power-of-two division, and the floored result
        is cast to the integer type.
        """
        operand_checks = (
            (numerator, 16, False, "numerator"),
            (inv_mantissa, 16, False, "inv_mantissa"),
            (inv_exponent, 5, True, "inv_exponent"),
        )
        for tensor, bits, flag, label in operand_checks:
            self._verify_data_dtype(tensor, bits, flag, label)

        product = tf.multiply(numerator, inv_mantissa)
        self._verify_data_dtype(product, 33, False, "mul")

        exponent_f = tf.cast(inv_exponent, self.FLOAT_TYPE_TF)
        product_f = tf.cast(product, self.FLOAT_TYPE_TF)

        floored = tf.floor(product_f / (2.0**exponent_f))
        softmax_result = tf.cast(floored, self.INT_TYPE_TF)

        self._verify_data_dtype(softmax_result, 34, False, "mul_after_exponent")
        return softmax_result

    def _compute_softmax(self, numerator, inv_mantissa, inv_exponent):
        """Return round(numerator * inv_mantissa / 2**inv_exponent), computed in the float type."""
        numerator_f = tf.cast(numerator, self.FLOAT_TYPE_TF)
        mantissa_f = tf.cast(inv_mantissa, self.FLOAT_TYPE_TF)

        product = tf.multiply(numerator_f, mantissa_f)
        divisor = 2**inv_exponent
        return tf.round(product / divisor)

    @property
    def lut_inv_scale_out(self):
        """Expose lut_scale_out of the inverse LUT calculator."""
        inverse_lut = self.lut_inv
        return inverse_lut.lut_scale_out

    def _get_mantissa_exp_vals(self, sum_e_x, bit_exact=False):
        """
        Query the inverse LUT calculator for sum_e_x.

        The op's float/int TF types, the requested bit_exact mode and the current
        is_lossless flag are forwarded to lut_vals_helper. Callers in this class
        unpack the result as (inv_mantissa, inv_exponent).
        """
        helper_options = {
            "float_type": self.FLOAT_TYPE_TF,
            "int_type": self.INT_TYPE_TF,
            "bit_exact": bit_exact,
            "is_lossless": self.is_lossless,
        }
        return self.lut_inv.lut_vals_helper(sum_e_x, **helper_options)

    def export_quant_weights(self):
        """Return an empty dict: this op exports no quantized weights."""
        return {}

    def export_weights(self):
        """Return an empty dict: this op exports no weights."""
        return {}

    def call_native(self, inputs, **kwargs):
        """Native path: true division of inputs[0] (e^x) by inputs[1] (sum of e^x)."""
        e_x, e_x_sum = inputs[0], inputs[1]
        return tf.truediv(e_x, e_x_sum)

    def call_hw_sim(self, inputs, **kwargs):
        """
        Hardware-simulation path.

        When the op is lossy, softmax_logic runs directly on the given inputs.
        When it is lossless, the inputs are decoded, call_native is applied, and the
        result (wrapped in a list if it is not one already) is encoded again.
        """
        if not self.is_lossless:
            return self.softmax_logic(inputs)

        decoded_inputs = self._decode_inputs(inputs)
        native_out = self.call_native(decoded_inputs, **kwargs)
        if not isinstance(native_out, list):
            native_out = [native_out]
        return self._encode_outputs(native_out)

    def call_bit_exact(self, inputs, **kwargs):
        """
        Bit-exact softmax finalization.

        inputs[0] is the numerator and inputs[1] is the denominator (sum of e^x).
        The denominator is checked with _verify_data_dtype (56, False), its inverse is
        fetched from the LUT helper in bit-exact mode as a mantissa/exponent pair,
        shift_to_plus (cast to the integer type) is subtracted from the exponent, and
        the value from _compute_softmax_bit_exact is passed with clip_before_apu to
        hw_simulation_by_lossy_element.

        In debug mode the mantissa and the shifted exponent are stored on the instance
        as inv_mantissa / inv_exponent.
        """
        numerator = inputs[0]
        sum_e_x = inputs[1]
        # The label previously read "sum_e_x_sum_e_xsum_e_x" (paste slip); it now matches
        # the variable name, like the other _verify_data_dtype labels in this class.
        self._verify_data_dtype(sum_e_x, 56, False, "sum_e_x")

        inv_mantissa, inv_exponent = self._get_mantissa_exp_vals(sum_e_x, bit_exact=True)
        shift_to_plus = tf.cast(self._shift_to_plus, self.INT_TYPE_TF)
        inv_exponent = inv_exponent - shift_to_plus

        softmax_value = self._compute_softmax_bit_exact(numerator, inv_mantissa, inv_exponent)
        softmax_value = self.hw_simulation_by_lossy_element(softmax_value, self.weight_lossy_elements.clip_before_apu)
        if self.debug_mode:
            # Keep the intermediates around for inspection.
            self.inv_mantissa = inv_mantissa
            self.inv_exponent = inv_exponent

        return softmax_value

    def is_differentiable(self) -> bool:
        """Always False: this op reports itself as non-differentiable."""
        return False

    def create_weight_quant_element(self, bits):
        """Replace clip_before_apu with a signed, non-wraparound QuantElement of `bits` bits."""
        clip_element = QuantElement(
            signed=True,
            bits=bits,
            wraparound=False,
            name=f"{self.full_name}/qe:clip_before_apu",
        )
        self.weight_lossy_elements = SoftmaxWeightsLossy(clip_before_apu=clip_element)

    def export_independent_params(self):
        """Return the independent parameters: shift_to_plus stored as float32."""
        shift_as_float = np.float32(self._shift_to_plus)
        return {"shift_to_plus": shift_as_float}

    def export_hw_params(self):
        """Return the hardware LUT table together with shift_to_plus converted to int8."""
        hw_params = {"lut_table": self.lut_inv.get_hw_lut()}
        hw_params["shift_to_plus"] = np.int8(self._shift_to_plus)
        return hw_params

    def import_independent_params(self, params):
        """Restore shift_to_plus from the "shift_to_plus" entry of params."""
        restored_shift = params["shift_to_plus"]
        self._shift_to_plus = restored_shift

    @property
    def scale_e_x(self):
        """Scale of input 0 (the numerator), cast to the op's float type."""
        numerator_scale = self.input_scales[0]
        is_scalar = np.array(numerator_scale).shape == ()
        if not is_scalar:
            # Non-scalar scale: only its first entry is used.
            numerator_scale = numerator_scale[0]
        return tf.cast(numerator_scale, self.FLOAT_TYPE_TF)

    @property
    def scale_e_x_sum(self):
        """Scale of input 1 (the denominator), cast to the op's float type."""
        denominator_scale = self.input_scales[1]
        is_scalar = np.array(denominator_scale).shape == ()
        if not is_scalar:
            # Non-scalar scale: only its first entry is used.
            denominator_scale = denominator_scale[0]
        return tf.cast(denominator_scale, self.FLOAT_TYPE_TF)

    @property
    def scale_out_inv(self):
        """Return lut_inv_scale_out * lut_inv.function(scale_e_x_sum)."""
        transformed_sum_scale = self.lut_inv.function(self.scale_e_x_sum)
        return self.lut_inv_scale_out * transformed_sum_scale

    def enforce_encoding(self, *args, **kwargs):
        """
        Set the output encoding from the input scales and shift_to_plus.

        output_scale becomes scale_e_x * scale_out_inv / 2**shift_to_plus repeated once
        per output channel (last dimension of output_shape); output_zero_point is 0.
        """
        num_channels = self.output_shape[-1]
        shifted_scale = self.scale_e_x * self.scale_out_inv / (2**self._shift_to_plus)
        self.output_scale = np.repeat(shifted_scale, num_channels)
        self.output_zero_point = np.array(0, self.FLOAT_TYPE_NP)
