from functools import partial

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType, BiasMode
from hailo_model_optimization.saitama.framework.apu_modules.apu_activation import MAX_PIECES
from hailo_model_optimization.saitama.translators.translator_utils import KeyHandler, _wrap_handlers


class CommonMappings:
    """Factory helpers that build lists of ``KeyHandler`` entries shared between translators.

    NOTE(review): ``KeyHandler`` lives in ``translator_utils`` and is not visible from this file. Judging
    by its usage here it receives a destination key, one or more source keys, an optional transform
    callable and optional ``tags`` / ``default_factory`` keywords -- confirm against its definition.
    """

    @staticmethod
    def get_bias_mapping(
        base,
        *,
        op_name="bias_add_op",
        scale="output_scale:0:0",
        zero_point="input_zero_point:0:0",
        factor="bias_factor_0:0",
        feed_repeat="bias_feed_repeat:0",
        mac_shift="mac_shift:0",
        weight="bias:0",
        shape_key="bias:0",
        tags=frozenset(),
    ):
        """Build the handlers for the bias entries stored under ``base``.

        ``scale``, ``zero_point``, ``factor``, ``feed_repeat`` and ``mac_shift`` are prefixed with
        ``"{op_name}/"`` before use; ``weight`` and ``shape_key`` are used exactly as given.

        Args:
            base: Prefix of the destination keys (``"{base}.scale"``, ``"{base}.weight"``, ...).
            op_name: Prefix prepended (with ``/``) to the source keys listed above.
            scale: Source key suffix feeding ``"{base}.scale"``.
            zero_point: Source key suffix feeding ``"{base}.zero_point"`` (the value is negated).
            factor: Source key suffix feeding ``"{base}.factor"``.
            feed_repeat: Source key suffix feeding ``"{base}.feed_repeat"``.
            mac_shift: Source key suffix feeding ``"{base}.mac_shift"``.
            weight: Source key feeding ``"{base}.weight"``. When falsy, the weight handler instead
                produces zeros shaped like the value read from ``shape_key``.
            shape_key: Source key whose value supplies the shape used to broadcast the scale and to
                build the all-zero weight fallback.
            tags: Tags passed through ``_wrap_handlers`` and attached to every handler.

        Returns:
            A list of ``KeyHandler`` objects: scale, zero point, the three single-scale-decomposition
            entries, and finally the weight handler.
        """
        common_tags = _wrap_handlers(tags)
        # Decomposition-only handlers carry the common tags plus the decomposition bias mode.
        decompose_tags = {BiasMode.single_scale_decomposition}
        decompose_tags.update(common_tags)
        scale = f"{op_name}/{scale}"
        zero_point = f"{op_name}/{zero_point}"
        factor = f"{op_name}/{factor}"
        feed_repeat = f"{op_name}/{feed_repeat}"
        mac_shift = f"{op_name}/{mac_shift}"

        if weight:
            weight_handler = KeyHandler(f"{base}.weight", weight, tags=common_tags)
        else:
            # No weight source key: emit zeros with the shape of the value found at ``shape_key``.
            weight_handler = KeyHandler(
                f"{base}.weight", shape_key, lambda x, **kw: np.zeros_like(x), tags=common_tags
            )

        mapping = [
            # Base bias
            # Adding zeros shaped like the ``shape_key`` value broadcasts the scale to that shape.
            KeyHandler(f"{base}.scale", (scale, shape_key), lambda x, y, **kw: x + np.zeros_like(y), tags=common_tags),
            KeyHandler(f"{base}.zero_point", zero_point, lambda x, **kw: -x, tags=common_tags),  # TODO: verify this
            # Single scale decomposition
            KeyHandler(f"{base}.factor", factor, tags=decompose_tags),
            KeyHandler(f"{base}.feed_repeat", feed_repeat, tags=decompose_tags),
            KeyHandler(f"{base}.mac_shift", mac_shift, tags=decompose_tags),
            weight_handler,
        ]

        return mapping

    @staticmethod
    def get_kernel_encoding_mapping(
        base,
        *,
        op_name="conv_op",
        qgroups="act_op/quantization_groups_size:0",
        eq_vec="layer_params/eq_vec_out:0",
        in_idx=0,
        out_idx=0,
        mac_shift=None,
        kernel_zp=None,
        output_scale=None,
        input_scale=None,
        is_channelwise=False,
    ):
        """Build the handlers for the kernel encoding entries stored under ``base``.

        Args:
            base: Prefix of the destination keys.
            op_name: Prefix used to build the default source keys below.
            qgroups: Source key of the quantization group sizes.
            eq_vec: Source key of the output equalization vector.
            in_idx: Index used in the default ``input_scale`` key.
            out_idx: Index used in the default ``output_scale`` key.
            mac_shift: Source key of the MAC shift; defaults to ``"{op_name}/mac_shift:0"``.
            kernel_zp: Source key of the kernel zero point; defaults to
                ``"{op_name}/kernel_zero_point:0"``.
            output_scale: Source key of the output scale; defaults to
                ``"{op_name}/output_scale:{out_idx}:0"``.
            input_scale: Source key of the input scale; defaults to
                ``"{op_name}/input_scale:{in_idx}:0"``.
            is_channelwise: Forwarded to ``reconstruct_equalization_vector_in``.

        Returns:
            A list of ``KeyHandler`` objects for scale, zero point, MAC shift and the input/output
            equalization vectors.
        """
        if mac_shift is None:
            mac_shift = f"{op_name}/mac_shift:0"
        if kernel_zp is None:
            kernel_zp = f"{op_name}/kernel_zero_point:0"
        if output_scale is None:
            output_scale = f"{op_name}/output_scale:{out_idx}:0"
        if input_scale is None:
            input_scale = f"{op_name}/input_scale:{in_idx}:0"
        return [
            KeyHandler(
                f"{base}.scale",
                (output_scale, eq_vec, qgroups, mac_shift),
                reconstruct_kernel_scale,
            ),
            KeyHandler(f"{base}.zero_point", kernel_zp),
            KeyHandler(f"{base}.mac_shift", mac_shift),
            KeyHandler(
                f"{base}.equalization_vector_in",
                (input_scale, output_scale),
                partial(reconstruct_equalization_vector_in, is_channelwise=is_channelwise),
            ),
            KeyHandler(f"{base}.equalization_vector_out", eq_vec),
        ]

    @staticmethod
    def get_accumulator_mapping(base, *, output_scale, output_zero_point):
        """Build the handlers for the accumulator quantizer stored under ``base``.

        Args:
            base: Prefix of the destination keys.
            output_scale: Source key feeding ``"{base}.accumulator_quantizer.scale"``.
            output_zero_point: Source key feeding ``"{base}.accumulator_quantizer.zero_point"``.

        Returns:
            A two-element list of ``KeyHandler`` objects (scale, zero point).
        """
        return [
            KeyHandler(f"{base}.accumulator_quantizer.scale", output_scale),
            KeyHandler(f"{base}.accumulator_quantizer.zero_point", output_zero_point),
        ]

    @staticmethod
    def get_activation_mapping(
        base,
        *,
        offsets="act_op/offsets:0",
        thresholds="act_op/thresholds:0",
        mantissas="act_op/slopes_mantissas:0",
        output_factors="act_op/output_factors:0",
        exponents="act_op/slopes_exponents:0",
        accumulator_scale="act_op/input_scale:0:0",
        output_shift="act_op/output_shift_bias:0",
        qgroups_split="act_op/quantization_groups_size:0",
        equalization_vector="layer_params/eq_vec_out:0",
        output_op="output_op",
    ):
        """Build the handlers for the quantized activation entries.

        Args:
            base: Prefix of the destination keys. A non-empty value gets a trailing ``"."``; an empty
                string leaves the destination keys unprefixed.
            offsets: Source key of the piece offsets (also used to count the activation pieces).
            thresholds: Source key of the piece thresholds.
            mantissas: Source key of the slope mantissas.
            output_factors: Source key of the output factors.
            exponents: Source key of the slope exponents.
            accumulator_scale: Source key of the accumulator (activation input) scale.
            output_shift: Source key of the output shift.
            qgroups_split: Source key of the quantization group sizes.
            equalization_vector: Source key of the output equalization vector.
            output_op: Op name used to build the ``output_scale:0:0`` / ``output_zero_point:0:0``
                source keys.

        Returns:
            A list of ``KeyHandler`` objects covering slopes, offsets, thresholds, the output
            quantizer and the number of activation pieces.
        """
        output_scale = f"{output_op}/output_scale:0:0"
        output_zero_point = f"{output_op}/output_zero_point:0:0"
        if base != "":
            base = f"{base}."
        mapping = [
            KeyHandler(f"{base}slopes.mantissas", (mantissas, output_factors), reconstruct_original_mantissa),
            KeyHandler(f"{base}slopes.exponents", exponents, lambda x, **kw: pad_activation_pieces(x.transpose(1, 0))),
            KeyHandler(f"{base}slopes.scale", output_factors, reconstruct_rescale_factor),
            KeyHandler(f"{base}slopes.q_groups_factor", output_factors, get_group_quant_scales),
            KeyHandler(f"{base}offsets.weight", offsets, pad_activation_pieces),
            KeyHandler(
                f"{base}offsets.scale",
                (accumulator_scale, equalization_vector, qgroups_split, output_factors, output_shift),
                reconstruct_offsets_scale,
            ),
            KeyHandler(
                f"{base}offsets.zero_point",
                (output_zero_point, output_shift),
                reconstruct_offsets_zero_point,
            ),
            KeyHandler(f"{base}offsets.final_shift", output_shift),
            KeyHandler(f"{base}thresholds.weight", (thresholds, qgroups_split), reconstruct_thresholds_weight),
            KeyHandler(
                f"{base}thresholds.scale",
                (accumulator_scale, equalization_vector, qgroups_split),
                reconstruct_accumulator_scale,
            ),
            KeyHandler(f"{base}output_quantizer.scale", output_scale),
            KeyHandler(f"{base}output_quantizer.zero_point", output_zero_point),
            # The piece count is the length of the (unpadded) offsets value.
            KeyHandler(f"{base}activation_pieces", offsets, lambda x, **kw: len(x)),
        ]
        return mapping

    @staticmethod
    def get_native_activation_mapping(base):
        """Build the handlers for the native (non-quantized) activation parameters.

        Each handler is tagged with the ``ActivationType`` it belongs to.

        Args:
            base: Prefix of the destination keys. A non-empty value gets a trailing ``"."``; an empty
                string leaves the destination keys unprefixed.

        Returns:
            A list of ``KeyHandler`` objects, one per activation parameter. ``inverse_act_factor``
            appears twice (tagged ``INV_POS`` and ``INV_SQRT``), each with a ``default_factory``
            returning ``1.0``.
        """
        if base != "":
            base = f"{base}."
        return [
            KeyHandler(f"{base}activation_delta_bias", "activation_delta_bias:0", tags=ActivationType.BIASED_DELTA),
            KeyHandler(f"{base}activation_greater_values", "activation_greater_values:0", tags=ActivationType.GREATER),
            KeyHandler(f"{base}leaky_alpha", "leaky_alpha:0", tags=ActivationType.LEAKY),
            KeyHandler(f"{base}activation_threshold", "activation_threshold:0", tags=ActivationType.THRESHOLD),
            KeyHandler(f"{base}prelu_slope", "prelu_slope:0", tags=ActivationType.PRELU),
            KeyHandler(f"{base}swish_beta", "swish_beta:0", tags=ActivationType.SWISH),
            KeyHandler(f"{base}activation_less_values", "activation_less_values:0", tags=ActivationType.LESS),
            KeyHandler(f"{base}hardsigmoid_alpha", "hardsigmoid_alpha:0", tags=ActivationType.HARDSIGMOID),
            KeyHandler(f"{base}hardsigmoid_beta", "hardsigmoid_beta:0", tags=ActivationType.HARDSIGMOID),
            KeyHandler(f"{base}clip_min", "clip_min:0", tags=ActivationType.CLIP),
            KeyHandler(f"{base}clip_max", "clip_max:0", tags=ActivationType.CLIP),
            KeyHandler(f"{base}pow_exponent", "pow_exponent:0", tags=ActivationType.POW),
            KeyHandler(
                f"{base}inverse_act_factor",
                "inverse_act_factor:0",
                tags=ActivationType.INV_POS,
                default_factory=lambda: 1.0,
            ),
            KeyHandler(
                f"{base}inverse_act_factor",
                "inverse_act_factor:0",
                tags=ActivationType.INV_SQRT,
                default_factory=lambda: 1.0,
            ),
        ]


def reconstruct_kernel_scale(accumulator_scale, equalization_vector, qgroups_size, mac_shift, **kwargs):
    """Return the per-group accumulator scale divided by ``2**mac_shift``.

    The per-group scale comes from ``reconstruct_accumulator_scale``; extra keyword arguments are
    forwarded to it unchanged.
    """
    # TODO: divide by min / max value of input scale?
    per_group_scale = reconstruct_accumulator_scale(
        accumulator_scale,
        equalization_vector,
        qgroups_size,
        **kwargs,
    )
    shift_divisor = 2**mac_shift
    return per_group_scale / shift_divisor


def reconstruct_equalization_vector_in(input_scale, output_scale, is_channelwise, **kwargs):
    """Return the reciprocal of ``input_scale``, repeated per element when ``is_channelwise``.

    In the channel-wise case every reciprocal is repeated
    ``len(output_scale) // len(input_scale)`` times (consecutively, via ``np.repeat``).
    """
    inverse_scale = 1 / input_scale
    if not is_channelwise:
        return inverse_scale
    repeats = len(output_scale) // len(input_scale)
    return np.repeat(inverse_scale, repeats)


def reconstruct_offsets_scale(
    accumulator_scale,
    equalization_vector,
    qgroups_size,
    output_factors,
    output_shift,
    **kwargs,
):
    """Return the single offsets scale shared by all quantization groups.

    The per-group accumulator scale (from ``reconstruct_accumulator_scale``) is divided by
    ``output_factors`` and then by ``2**output_shift``. All resulting entries are asserted to be
    (numerically) equal, and the first one is returned.
    """
    per_group_scale = reconstruct_accumulator_scale(
        accumulator_scale,
        equalization_vector,
        qgroups_size,
        **kwargs,
    )
    factored_scale = per_group_scale / output_factors
    offset_scale = factored_scale / 2**output_shift
    first_scale = offset_scale[0]
    # Every group is expected to end up with the same offsets scale.
    assert np.allclose(offset_scale, first_scale)
    return first_scale


def reconstruct_original_mantissa(mantissa, output_factors, **kwargs):
    """Divide ``mantissa`` by ``output_factors``, swap the two axes and pad the pieces.

    ``output_factors`` is reshaped to a column so the division broadcasts along the second axis of
    ``mantissa``. Extra keyword arguments are forwarded to ``pad_activation_pieces``.
    """
    factors_column = output_factors.reshape(-1, 1)
    unscaled = mantissa / factors_column
    pieces_first = unscaled.transpose(1, 0)
    return pad_activation_pieces(pieces_first, **kwargs)


def reconstruct_offsets_zero_point(output_zero_point, output_shift, **kwargs):
    """Return ``output_zero_point`` multiplied by ``2**output_shift``."""
    shift_multiplier = 2**output_shift
    return output_zero_point * shift_multiplier


def reconstruct_accumulator_scale(accumulator_scale, equalization_vector, qgroups_size, **kwargs):
    """Return the minimum equalized accumulator scale of each quantization group.

    ``accumulator_scale`` is divided element-wise by ``equalization_vector``; the result is then
    split into consecutive runs whose lengths are given by ``qgroups_size``, and the minimum of each
    run is stored in a float array of length ``len(qgroups_size)``.
    """
    equalized = accumulator_scale / equalization_vector
    group_minima = np.zeros(len(qgroups_size))
    start = 0
    for group_idx, group_len in enumerate(qgroups_size):
        stop = start + group_len
        group_minima[group_idx] = np.min(equalized[start:stop])
        start = stop
    return group_minima


def reconstruct_thresholds_weight(thresholds, qgroups_size, **kwargs):
    """Replicate ``thresholds`` once per quantization group and pad the pieces with ``inf``.

    The thresholds are flattened into a row, repeated ``len(qgroups_size)`` times and transposed so
    the first axis indexes the thresholds; that axis is then padded with ``np.inf`` up to
    ``MAX_PIECES - 1`` entries by ``pad_activation_pieces``.
    """
    group_count = len(qgroups_size)
    thresholds_row = np.reshape(thresholds, (1, -1))
    per_group = np.repeat(thresholds_row, group_count, axis=0)
    thresholds = per_group.transpose(1, 0)
    return pad_activation_pieces(thresholds, value=np.inf, num_pieces=MAX_PIECES - 1, **kwargs)


def get_group_quant_scales(output_factors, **kwargs):
    """Return ``1 / output_factors`` normalized by the factor from ``reconstruct_rescale_factor``."""
    shared_rescale = reconstruct_rescale_factor(output_factors)
    inverse_factors = 1 / output_factors
    return inverse_factors / shared_rescale


def reconstruct_rescale_factor(output_factors, **kwargs):
    """Return the smallest value of ``1 / output_factors``."""
    inverse_factors = 1 / output_factors
    return np.min(inverse_factors)


def pad_activation_pieces(pieces, value=0, num_pieces=MAX_PIECES, **kwargs):
    """Pad ``pieces`` at the end of its first axis with ``value`` up to ``num_pieces`` entries.

    Only the first axis grows; all other axes keep their size. Extra keyword arguments are
    accepted and ignored.
    """
    missing = num_pieces - len(pieces)
    trailing_axes = pieces.ndim - 1
    pad_width = [(0, missing)] + [(0, 0)] * trailing_axes
    return np.pad(pieces, pad_width, mode="constant", constant_values=value)


def reorder_kernel(kernel, **kwargs):
    """Permute a 4-D kernel's axes with ``transpose(3, 2, 0, 1)``.

    A 2-D kernel is first reshaped to 4-D by prepending two singleton axes.

    Raises:
        ValueError: If the kernel is neither 2-D nor 4-D.
    """
    if kernel.ndim == 2:
        # Promote to 4-D so the same permutation applies.
        kernel = kernel.reshape(1, 1, *kernel.shape)
    if kernel.ndim != 4:
        raise ValueError(f"Unsupported kernel shape: {kernel.shape}")
    return kernel.transpose(3, 2, 0, 1)


def handle_const_input_scale(scale, data, **kwargs):
    """Return the minimum of ``scale`` over its first ``data.shape[-1]`` last-axis entries."""
    width = data.shape[-1]
    trimmed_scale = scale[..., :width]
    return trimmed_scale.min()


def handle_const_input_equalization(scale, data, **kwargs):
    """Return the reciprocal of the trimmed ``scale`` normalized by its minimum.

    ``scale`` is trimmed to its first ``data.shape[-1]`` last-axis entries and divided by the value
    from ``handle_const_input_scale``; the reciprocal of that ratio is returned.
    """
    shared_scale = handle_const_input_scale(scale, data, **kwargs)
    width = data.shape[-1]
    relative_scale = scale[..., :width] / shared_scale
    return 1 / relative_scale
