import os
from typing import Tuple, Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    AccumulatorQuantElement,
    APUOutputQuantElement,
    APUOutputSignedQuantElement,
    APUPreOffsetQuantElement,
    BaseDecompositionElement,
    BaseQuantElement,
    MACDataQuantElement,
    MaskQuantElement,
    QuantElement,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_ACCUMULATOR_SIZE,
    HW_SHIFTS,
    BiasMode,
    DataPath,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasNumerizationError
from hailo_model_optimization.acceleras.utils.logger import default_logger


def rep_as_uint_x_int_repeats(vector, bit_re, max_feed_repeat):
    """
    Decompose the members of ``vector`` as R*U*I, with R (repeats) shared by the whole vector.

    Args:
        vector: values to decompose; indexed with ``[0]`` and compared element-wise.
        bit_re: object whose ``bits`` attribute gives the bit width used for the U and I ranges.
        max_feed_repeat: upper bound on the repeat count, forwarded as ``maxsmallnum``
            to ``uint_smallnum_factorize`` (used only when the members differ).

    Returns:
        ``(factor, repeats)``; ``(0, 1)`` when there is nothing to represent.
    """
    n_bits = bit_re.bits
    int_limit = 2 ** (n_bits - 1) - 1.0
    uint_limit = 2**n_bits - 1.0

    first = vector[0]
    if np.all(vector == first):
        # All members are identical, so a plain U*I factorization of that single value is enough.
        if first == 0:
            return 0, 1
        repeats = np.ceil(np.abs(first / (int_limit * uint_limit)))
        factor, _ = uint_int_factorize(np.round(vector / repeats)[0], bits=n_bits)
        return factor, repeats

    # R and U are shared and fixed, so I is the only per-member degree of freedom.
    # For best utilization the largest I maps to the largest member of the vector,
    # and the shared U*R product is derived from it.
    u_x_r = np.ceil(np.max(np.abs(vector)) / int_limit)
    if u_x_r == 0:
        return 0, 1
    factor, repeats = uint_smallnum_factorize(u_x_r, bits=n_bits, maxsmallnum=max_feed_repeat)

    # TODO support double scale decomposition future
    return factor, repeats


def int_smallnum_factorize(target, bits, maxsmallnum):
    """
    Minimum-error factorization of an INT number into the integer product I*R.

    The target is expected to be smaller, by absolute value, than
    ``(2^(bits-1) - 1) * maxsmallnum``; I is an int and R is ``maxsmallnum`` or smaller.
    The work is delegated to ``BaseDecompositionElement.int_smallnum_factorize``.
    """
    factorize = BaseDecompositionElement.int_smallnum_factorize
    return factorize(target, bits, maxsmallnum)


def uint_smallnum_factorize(target, bits, maxsmallnum):
    """
    Minimum-error factorization of a UINT number into the integer product U*R.

    The target is expected to be smaller than ``(2^bits - 1) * maxsmallnum``;
    U is a uint and R is ``maxsmallnum`` or smaller.
    The work is delegated to ``BaseDecompositionElement.uint_smallnum_factorize``.
    """
    factorize = BaseDecompositionElement.uint_smallnum_factorize
    return factorize(target, bits, maxsmallnum)


def a_b_factorize(target, max_a, max_b):
    """
    Minimum-error factorization of a number as the product of two uints.

    One factor ranges from 0 up to ``max_a``, the other from 0 up to ``max_b``.
    The work is delegated to ``BaseDecompositionElement.a_b_factorize``.
    """
    factorize = BaseDecompositionElement.a_b_factorize
    return factorize(target, max_a, max_b)


def uint_int_factorize(target, bits):
    """
    Minimum-error factorization of an INT number into U*I.

    (In the INT16* case the target is below 127*255.)
    The work is delegated to ``BaseDecompositionElement.uint_int_factorize``.
    """
    factorize = BaseDecompositionElement.uint_int_factorize
    return factorize(target, bits)


def calculate_shifts(
    expected_max_accumulator,
    accumulator_size,
    shift_buffer,
    force_rounded_shift_delta=False,
    hw_shifts=None,
    return_needed_shift=False,
    utilize_wraparound=False,
):
    """
    Choose the pre-accumulator shift that keeps the expected maximum inside the accumulator.

    To avoid overflow the accumulator input can be shifted by one of the supported hardware
    shifts. The desired shift is derived from statistics (the expected maximum value entering
    the accumulator) plus ``shift_buffer``; the smallest supported shift that is not below
    the desired one is selected, falling back to the largest supported shift when none is.

    Args:
        expected_max_accumulator: expected maximum value going into the accumulator.
        accumulator_size: accumulator width in bits; ``DEFAULT_ACCUMULATOR_SIZE`` and 32 are handled.
        shift_buffer: margin added to the computed shift.
        force_rounded_shift_delta: round the leftover shift up to a whole number.
        hw_shifts: candidate shifts; ``HW_SHIFTS`` is used when None.
        return_needed_shift: also return the desired shift clipped at 0.
        utilize_wraparound: use the full ``2**accumulator_size - 1`` range instead of the
            signed ``2**(accumulator_size - 1) - 1`` range.

    Returns:
        ``(pre_acc_shift, shift_delta)``, plus the non-negative desired shift when
        ``return_needed_shift`` is set.

    Raises:
        AccelerasNumerizationError: for an unsupported ``accumulator_size``.
    """
    usable_bits = accumulator_size if utilize_wraparound else accumulator_size - 1
    accumulator_max_val = 2**usable_bits - 1
    desired_shift = np.log2(expected_max_accumulator / accumulator_max_val) + shift_buffer
    available_shifts = sorted(HW_SHIFTS if hw_shifts is None else hw_shifts)

    if accumulator_size == DEFAULT_ACCUMULATOR_SIZE:
        # Mark, per element, every supported shift that is large enough.
        fits = np.expand_dims(desired_shift, -1) - np.array(available_shifts) <= 0
        # Smallest sufficient shift; `initial` supplies the largest one when nothing fits.
        pre_acc_shift = np.min(
            np.broadcast_to(available_shifts, fits.shape),
            axis=-1,
            initial=np.max(available_shifts),
            where=fits,
        )
    elif accumulator_size == 32:
        pre_acc_shift = 0
    else:
        raise AccelerasNumerizationError(f"We do not support accumulator of size {accumulator_size}")

    # Part of the desired shift that the chosen shift does not cover (aka "shift delta").
    shift_delta = np.maximum(0, desired_shift - pre_acc_shift)
    if force_rounded_shift_delta:
        shift_delta = np.ceil(shift_delta)

    if not return_needed_shift:
        return pre_acc_shift, shift_delta
    return pre_acc_shift, shift_delta, np.maximum(desired_shift, 0)


def limvals_to_zp_scale(
    limvals,
    bit_reducer,
    name="matching",
    logger=None,
    force_range_to_cover_zero=True,
    activation_symmetric_range=False,
    split_precision_zp=None,
):
    """
    Convert a (min, max) value range into a zero-point and scale for ``bit_reducer``.

    NOTE: nudging scale s.t. zp is round (and consistent with scale)

    Args:
        limvals: pair ``(rmin, rmax)`` describing the range to represent.
        bit_reducer: a ``QuantElement``; its ``signed`` and ``bits`` attributes are read,
            and ``max_value`` / ``min_value`` as well in the signed case.
        name: label used only in the log message for degenerate ranges.
        logger: object exposing ``verbose``; ``default_logger()`` is used when None.
        force_range_to_cover_zero: unsigned only - widen the range so it contains 0.
        activation_symmetric_range: unsigned only - use a range symmetric around 0.
        split_precision_zp: unsigned only, consulted when ``activation_symmetric_range`` is
            False - when not None the zero-point is chosen as ``split_precision_zp + N * 256``.

    Returns:
        ``(zp, scale, limvals_nudged)``, each converted with ``np.float32``;
        ``limvals_nudged`` is the range actually covered by the returned zp and scale.

    Raises:
        AccelerasNumerizationError: if ``bit_reducer`` is not a ``QuantElement``.
    """
    if logger is None:
        logger = default_logger()
    if not isinstance(bit_reducer, QuantElement):
        raise AccelerasNumerizationError(f"Th {type(bit_reducer)} must be QuantElement")
    rmin, rmax = limvals
    # Degenerate range (single value): widen it so the scale computed below is well defined.
    if rmin == rmax:
        if rmin == 0:  # Special case for both signed and unsigned
            # This behavior (for unsigned) is different from older version, to support for 16 bit weights that would be
            # quantize to unsigned.
            rmin, rmax = np.float32(-1), np.float32(1)
        elif not bit_reducer.signed:  # Unsigned
            # When unsigned, and x is positive (negative), change the range to [0, factor * x] ([factor * x, 0])
            # s.t x would be mapped to a round number in the middle.
            # 0 is included in the range (duplicating the force_range_to_cover_zero handling below)
            # to support the case where force_range_to_cover_zero is False.
            factor = 2 - 1 / (2 ** (bit_reducer.bits - 1))
            rmin, rmax = min(factor * rmin, np.float32(0)), max(factor * rmax, np.float32(0))
        logger.verbose(f"for layer {name} all the same - changing to limvals {[rmin, rmax]}")

    if bit_reducer.signed:  # usually, weights.
        # Signed representation: zero-point is fixed at 0 and the scale is set by the largest magnitude.
        zp = 0
        # Ratio between the positive and negative representable extents; rmin is rescaled by it
        # before the magnitudes of the two ends are compared.
        min_factor = np.abs(bit_reducer.max_value / bit_reducer.min_value)
        limvals = np.array([rmin * min_factor, rmax])
        absmax = np.max(np.abs(limvals))
        bins_pos = 2.0 ** (bit_reducer.bits - 1) - 1
        scale = absmax / bins_pos
        limvals_nudged = np.array([-absmax / min_factor, absmax])
    else:  # Unsigned; usually, activations.
        bins = 2**bit_reducer.bits - 1
        if force_range_to_cover_zero:
            # we want to verify that Zero is inside representable range
            # TODO - for now we do sync with Legacy, we may want to rethink
            rmin = min(rmin, np.float32(0))
            rmax = max(rmax, np.float32(0))

        if activation_symmetric_range:
            # Make the range symmetric around 0 using the larger of the two magnitudes.
            abs_val = np.maximum(rmax, np.abs(rmin))
            rmax = abs_val
            rmin = -abs_val
            scale_proposal = abs_val / bins
            zp_proposal = -rmin / scale_proposal
            # zp is snapped to a multiple of value_zp plus half of value_zp.
            # NOTE(review): the meaning of the constant 256 is not visible here - confirm with the HW spec.
            value_zp = 256
            zp = np.round((zp_proposal) / value_zp) * value_zp + value_zp / 2

            # Take the larger scale so that both ends of the range stay representable.
            scale_min = -rmin / zp
            scale_max = rmax / (2 * bins + 1 - zp)
            scale = max(scale_min, scale_max)
            limvals_nudged = -zp * scale, (2 * bins + 1 - zp) * scale
        elif split_precision_zp is not None:
            # zp = value + N * 256 and zp > scale_proposal
            scale_proposal = (rmax - rmin) / bins
            zp_proposal = -rmin / scale_proposal

            # Smallest N for which split_precision_zp + N * 256 is not below the proposed zp.
            N = np.ceil((zp_proposal - split_precision_zp) / 256)
            zp = split_precision_zp + N * 256

            # Re-derive the scale from the end of the range that is still free once zp is fixed.
            scale = rmax / (bins - zp) if zp < bins else -rmin / zp
            limvals_nudged = -zp * scale, (bins - zp) * scale
        else:
            scale_proposal = (rmax - rmin) / bins
            zp_proposal = -rmin / scale_proposal
            # NOTE(review): the environment value is a string, so any non-empty setting
            # (including "0") selects the rounding branch.
            if os.environ.get(
                "USE_ROUND_ZP_FOR_FORCE_RANGE", 0
            ):  # special case for force_range to match precomputed values
                zp = min(np.round(zp_proposal), bins)
            else:  # regular case
                zp = min(np.ceil(zp_proposal), bins)
            # Re-derive the scale so that it is consistent with the integer zp.
            scale = rmax / (bins - zp) if zp < bins else -rmin / zp
            # TODO consider saving the "effectively used limvals" somewhere
            #   (in SDK analog it's written back to limvals)
            limvals_nudged = -zp * scale, (bins - zp) * scale
    return np.float32(zp), np.float32(scale), np.float32(limvals_nudged)


def mmse(ker, bits=4):
    """
    Return the mmse-optimal dynamic range, balancing the clipping and quantization errors.

    Algorithm used is the "Progressive Project Quantization" (PPQ) from the "Alpha-Blend" paper:
     https://arxiv.org/pdf/1903.01061.pdf (ARM, 2019)

    Args:
        ker: kernel values (anything ``np.asarray`` accepts).
        bits: bit width of the symmetric quantization grid (``2**(bits-1) - 1`` positive bins).

    Returns:
        The clipping magnitude (quantization step times the number of positive bins).
        For a single-element or an all-zero kernel the maximal absolute value is returned as is.
    """
    ker = np.asarray(ker)
    bins_pos = 2 ** (bits - 1) - 1
    aker = np.abs(ker)
    nmax = np.max(aker)
    # Nothing to optimize for a single value; an all-zero kernel would otherwise divide by zero.
    if ker.size == 1 or nmax == 0:
        return nmax

    # Initial clipping guess: a high percentile of the magnitudes, relative to the maximum.
    percentile = 99.99 if bits == 8 else 99
    baseclip = np.percentile(aker, percentile) / nmax
    if baseclip == 0:
        # Very sparse kernel: the percentile is 0, which would yield a zero step (and NaN result).
        # Start the iterations from the full range instead.
        baseclip = 1.0

    nstep = nmax / bins_pos * baseclip
    nquant = np.clip(np.round(ker / nstep), -bins_pos, bins_pos)
    # Alternate between the least-squares optimal step for the current integer assignment
    # and the re-quantization of the kernel with that step.
    for _ in range(20):
        nstep = np.sum(ker * nquant) / np.sum(nquant * nquant)
        nquant = np.clip(np.round(ker / nstep), -bins_pos, bins_pos)

    return nstep * bins_pos


def get_kernel_bits_and_sign_by_precision_mode(precision_mode: PrecisionMode, force_signed_kernel=False):
    """
    Return the ``(bits, signed)`` pair of the kernel for the given precision mode.

    8-bit and 4-bit weight modes are always signed. The 16-bit weight modes are reported as
    ``(15, False)`` unless ``force_signed_kernel`` is set, in which case they are ``(16, True)``.
    A mode that is not listed raises ``KeyError``.
    """
    if force_signed_kernel:
        wide_kernel = (16, True)
    else:
        wide_kernel = (15, False)

    modes_w8 = (
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a16_w8_a8,
        PrecisionMode.a16_w8_a16,
        PrecisionMode.a8_w8_a16,
    )
    modes_w4 = (
        PrecisionMode.a8_w4,
        PrecisionMode.a8_w4_a8,
        PrecisionMode.a8_w4_a16,
        PrecisionMode.a16_w4_a16,
        PrecisionMode.a16_w4_a8,
    )
    modes_w16 = (
        PrecisionMode.a16_w16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w16_a16,
    )

    kernel_config = {
        **dict.fromkeys(modes_w8, (8, True)),
        **dict.fromkeys(modes_w4, (4, True)),
        **dict.fromkeys(modes_w16, wide_kernel),
    }
    return kernel_config[precision_mode]


def get_input_bits_by_precision_mode(precision_mode: PrecisionMode):
    """
    Return the input bit width for the given precision mode.

    Only the first two ``_``-separated tokens of the mode value (activation and weight) are
    considered, so a trailing output token does not affect the result.
    """
    leading_tokens = precision_mode.value.split("_")[:2]
    pair_mode = PrecisionMode("_".join(leading_tokens))
    bits_by_pair = {
        PrecisionMode.a8_w8: 8,
        PrecisionMode.a8_w4: 8,
        PrecisionMode.a16_w16: 15,
        PrecisionMode.a16_w8: 16,
        PrecisionMode.a16_w4: 16,
    }
    return bits_by_pair[pair_mode]


def get_output_bits_by_precision_mode(precision_mode: PrecisionMode):
    """
    Return the output bit width encoded in the precision mode, or None when it has none.

    The mode value is split on ``_``; a two-token value carries no output part. Otherwise the
    third token (e.g. ``a8``) gives the width, with a declared width of 16 reported as 15.
    """
    tokens = precision_mode.value.split("_")
    if len(tokens) == 2:
        return None
    declared_bits = int(tokens[2].replace("a", ""))
    return 15 if declared_bits == 16 else declared_bits


def get_accumulator_bits_by_precision_mode(precision_mode: PrecisionMode):
    """
    Return the accumulator width implied by the weight width of the precision mode.

    The weight width is read from the second ``_``-separated token of the mode value
    (e.g. ``w8``): 4- and 8-bit weights map to 16 bits, 16-bit weights to 32 bits.
    """
    weight_token = precision_mode.value.split("_")[1]
    weight_bits = int(weight_token.replace("w", ""))
    return {4: 16, 8: 16, 16: 32}[weight_bits]


def get_output_predecessor_precision_mode_by_bits(bits: int) -> PrecisionMode:
    """
    Map a bit width to a precision mode: 8 gives ``a8_w8``; 15 and 16 both give ``a16_w16``.

    Any other width raises ``KeyError``.
    """
    wide_mode = PrecisionMode.a16_w16
    mode_lookup = {8: PrecisionMode.a8_w8, 16: wide_mode, 15: wide_mode}
    return mode_lookup[bits]


def get_decomposition_count_by_bias_mode(bias_mode: BiasMode):
    """
    Return the number of decompositions associated with a bias mode.

    ``double_scale_initialization`` -> 0, ``single_scale_decomposition`` -> 1,
    ``double_scale_decomposition`` -> 2. An unlisted mode raises ``KeyError``.
    """
    # Listed in increasing decomposition count, so the position is the count itself.
    modes_by_count = (
        BiasMode.double_scale_initialization,
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_decomposition,
    )
    count_by_mode = {mode: count for count, mode in enumerate(modes_by_count)}
    return count_by_mode[bias_mode]


def get_quant_element_by_data_path(data_path: DataPath, bits: int):
    """
    Return the quant-element class used for the given data path.

    ``bits`` only matters for the ``LAYER_IN_WEIGHTS`` / ``LAYER_OUT_WEIGHTS`` paths, which use
    the signed APU output element when ``bits == 8`` and the unsigned one otherwise.
    An unlisted data path raises ``KeyError``.
    """
    weights_element = APUOutputSignedQuantElement if bits == 8 else APUOutputQuantElement

    # Paths that all resolve to the plain (unsigned) APU output element.
    apu_output_paths = (
        DataPath.LAYER_OUT,
        DataPath.LAYER_IN,
        DataPath.LAYER_X2_SUM,
        DataPath.LAYER_E_X_SUM,
        DataPath.EXP_NUME,
        DataPath.EXP_DENO,
        DataPath.LAYER_IN_INV,
        DataPath.LAYER_SPLIT_INPUT,
        DataPath.LAYER_ROOT,
        DataPath.LAYER_MU,
        DataPath.INTER_BLOCK_16,
        DataPath.INTER_BLOCK_8,
    )
    qelem_by_data_path = dict.fromkeys(apu_output_paths, APUOutputQuantElement)
    qelem_by_data_path.update(
        {
            DataPath.ACCUMULATOR: AccumulatorQuantElement,
            DataPath.DATA_MULT: AccumulatorQuantElement,
            DataPath.MAC_DATA: MACDataQuantElement,
            DataPath.POST_DATA_MULT: APUPreOffsetQuantElement,
            DataPath.LAYER_X_SUM: APUOutputSignedQuantElement,
            DataPath.LAYER_IN_WEIGHTS_16: APUOutputSignedQuantElement,
            DataPath.LAYER_IN_MASK: MaskQuantElement,
            DataPath.LAYER_IN_WEIGHTS: weights_element,
            DataPath.LAYER_OUT_WEIGHTS: weights_element,
        }
    )
    return qelem_by_data_path[data_path]


def _get_limvals_numeric(stats_min, stats_max, scale):
    """
    per-channel min/max (@calib) of expected HW values, up to global scalar factor which we’ll find below.
    If that factor is 1 (e.g. repeated invocation), we expect to see values within 0,255 ,
    reaching those bounds for at least one channel, or hopefully many channels if coming here
    after successful equalization
    Args:
        stats_min: min stats per channel
        stats_max: max stats per channel
        scale: the scale of the layer
    Returns:

    """
    return np.min(stats_min / scale), np.max(stats_max / scale)


def update_scale(
    scale: Union[float, np.ndarray],
    limvals: Tuple[np.ndarray, np.ndarray],
    bit_reducer: BaseQuantElement,
    name: str,
    logger,
    activation_symmetric_range: bool = False,
    split_precision_zp=None,
) -> Tuple[np.ndarray, float]:
    """
    Recompute the scale (and zero-point) after the per-channel scale factors have changed.

    In acceleras the kernel weights are no longer modified; instead the input/output scales are
    updated. For example, equalization changes the input and output scales rather than the
    kernel weights. After such a change the scales have the form:
        vector_scales_candidate := scalar_scale * vector_rescale_factors
    Where:
        scalar_scale - the scalar scale calculated before.
        vector_rescale_factors - the vector of scales calculated in equalization.

    Because the limvals changed, the scalar scale has to be updated as well.

    1. Observation tying limvals to scales:
        if limvals_1 = scalar * limvals_0, then:
        scale_1 = scalar * scale_0
        zp_1 = zp_0
        where qp_1 = (zp_1, scale_1) is the qp of limvals_1 and qp_0 = (zp_0, scale_0) is the qp of limvals_0.
    2. Observation tying the factors to equalization:
        native_stats_before_equalization / vector_rescale_factors = stats_after_equalization

    From observation 2:
    limvals_after_equalization
               = (np.min(stats_min / vector_rescale_factors), np.max(stats_max / vector_rescale_factors))
               = (np.min(stats_min_after_equalization), np.max(stats_max_after_equalization))

    With vector_scales_candidate:
    limvals_candidate_numeric
               = (np.min(stats_min / vector_scales_candidate), np.max(stats_max / vector_scales_candidate))
               = (np.min(stats_min / (vector_rescale_factors * scalar_scale)),
                  np.max(stats_max / (vector_rescale_factors * scalar_scale)))
               = 1 / scalar_scale * (np.min(stats_min / vector_rescale_factors),
                                     np.max(stats_max / vector_rescale_factors))
               = 1 / scalar_scale * limvals_after_equalization

    Hence limvals_candidate_numeric * scalar_scale = limvals_after_equalization, and from observation 1:
         scalar_scales_candidates_numeric * scalar_scale = scalar_scales_after_equalization

    So computing scalar_scales_candidates_numeric and multiplying it by the current scale
    gives the updated scale.

    Args:
        scale: current scale (scalar or per-channel vector).
        limvals: pair of (min stats, max stats).
        bit_reducer: quant element forwarded to ``limvals_to_zp_scale``.
        name: label forwarded to ``limvals_to_zp_scale``.
        logger: logger forwarded to ``limvals_to_zp_scale``.
        activation_symmetric_range: forwarded to ``limvals_to_zp_scale``.
        split_precision_zp: forwarded to ``limvals_to_zp_scale``.

    Returns:
        ``(new_scale, zp)`` where ``new_scale`` is a float32 array.
    """
    stats_min, stats_max = limvals[0], limvals[1]
    numeric_limvals = _get_limvals_numeric(stats_min, stats_max, scale)
    zero_point, numeric_scale, _ = limvals_to_zp_scale(
        numeric_limvals,
        bit_reducer,
        name,
        logger,
        activation_symmetric_range=activation_symmetric_range,
        split_precision_zp=split_precision_zp,
    )
    return np.array(numeric_scale * scale, np.float32), zero_point


def get_scalar_vector(vector, *, name="", rtol=1e-5, atol=1e-7):
    """
    Collapse a vector whose entries are all (approximately) equal into its first entry.

    Args:
        vector: a Python/numpy scalar, a 0-d array, or an array expected to be constant.
        name: prefix for the error message.
        rtol: relative tolerance passed to ``np.allclose``.
        atol: absolute tolerance passed to ``np.allclose``.

    Returns: the scalar_vector - the input itself for scalars and 0-d arrays, otherwise ``vector[0]``.

    Raises:
        AccelerasNumerizationError: if the entries are not all close to the first one.
    """
    if isinstance(vector, (float, int, np.float32, np.float64)):
        return vector
    if vector.shape == ():
        return vector

    first = vector[0]
    if np.allclose(first * np.ones_like(vector), vector, rtol=rtol, atol=atol):
        return first

    raise AccelerasNumerizationError(
        f"{name}: the vector must be a scalar but there is a diff "
        f"{np.max(np.abs(vector - vector[0]) / vector)}"
    )


# debugging utils
def get_number_bits(data, signed):
    """
    Estimate how many bits are needed to hold the extreme values of ``data``.

    Args:
        data: array-like of values.
        signed: count bits for a signed representation (one extra bit for the sign)
            instead of an unsigned one.

    Returns:
        The bit count as a numpy float. Unsigned data whose maximum is 0 needs 0 bits.
    """
    if signed:
        min_data = np.abs(np.min(data))
        max_data = np.abs(np.max(data))

        # A zero minimum never dominates the result; skipping log2(0) avoids -inf and a warning.
        bits_min = np.ceil(np.log2(min_data)) + 1 if min_data > 0 else 0.0
        # The positive side needs room for max_data itself, hence the +1 inside the log.
        bits_max = np.ceil(np.log2(max_data + 1)) + 1

        return np.max([bits_min, bits_max])

    # The small offset makes exact powers of two (e.g. 256) round up to the next bit count.
    max_val = np.max(data) + 0.1
    # Clamp at 0 so that an all-zero input does not report a negative number of bits.
    return np.maximum(np.ceil(np.log2(max_val)), 0.0)


# lut function calculator
def calc_lut_table(inp, lut_func, zp_in, s_in, zp_out, s_out, bits_out, signed=False, quant=True):
    """
    Compute the output of a LUT for the given input/output quantization parameters.

    The input is de-quantized with ``(inp - zp_in) * s_in``, passed through ``lut_func``, and
    re-quantized with ``/ s_out + zp_out``.

    Args:
        inp: quantized input values.
        lut_func: callable applied to the de-quantized input.
        zp_in, s_in: input zero-point and scale.
        zp_out, s_out: output zero-point and scale.
        bits_out: output bit width, defining the clipping range.
        signed: clip to the signed range of ``bits_out`` instead of the unsigned one.
        quant: when True the result is rounded and clipped to the output range;
            when False the un-rounded, un-clipped value is returned.
    """
    if signed:
        magnitude_bits = bits_out - 1
        min_val, max_val = -(2**magnitude_bits), (2**magnitude_bits) - 1
    else:
        min_val, max_val = 0, (2**bits_out) - 1

    native_in = (inp - zp_in) * s_in
    out_quant = lut_func(native_in) / s_out + zp_out
    if not quant:
        return out_quant
    return tf.clip_by_value(tf.math.round(out_quant), min_val, max_val)


def verify_data_dtype(data_1, bit_width, signed, name):
    """
    Check that every value of ``data_1`` fits in ``bit_width`` bits.

    Args:
        data_1: values to check (converted with ``np.array``).
        bit_width: number of bits available.
        signed: use the signed range ``[-2**(bit_width-1), 2**(bit_width-1) - 1]``
            instead of the unsigned range ``[0, 2**bit_width - 1]``.
        name: parameter name used in the error message.

    Raises:
        Exception: if any value falls outside the representable range.
    """
    data = np.array(data_1)
    if signed:
        magnitude_bits = bit_width - 1
        min_val, max_val = -(2**magnitude_bits), (2**magnitude_bits) - 1
    else:
        min_val, max_val = 0, (2**bit_width) - 1

    # Bits actually required by the data; reported in the error message only.
    current_bits = get_number_bits(data, signed)
    fits = np.all(min_val <= data) and np.all(data <= max_val)
    if fits:
        return
    msg = f"param: {name} -current {current_bits} got {name} data {[np.min(data), np.max(data)]} - TOOO BIGGG  {bit_width} signed {signed} {[min_val, max_val]}"
    raise Exception(msg)


def bankers_round_int_shift(value, shift):
    """
    Implements round(value / 2**shift) in integer arithmetic, with ties going to the even result.

    Args:
        value: tensor to shift; its ``dtype`` is used for the tie mask.
        shift: number of bits to shift right by.
    """
    half = 2 ** (shift - 1)
    divisor = 2**shift
    # Adding half before the floor division rounds ties upwards.
    rounded_up = (value + half) // divisor
    # 1 where the remainder is exactly half (a tie), 0 elsewhere.
    is_tie = tf.cast(half == value % divisor, value.dtype)
    # On a tie that landed on an odd result, step back by one to reach the even neighbour.
    return rounded_up - (rounded_up % 2) * is_tie


def reduce_dataset(data, reduced_batch_size):
    """
    Get the desired BS out of the data_io tensor.

    Args:
        data (list, dict, np.array): data to cut from. For a list or a dict every
            entry/value is cut separately; anything else is sliced directly.
        reduced_batch_size (int): number of BS to extract from each entry.

    Returns:
        A new container of the same kind holding the first ``reduced_batch_size`` items of each entry.
    """
    if isinstance(data, list):
        return [entry[:reduced_batch_size] for entry in data]
    if isinstance(data, dict):
        return {key: entry[:reduced_batch_size] for key, entry in data.items()}
    return data[:reduced_batch_size]
