import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PrecisionMode, PreFTClippingMethod
from hailo_sdk_client.numeric_translator.params_sorter import ParamsSorter
from hailo_sdk_common.logger.logger import default_logger


def mmse(ker, bits=4):
    """
    Return the mmse-optimal dynamic range, balancing the clipping and quantization errors.

    Algorithm used is the "Progressive Project Quantization" (PPQ) from the "Alpha-Blend" paper:
    https://arxiv.org/pdf/1903.01061.pdf (ARM, 2019)

    The quantization step is initialized from the 99th percentile of the absolute values and then
    refined for a fixed number of iterations: the weights are rounded/clipped onto the symmetric
    integer grid, and the step is re-fitted in the least-squares sense against that grid.

    Args:
        ker: array-like of weights (any shape, may contain negative values).
        bits: bitwidth of the symmetric quantizer; the integer grid spans
            [-(2 ** (bits - 1) - 1), 2 ** (bits - 1) - 1].

    Returns:
        Scalar range, i.e. the fitted step multiplied by the largest positive grid index.
        For a single-element or an all-zero input the maximal absolute value is returned as is.
    """
    ker = np.asarray(ker)
    bins_pos = 2 ** (bits - 1) - 1
    aker = np.abs(ker)
    nmax = np.max(aker)
    if ker.size == 1 or nmax == 0:
        # Nothing to optimize. For an all-zero kernel the normalization below would be 0/0 (NaN).
        return nmax

    baseclip = np.percentile(aker, 99) / nmax
    nstep = nmax / bins_pos * baseclip
    if nstep == 0:
        # The 99th percentile is zero (very sparse kernel): start from the unclipped range
        # instead of dividing by a zero step.
        nstep = nmax / bins_pos

    nquant = np.clip(np.round(ker / nstep), -bins_pos, bins_pos)
    for _ in range(20):
        grid_energy = np.sum(nquant * nquant)
        if grid_energy == 0:
            # Every weight rounded to zero; the least-squares refit is undefined, keep last step.
            break
        nstep = np.sum(ker * nquant) / grid_energy
        nquant = np.clip(np.round(ker / nstep), -bins_pos, bins_pos)

    return nstep * bins_pos


def mmse4b_max8b(kernel, bits=4):
    """
    Return the range used for a kernel, depending on its bitwidth.

    For exactly 4 bits the range is delegated to ``mmse``; for any other bitwidth the plain
    maximal absolute value of the kernel is returned.
    """
    if bits != 4:
        return np.max(np.abs(kernel))
    return mmse(kernel, bits)


def _clip_kernel_pre_ft(
    kernel,
    lname,
    bits,
    verbose=True,
    clip_method=PreFTClippingMethod.MMSE_IF4B,
    clip_factor=None,
    clip_percentile=None,
    clip_percentile_8b=None,
):
    """
    Pick a single clipping value for a kernel (or for a slice of a kernel).

    The value is chosen from the absolute values of ``kernel`` according to ``clip_method``:
      * SET_FACTOR     - ``clip_factor`` times the naive max for 4 bits, the naive max otherwise.
      * SET_PERCENTILE - the ``clip_percentile`` percentile for 4 bits; otherwise the
                         ``clip_percentile_8b`` percentile if given, else the naive max.
      * MMSE           - result of ``mmse`` at the given bitwidth.
      * MMSE_IF4B      - result of ``mmse4b_max8b`` at the given bitwidth.
    Any other method raises NotImplementedError.

    When ``verbose`` is set, a debug line summarizing the choice is sent to the default logger;
    ``lname`` is used only in that message.
    """
    magnitudes = np.abs(kernel).flatten()
    naive_max = np.max(magnitudes)
    is_4bit = bits == 4

    if clip_method == PreFTClippingMethod.SET_FACTOR:
        if is_4bit:
            clip_value = naive_max * clip_factor
        else:
            clip_value = naive_max
    elif clip_method == PreFTClippingMethod.SET_PERCENTILE:
        if is_4bit:
            clip_value = np.percentile(magnitudes, clip_percentile)
        elif clip_percentile_8b is None:
            clip_value = naive_max
        else:
            clip_value = np.percentile(magnitudes, clip_percentile_8b)
    elif clip_method == PreFTClippingMethod.MMSE:
        clip_value = mmse(magnitudes, bits=bits)
    elif clip_method == PreFTClippingMethod.MMSE_IF4B:
        clip_value = mmse4b_max8b(magnitudes, bits=bits)
    else:
        raise NotImplementedError(f"Clipping method {clip_method.value} not yet implemented")

    if not verbose:
        return clip_value

    # Describe the chosen value both as a percentile of the data and as a fraction of the max.
    final_perc = 100 * np.sum(magnitudes <= clip_value) / len(magnitudes)
    final_factor = clip_value / naive_max
    template = (
        "Clipped kernel (or slice) of {} from naive range of {:.2f} to {:.3f} "
        + "(equivalent to {:.2f}% perc', {:.2f} factor), as chosen by {} method"
    )
    report = template.format(lname, naive_max, clip_value, final_perc, final_factor, str(clip_method))
    if is_4bit:
        # NOTE(review): clip_value / 14 is presumably half of a 4-bit step (clip / 7 / 2),
        # i.e. the magnitude below which a weight would round to zero - confirm.
        zeros_share = 100 * np.mean(magnitudes < clip_value / 14)
        report += f"..{zeros_share:.2f}% zeros"
    default_logger().debug(report)
    return clip_value


def clip_kernel_pre_ft(kernel, lname, bits, num_groups=1, verbose=True, **kwargs):
    """
    Compute clipping values for a kernel, shared within each group of output channels.

    The last axis of ``kernel`` is split into ``num_groups`` equally sized, contiguous groups and
    one clipping value is computed per group with ``_clip_kernel_pre_ft``.

    * ``num_groups == 1``  - the whole kernel is clipped at once and a single value is returned.
    * ``num_groups is None`` - every channel is its own group; per-channel logging is disabled
      because it would be too noisy.
    * otherwise - a vector with one entry per channel (last axis) is returned, where all channels
      of a group hold the same value. The channel count must be divisible by ``num_groups``.

    Extra keyword arguments are forwarded to ``_clip_kernel_pre_ft``.
    """
    if num_groups == 1:
        return _clip_kernel_pre_ft(kernel, lname, bits, verbose=verbose, **kwargs)

    n_channels = kernel.shape[-1]
    if num_groups is None:
        # full channelwise: one group per channel, and keep the log quiet
        num_groups, verbose = n_channels, False

    per_channel_clip = np.ones(n_channels)

    channels_per_group, leftover = divmod(n_channels, num_groups)
    assert leftover == 0

    for first_ch in range(0, n_channels, channels_per_group):
        window = slice(first_ch, first_ch + channels_per_group)
        per_channel_clip[window] = _clip_kernel_pre_ft(
            kernel[..., window], lname, bits, verbose=verbose, **kwargs
        )

    return per_channel_clip


class ClipAwareParamsSorter(ParamsSorter):
    """
    ParamsSorter whose feature order is decided by a clipping-aware heuristic:

    (A) Every producer kernel takes part in the decision ("votes").
        Each producer's slice is first normalized by its own optimal range, so that all
        producers weigh equally, and the normalized slices are concatenated.
    (B) For the resulting synthetic kernel, the optimal range of every channel slice is
        computed; channels are ordered by that range.
    (C) Bitwidth awareness is greedy: if any producer has 4-bit weights, only the 4-bit
        producers are used for the concatenation.
        The "optimal range" at every stage is computed for the relevant bitwidth.
    """

    def __init__(self, hailo_nn, opt_range_func=mmse4b_max8b, hn_layers_4b_weights=None):
        """
        Args:
            hailo_nn: network passed on to the base ParamsSorter; also used to look up the
                4-bit layers when they are not given explicitly.
            opt_range_func: callable ``f(flat_kernel, bits=...)`` returning the optimal range.
            hn_layers_4b_weights: names of layers with 4-bit weights. When missing or empty,
                every layer whose reduced precision mode equals ``PrecisionMode.a8_w4`` is used.
        """
        super().__init__(hailo_nn)
        self.opt_range_func = opt_range_func

        if not hn_layers_4b_weights:
            hn_layers_4b_weights = [
                layer.name
                for layer in hailo_nn.stable_toposort()
                if layer.precision_config.precision_mode.reduce() == PrecisionMode.a8_w4
            ]
        self.hn_layers_4b_weights = hn_layers_4b_weights

    def _get_feature_sort_order(self, flow):
        """
        Return the channel order for ``flow`` using the heuristic from the class docstring.

        The result is the argsort (ascending) of the per-channel optimal range of the
        concatenated, normalized producer kernels.
        """
        voters = [flow.source] + flow.ew_bouncers  # TODO multiple sources

        normed_slices = []
        slice_bits = []
        for voter in voters:
            # only the producing layer and its channel indices are needed here
            _, voter_layer, _, layer_indices = voter.get_as_tuple()
            voter_kernel = self._sorted_params[voter_layer.name + "/kernel:0"]
            bits = 4 if voter_layer.name in self.hn_layers_4b_weights else 8
            slice_bits.append(bits)

            picked = np.copy(voter_kernel[..., np.array(layer_indices)])
            # normalize in place by the slice's own optimal range
            picked /= self.opt_range_func(picked.flatten(), bits=bits)
            # collapse all leading axes so that every column is one channel
            normed_slices.append(picked.reshape(-1, picked.shape[-1]))

        any_4bit = 4 in slice_bits
        if any_4bit:
            # (!) if any 4-bit layers exist, they alone decide the order
            normed_slices = [s for s, b in zip(normed_slices, slice_bits) if b == 4]
        vote_bits = 4 if any_4bit else 8

        stacked = np.concatenate(normed_slices)

        # rows of the transpose are the channel columns of the stacked kernel
        per_channel_range = np.array(
            [self.opt_range_func(column.flatten(), bits=vote_bits) for column in stacked.T],
        )
        return np.argsort(per_channel_range)
