from collections import namedtuple
from dataclasses import dataclass
from enum import Enum

from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError

# Shift-value option lists; how each list is consumed is defined outside this module.
HW_SHIFTS = [1, 2, 3, 4]
SHIFT_OPTIONS_DEFAULT = [1, 2, 3, 4]
SHIFT_OPTIONS_4BIT = [0, 1, 2, 3]
HW_SHIFTS_PLUTO = [0, 1, 2, 3, 4]

# Before the UINT8 x INT8 MAC multiplier, the 4-bit weights are "placed" into the 8-bit field.
# This is done with a (HW hard-coded) shift-left, motivated by the non-zero pre-accumulator shift-right.
WEIGHTS_PLACEMENT_SHIFT = 1

# The inputs to ew_mult are two int9 arguments; their product is an int17.
# The APU multiplier supports int16 input, so bits have to be dropped.
# NOTE(review): the text above implies dropping 1 bit, which matches the Hailo-15 value below but
# not EW_MULT_MULTIPLIER_SHIFT = 2 -- confirm which explanation applies to which constant.
EW_MULT_MULTIPLIER_SHIFT = 2
HAILO15_EW_MULT_MULTIPLIER_SHIFT = 1

ZP_FEED_REPEAT = 4  # TODO: will be changed once this is supported properly

# Number of bits for the activation's (APU) lossy banker's-shift rounding, as performed on HW.
POST_SHIFT_1_ROUNDING = 6
POST_SHIFT_2_ROUNDING = 3

OUTPUT_ENCODING_CHICKEN_BIT = False

# Stand-in for -inf: the padding constant value used in pooling / external pad layers.
# -(2**31) is the minimum int32 value.
DEFAULT_PADDING_NEG_INF_VALUE = -(2**31)


# Mars softmax definitions
EXP_OUT_BITS = 29
LUT_IN_BITS = 16
EXP_NUME_BITS = 16

ZP_LOW_SPLIT_PRECISION_PIXEL = 127.0


class AccelerasEnum(Enum):
    """Base enum whose failed value lookups raise ``AccelerasValueError``.

    ``Enum`` calls ``_missing_`` only after the regular value lookup has failed.
    Raising here replaces the standard ``ValueError`` with the acceleras-specific error.
    """

    @classmethod
    def _missing_(cls, value):
        enum_name = cls.__name__
        message = f"{value} is not a valid {enum_name}"
        raise AccelerasValueError(message)


# Note: Enums decorated with `@auto_generate_schema` are used in HN schemas!
# To allow for HN backwards compatibility:
# * Do not remove keys from the enums; append new keys instead.
# * Do not change the string values.
def auto_generate_schema(enumeration):
    """Mark an enum for inclusion in the autogenerated HN schema.

    Enums decorated with this decorator will be added to the autogenerated
    hn_definitions.schema.auto.json file.

    Sets the class attribute ``auto_generate_schema = True`` and returns the very
    same class object, so it can be used as a plain class decorator.
    """
    setattr(enumeration, "auto_generate_schema", True)
    return enumeration


class CaseInsensitiveEnum(AccelerasEnum):
    """Enum whose string lookups retry with the upper-case, then lower-case form.

    Values that are not strings, or that match no member in either case, are
    delegated to ``AccelerasEnum._missing_``, which raises ``AccelerasValueError``.
    """

    @classmethod
    def _missing_(cls, value):
        if not isinstance(value, str):
            return super()._missing_(value)
        members_by_value = cls._value2member_map_
        # Upper-case is tried first, mirroring the original lookup order.
        for candidate in (value.upper(), value.lower()):
            if candidate in members_by_value:
                return members_by_value[candidate]
        return super()._missing_(value)


class DataPath(Enum):
    """String identifiers for the named data paths of a layer.

    Members name points such as the layer input/output (and their weights variants),
    the accumulator, softmax exp numerator/denominator and 8/16-bit intermediate block
    connections; their consumers live outside this module.

    This is a plain ``Enum`` (not ``AccelerasEnum``), so an unknown value raises the
    standard ``ValueError``.
    """

    LAYER_IN = "layer_in"
    LAYER_IN_WEIGHTS = "layer_in_weights"
    ACCUMULATOR = "accumulator"
    DATA_MULT = "data_mult"
    POST_DATA_MULT = "post_data_mult"
    MAC_DATA = "mac_data"
    LAYER_OUT = "layer_out"
    LAYER_OUT_WEIGHTS = "layer_out_weights"
    LAYER_X_SUM = "layer_x_sum"
    LAYER_X2_SUM = "layer_x2_sum"
    LAYER_E_X_SUM = "layer_e_x_sum"

    LAYER_IN_WEIGHTS_16 = "layer_in_weights_16"
    LAYER_IN_MASK = "layer_in_mask"

    EXP_DENO = "exp_deno"
    EXP_NUME = "exp_nume"

    LAYER_SPLIT_INPUT = "layer_split_input"
    LAYER_IN_INV = "layer_in_inv"
    LAYER_MU = "layer_mu"
    LAYER_ROOT = "layer_root"

    INTER_BLOCK_16 = "intermediate_block_connection_16_bits"
    INTER_BLOCK_8 = "intermediate_block_connection_8_bits"


# hn definitions #####################################################################################
# TODO: values were copied from hn; make sure nothing is missing and add a test.


class LayerType(AccelerasEnum):
    """Enum-like class for layer types.

    Values are the lower-case layer-type strings (e.g. ``"conv"``, ``"ew_add"``).
    Unknown values raise ``AccelerasValueError`` via ``AccelerasEnum``.
    """

    CONV = "conv"
    DW = "dw"
    DENSE = "dense"
    DECONV = "deconv"
    NORMALIZATION = "normalization"
    BATCH_NORM = "batch_norm"
    ACTIVATION = "activation"
    SOFTMAX = "softmax"
    MAXPOOL = "maxpool"
    AVGPOOL = "avgpool"
    RESIZE = "resize"
    DEPTH_TO_SPACE = "depth_to_space"
    SPACE_TO_DEPTH = "space_to_depth"
    CONCAT = "concat"
    SLICE = "slice"
    BBOX_DECODER = "bbox_decoder"
    PROPOSAL_GENERATOR = "proposal_generator"
    NMS = "nms"
    FEATURE_SPLITTER = "feature_splitter"
    ELEMENTWISE_ADD = "ew_add"
    ELEMENTWISE_MULT = "ew_mult"
    ELEMENTWISE_SUB = "ew_sub"
    ELEMENTWISE_MAX = "ew_max"
    FEATURE_SHUFFLE = "feature_shuffle"
    INPUT_LAYER = "input_layer"
    OUTPUT_LAYER = "output_layer"
    FORMAT_CONVERSION = "format_conversion"
    MATMUL = "matmul"
    REDUCE_SUM = "reduce_sum"
    REDUCE_MAX = "reduce_max"
    ARGMAX = "argmax"
    CONST_INPUT = "const_input"
    SHORTCUT = "shortcut"
    REDUCE_MEAN = "reduce_mean"
    FEATURE_MULTIPLIER = "feature_multiplier"
    POSTPROCESS = "postprocess"
    EXTERNAL_PAD = "external_pad"
    ROW_SPLITTER = "row_splitter"
    WIDTH_SPLITTER = "width_splitter"
    CONV_DECOMPOSE = "conv_decompose"
    PORTAL = "portal"
    LAYER_NORM = "layer_normalization"
    FUSED_BBOX_DECODER = "fused_bbox_decoder"
    PRECISION_SPLITTER = "precision_splitter"
    PRECISION_SPLITTER_SIGNED = "precision_splitter_signed"


class ActivationType(AccelerasEnum):
    """Activation function types, keyed by their lower-case string names.

    Unknown values raise ``AccelerasValueError`` via ``AccelerasEnum``.
    """

    RELU = "relu"
    RELU6 = "relu6"
    SIGMOID = "sigmoid"
    LINEAR = "linear"
    LEAKY = "leaky"
    ELU = "elu"
    EXP = "exp"
    TANH = "tanh"
    PRELU = "prelu"
    SOFTPLUS = "softplus"
    SILU = "silu"
    GELU = "gelu"
    MISH = "mish"
    INV_POS = "inv_pos"
    MINUS_INV_POS = "minus_inv_pos"
    THRESHOLD = "threshold"
    BIASED_DELTA = "biased_delta"
    HARDSWISH = "hardswish"
    SWISH = "swish"
    RELU1 = "relu1"
    SQRT = "sqrt"
    LESS = "less"
    LOG = "log"
    HARDSIGMOID = "hardsigmoid"
    CLIP = "clip"
    INV_SQRT = "inv_sqrt"
    SOFTSIGN = "softsign"
    DELTA = "delta"
    GREATER = "greater"
    POW = "pow"
    HDR_COMPRESSION = "hdr_compression"
    RELU_POSITIVE_SQUARE = "relu_positive_square"
    PWL = "pwl"
    EXP_DECOMPOSE = "exp_decompose"
    SHIFT = "shift"


# Maps every ActivationType *value* (string) to whether the activation's output range is
# bounded (e.g. relu6 / sigmoid / tanh / clip -> True) or unbounded (e.g. relu / exp -> False).
# Every ActivationType member must have an entry here, otherwise lookups by value raise KeyError.
BoundedActivation = {
    ActivationType.LINEAR.value: False,
    ActivationType.RELU.value: False,
    ActivationType.RELU6.value: True,
    ActivationType.RELU1.value: True,
    ActivationType.LEAKY.value: False,
    ActivationType.ELU.value: False,
    ActivationType.SIGMOID.value: True,
    ActivationType.EXP.value: False,
    ActivationType.TANH.value: True,
    ActivationType.THRESHOLD.value: False,
    ActivationType.BIASED_DELTA.value: True,
    ActivationType.PRELU.value: False,
    ActivationType.SOFTPLUS.value: False,
    ActivationType.SILU.value: False,
    ActivationType.GELU.value: False,
    ActivationType.MISH.value: False,
    ActivationType.INV_POS.value: False,
    ActivationType.HARDSWISH.value: False,
    ActivationType.HARDSIGMOID.value: True,
    ActivationType.CLIP.value: True,
    ActivationType.INV_SQRT.value: False,
    ActivationType.SOFTSIGN.value: True,
    ActivationType.DELTA.value: True,
    ActivationType.MINUS_INV_POS.value: False,
    ActivationType.SWISH.value: False,
    ActivationType.SQRT.value: False,
    ActivationType.LESS.value: True,
    ActivationType.LOG.value: False,
    ActivationType.GREATER.value: True,
    # x ** p grows without bound for unbounded input, so pow is treated as unbounded.
    ActivationType.POW.value: False,
    ActivationType.HDR_COMPRESSION.value: False,
    ActivationType.RELU_POSITIVE_SQUARE.value: False,
    ActivationType.PWL.value: False,
    ActivationType.EXP_DECOMPOSE.value: False,
    ActivationType.SHIFT.value: False,
}


class NativeName(AccelerasEnum):
    """Native layer names (currently only ``ew_sub``)."""

    EW_SUB = "ew_sub"


class OpStates(Enum):
    """Enum-like class for the different states an Op or Layer can be in."""

    FP = "floating_point"
    CALIBRATED = "calibrated"
    QUANTIZED = "quantized"


class PaddingType(AccelerasEnum):
    """Padding types; values are upper-case strings.

    Lookup is case-sensitive (``AccelerasEnum``, not ``CaseInsensitiveEnum``).
    """

    VALID = "VALID"
    SAME = "SAME"
    DECONV = "DECONV"
    SAME_TENSORFLOW = "SAME_TENSORFLOW"


class ResizeBilinearPixelsMode(AccelerasEnum):
    """
    Enum-like class for resize bilinear pixel coordinates transformation mode.
    Currently supporting basic PyTorch and TF variants.
    """

    DISABLED = "disabled"
    ALIGN_CORNERS = "align_corners"
    HALF_PIXELS = "half_pixels"


class ResizeMethod(AccelerasEnum):
    """Enum-like class for methods supported by 'resize' LayerType"""

    BILINEAR = "bilinear"
    NEAREST_NEIGHBOR = "nearest_neighbor"


class DepthToSpaceType(Enum):
    """Enum-like class for the types of depth_to_space used by the 'DepthToSpace' LayerType.

    Presumably the DCR / CRD element orderings (as in ONNX DepthToSpace) -- confirm.
    Plain ``Enum``: an unknown value raises the standard ``ValueError``.
    """

    dcr = "dcr"
    crd = "crd"


class SpaceToDepthType(AccelerasEnum):
    """Enum-like class (acceleras) for the types of space_to_depth used by the 'SpaceToDepth' LayerType"""

    CLASSIC_DCR = "classic_dcr"
    CLASSIC_CRD = "classic_crd"
    FOCUS = "focus"
    SERIAL = "serial"


@auto_generate_schema
class PrecisionSplitMode(AccelerasEnum):
    """Mode of a precision split: ``normal`` or ``pixels``.

    Part of the auto-generated HN schema, so values must not change.
    NOTE(review): presumably selects how precision-splitter layers split their input
    (see ``ZP_LOW_SPLIT_PRECISION_PIXEL``) -- confirm and document each mode.
    """

    NORMAL = "normal"
    PIXELS = "pixels"


class NpzExportMode(AccelerasEnum):
    """Enum-like class (acceleras) for the types of npz export mode (integer-valued)."""

    WEIGHTS = 0
    QNPZ = 1
    ACCELERAS = 2


@auto_generate_schema
class ConcatAxis(AccelerasEnum):
    """Enum-like class (acceleras) for the axis to concatenate along."""

    features = "features"
    spatial_h = "spatial_h"
    spatial_w = "spatial_w"


DEFAULT_CONCAT_AXIS = ConcatAxis.features
# Tensor dimension index -> ConcatAxis (dim 1 = height, 2 = width, 3 = features; dim 0 is not a
# concat axis here -- presumably the batch dimension).
CONCAT_DIM_TO_AXIS = {
    1: ConcatAxis.spatial_h,
    2: ConcatAxis.spatial_w,
    3: ConcatAxis.features,
}
# Inverse of CONCAT_DIM_TO_AXIS (same key order), derived so the two can never drift apart.
CONCAT_AXIS_TO_DIM = {axis: dim for dim, axis in CONCAT_DIM_TO_AXIS.items()}


@auto_generate_schema
class FormatConversionType(str, Enum):
    """Enum-like class for format conversion types.

    Mixes in ``str``, so members compare equal to their string values. Part of the
    auto-generated HN schema: append new members, never rename existing values.
    """

    mipi_bayer_rggb_to_hailo_rgb = "mipi_bayer_rggb_to_hailo_rgb"
    mipi_bayer_bggr_to_hailo_rgb = "mipi_bayer_bggr_to_hailo_rgb"
    mipi_bayer_grbg_to_hailo_rgb = "mipi_bayer_grbg_to_hailo_rgb"
    mipi_bayer_gbrg_to_hailo_rgb = "mipi_bayer_gbrg_to_hailo_rgb"
    twelve_to_eight_bit = "twelve_to_eight_bit"
    twelve_to_sixteen_bit = "twelve_to_sixteen_bit"
    sixteen_to_twelve_bit = "sixteen_to_twelve_bit"
    features_to_width_features = "features_to_width_features"
    flat_to_frames = "flat_to_frames"
    frames_to_flat = "frames_to_flat"
    transpose_width_features = "transpose_width_features"
    transpose_matmul = "transpose_matmul"
    spatial_expand = "spatial_expand"  # backwards compatibility
    spatial_flatten = "spatial_flatten"  # backwards compatibility
    spatial_reshape = "spatial_reshape"
    tf_rgb_to_hailo_rgb = "tf_rgb_to_hailo_rgb"
    tf_rgbx_to_hailo_rgb = "tf_rgbx_to_hailo_rgb"
    mipi_rgb888_to_hailo_rgb = "mipi_rgb888_to_hailo_rgb"
    hailo_rgb_to_tf_rgb = "hailo_rgb_to_tf_rgb"
    hailo_rgb_to_ppu = "hailo_rgb_to_ppu"
    ppu_to_hailo_rgb = "ppu_to_hailo_rgb"
    hailo_rgb_to_f8cr = "hailo_rgb_to_f8cr"
    f8cr_to_hailo_rgb = "f8cr_to_hailo_rgb"
    yuy2_to_hailo_yuv = "yuy2_to_hailo_yuv"
    hxf_to_w_transposed = "hxf_to_w_transposed"
    f_to_hxw_transposed = "f_to_hxw_transposed"
    fcr_to_c8fr = "fcr_to_c8fr"
    f8cr_to_fcr = "f8cr_to_fcr"
    c8fr_to_frames = "c8fr_to_frames"
    reshape_1xw0_to_hxw = "reshape_1xw0_to_hxw"
    nv12_to_hailo_yuv = "nv12_to_hailo_yuv"
    nv21_to_hailo_yuv = "nv21_to_hailo_yuv"
    hailo_rgb_to_lcu = "hailo_rgb_to_lcu"
    i420_to_hailo_yuv = "i420_to_hailo_yuv"
    transpose_height_width = "transpose_height_width"
    general_reshape = "general_reshape"
    reshape_height_features = "reshape_height_features"
    reshape_post_ew_mult = "reshape_post_ew_mult"
    rotation = "rotation"
    split_windowed_attention = "split_windowed_attention"  # [b, h * w, c] -> [1, number_of_windows, window_size, c], split into several layers in the fuser
    merge_windowed_attention = "merge_windowed_attention"  # [1, number_of_windows, window_size, c] -> [b, sqrt(number_of_windows*window_size), sqrt(number_of_windows*window_size), c], split into several layers in the fuser
    groups_to_spatial_flatten = (
        "groups_to_spatial_flatten"  # [B, H, W, C * G] -> [B, H * W * G, C], split into several layers in the fuser
    )
    spatial_flatten_to_groups = (
        "spatial_flatten_to_groups"  # [B, H * W * G, C] -> [B, H, W, C * G], split into several layers in the fuser
    )
    partial_groups_to_spatial_flatten = "partial_groups_to_spatial_flatten"  # [B, H, W, C * G1] -> [B, H * W * G1, C * G2], split into several layers in the fuser
    mask = "mask"  # Special format conversion to help with LLM Softmax mask generation.
    cos = "cos"  # Special format conversion to help with LLM RoPE coefficients generation.
    sin = "sin"  # Special format conversion to help with LLM RoPE coefficients generation.
    embedding = "embedding"  # Special format conversion to help with LLM inference.


@auto_generate_schema
class ColorConversionType(str, Enum):
    """Enum-like class for color conversion types"""

    yuv_to_rgb = "yuv_to_rgb"
    yuv_full_range_to_rgb = "yuv_full_range_to_rgb"
    yuv601_to_rgb = "yuv601_to_rgb"
    yuv709_to_rgb = "yuv709_to_rgb"
    bgr_to_rgb = "bgr_to_rgb"
    yuv_to_bgr = "yuv_to_bgr"
    yuv_full_range_to_bgr = "yuv_full_range_to_bgr"
    yuv601_to_bgr = "yuv601_to_bgr"
    yuv709_to_bgr = "yuv709_to_bgr"
    rgb_to_bgr = "rgb_to_bgr"


# String value -> ColorConversionType member.
ColorConversionTypes = {x.value: x for x in ColorConversionType}


@auto_generate_schema
class HybridConversionType(str, Enum):
    """Enum-like class for hybrid conversion types, composed of a color conversion and a format conversion type"""

    yuy2_to_rgb = "yuy2_to_rgb"
    nv12_to_rgb = "nv12_to_rgb"
    nv21_to_rgb = "nv21_to_rgb"
    i420_to_rgb = "i420_to_rgb"


# String value -> HybridConversionType member.
HybridConversionTypes = {x.value: x for x in HybridConversionType}


@auto_generate_schema
class EWMultType(str, Enum):
    """Where an element-wise multiplication is executed: on the APU or on the MAC."""

    on_apu = "on_apu"
    on_mac = "on_mac"


# Every color conversion, in declaration order; the common prefix of both lists below.
_COLOR_CONVERSIONS = (
    ColorConversionType.yuv_to_rgb,
    ColorConversionType.yuv_full_range_to_rgb,
    ColorConversionType.yuv601_to_rgb,
    ColorConversionType.yuv709_to_rgb,
    ColorConversionType.bgr_to_rgb,
    ColorConversionType.yuv_to_bgr,
    ColorConversionType.yuv_full_range_to_bgr,
    ColorConversionType.yuv601_to_bgr,
    ColorConversionType.yuv709_to_bgr,
    ColorConversionType.rgb_to_bgr,
)

# Conversions (color, format and hybrid) that may be applied at the network input.
InputConversions = [
    *_COLOR_CONVERSIONS,
    FormatConversionType.tf_rgb_to_hailo_rgb,
    FormatConversionType.tf_rgbx_to_hailo_rgb,
    FormatConversionType.yuy2_to_hailo_yuv,
    FormatConversionType.nv12_to_hailo_yuv,
    FormatConversionType.nv21_to_hailo_yuv,
    FormatConversionType.i420_to_hailo_yuv,
    HybridConversionType.yuy2_to_rgb,
    HybridConversionType.nv12_to_rgb,
    HybridConversionType.nv21_to_rgb,
    HybridConversionType.i420_to_rgb,
]

# String value -> member, across all three conversion enums.
InputConversionTypes = {conversion.value: conversion for conversion in InputConversions}

# Subset of the input conversions that emulation supports (note: not tf_rgb_to_hailo_rgb,
# nor the nv21/i420 hybrid conversions).
EmulationSupportedConversions = [
    *_COLOR_CONVERSIONS,
    FormatConversionType.yuy2_to_hailo_yuv,
    FormatConversionType.tf_rgbx_to_hailo_rgb,
    FormatConversionType.nv12_to_hailo_yuv,
    HybridConversionType.yuy2_to_rgb,
    HybridConversionType.nv12_to_rgb,
    FormatConversionType.nv21_to_hailo_yuv,
    FormatConversionType.i420_to_hailo_yuv,
]


DEFAULT_RANK2_SLICE = [0, 1, 1]
DEFAULT_X_POINTS_MAX_VALUE = 1e8
DEFAULT_ACCUMULATOR_SIZE = 16
SHIFT_CALCULATE_BUFFER = 0.297
APU_MANTISSA_BITS = 10
APU_EXP_BITS = 4

DUMMY_EXPONENT = 10
MAX_NUM_REPEATS = 32
MAX_NUM_REPEATS_ELTWISE = 4
DEFAULT_NULL_CHANNELS_CUTOFF_FACTOR = 1e-4
MAX_ACTIVATION_VALUE = 255
RECOMMENDED_DATASET_SIZE = 1024
RECOMMENDED_CALIBSET_SIZE_FOR_BN_CHECKER = 16
MINIMUM_PARAMS_FOR_COMPRESSION = 20e6
DEFAULT_ZERO_STATIC_CHANNELS_EPSILON = 1e-7
# Limit used when deciding whether to enable 16-bit output, derived from:
#   (1) PCIe Gen 3 single lane (~1 GB/s -> 1e9)
#   (2) utilization = 0.8
#   (3) fps = 300
# NOTE(review): 1e9 * 0.8 / 300 ~= 2.67e6, i.e. a per-frame budget (bytes per frame if 1e9 is
# bytes/s) rather than a throughput, despite the constant's name.
MAXIMUM_THROUGHPUT_FOR_16BIT_OUTPUT = 1e9 * 0.8 / 300
MAXIMUM_DESCRIPTOR_SIZE_IN_HAILORT = 2e6
# Layer types supported in the a16_w16 precision mode.
# List taken from the precision_mode section of the Hailo DFC user guide.
SUPPORTED_LAYERS_IN_A16_W16 = {
    LayerType.CONV.value,
    LayerType.DENSE.value,
    LayerType.DW.value,
    LayerType.ELEMENTWISE_ADD.value,
}

# APU bit widths for the 16-->>8 mode (presumably input bits -->> output bits; confirm).
APU_EXP_BIAS_BITS = 7
APU_CLIP_BITS_1 = 13
APU_OFFSET_BITS = 13
APU_CLIP_BITS_2 = 11

# APU bit widths ("_D" variants) for all the other modes:
# 16-->>16
# 32-->>16
# 32-->> 8
APU_EXP_BIAS_BITS_D = 11
APU_CLIP_BITS_1_D = 20
APU_OFFSET_BITS_D = 19
APU_CLIP_BITS_2_D = 18

# APU final shift for the 32-->>8 mode.
APU_FINAL_SHIFT = 10

# APU final shift ("_D" variant) for all the other modes:
# 16-->>16
# 32-->>16
# 16-->>8
APU_FINAL_SHIFT_D = 3

ACTIVATION_CLIP_BITS_HAILO_LAYER_NORM = 30  # int

RECOMMENDED_COMP_LEVEL = 4
RECOMMENDED_OPTM_LEVEL = 4


class LossType(CaseInsensitiveEnum):
    """Loss types.

    Lookup by value is case-insensitive (``CaseInsensitiveEnum``).
    """

    CROSS_ENTROPY = "ce"
    L2 = "l2"
    L2REL = "l2rel"
    COSINE = "cosine"
    L2REL_CHW = "l2rel_channelwise_weighted"

    @classmethod
    def internals(cls):
        """Return the set of loss types classified as internal (currently only ``L2REL_CHW``).

        NOTE(review): presumably these are not meant to be user-selectable -- confirm with callers.
        """
        return {cls.L2REL_CHW}


class StatsState(AccelerasEnum):
    """State of statistics collection: blank, collecting, or complete."""

    BLANK = "stats_blank"
    RUNNING = "stats_collecting"
    COMPLETE = "stats_complete"


class EmulationType(AccelerasEnum):
    """Emulation type: ``regular`` or ``double``."""

    REGULAR = "regular"
    DOUBLE = "double"


class StrideAlignType(AccelerasEnum):
    """Stride alignment corner: ``NW`` or ``SE``."""

    NW = "NW"
    SE = "SE"


# General Algorithm definitions ###################################################################


class EncodingMatchType(AccelerasEnum):
    """Kind of encoding match (detection-head, scale, zero-point, or none)."""

    DETECTION_HEAD_MATCH = "detection_head_match"
    SCALE_MATCH = "scale_match"
    ZERO_POINT_MATCH = "zero_point_match"
    NO_MATCH = "no_match"


class IOVectorPolicy(AccelerasEnum):
    """Policies for input/output vectors: scalar input scale or scalar input zero-point."""

    SCALAR_INPUT_SCALE = "scalar_input_scale"
    SCALAR_INPUT_ZP = "scalar_input_zp"


class ModelOptimizationCommand(Enum):
    """Names of the top-level model-optimization commands."""

    pre_quantization_optimization = "pre_quantization_optimization"
    compression_params = "compression_params"
    quantization_param = "quantization_param"
    model_optimization_config = "model_optimization_config"
    post_quantization_optimization = "post_quantization_optimization"
    model_optimization_flavor = "model_optimization_flavor"


class LayerSupportStatus(Enum):
    """Support status of a layer: supported, unsupported, or unexpected."""

    supported = "supported"
    unsupported = "unsupported"
    unexpected = "unexpected"


class QuantizationAlgorithms(Enum):
    """Names of quantization-related algorithms."""

    equalization = "equalization"
    params_sorter = "params_sorter"
    dead_channels_removal = "dead_channels_removal"
    ibc = "ibc"
    quarot = "quarot"


class AccelerasPolicy(Enum):
    """Three-state policy: enabled, allowed, or disabled."""

    enabled = "enabled"
    allowed = "allowed"
    disabled = "disabled"


# Equiv set definitions ###########################################################################


class LayerHandlerType(Enum):
    """Kinds of layer handlers used when building equivalence sets."""

    transparent = "transparent"
    consumer = "consumer"
    multi_source = "multi_source"
    cc_aggregator = "cc_aggregator"
    unsupported = "unsupported"
    featurewise = "featurewise"
    ew_bouncer = "ew_bouncer"
    activation = "activation"
    skip = "skip"
    unexpected = "unexpected"
    output = "output"
    undefined = "undefined"
    matmul = "matmul"
    matmul_transpose = "matmul_transpose"


class LayerEquivType(Enum):
    """Role of a layer in an equivalence set: consumer or producer."""

    consumer = "consumer"
    producer = "producer"


# Quantization params definitions #################################################################


class QuantizationParamField(Enum):
    """Field names accepted by the quantization_param command."""

    max_elementwise_feed_repeat = "max_elementwise_feed_repeat"
    null_channels_cutoff_factor = "null_channels_cutoff_factor"
    max_bias_feed_repeat = "max_bias_feed_repeat"
    activation_fit = "activation_fit"
    quantization_groups = "quantization_groups"
    precision_mode = "precision_mode"
    bias_mode = "bias_mode"
    ignore_hw_limitation_assertion = "ignore_hw_limitation_assertion"


class QuantizationDeprecatedParam(str, Enum):
    """Deprecated quantization parameter names."""

    use_16bit_bias = "use_16bit_bias"
    use_4bit_weights = "use_4bit_weights"
    exponential_mode_4bit_weights = "exponential_mode_4bit_weights"


@auto_generate_schema
class ActivationFitPolicy(Enum):
    """Policy values for the activation_fit quantization parameter."""

    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


# String value -> ActivationFitPolicy member.
ActivationFitPolicys = {x.value: x for x in ActivationFitPolicy}


@auto_generate_schema
class IgnoreHwLimitationAssertionPolicy(Enum):
    """Policy values for the ignore_hw_limitation_assertion quantization parameter."""

    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


@auto_generate_schema
class OutputMinMaxStrategy(Enum):
    """Strategy for output min/max values: sigmoid, softmax, or default."""

    sigmoid = "sigmoid"
    softmax = "softmax"
    default = "default"


# String value -> OutputMinMaxStrategy member.
OutputMinMaxStrategies = {x.value: x for x in OutputMinMaxStrategy}


@auto_generate_schema
class PrecisionMode(Enum):
    """Enum-like class for selecting the precision mode.

    Most values follow ``a<input bits>_w<weight bits>[_a<output bits>]``. A few special
    modes (``a8_w4_exp``, ``a16_w16_non_zero``, ``native``) do not carry an output-bits part.
    """

    a8_w8 = "a8_w8"
    a8_w4 = "a8_w4"
    a16_w16 = "a16_w16"
    a16_w8 = "a16_w8"
    a16_w4 = "a16_w4"
    a8_w4_exp = "a8_w4_exp"
    a16_w16_non_zero = "a16_w16_non_zero"
    native = "native"
    a8_w8_a8 = "a8_w8_a8"
    a8_w8_a16 = "a8_w8_a16"
    a8_w4_a8 = "a8_w4_a8"
    a8_w4_a16 = "a8_w4_a16"
    a16_w16_a8 = "a16_w16_a8"
    a16_w16_a16 = "a16_w16_a16"
    a16_w8_a8 = "a16_w8_a8"
    a16_w8_a16 = "a16_w8_a16"
    a16_w4_a16 = "a16_w4_a16"
    a16_w4_a8 = "a16_w4_a8"

    def reduce(self) -> "PrecisionMode":
        """Return the two-part ``a<in>_w<weights>`` mode this mode starts with.

        For example ``a8_w8_a16`` -> ``a8_w8``, and ``a8_w4_exp`` -> ``a8_w4``.
        Modes without such a prefix (e.g. ``native``) are returned unchanged.
        """
        if self.name.startswith("a8_w4"):
            res = PrecisionMode.a8_w4
        elif self.name.startswith("a8_w8"):
            res = PrecisionMode.a8_w8
        elif self.name.startswith("a16_w16"):
            res = PrecisionMode.a16_w16
        elif self.name.startswith("a16_w8"):
            res = PrecisionMode.a16_w8
        elif self.name.startswith("a16_w4"):
            res = PrecisionMode.a16_w4
        else:
            res = self
        return res

    def input_precision_mode(self) -> "PrecisionMode":
        """Return the canonical mode describing this mode's input side.

        8-bit-input modes map to ``a8_w8_a8``; 16-bit-input modes map to ``a16_w16_a16``.

        Raises:
            KeyError: for modes without an input mapping (currently only ``native``).
        """
        input_precision = {
            PrecisionMode.a8_w8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w4: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16: PrecisionMode.a16_w16_a16,
            # a16_w8 / a16_w4 were missing and raised KeyError; they take 16-bit input.
            PrecisionMode.a16_w8: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w4: PrecisionMode.a16_w16_a16,
            PrecisionMode.a8_w4_exp: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16_non_zero: PrecisionMode.a16_w16_a16,
            PrecisionMode.a8_w8_a8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w8_a16: PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w4_a8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w4_a16: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16_a8: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w16_a16: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w8_a8: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w8_a16: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w4_a8: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w4_a16: PrecisionMode.a16_w16_a16,
        }

        return input_precision[self]

    def output_precision_mode(self) -> "PrecisionMode":
        """Return the canonical mode describing this mode's output side.

        8-bit-output modes map to ``a8_w8_a8``; 16-bit-output modes map to ``a16_w16_a16``.
        Two-part modes output at their input width.

        Raises:
            KeyError: for modes without an output mapping (currently only ``native``).
        """
        output_precision = {
            PrecisionMode.a8_w8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w4: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16: PrecisionMode.a16_w16_a16,
            # a16_w8 / a16_w4 were missing and raised KeyError; like a16_w16 they output 16 bits
            # (consistent with get_input_output_precision_mode).
            PrecisionMode.a16_w8: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w4: PrecisionMode.a16_w16_a16,
            PrecisionMode.a8_w4_exp: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16_non_zero: PrecisionMode.a16_w16_a16,
            PrecisionMode.a8_w8_a8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w8_a16: PrecisionMode.a16_w16_a16,
            PrecisionMode.a8_w4_a8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w4_a16: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w16_a8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16_a16: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w8_a8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w8_a16: PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w4_a8: PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w4_a16: PrecisionMode.a16_w16_a16,
        }

        return output_precision[self]

    def has_output_bits(self) -> bool:
        """Return True when the value has a third ``a<bits>`` part (e.g. ``a8_w8_a16``)."""
        precision = self.value.split("_")
        return len(precision) == 3 and precision[2].startswith("a")

    def output_bits(self) -> int:
        """Return 16 or 8 from an ``a16``/``a8`` suffix.

        Raises:
            ValueError: if the value does not end with an output-bits part.
        """
        if self.value.endswith("a16"):
            res = 16
        elif self.value.endswith("a8"):
            res = 8
        else:
            raise ValueError("PrecisionMode doesn't specify output_bits, operation not premitted")
        return res

    def input_bits(self) -> int:
        """Return 16 or 8 from an ``a16``/``a8`` prefix.

        Raises:
            ValueError: if the value does not start with an input-bits part (e.g. ``native``).
        """
        if self.value.startswith("a16"):
            res = 16
        elif self.value.startswith("a8"):
            res = 8
        else:
            raise ValueError("PrecisionMode doesn't specify input_bits, operation not premitted")
        return res

    def weight_bits(self) -> int:
        """Return the weight bit width parsed from the ``w<bits>`` part.

        Raises:
            ValueError: if the value has no parsable second part. This includes ``native``,
                which previously leaked an IndexError.
        """
        parts = self.value.split("_")
        try:
            res = int(parts[1].lstrip("w"))
        except (IndexError, ValueError) as err:
            raise ValueError("PrecisionMode doesn't specify weight_bits, operation not premitted") from err
        return res


@dataclass
class SplittedPrecisionMode:
    """Bit-width triple (input, weights, output) decoupled from ``PrecisionMode``.

    Any combination of bit widths can be stored, which is convenient when a precision is
    changed step by step. ``to_precision_mode`` converts back and fails when the combination
    is not a defined ``PrecisionMode``.
    """

    input: int
    weights: int
    output: int

    @classmethod
    def from_precision_mode(cls, precision_mode: PrecisionMode):
        """Build from a three-part mode such as ``a8_w4_a16`` -> (8, 4, 16).

        Raises:
            ValueError: if the mode does not have three parts, or a part is not numeric.
        """
        error_text = "PrecisionMode doesn't specify all 3 bits, operation not premitted"
        parts = precision_mode.value.split("_")
        if len(parts) != 3:
            raise ValueError(error_text)
        in_part, weight_part, out_part = parts
        try:
            bit_widths = (
                int(in_part.lstrip("a")),
                int(weight_part.lstrip("w")),
                int(out_part.lstrip("a")),
            )
        except ValueError:
            raise ValueError(error_text)
        return cls(*bit_widths)

    def to_precision_mode(self) -> PrecisionMode:
        """Return the matching ``PrecisionMode``.

        Raises:
            ValueError: if no mode named ``a<in>_w<weights>_a<out>`` exists.
        """
        mode_value = f"a{self.input}_w{self.weights}_a{self.output}"
        try:
            return PrecisionMode(mode_value)
        except ValueError:
            raise ValueError("PrecisionMode not supported")


# String value (e.g. "a8_w8_a16") -> PrecisionMode member.
PrecisionModes = dict((mode.value, mode) for mode in PrecisionMode)
# Modes spelling out input, weight and output bits explicitly ("a<in>_w<weights>_a<out>").
# Member names equal their values, so lookup by name is used here.
ExplicitPrecisionModes = [
    PrecisionMode[mode_name]
    for mode_name in (
        "a8_w8_a8",
        "a8_w8_a16",
        "a8_w4_a8",
        "a8_w4_a16",
        "a16_w16_a8",
        "a16_w16_a16",
        "a16_w8_a8",
        "a16_w8_a16",
        "a16_w4_a8",
        "a16_w4_a16",
    )
]


def get_input_output_precision_mode(precision_mode):
    """Return the canonical ``(input_mode, output_mode)`` pair for ``precision_mode``.

    The input side is ``a8_w8_a8`` when the mode's value starts with ``"a8_"`` and
    ``a16_w16_a16`` otherwise (including ``native``). The output side is ``a8_w8_a8`` for
    the 8-bit-output modes listed below and ``a16_w16_a16`` for every other mode.
    """
    eight_bit_output_modes = (
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w4,
        PrecisionMode.a8_w4_exp,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w4_a8,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w8_a8,
        PrecisionMode.a16_w4_a8,
    )
    eight_bit_input = precision_mode.value.startswith("a8_")
    input_mode = PrecisionMode.a8_w8_a8 if eight_bit_input else PrecisionMode.a16_w16_a16
    eight_bit_output = precision_mode in eight_bit_output_modes
    output_mode = PrecisionMode.a8_w8_a8 if eight_bit_output else PrecisionMode.a16_w16_a16
    return input_mode, output_mode


@auto_generate_schema
class BiasMode(Enum):
    """Enum-like class for selecting the bias representation mode."""

    single_scale_decomposition = "single_scale_decomposition"
    double_scale_initialization = "double_scale_initialization"
    double_scale_decomposition = "double_scale_decomposition"


# String value -> BiasMode member.
BiasModes = {x.value: x for x in BiasMode}


@auto_generate_schema
class Use16bitBiasPolicy(Enum):
    """Enum-like class for the different modes of 16-bit bias representation."""

    disabled = "disabled"
    enabled = "enabled"
    decomposition = "decomposition"


# String value -> Use16bitBiasPolicy member.
Use16bitBiasPolicies = {x.value: x for x in Use16bitBiasPolicy}
# Pairs members *positionally* (declaration order): disabled -> single_scale_decomposition,
# enabled -> double_scale_initialization, decomposition -> double_scale_decomposition.
# Reordering or extending either enum silently changes this mapping.
Use16bitBiasPolicyToBiasMode = dict(zip(Use16bitBiasPolicy, BiasMode))
LAYERS_KEY = "layers"

# Full-precision optimization definitions #########################################################


# Classification of a layer inside an equivalence set: its handler type (presumably a
# LayerHandlerType member -- confirm with callers) and whether it acts as a source.
EquivClassification = namedtuple("EquivClassification", "handler_type is_source")


class PreQuantizationFeature(Enum):
    """Names of the pre-quantization (full-precision) optimization features."""

    dead_channels_removal = "dead_channels_removal"
    zero_static_channels = "zero_static_channels"
    se_optimization = "se_optimization"
    equalization = "equalization"
    dead_layers_removal = "dead_layers_removal"
    weights_clipping = "weights_clipping"
    activation_clipping = "activation_clipping"
    ew_add_fusing = "ew_add_fusing"
    layer_decomposition = "layer_decomposition"
    smart_softmax_stats = "smart_softmax_stats"
    defuse = "defuse"
    resolution_reduction = "resolution_reduction"
    global_avgpool_reduction = "global_avgpool_reduction"
    add_shortcut_layer = "add_shortcut_layer"
    layer_norm_decomposition = "layer_norm_decomposition"
    matmul_correction = "matmul_correction"
    matmul_equalization = "matmul_equalization"
    matmul_decomposition = "matmul_decomposition"
    switch_concat_with_add = "switch_concat_with_add"
    split_ew_mult_by_bit_significance = "split_ew_mult_by_bit_significance"
    use_prequantized_weights = "use_prequantized_weights"
    conv_decomposition = "conv_decomposition"
    split_fused_activation = "split_fused_activation"
    quarot = "quarot"
    conv_a16_w4 = "conv_a16_w4"


class ResolutionReductionStage(Enum):
    """Stage of resolution reduction: apply or revert."""

    apply = "apply"
    revert = "revert"


class ResolutionReductionInterpolationMode(Enum):
    """Interpolation used by resolution reduction: disabled or bilinear."""

    disabled = "disabled"
    bilinear = "bilinear"


class LayerNormMode(Enum):
    """Where layer normalization runs: nn_core, ppu, or auto."""

    nn_core = "nn_core"
    ppu = "ppu"
    auto = "auto"


class LayerNormDecompositionMode(str, Enum):
    """Precision scheme for layer-norm decomposition: split, uniform, or auto."""

    split_precision = "split_precision"
    uniform_precision = "uniform_precision"
    auto = "auto"


class WeightsClippingMode(Enum):
    """Enum-like class for the weights clipping mode in quantization"""

    disabled = "disabled"
    manual = "manual"
    percentile = "percentile"
    mmse = "mmse"  # minimum mean square error
    mmse_if4b = "mmse_if4b"  # presumably MMSE only for 4-bit weights (cf. PreFTClippingMethod.MMSE_IF4B)


class CalcKernelMode(Enum):
    """Enum-like class for the weights mode of kernel force-scale quantization"""

    limvals = "limvals"
    kernel_vals = "kernel_vals"


@auto_generate_schema
class EqualizationPolicy(Enum):
    """Policy for equalization: allowed, enabled, or disabled."""

    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


@auto_generate_schema
class EqualizationTargetPolicy(Enum):
    """Equalization target: default, activation, or weights."""

    default = "default"
    activation = "activation"
    weights = "weights"


class EqualizationMode(Enum):
    """Equalization method: kernel-based, min-based, or noise-based."""

    kernel_equalization = "kernel_equalization"
    min_based = "min_based"
    noise_based = "noise_based"


class DeadLayersRemovalPolicy(Enum):
    """Policy for dead-layer removal: allowed, enabled, or disabled."""

    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


class DeadChannelsRemovalPolicy(Enum):
    """Policy for dead-channel removal: enabled or disabled."""

    enabled = "enabled"
    disabled = "disabled"


class ActivationClippingMode(Enum):
    """Enum-like class for activation clipping mode in quantization"""

    disabled = "disabled"
    manual = "manual"
    percentile = "percentile"


class SEOptimizationMethod(Enum):
    """Squeeze-and-excite optimization method (currently only ``tse``)."""

    tse = "tse"


class TiledSqueezeAndExciteMode(Enum):
    """Tiled squeeze-and-excite mode: sequential, custom, or disabled."""

    sequential = "sequential"
    custom = "custom"
    disabled = "disabled"


# Quantized optimization definitions ##############################################################
class PostQuantizationFeature(Enum):
    """Names of the post-quantization optimization features."""

    bias_correction = "bias_correction"
    train_encoding = "train_encoding"
    finetune = "finetune"
    adaround = "adaround"
    block_round_training = "block_round_training"
    mix_precision_search = "mix_precision_search"


class MOConfigCommand(Enum):
    """Names of model-optimization configuration commands (note: ``global_config`` is ``"globals"``)."""

    compression_params = "compression_params"
    negative_exponent = "negative_exponent"
    global_config = "globals"
    calibration = "calibration"
    checker_cfg = "checker_cfg"
    precision_config = "precision_config"


class TogglePolicy(Enum):
    """Two-state policy: enabled or disabled."""

    enabled = "enabled"
    disabled = "disabled"


class ThreeWayPolicy(Enum):
    """Three-state policy: allowed, enabled, or disabled."""

    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


# Adaround
class FeaturePolicy(Enum):
    """Feature-level policy: enabled or disabled."""

    enabled = "enabled"
    disabled = "disabled"


class LayerFeaturePolicy(Enum):
    """Per-layer feature policy: allowed, enabled, or disabled."""

    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


class AdaRoundMode(Enum):
    """AdaRound training scope: 4-bit layers only, or all layers."""

    train_4bit = "train_4bit"
    train_all = "train_all"


class CompressionTypes(Enum):
    """Compression algorithms: none, gzip, or zlib."""

    none = "none"
    gzip = "gzip"
    zlib = "zlib"


class InfusibleEWAddType(Enum):
    """
    EW Add layer type in case it is not fusible.
    """

    conv = "conv"
    ew_add = "ew_add"


# Bias correction
@auto_generate_schema
class BiasCorrectionPolicy(Enum):
    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


# Fine Tune
class FinetunePolicy(Enum):
    enabled = "enabled"
    disabled = "disabled"


class PreFTClippingMethod(Enum):
    """Pre fine tune clipping methods.

    Member names are upper-case, but lookup by value uses the lower-case strings
    (e.g. ``"mmse"``).
    """

    SET_FACTOR = "factor"
    SET_PERCENTILE = "percentile"
    MMSE = "mmse"  # minimum mean square error
    # NOTE(review): the value suggests MMSE clipping for 4-bit kernels and no
    # clipping for 8-bit kernels. Confirm against the clipping implementation.
    MMSE_IF4B = "4b_mmse_8b_noclip"


DEFAULT_CLIP_METHOD = PreFTClippingMethod.MMSE_IF4B  #: Default method for pre-QFT clipping
DEFAULT_CLIP_PERCENTILE = 99.7  #: Default clipping percentile for 4-bit kernels
DEFAULT_CLIP_FACTOR = 0.25  #: Default clipping factor for 4-bit kernels
DEFAULT_EPOCHS = 4  #: Default epochs for fine tune training
DEFAULT_LEARNING_RATE = 2e-4  #: Default learning rate for fine tune training
DEFAULT_DATASET_SIZE = 1024  #: Default dataset size for fine tune training
DEFAULT_BATCH_SIZE = 8  #: Default batch size


class ScheduleType(Enum):
    """Types of scheduling, used mainly for learning rate.

    Members: ``cosine_restarts``, ``exponential`` and ``constant``.
    """

    COSINE_RESTARTS = "cosine_restarts"
    EXPONENTIAL = "exponential"
    CONSTANT = "constant"


class Optimizer(Enum):
    """Fine tune algorithm optimizers: adam, sgd, momentum or rmsprop."""

    adam = "adam"
    sgd = "sgd"
    momentum = "momentum"
    rmsprop = "rmsprop"


class WarmupStrategy(Enum):
    """Warmup strategy: ``constant`` or ``gradual``.

    NOTE(review): presumably a learning-rate warmup for fine tune training.
    Confirm.
    """

    CONSTANT = "constant"
    GRADUAL = "gradual"


class MetaArchType(Enum):
    """Meta-architecture types; ``yolo`` is the only one defined here."""

    yolo = "yolo"


class QFTWriterMode(Enum):
    """Verbosity level of the QFT writer.

    Each level records everything the previous level records, plus more.
    """

    disabled = "disabled"
    basic = "basic"  # loss & learning rate
    advanced = "advanced"  # basic + gradients l2norm
    expert = "expert"  # advanced + actual values of weights


###################################
##### MixPrecisionSearch #########
##################################


class SensitivitySearch(Enum):
    """Sensitivity search strategy for mix-precision search: ``linear`` or ``pareto``."""

    LINEAR = "linear"
    PARETO = "pareto"


class MultiOutputMetric(Enum):
    """Multi-output metric for mix-precision search; only ``harmony`` is defined."""

    HARMONY = "harmony"


class ComprecisionMetric(Enum):
    """Cost metric for mix-precision search: ``macs``, ``bops`` or ``weighs``."""

    MACS = "macs"
    BOPS = "bops"
    # NOTE(review): "WEIGHS"/"weighs" looks like a misspelling of "weights".
    # The string value is looked up at runtime, so do not rename it without a
    # migration path for existing configs.
    WEIGHS = "weighs"


# Compression definitions #########################################################################
class ParamsCompressionPolicy(Enum):
    """Modes for compressing params in the CAT: ``allowed``, ``enabled`` or ``disabled``."""

    allowed = "allowed"
    enabled = "enabled"
    disabled = "disabled"


@auto_generate_schema
class FeatureMultiplierType(Enum):
    """Feature multiplier working mode: ``user_specified``, ``square`` or ``yolov5``.

    This enum is decorated with ``auto_generate_schema``, so it is part of the
    HN schemas. For backwards compatibility, append new members only, and never
    remove members or change their string values.
    """

    user_specified = "user_specified"
    square = "square"
    yolov5 = "yolov5"


class PostprocessTarget(str, Enum):
    """Post-processing target: ``nn_core``, ``cpu`` or ``auto``.

    Because of the ``str`` mixin, members compare equal to their string values.
    """

    NN_CORE = "nn_core"
    CPU = "cpu"
    AUTO = "auto"


class IOType(Enum):
    """IO types: ``standard`` or ``cache``."""

    STANDARD = "standard"
    CACHE = "cache"


class CacheOpMode(Enum):
    """Cache operation modes: ``read`` or ``write``."""

    READ = "read"
    WRITE = "write"


class NMSOnCpuMetaArchitectures(Enum):
    """Network meta architectures to which on-cpu post-processing can be added.

    String values are the lower-case architecture names (e.g. ``"yolov5_seg"``).
    """

    YOLOV5 = "yolov5"
    YOLOX = "yolox"
    SSD = "ssd"
    YOLOV5_SEG = "yolov5_seg"
    YOLOV8 = "yolov8"
    DAMOYOLO = "damoyolo"


class PostprocessType(Enum):
    """Type of post-process layers which run on the CPU."""

    NMS = "nms"
    IOU = "iou"
    LOGITS = "logits"
    RESIZE = "resize"
    BBOX_DECODER = "bbox_decoder"


class ResizePostprocessLayerParams(Enum):
    """Parameter names of the resize post-process layer."""

    RESIZE_SHAPE = "resize_shape"
    RESIZE_METHOD = "resize_method"
    # The value carries a "resize_bilinear_" prefix that the member name lacks.
    PIXELS_MODE = "resize_bilinear_pixels_mode"


class NMSProperties(Enum):
    """NMS post-process command arguments.

    Member order matters: ``NMS_ARGUMENTS_ORDER`` is built by iterating this
    enum in definition order. Append new members only where the resulting
    argument order is intended.
    """

    CONFIG_PATH = "config_path"
    META_ARCH = "meta_arch"
    ENGINE = "engine"
    ENFORCE_IOU_THRESHOLD = "enforce_iou_threshold"
    BBOX_DECODING_ONLY = "bbox_decoding_only"
    DFL_ON_NN_CORE = "dfl_on_nn_core"
    OUTPUT_ORIGINAL_NAME = "output_original_name"

    # The members below can be also configured via the json
    SCORES_TH = "nms_scores_th"
    IOU_TH = "nms_iou_th"
    IMAGE_DIMS = "image_dims"
    CLASSES = "classes"
    MAX_TOTAL_OUTPUT_PROPOSALS = "max_total_output_proposals"
    MAX_PROPOSALS_PER_CLASS = "max_proposals_per_class"
    MASK_THRESHOLD = "mask_threshold"
    REGRESSION_LENGTH = "regression_length"


# Ordered NMSProperties members that are passed as arguments of the nms
# postprocess command. The two proposal-count properties are not command
# arguments, so they are filtered out. The remaining members keep their
# NMSProperties definition order.
NMS_ARGUMENTS_ORDER = list(
    filter(
        lambda prop: prop not in (NMSProperties.MAX_TOTAL_OUTPUT_PROPOSALS, NMSProperties.MAX_PROPOSALS_PER_CLASS),
        NMSProperties,
    )
)

# two coordinates and a score
# NOTE(review): 5 values presumably means two (x, y) corner points (4 values)
# plus one score. Confirm against the bbox decoder layout.
DEFAULT_BOX_AND_OBJ_PXLS = 5


class BBoxDecodersInfo(Enum):
    """Keys describing a bbox decoder, grouped by the meta architecture that uses them."""

    # general bbox info
    NAME = "name"
    CLS_LAYER = "cls_layer"
    # centernet bbox info
    REG_LAYER_H = "reg_layer_h"
    REG_LAYER_W = "reg_layer_w"
    # SSD and YOLOV5 bbox info
    H = "h"
    W = "w"
    # SSD bbox info
    REG_LAYER = "reg_layer"
    # YOLOV5 bbox info
    STRIDE = "stride"
    ENCODED_LAYER = "encoded_layer"
    # YOLOX bbox additional info
    OBJ_LAYER = "objectness_layer"
    # YOLOv8 combined layer of regression and class prediction
    COMBINED_LAYER = "combined_layer"


class ProtoInfo(Enum):
    """Keys describing a proto layer: number, stride and the proto layer name."""

    # proto layer info (NOTE(review): the original comment read "general bbox
    # info" and looks copy-pasted from BBoxDecodersInfo)
    NUMBER = "number"
    STRIDE = "stride"
    PROTO_LAYER = "proto_layer"


class IOUPostprocessLayerParams(Enum):
    """Parameter names of the IOU post-process layer."""

    IOU_TH = "iou_th"
    MAX_PROPOSALS_PER_CLASS = "max_proposals_per_class"
    CLASSES = "classes"
    NMS_SCORES_TH = "nms_scores_th"


class LogitsPostprocessLayerParams(Enum):
    """Parameter names of the logits post-process layer."""

    TYPE = "logits_type"
    AXIS = "logits_axis"


class CalibrationDataType(Enum):
    """Types of data used for calibration during quantization."""

    #: ``numpy.ndarray`` or dict of ``numpy.ndarray``
    np_array = "np_array"

    #: Tensorflow 2.x batched dataset object
    dataset = "dataset"

    #: A ``.npy`` file (presumably given by path; confirm)
    npy_file = "npy_file"

    #: A directory of ``.npy`` files (presumably given by path; confirm)
    npy_dir = "npy_dir"

    #: A callable data source (expected signature not defined here)
    callable = "callable"

    #: Auto detect calibration data type
    auto = "auto"


class GPUAvailabilityMode(Enum):
    """GPU availability state: ``not_in_use``, ``in_use`` or ``not_available``."""

    NOT_IN_USE = "not_in_use"
    IN_USE = "in_use"
    NOT_AVAILABLE = "not_available"


class GPUInfo(BaseModel):
    """Pydantic model holding GPU availability information.

    Defaults: ``gpu_availability=GPUAvailabilityMode.NOT_AVAILABLE`` and
    ``num_gpus=0``.
    """

    gpu_availability: GPUAvailabilityMode = Field(default=GPUAvailabilityMode.NOT_AVAILABLE)
    num_gpus: int = Field(default=0)


class DistributionStrategy(Enum):
    """Distribution strategy: single, data parallelization, auto or model parallelization."""

    SINGLE = "single"
    DATA_P = "data_parallelization"
    AUTO = "auto"
    MODEL_P = "model_parallelization"


class OptimizationTarget(str, Enum):
    """Optimization target: sage, mercury, mars, pluto or emulation.

    Because of the ``str`` mixin, members compare equal to their string values.
    """

    SAGE = "sage"
    MERCURY = "mercury"
    MARS = "mars"
    PLUTO = "pluto"
    EMULATION = "emulation"


class PreQuantizationDefuseType(CaseInsensitiveEnum):
    """Pre-quantization defuse type: ``input_features`` or ``mha``.

    Inherits ``CaseInsensitiveEnum``, whose ``_missing_`` special-cases string
    lookups (presumably case-insensitive matching; confirm in that class).
    """

    INPUT_FEATURES = "input_features"
    MHA = "mha"


#: Default optimization target.
DEFAULT_OPTIMIZATION_TARGET = OptimizationTarget.SAGE


class MatmulCorrectionType(AccelerasEnum):
    """Matmul zero-point compensation variant.

    Inherits ``AccelerasEnum``, so looking up an unknown value raises
    ``AccelerasValueError``.
    """

    ZP_COMP_WEIGHTS = "zp_comp_weights"
    ZP_COMP_BLOCK = "zp_comp_block"
    ZP_COMP_BLOCK_2 = "zp_comp_block_2"
    ZP_COMP_BLOCK_3 = "zp_comp_block_3"
    ZP_COMP = "zp_comp"
    ZP_COMP_NONE = "zp_comp_none"


class SplitFusedActivationPolicy(Enum):
    """Policy for splitting fused activations: ``enabled``, ``disabled`` or ``allowed``."""

    enabled = "enabled"
    disabled = "disabled"
    allowed = "allowed"


class OrthoGenType(Enum):
    """Orthogonal-matrix generation type: HADAMARD, RANDOM or PARTIAL_RANDOM.

    The string values are upper-case, matching the member names.
    """

    HADAMARD = "HADAMARD"
    RANDOM = "RANDOM"
    PARTIAL_RANDOM = "PARTIAL_RANDOM"


class FlowState(Enum):
    """Flow state: FULLY_NATIVE, NUMERIC or BIT_EXACT (upper-case string values)."""

    FULLY_NATIVE = "FULLY_NATIVE"
    NUMERIC = "NUMERIC"
    BIT_EXACT = "BIT_EXACT"


class SoftmaxBiasOptimizationAlgorithm(Enum):
    """Softmax bias optimization algorithm: MSE, ZERO_MEAN or AC_DC (upper-case values)."""

    MSE = "MSE"
    ZERO_MEAN = "ZERO_MEAN"
    AC_DC = "AC_DC"


class AdapterType(AccelerasEnum):
    """Adapter type: ``base`` or ``lora``.

    Inherits ``AccelerasEnum``, so looking up an unknown value raises
    ``AccelerasValueError``.
    """

    BASE = "base"
    LORA = "lora"


class TrackerStage(AccelerasEnum):
    """Tracker stage: ``fp_optimize`` or ``quantize``.

    Inherits ``AccelerasEnum``, so looking up an unknown value raises
    ``AccelerasValueError``.
    """

    FP_OPTIMIZE = "fp_optimize"
    QUANTIZE = "quantize"


class TrackerType(AccelerasEnum):
    """Tracker type: fold_normalization, matrix_multiplication, gather or split.

    Inherits ``AccelerasEnum``, so looking up an unknown value raises
    ``AccelerasValueError``.
    """

    FOLD_NORMALIZATION = "fold_normalization"
    MATRIX_MULTIPLICATION = "matrix_multiplication"
    GATHER = "gather"
    SPLIT = "split"
