import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType

# NOTE(review): presumably the number of linear pieces used when an activation is
# fitted (see ACTIVATIONS_TO_FIT); it is not referenced in this file chunk — confirm
# at the call sites.
ACTIVATION_FITTING_PIECES = 8

# Activation types that have a quantized implementation. get_num_of_pieces returns 1
# for any activation that is not in this list.
QUANTIZATION_SUPPORTED_ACTIVATION_TYPES = [
    ActivationType.RELU,
    ActivationType.LINEAR,
    ActivationType.LEAKY,
    ActivationType.RELU6,
    ActivationType.SIGMOID,
    ActivationType.EXP,
    ActivationType.TANH,
    ActivationType.INV_POS,
    ActivationType.MINUS_INV_POS,
    ActivationType.BIASED_DELTA,
    ActivationType.LESS,
    ActivationType.THRESHOLD,
    ActivationType.HARDSIGMOID,
    ActivationType.SOFTPLUS,
    ActivationType.SQRT,
    ActivationType.ELU,
    ActivationType.GELU,
    ActivationType.HARDSWISH,
    ActivationType.CLIP,
    ActivationType.SILU,
    ActivationType.INV_SQRT,
    ActivationType.MISH,
    ActivationType.SWISH,
    ActivationType.LOG,
    ActivationType.SOFTSIGN,
    ActivationType.DELTA,
    ActivationType.GREATER,
    ActivationType.POW,
    ActivationType.HDR_COMPRESSION,
    ActivationType.RELU1,
    ActivationType.RELU_POSITIVE_SQUARE,
    ActivationType.PWL,
    ActivationType.EXP_DECOMPOSE,
    ActivationType.SHIFT,
]

# NOTE(review): by its name, activations that are only available natively (not
# quantized). PRELU is indeed absent from QUANTIZATION_SUPPORTED_ACTIVATION_TYPES.
NATIVE_ONLY_ACTIVATION_TYPES = [ActivationType.PRELU]

# Activations that are candidates for piecewise fitting.
# NOTE(review): exact semantics inferred from the name — confirm against the fitting code.
ACTIVATIONS_TO_FIT = [
    ActivationType.EXP,
    ActivationType.INV_POS,
    ActivationType.MINUS_INV_POS,
    ActivationType.SQRT,
    ActivationType.LOG,
    ActivationType.POW,
    ActivationType.RELU_POSITIVE_SQUARE,
]

# Subset of ACTIVATIONS_TO_FIT for which get_num_of_pieces has no fixed piece count
# (it raises ValueError for them), so they presumably must always be fitted.
MUST_FIT_ACTIVATIONS = [
    ActivationType.SQRT,
    ActivationType.POW,
    ActivationType.RELU_POSITIVE_SQUARE,
]

# ACTIVATIONS_TO_FIT plus INV_SQRT and GELU.
# NOTE(review): presumably every activation the fitting path accepts — confirm.
ACTIVATIONS_FITTING_SUPPORTED = [*ACTIVATIONS_TO_FIT, ActivationType.INV_SQRT, ActivationType.GELU]


def _has_negative_inverse_factor(native_params, consider_params):
    """Return whether the fixed INV_POS / INV_SQRT piece count applies.

    This holds when parameters are ignored (``consider_params`` is False), or when
    ``native_params["inverse_act_factor"]`` (default 0) is negative.
    """
    return not consider_params or native_params.get("inverse_act_factor", 0) < 0


def get_num_of_pieces(activation, act_native_params=None, act_numeric_params=None, consider_params=True):
    """Return the number of linear pieces used to represent ``activation``.

    Args:
        activation: ``ActivationType`` member to look up.
        act_native_params: Dict-like native activation parameters, or ``None``
            (treated as an empty dict). It is read for two keys:
            ``"inverse_act_factor"`` (INV_POS / INV_SQRT, only when
            ``consider_params``) and ``"slopes"`` (PWL).
        act_numeric_params: Dict-like numeric activation parameters, or ``None``
            (treated as an empty dict). It is read for ``"clip_min"`` and
            ``"clip_max"`` (CLIP, only when ``consider_params``).
        consider_params: When False, the parameter-dependent activations
            (INV_POS, INV_SQRT, CLIP) use their parameter-independent count.

    Returns:
        int: 1 for LINEAR and for any activation outside
        ``QUANTIZATION_SUPPORTED_ACTIVATION_TYPES``. Otherwise, the
        activation-specific piece count.

    Raises:
        ValueError: For a supported activation that has no fixed count here.
            Examples are SQRT, POW and RELU_POSITIVE_SQUARE. INV_POS / INV_SQRT
            also raise when ``consider_params`` is set and the inverse factor is
            not negative.
        KeyError: For PWL when the native params have no ``"slopes"`` entry.
    """
    # Both parameter dicts default to None, yet consider_params defaults to True.
    # Treat a missing dict as empty so the per-key defaults apply, instead of
    # crashing on ``None.get``.
    native_params = {} if act_native_params is None else act_native_params
    numeric_params = {} if act_numeric_params is None else act_numeric_params

    if activation not in QUANTIZATION_SUPPORTED_ACTIVATION_TYPES or activation == ActivationType.LINEAR:
        return 1
    if activation in (
        ActivationType.RELU,
        ActivationType.LEAKY,
        ActivationType.LESS,
        ActivationType.GREATER,
        ActivationType.THRESHOLD,
    ):
        return 2
    if activation in (
        ActivationType.RELU6,
        ActivationType.RELU1,
        ActivationType.BIASED_DELTA,
        ActivationType.DELTA,
        ActivationType.HARDSIGMOID,
    ):
        return 3
    if activation == ActivationType.HDR_COMPRESSION:
        return 4
    if activation == ActivationType.ELU:
        return 5
    # INV_POS / INV_SQRT only have a fixed count for a negative inverse factor (or
    # when params are ignored); otherwise they fall through to the ValueError below.
    # The helper is called only for these two types, so the params stay lazily read.
    if activation in (
        ActivationType.EXP,
        ActivationType.TANH,
        ActivationType.GELU,
        ActivationType.HARDSWISH,
        ActivationType.LOG,
        ActivationType.SOFTSIGN,
        ActivationType.MINUS_INV_POS,
    ) or (activation == ActivationType.INV_POS and _has_negative_inverse_factor(native_params, consider_params)):
        return 8
    if activation in (
        ActivationType.SIGMOID,
        ActivationType.SOFTPLUS,
        ActivationType.SILU,
        ActivationType.MISH,
        ActivationType.EXP_DECOMPOSE,
        ActivationType.SHIFT,
        ActivationType.SWISH,
    ) or (activation == ActivationType.INV_SQRT and _has_negative_inverse_factor(native_params, consider_params)):
        return 9
    if activation == ActivationType.PWL:
        return len(native_params["slopes"])

    if activation == ActivationType.CLIP:
        pieces = 3
        if consider_params:
            # Each infinite clip bound removes one piece from the count.
            if numeric_params.get("clip_min", -1) == -np.inf:
                pieces -= 1
            if numeric_params.get("clip_max", 1) == np.inf:
                pieces -= 1
        return pieces

    raise ValueError(f"Failed get_len_slopes, unexpected activation {activation}")
