import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasNumerizationError
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams

# Default for SoftmaxOp.cut_off: threshold in de-quantized input units (LUT index times the
# scalar input scale) at or beyond which SoftmaxOp.exp_lut entries are left at zero.
CUT_OFF = 10


class SoftmaxOp(BaseAtomicOp):
    """
    Emulate softmax operation exp(x_i)/sum_j(exp(x_j))

    The native path optionally adds an imported ``additive_mask`` to the input and can
    normalize ``groups`` equal slices of the last axis independently. The hardware-simulation
    path and the exported tables (``exp_lut`` / ``inv_lut``, packed by ``_get_softmax_lut``)
    model the quantized implementation.

    Examples
        Examples of use
        >>> op = SoftmaxOp("softmax1")

    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, axis=-1, groups=1, **kwargs):
        """
        Args:
            name: op name, forwarded to ``BaseAtomicOp``.
            logger: forwarded to ``BaseAtomicOp``.
            fully_native: forwarded to ``BaseAtomicOp``.
            axis: axis along which the softmax is normalized.
            groups: number of equal slices of the last axis that ``call_native`` normalizes
                independently; values <= 1 disable grouping.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._generate_inv_lut()
        self.cut_off = CUT_OFF
        self._axis = axis
        self._groups = groups
        # Optional additive mask filled by import_weights(); None disables masking.
        self._additive_mask = None

    def call_native(self, inputs, **kwargs):
        """
        Float softmax of ``inputs[0]`` along ``self._axis``.

        If an additive mask was imported, it is tiled along the last axis to the input's width
        and added before the softmax. When ``groups > 1`` the last axis is split into ``groups``
        equal slices, each slice is normalized separately, and the results are concatenated back
        along the last axis.

        Raises:
            ValueError: if ``groups > 1`` and the (statically known) last-axis size is not
                divisible by ``groups``.
        """
        inp = inputs[0]

        if self._additive_mask is not None:
            # additive mask is used to add a constant to the input before softmax for dropping
            # unnecessary values; np.tile with an int repeats it along the last axis.
            reps = inp.shape[-1] // self._additive_mask.shape[-1]
            inp = inp + np.tile(self._additive_mask, reps)

        if self._groups <= 1:
            return tf.nn.softmax(inp, axis=self._axis)

        channels = inp.shape[-1]
        if channels is not None and channels % self._groups != 0:
            # Refuse instead of silently dropping trailing channels, which would make the output
            # narrower than the input (and disagree with _compute_output_shape).
            raise ValueError(
                f"softmax groups={self._groups} must evenly divide the last axis (size {channels})"
            )

        # Split on the last axis (not a hard-coded axis 3) so any input rank works.
        group_inputs = tf.split(inp, self._groups, axis=-1)
        group_outputs = [tf.nn.softmax(group, axis=self._axis) for group in group_inputs]
        return tf.concat(group_outputs, axis=-1)

    def _compute_output_shape(self, input_shapes):
        """Softmax preserves shape: the output shapes are the input shapes."""
        return input_shapes

    def call_hw_sim(self, inputs, **kwargs):
        """
        Simulate the quantized softmax on ``inputs[0]``.

        The input is shifted by its max along ``self._axis`` (softmax is shift-invariant),
        multiplied by ``input_scales[0]`` before exponentiation, normalized by the sum along the
        same axis, and finally divided by ``output_scale``.
        """
        # NOTE(review): unlike call_native, this path does not apply ``_additive_mask`` and
        # ignores ``_groups`` — confirm that is intended.
        inp = inputs[0]
        max_input = tf.reduce_max(inp, axis=self._axis, keepdims=True)
        inputs_norm = inp - max_input
        exp = tf.math.exp(inputs_norm * self.input_scales[0])
        exp_sum = tf.reduce_sum(exp, keepdims=True, axis=self._axis)
        res_native = tf.math.divide(exp, exp_sum)
        return tf.math.divide(res_native, self.output_scale)

    def is_differentiable(self) -> bool:
        """This op reports itself as non-differentiable."""
        return False

    def enforce_encoding(self, *args, **kwargs):
        """
        Fix the output encoding independently of the data.

        ``output_scale`` is ``1 / max_value`` of the output lossy element, repeated once per
        entry of ``input_scales[0]``, and ``output_zero_point`` is 0. Real outputs in [0, 1] are
        therefore represented on [0, max_value].
        """
        max_int = self.output_lossy_element.max_value
        self.output_scale = np.repeat(np.array(1 / max_int, np.float32), len(self.input_scales[0]))
        self.output_zero_point = np.array(0, np.float32)

    def _generate_inv_lut(self):
        """
        Build the 256-entry reciprocal table stored in ``self._inv_lut``.

        Starting from ``round(1 / linspace(1 / (255 * max_16b), 1 / max_16b, 256))`` (values from
        ``255 * max_16b`` down to ``max_16b``), entry 0 is kept and every later entry is replaced
        by the rounded midpoint of itself and its predecessor. All values fit in 24 bits.
        """
        max_16b = 2**16 - 1
        inv_lut = np.round(1.0 / np.linspace(1.0 / (255.0 * max_16b), 1.0 / max_16b, 256))
        inv_lut = np.rint(np.append(np.array(inv_lut[0]), np.array(0.5 * (np.add(inv_lut[0:-1], inv_lut[1:])))))
        self._inv_lut = inv_lut

    def create_weight_quant_element(self, **kwargs):
        """Install a default ``BaseWeightLossyElements()`` container as the weight lossy elements."""
        self.weight_lossy_elements = BaseWeightLossyElements()

    @property
    def exp_lut(self):
        """
        256-entry exp table in 16-bit fixed point.

        With ``scale`` the scalar input scale, ``lut_max = 2**16 - 1`` and
        ``cut_off_q = clip(ceil(cut_off / scale), 0, 255)``:
        entry 0 is ``lut_max``; entry ``k`` for ``1 <= k < cut_off_q`` is
        ``round(lut_max * exp(-scale * k))``; all remaining entries are 0.

        Raises:
            AccelerasNumerizationError: if ``input_scales[0]`` is not (near-)constant.
        """
        lut_size = 255.0
        lut_max = 2**16 - 1
        scale = self.get_scalar_vector(self.input_scales[0])
        cut_off_q = int(np.clip(np.ceil(self.cut_off / scale), 0, lut_size))
        exp_lut = np.zeros(256)
        # Seed entries below the cut-off with their own index; the nonzero ones (1..q-1) are
        # then replaced by the exponential, entries >= cut_off_q stay zero.
        exp_lut[:cut_off_q] = np.arange(cut_off_q)
        inds = exp_lut.nonzero()
        exp_lut[inds] = np.round(lut_max * np.exp(-scale * exp_lut[inds]))
        exp_lut[0] = lut_max
        return exp_lut

    @property
    def inv_lut(self):
        """The reciprocal table built by ``_generate_inv_lut``."""
        return self._inv_lut

    @staticmethod
    def get_scalar_vector(vector, eps=1e-5):
        """
        Collapse a (near-)constant vector to a single scalar.

        Args:
            vector: a scalar (Python/numpy number or 0-d array), returned unchanged, or an
                indexable array whose elements must all agree with ``vector[0]``.
            eps: maximal allowed relative deviation ``|v - v[0]| / v``.

        Returns:
            ``vector`` itself for 0-d inputs, otherwise ``vector[0]``.

        Raises:
            AccelerasNumerizationError: if some element deviates from ``vector[0]`` by more
                than ``eps`` (relative).
        """
        if np.ndim(vector) == 0:
            # Any scalar-like value (float, int, numpy scalar, 0-d array) is already a scalar.
            return vector
        if eps < np.max(np.abs(vector - vector[0]) / vector):
            raise AccelerasNumerizationError("the vector must be a scalar")
        return vector[0]

    def import_weights(self, layer_params: LayerParams):
        """Read the optional ``additive_mask`` entry; when absent, masking stays disabled (None)."""
        param_dict = dict(layer_params)
        self._additive_mask = param_dict.get("additive_mask")

    def export_weights(self):
        """This op exports no float weights."""
        return dict()

    def export_quant_weights(self):
        """Return ``{"softmax_lut": <uint32 array>}`` with the packed exp/reciprocal tables."""
        softmax_lut = self._get_softmax_lut()
        return {
            "softmax_lut": np.uint32(softmax_lut),
        }

    def export_hw_params(self) -> dict:
        """Hardware params are the quantized weights (the packed softmax LUT)."""
        return self.export_quant_weights()

    def _get_softmax_lut(self):
        """
        Pack ``exp_lut`` and ``inv_lut`` into interleaved 32-bit words.

        For each entry ``i`` two words are emitted, in order:
          * first:  bits 0-15 = ``exp_lut[i]``, bits 16-31 = low 16 bits of ``inv_lut[i]``;
          * second: bits 0-7 = ``inv_lut[i] >> 16`` (the reciprocal's top bits), bits 8+ = ``i``.

        Returns:
            1-D ``np.uint32`` array of length ``2 * min(len(exp_lut), len(inv_lut))``.
        """
        # Truncate toward zero like int(); the tables already hold integral values.
        exp_vals = np.asarray(self.exp_lut).astype(np.int64)
        inv_vals = np.asarray(self.inv_lut).astype(np.int64)
        size = min(exp_vals.size, inv_vals.size)
        exp_vals, inv_vals = exp_vals[:size], inv_vals[:size]
        index = np.arange(size, dtype=np.int64)

        # Aligned because each line mem is 48B, but each write is 32B.
        # Mask to 32 bits explicitly rather than relying on int64 -> uint32 cast wraparound.
        first = ((inv_vals << 16) | exp_vals) & 0xFFFFFFFF
        second = ((index << 8) | (inv_vals >> 16)) & 0xFFFFFFFF

        # Interleave as [first_0, second_0, first_1, second_1, ...].
        return np.stack([first, second], axis=1).reshape(-1).astype(np.uint32)
