import warnings
from enum import Enum

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasDecompositionError,
    AccelerasNumerizationError,
)
from hailo_model_optimization.acceleras.utils.flow_state_utils import (
    LossyState,
)


class QuantTrainMode(Enum):
    """
    Selects how the quantization step behaves while a quant element is called in training.

    The mode is consumed by BaseQuantElement.quant; BaseQuantElement.quant_call replaces it
    with DISABLED whenever it is called with a falsy `training` flag.
    """

    # round + clip through fake_quant_int, i.e. with a straight-through style gradient
    STE = "ste"
    # no rounding: uniform noise drawn with tf.random.uniform(-0.5, 0.5) is added, then the value is clipped
    NOISE = "noise"
    # branch that draws uniform noise and quantizes through fake_quant_int
    NOISED_STE = "noised_ste"
    # clip only, no rounding
    NATIVE = "native"
    # plain round + clip, without the custom gradient of fake_quant_int
    DISABLED = "disabled"


@tf.custom_gradient
def fake_quant_int(x, min_val, max_val):
    """
    Basic home-bred STE (aka fake-quant): round, then clip to [min_val, max_val].

    A custom implementation is used instead of tf.fake_quantization, because:
        A. there are some rare edge cases where tf.fake_quantization is just wrong w.r.t round(),
           e.g. try fqint16(1.499)==2.
        B. it leaves more possibilities open for later on.
    NOTE: the input is assumed to be already at the right scale, so no rescaling is done here
          (leveraging AcceLeras separation of numerization and bit-reduction
           to keep this part bulletproof by triviality..)

    Args:
        x: tensor to quantize
        min_val: lower clipping bound
        max_val: upper clipping bound

    Returns:
        The rounded and clipped tensor, together with the gradient function
        required by tf.custom_gradient.

    """
    quantized = tf.clip_by_value(tf.round(x), min_val, max_val)
    above_min = tf.less(min_val, x)
    below_max = tf.less(x, max_val)
    passes_gradient = tf.logical_and(above_min, below_max)

    def grad(upstream):
        # clipped ramp: the incoming gradient is passed as-is only where x lies strictly
        # between the bounds, and zeroed elsewhere; the bounds themselves get a 0 gradient
        masked = tf.where(passes_gradient, upstream, 0.0)
        return masked, 0, 0

    return quantized, grad


def do_wraparound(inp, signed, bits):
    """
    Wrap the values of `inp` into the range of a `bits`-wide integer using floormod.

    Unsigned values are wrapped into [0, 2**bits); signed values are shifted by 2**(bits-1),
    wrapped, and shifted back, which gives [-2**(bits-1), 2**(bits-1)).
    The signed path is computed in float64. The result is cast back to the dtype of `inp`.
    """
    original_dtype = inp.dtype
    modulus = 2**bits
    # NOTE - floormod used to run on the CPU, but it might've been optimized?
    if not signed:
        return tf.cast(tf.math.floormod(inp, modulus), original_dtype)

    half_range = 2 ** (bits - 1)
    shifted = tf.cast(inp, tf.float64) + half_range
    wrapped = tf.math.floormod(shifted, modulus) - half_range
    return tf.cast(wrapped, original_dtype)


def get_int_dtype(bits, signed):
    """
    Return the TensorFlow integer dtype used to hold a `bits`-wide (un)signed value.

    8/16/32 bits map to the exactly matching signed or unsigned dtype; any other width
    below 64 falls back to tf.int64.

    Raises:
        ValueError: if there is no exact match and `bits` is not below 64

    """
    exact_matches = {
        # (bits, signed): dtype
        (8, True): tf.int8,
        (8, False): tf.uint8,
        (16, True): tf.int16,
        (16, False): tf.uint16,
        (32, True): tf.int32,
        (32, False): tf.uint32,
    }
    key = (bits, signed)
    if key in exact_matches:
        return exact_matches[key]
    if bits < 64:
        return tf.int64
    raise ValueError(f"Unsupported bit-width: {bits}")


def _integer_wraparound(inp, signed, bits):
    """
    Round `inp` and wrap it into the range of a `bits`-wide integer.

    When `bits` is exactly the width of the dtype returned by get_int_dtype, the wraparound
    is obtained through integer casts; otherwise it falls back to do_wraparound.
    The result has the dtype of `inp`.
    """
    original_dtype = inp.dtype
    target_int_dtype = get_int_dtype(bits, signed)
    # Explicit round for banker's rounding
    rounded = tf.round(inp)

    if bits == target_int_dtype.size * 8:
        # large int cast and then small int cast to avoid undefined behavior
        # float to int cast out is undefined behavior if the float is outside the range of the int
        wide_int_dtype = tf.int32 if bits < 32 else tf.int64
        narrowed = tf.cast(tf.cast(rounded, wide_int_dtype), target_int_dtype)
        return tf.cast(narrowed, original_dtype)

    # Bitmask operations are not applied on CPU, so we use the default logic instead
    return do_wraparound(rounded, signed, bits)


class BaseQuantElement(BaseLossyElement):
    """
    Base class for lossy elements that reduce a tensor to an integer grid.

    Holds the training mode (see QuantTrainMode) and the optional wraparound-loss switch, and
    implements the shared round / clip / wraparound logic. `quant_call` falls back to the
    `bits`, `signed`, `wraparound` and `symmetric` attributes of the instance; this class does
    not define them, so subclasses must provide them or pass the values explicitly.
    """

    def __init__(self, **kwargs):
        super().__init__(kwargs.get("name"))
        self._train_mode = QuantTrainMode.STE
        self._wraparound_loss = False

    def enable_wraparound_loss(self):
        """Replace the wraparound by a clip plus an overflow penalty loss while training."""
        self._wraparound_loss = True

    def disable_wraparound_loss(self):
        """Go back to a real wraparound while training."""
        self._wraparound_loss = False

    @property
    def train_mode(self):
        return self._train_mode

    @train_mode.setter
    def train_mode(self, value):
        # accepts either a QuantTrainMode member or its string value
        self._train_mode = QuantTrainMode(value)

    def quant_call(self, inp, training, bits=None, signed=None, wraparound=None, symmetric=None):
        """
        Quantize `inp` according to the element configuration.

        Args:
            inp: tensor to quantize, assumed to be already scaled to the integer range
            training: when falsy, the train mode is forced to QuantTrainMode.DISABLED
            bits: bit-width, defaults to self.bits
            signed: whether the range is signed, defaults to self.signed
            wraparound: whether overflow wraps around instead of being clipped, defaults to self.wraparound
            symmetric: whether the signed range is symmetric, defaults to self.symmetric;
                only used when wraparound is falsy

        Returns:
            The quantized tensor

        """
        bits = self.bits if bits is None else bits
        signed = self.signed if signed is None else signed
        wraparound = self.wraparound if wraparound is None else wraparound
        symmetric = self.symmetric if symmetric is None else symmetric
        train_mode = self._train_mode if training else QuantTrainMode.DISABLED
        if wraparound:
            if train_mode != QuantTrainMode.DISABLED and self._wraparound_loss:
                # training with wraparound loss: penalize values close to overflow and clip
                # to the full (asymmetric) range instead of wrapping around
                self._add_wraparound_loss(inp, bits, signed)
                quant_value = self.quant(inp, bits, signed, symmetric=False, train_mode=train_mode)
            else:
                quant_value = self.quant_wraparound(inp, bits, signed, train_mode=train_mode)
        else:
            quant_value = self.quant(inp, bits, signed, symmetric=symmetric, train_mode=train_mode)
        return quant_value

    def _add_wraparound_loss(self, inp, bits, signed):
        """
        Register (via add_loss) a penalty for values that approach the edges of the range.

        The top term is non-zero where inp > 0.9 * maxval, the bottom term where
        inp < minval + 0.1 * maxval; the loss is the mean of the squared sum of both terms.
        """
        maxval = self.get_max_value(bits, signed)
        minval = self.get_min_value(bits, signed)
        top_overflow_penalty = tf.maximum(10 * inp / maxval - 9, 0)
        bot_overflow_penalty = tf.maximum(10 * (maxval + minval - inp) / maxval - 9, 0)
        term = top_overflow_penalty + bot_overflow_penalty
        loss = tf.reduce_mean(tf.pow(term, 2))
        self.add_loss(loss)

    @classmethod
    def quant(cls, inp, bits, signed, symmetric, train_mode):
        """
        Reduce `inp` to the integer range defined by bits / signed / symmetric, without wraparound.

        The behavior depends on `train_mode`:
            NOISE      - add uniform noise and clip, no rounding
            NOISED_STE - add uniform noise, then round and clip through fake_quant_int
            STE        - round and clip through fake_quant_int
            NATIVE     - clip only
            otherwise  - plain round and clip
        """
        dtype = inp.dtype
        maxval = tf.cast(cls.get_max_value(bits, signed), dtype)
        minval = tf.cast(cls.get_min_value(bits, signed, symmetric), dtype)
        if train_mode == QuantTrainMode.NOISE:
            noise = tf.random.uniform(tf.shape(inp), -0.5, 0.5, dtype=dtype)
            out = tf.clip_by_value(inp + noise, minval, maxval, name="noised_quant")
        elif train_mode == QuantTrainMode.NOISED_STE:
            # The noise used to be drawn and then dropped, which made this mode identical to STE.
            # It is now applied to the input before the STE rounding.
            noise = tf.random.uniform(tf.shape(inp), -0.5, 0.5, dtype=dtype)
            out = fake_quant_int(inp + noise, minval, maxval)
        elif train_mode == QuantTrainMode.STE:
            out = fake_quant_int(inp, minval, maxval)
        elif train_mode == QuantTrainMode.NATIVE:
            out = tf.clip_by_value(inp, minval, maxval, name="clip")
        else:
            out = tf.clip_by_value(tf.round(inp), minval, maxval, name="quant")
        return out

    @classmethod
    def quant_wraparound(cls, inp, bits, signed, train_mode):
        """
        Quantize `inp` with wraparound semantics.

        With training disabled the wraparound is done by _integer_wraparound rather than with the
        modulo op, because the modulo op is not optimized on the GPU. In every other mode the input
        is wrapped with do_wraparound and then quantized to the full (asymmetric) range.
        """
        if train_mode == QuantTrainMode.DISABLED:
            return _integer_wraparound(inp, signed, bits)
        wrapped_inp = do_wraparound(inp, signed, bits)
        return cls.quant(wrapped_inp, bits, signed, symmetric=False, train_mode=train_mode)

    @classmethod
    def get_max_value(cls, bits, signed):
        """Return the largest representable value: 2**bits - 1, or 2**(bits-1) - 1 when signed."""
        if signed:
            bits = bits - 1
        return 2**bits - 1

    @classmethod
    def get_min_value(cls, bits, signed, symmetric=True):
        """
        Return the smallest representable value.

        0 when unsigned; -max when signed and symmetric; -max - 1 when signed and not symmetric.
        """
        if signed:
            minval = -cls.get_max_value(bits, signed)
            if not symmetric:
                minval -= 1
        else:
            minval = 0
        return minval

    # TODO - check where this is used
    @classmethod
    def float_quant(cls, inp, mantissa_bits, exponent_bits):
        """Placeholder: currently returns `inp` unchanged."""
        return inp  # TODO!

    def export_flow_state(self) -> LossyState:
        """Export the parent flow state, extended with the wraparound-loss flag and the train mode."""
        flow_state = super().export_flow_state()
        flow_state.lossy_dict_kwgs = {
            "wraparound_loss": self._wraparound_loss,
            "quant_train_mode": self._train_mode.value,
        }
        return flow_state

    def import_flow_state(self, lossy_state: LossyState) -> None:
        """Restore the wraparound-loss flag and the train mode written by export_flow_state."""
        super().import_flow_state(lossy_state)
        self._wraparound_loss = lossy_state.lossy_dict_kwgs["wraparound_loss"]
        self._train_mode = QuantTrainMode(lossy_state.lossy_dict_kwgs["quant_train_mode"])


class BaseDecompositionElement(BaseQuantElement):
    """Quant element extended with helpers that factorize a number into a product of bounded integers."""

    @classmethod
    def uint_smallnum_factorize(cls, target, bits, maxsmallnum):
        """
        find min-error factorization of a UINT number (smaller than (2^bits - 1)*maxsmallnum)
        into integers multiplication U*R where U is uint and R is maxsmallnum or smaller

        Returns the pair (U, R).
        """
        uint_limit = 2**bits - 1
        return cls.a_b_factorize(target, uint_limit, maxsmallnum)

    @classmethod
    def int_smallnum_factorize(cls, target, bits, maxsmallnum):
        """
        find min-error factorization of a INT number (smaller than (2^(bits-1) - 1)*maxsmallnum by abs value)
        into integers multiplication I*R where I is int and R is maxsmallnum or smaller

        The magnitude is factorized as unsigned and the sign of the target is re-applied to I.
        Returns the pair (I, R).
        """
        magnitude = np.abs(target)
        int_limit = 2 ** (bits - 1) - 1
        int_fac, smallnum = cls.a_b_factorize(magnitude, int_limit, maxsmallnum)
        signed_int_fac = int_fac * np.sign(target)
        return signed_int_fac, smallnum

    @staticmethod
    def a_b_factorize(target, max_a, max_b):
        """
        find min-error factorization of a number as multiplication of two uints,
        one ranged up from 0 to max_a, the other ranged from 0 up to max_b

        Returns the pair (a, b) with the smallest absolute error |target - a*b|.

        Raises:
            RuntimeError: if target is negative
            AccelerasDecompositionError: if target exceeds max_a*max_b by more than min(max_a, max_b)/2

        """
        if target < 0:
            raise RuntimeError(f"Unexpected value in unsigned decomposition logic, target value is negative: {target}")
        upper_limit = (max_a * max_b) + min(max_a, max_b) / 2
        if target > upper_limit:
            raise AccelerasDecompositionError(target, max_a, max_b)

        target = np.float32(target)
        # candidates for b start at floor(target / max_a), and never below 1
        first_b = max(np.floor(target / max_a), 1)
        b_candidates = np.arange(first_b, max_b + 1)
        # best a for each b, capped at max_a
        a_candidates = np.minimum(np.round(target / b_candidates), max_a)
        abs_error = np.abs(target - a_candidates * b_candidates)
        best = np.argmin(abs_error)
        return a_candidates[best], b_candidates[best]

    @staticmethod
    def uint_int_factorize(target, bits):
        """
        find min-error factorization of INT number (in case of INT16* - < 127*255) into U*I

        Returns the pair (U, I), where I carries the sign of the target.

        Raises:
            AccelerasNumerizationError: if |target| is larger than the largest U*I product

        """
        max_uint = 2**bits - 1.0
        # in case of 16 bit, hardware implements INT16 using negate
        max_int = 2**bits - 1.0 if bits == 15 else 2 ** (bits - 1) - 1.0
        max_int_uint_mul = max_int * max_uint
        if abs(target) > max_int_uint_mul:
            raise AccelerasNumerizationError(f"The number {target} can't be represented, to big {max_int_uint_mul}")

        target = np.float32(target)
        sign = 1 if target == 0 else np.sign(target)
        # smallest |I| for which the matching U still fits in max_uint, and never below 1
        smallest_int = max(np.ceil(np.abs(target) / max_uint), 1)
        int_candidates = np.arange(smallest_int, max_int + 1) * sign
        uint_candidates = np.round(np.abs(target / int_candidates))
        abs_error = np.abs(target - uint_candidates * int_candidates)
        best = np.argmin(abs_error)
        return uint_candidates[best], int_candidates[best]


class BiasQuantElement(BaseDecompositionElement):
    """
    This element transforms the bias from full-precision to its quant value in the accumulator
    The decomposed factor and num repeats is calculated offline, during the quantization
    The INT value of the bias is calculated online during inference, based on the given input.

    Args:
        accumulator_bits: data bit-width in the accumulator
        uint_bits: bit-width in the MAC data inputs, used when num_decomposition > 0
        int_bits: bit-width in the MAC weight inputs

        num_decomposition: describes the bias_mode, 0 - initialization, 1 - single_decomp, 2 double-decomp
        max_feed_repeat: max repeat value during the decomposition
        kernel_bits: only used here to pick the default of `symmetric`
        signed: stored on the instance as `_signed`
        symmetric: when None, defaults to True if num_decomposition > 0 and kernel_bits >= 15, else False

    """

    def __init__(
        self,
        accumulator_bits,
        uint_bits,
        int_bits,
        num_decomposition,
        max_feed_repeat,
        kernel_bits,
        signed,
        symmetric=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._accumulator_bits = accumulator_bits
        self._pre_acc_shift = 0
        self._uint_bits = uint_bits
        self._int_bits = int_bits
        self._num_decomposition = num_decomposition
        self._max_feed_repeat = max_feed_repeat
        # both are filled in by decompose(); None means "not decomposed yet"
        self._factors = None
        self._repeats = None
        self._kernel_bits = kernel_bits
        self._signed = signed
        if symmetric is None:
            self._symmetric = True if (num_decomposition > 0 and kernel_bits >= 15) else False
        else:
            self._symmetric = symmetric

    def _is_eq(self, other):
        return (
            (other.accumulator_bits == self._accumulator_bits)
            and (other.uint_bits == self._uint_bits)
            and (other.int_bits == self._int_bits)
            and (other.num_decomposition == self._num_decomposition)
            and (other.max_feed_repeat == self._max_feed_repeat)
        )

    @property
    def accumulator_bits(self):
        return self._accumulator_bits

    @property
    def pre_acc_shift(self):
        return self._pre_acc_shift

    @pre_acc_shift.setter
    def pre_acc_shift(self, value):
        self._pre_acc_shift = value

    @property
    def uint_bits(self):
        return self._uint_bits

    @property
    def int_bits(self):
        return self._int_bits

    @property
    def num_decomposition(self):
        return self._num_decomposition

    @property
    def max_feed_repeat(self):
        return self._max_feed_repeat

    @property
    def factors(self):
        # NOTE: returns a freshly sorted (descending) list on every access
        return sorted(self._factors, reverse=True) if self._factors is not None else None

    @factors.setter
    def factors(self, value):
        self._factors = value

    @property
    def repeats(self):
        return self._repeats

    @repeats.setter
    def repeats(self, value):
        self._repeats = value

    @property
    def symmetric(self):
        return self._symmetric

    def lossy_call(self, inp, training=False):
        """
        Return the bias as it is represented in the accumulator.

        Without decomposition this is just the accumulator quantization of `inp`. Otherwise the
        value is decomposed into INT values, and rebuilt from factor * INT (shifted back by
        pre_acc_shift), times repeats, with an accumulator quantization after every step.
        """
        vector_in_accumulator = self.accumulator_quant(inp, training)
        if self.num_decomposition == 0:
            return vector_in_accumulator

        int_values = self.get_decomposed_int(vector_in_accumulator, training=training)
        # the factors property sorts on every access, so read it once outside the loop
        factors = self.factors
        reconstructed_value = 0
        # TODO: make sure it works properly
        for i in range(self.num_decomposition):
            mul_res = factors[i] * int_values[i]
            reconstructed_i = self.accumulator_quant(
                mul_res / tf.cast(2**self.pre_acc_shift, mul_res.dtype),
                training,
            )
            reconstructed_value += reconstructed_i * self.repeats
        bias_value = self.accumulator_quant(reconstructed_value, training)
        return bias_value

    def accumulator_quant(self, vector, training):
        """Quantize `vector` as a signed, wrapping value of accumulator_bits width."""
        return self.quant_call(vector, training, bits=self.accumulator_bits, signed=True, wraparound=True)

    def get_decomposed_int(self, vector_in_accumulator, training):
        """
        The bias is decomposed to 3 elements - INT, UINT, repeats.
        This function assumes the UINT and repeats have already been computed, and returns the expected INT value

        Args:
            vector_in_accumulator: desired bias value as it should be represented in the accumulator
            training: forwarded to quant_call

        Returns:
            A list with one INT tensor per decomposition

        Raises:
            RuntimeError: if decompose() has not been run yet

        """
        # the factors property sorts on every access, so read it once
        factors = self.factors
        if (factors is None) or (self.repeats is None):
            raise RuntimeError("Can't run decomposed bias before decomposition")
        value_pre_shift = vector_in_accumulator * tf.cast(2**self.pre_acc_shift, vector_in_accumulator.dtype)
        if self.repeats == 0:
            desired_post_mult = value_pre_shift
        else:
            desired_post_mult = value_pre_shift / self.repeats
        residue = desired_post_mult
        int_values = []
        for index in range(self.num_decomposition):
            factor = factors[index]
            # a zero factor is divided as 1 to avoid a division by zero
            decomposed_int = residue / max(factor, 1)
            decomposed_int = self.quant_call(
                decomposed_int,
                training,
                bits=self.int_bits,
                signed=True,
                wraparound=False,
            )

            int_values.append(decomposed_int)
            # what is left for the next (smaller) factor to represent
            residue -= decomposed_int * factor
        return int_values

    def export_as_quant(self, inp):
        """Return the list of integer values representing `inp` (inference mode)."""
        vector_in_accumulator = self.accumulator_quant(inp, training=False)
        if self.num_decomposition == 0:
            return [vector_in_accumulator]

        int_values = self.get_decomposed_int(vector_in_accumulator, training=False)
        return int_values

    def decompose(self, inp):
        """Compute and store the factors and repeats for `inp`, according to num_decomposition."""
        decomposition_handlers = {
            0: self.initialization,
            1: self.single_decomposition,
            2: self.double_decomposition,
        }
        decomposer = decomposition_handlers[self.num_decomposition]
        factors, repeats = decomposer(inp)
        self.factors = factors
        self.repeats = repeats

    @staticmethod
    def _copy_if_ndarray(vector):
        """Return a private copy of numpy inputs, so that in-place scaling can't leak to the caller."""
        return vector.copy() if isinstance(vector, np.ndarray) else vector

    def initialization(self, vector):
        """No decomposition: a single factor of 1, repeated once."""
        return [1], 1

    def double_decomposition(self, vector):
        """
        Decomposition with the two fixed factors [max_uint, 1].

        Repeats is the smallest count for which the largest absolute value of the (pre-shifted)
        vector fits in (max_uint + 1) * max_int * repeats.
        """
        # `*=` used to scale the caller's numpy array in place; work on a copy instead
        vector = self._copy_if_ndarray(vector)
        vector *= 2**self.pre_acc_shift
        max_int = self.get_max_value(self.int_bits, True)
        max_uint = self.get_max_value(self.uint_bits, False)
        factors = np.array([max_uint, 1])
        repeats = np.ceil(np.max(np.abs(vector)) / ((factors[0] + factors[1]) * max_int))
        return factors, repeats

    def single_decomposition(self, vector):
        """
        Decomposition with a single factor.

        A constant vector is factorized exactly through uint_int_factorize; otherwise the
        factor * repeats product is chosen so that the largest absolute value fits in max_int bins.
        Returns (np.array([factor]), repeats).
        """
        # `*=` used to scale the caller's numpy array in place; work on a copy instead
        vector = self._copy_if_ndarray(vector)
        vector *= tf.cast(2**self.pre_acc_shift, vector.dtype)
        max_int = self.get_max_value(self.int_bits, True)
        max_uint = self.get_max_value(self.uint_bits, False)
        max_int_uint_mul = max_int * max_uint
        if np.all(vector == vector[0]):
            if vector[0] == 0:
                factor = 0
                repeats = 1
            else:
                repeats = np.ceil(np.abs(vector[0] / max_int_uint_mul))
                factor, _ = self.uint_int_factorize(np.round(vector / repeats)[0], bits=self.uint_bits)
        else:
            bins = max_int
            u_x_r = np.ceil(np.max(np.abs(vector)) / bins)
            if u_x_r != 0:
                factor, repeats = self.uint_smallnum_factorize(
                    u_x_r,
                    bits=self.uint_bits,
                    maxsmallnum=self.max_feed_repeat,
                )
            else:
                repeats = 1
                factor = 0
        return np.array([factor]), repeats


class QuantElement(BaseQuantElement):
    """
    Encapsulating an "atom" of information loss as a unary op modifying a single tensor,
    aka "lossy gate" across an edge of the computational graph.

    It's called this way to encompass general "bit loss" - Pruning, scalar/vector quantizations, etc.
    Removes "least significant bits" (in the wide sense), and does only that. No [re]scaling here!
    The tensor is assumed to arrive appropriately pre-conditioned (e.g. scaled to INT8 range)

    Notes
        - Currently implementing scalar quantization in forward pass,
            supporting STE-style backpropagation (aka "FakeQuant")
        - Seems straightforward to generalize to vector quantization
            (finding appropriate cluster and replacing by its centroid..)
        - Pruning can be added on top as a multiplication by a boolean mask,
            though management of mask itself will be external

    Args:
        signed: whether the integer range is signed
        bits: bit-width of the integer range
        wraparound: whether overflow wraps around instead of being clipped
        symmetric: whether the signed range is symmetric (min == -max)

    """

    def __init__(self, signed=False, bits=8, wraparound=False, symmetric=True, **kwargs):
        super().__init__(**kwargs)
        self.signed = signed
        self.bits = bits
        self.wraparound = wraparound
        self.symmetric = symmetric

    def lossy_call(self, inp, training=False):
        """
        this is the callable part of the object - quantizes `inp` with the element configuration
        """
        return self.quant_call(inp, training)

    @property
    def max_value(self):
        return self.get_max_value(self.bits, self.signed)

    @property
    def min_value(self):
        return self.get_min_value(self.bits, self.signed, self.symmetric)

    @property
    def bins_count(self):
        # distance between the two ends of the range
        return self.max_value - self.min_value

    def gen_pruning_op(self, boolean_mask):
        """
        TODO ..just a rough sketch now..

        Returns a function that multiplies its input by `boolean_mask`.
        A FutureWarning is emitted on every call. The method used to be decorated with
        `@FutureWarning`, which replaced it by an exception instance and made it uncallable.
        """
        warnings.warn(
            "QuantElement.gen_pruning_op is a rough sketch and may change in the future",
            FutureWarning,
            stacklevel=2,
        )

        def prune(inp):
            return inp * boolean_mask

        return prune

    def _is_eq(self, other):
        return (self.signed == other.signed) and (self.bits == other.bits) and (self.wraparound == other.wraparound)

    def align_to_range(self, inp, training=False):
        """
        Bring `inp` into the integer range without rounding it.

        Wraps around when the element is a wraparound element, unless the wraparound loss is
        active (enabled, training, and train mode not DISABLED); clips otherwise.
        """
        # TODO: consider moving this logic to the quant call itself. (STE won't have custom grad for clip)
        # NOTE(review): the lower bound uses the default symmetric=True and ignores self.symmetric,
        # unlike the min_value property - confirm this is intended
        minval = self.get_min_value(self.bits, self.signed)
        maxval = self.get_max_value(self.bits, self.signed)
        if self.wraparound and (
            not self._wraparound_loss or not training or self._train_mode == QuantTrainMode.DISABLED
        ):
            out = do_wraparound(inp, self.signed, self.bits)
        else:
            out = tf.clip_by_value(inp, minval, maxval)
        return out


class AccumulatorQuantElement(QuantElement):
    """QuantElement preset for the Accumulator: signed, overflow wraps around."""

    def __init__(self, bits, **kwargs):
        super().__init__(bits=bits, signed=True, wraparound=True, **kwargs)


class APUOutputQuantElement(QuantElement):
    """QuantElement preset for the standard output of the APU: unsigned, no wraparound."""

    def __init__(self, bits, **kwargs):
        super().__init__(bits=bits, signed=False, wraparound=False, **kwargs)


class APUOutputSignedQuantElement(QuantElement):
    """QuantElement that simulates the signed output of the APU (signed, non-symmetric range, without wraparound)"""

    def __init__(self, bits, **kwargs):
        super().__init__(signed=True, bits=bits, wraparound=False, symmetric=False, **kwargs)


class MACDataQuantElement(QuantElement):
    """QuantElement preset for the Data line of the MAC unit: no wraparound, signed unless stated otherwise."""

    def __init__(self, bits, signed=True, **kwargs):
        super().__init__(bits=bits, signed=signed, wraparound=False, **kwargs)


class APUPreOffsetQuantElement(QuantElement):
    """QuantElement preset for the offset line of the APU unit: signed, no wraparound."""

    def __init__(self, bits, **kwargs):
        super().__init__(bits=bits, signed=True, wraparound=False, **kwargs)


class AdaRoundQuantElement(QuantElement):
    """
    QuantElement whose rounding is driven by an external variable.

    When called, the input is replaced by floor(input) plus a fractional value in [0, 1]
    derived from `var` (see get_fractional_value), and then passed on to the parent call.
    """

    def __init__(self, signed=False, bits=4, **kwargs):
        super().__init__(**kwargs)
        self.signed, self.bits = signed, bits
        # stretch parameters of the sigmoid used in get_fractional_value
        self.gamma, self.zeta = -0.1, 1.1
        self.var = None

    def lossy_call(self, inp, training=False):
        """Clip to the symmetric range of the element; round first unless training."""
        if self.signed:
            upper = (2 ** (self.bits - 1)) - 1
            lower = -upper
        else:
            upper = (2**self.bits) - 1
            lower = 0
        lower, upper = np.float32(lower), np.float32(upper)
        values = inp if training else tf.round(inp)
        return tf.clip_by_value(values, lower, upper)

    def set_var(self, var):
        """Attach the variable that controls the rounding."""
        self.var = var

    def get_initial_var_value(self, inp):
        """Return the value of `var` for which the fractional value equals the fractional part of `inp`."""
        # Initialization is based on the inverse of the forward function
        fractional_part = inp - np.floor(inp)
        span = self.zeta - self.gamma
        ratio = span / (fractional_part - self.gamma)
        return -np.log(ratio - 1)

    def get_fractional_value(self):
        """Return sigmoid(var) stretched to [gamma, zeta] and clipped to [0, 1]."""
        stretched = tf.sigmoid(self.var) * (self.zeta - self.gamma) + self.gamma
        return tf.clip_by_value(stretched, 0, 1)

    def __call__(self, inp, **kwargs):
        learned_rounding = tf.floor(inp) + self.get_fractional_value()
        return super().__call__(learned_rounding, **kwargs)


class MaskQuantElement(QuantElement):
    """Unsigned, non-wrapping quant element that leaves infinite entries untouched."""

    def __init__(self, bits, **kwargs):
        super().__init__(bits=bits, signed=False, wraparound=False, **kwargs)

    def lossy_call(self, inp, training=False):
        """Round and clip to [0, 2**bits - 1], keeping -inf and +inf entries as they are."""
        upper_bound = (2**self.bits) - 1
        is_neg_inf = tf.where(inp == -np.inf, True, False)
        is_pos_inf = tf.where(inp == np.inf, True, False)
        quantized = tf.clip_by_value(tf.math.round(inp), 0, upper_bound)

        # the clip turned the infinite entries into finite ones; put them back
        quantized = tf.where(is_neg_inf, -np.inf, quantized)
        return tf.where(is_pos_inf, np.inf, quantized)


class RoundUpQuantElement(QuantElement):
    """QuantElement that biases the rounding upwards."""

    def lossy_call(self, inp, training=False):
        """
        switching tf.round to tf.ceil in the bit reducer, by adding 0.5 to the input
        before handing it to the regular quantization
        """
        shifted_up = inp + 0.5
        return super().lossy_call(shifted_up, training=training)


class RoundDownQuantElement(QuantElement):
    """QuantElement that biases the rounding downwards."""

    def lossy_call(self, inp, training=False):
        """
        switching tf.round to tf.floor in the bit reducer, by subtracting 0.5 from the input
        before handing it to the regular quantization
        """
        shifted_down = inp - 0.5
        return super().lossy_call(shifted_down, training=training)


class BiasedDeltaQuantElement(QuantElement):
    """In case of a biased delta activation, quant the pieces to [0, 1]"""

    def lossy_call(self, inp, training=False):
        """
        switching tf.round to tf.ceil in the bit reducer: the input is only rounded up,
        with no clipping; `training` is not used
        """
        rounded_up = tf.math.ceil(inp)
        return rounded_up


class RoundQuantElement(QuantElement):
    """do only round with no clip"""

    def lossy_call(self, inp, training=False):
        """
        switching bit reducer to round with no clip and wrap around

        `training` is ignored. It is accepted so that the signature matches
        QuantElement.lossy_call and a call that passes `training` doesn't raise a TypeError.
        """
        return tf.math.round(inp)
