from hailo_model_optimization.acceleras.hailo_layers.hailo_argmax import HailoArgMax
from hailo_model_optimization.acceleras.hailo_layers.hailo_avgpool_v2 import HailoAvgPool
from hailo_model_optimization.acceleras.hailo_layers.hailo_batch_norm import HailoBatchNorm
from hailo_model_optimization.acceleras.hailo_layers.hailo_batch_norm_add import HailoBatchNormAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_bbox_decoder import HailoBBoxDecoder
from hailo_model_optimization.acceleras.hailo_layers.hailo_concat import HailoConcat
from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv3d import HailoConv3D
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_a16_pre_act_sum import HailoConvA16PreActSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_a16_quant_weight_group import (
    HailoConvA16W8QuantWeightGroup,
)
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_decompose import HailoConvDecompose
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_decompose_pluto import HailoConvDecomposePluto
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_quant_weight_group import (
    HailoConvQuantWeightGroup,
)
from hailo_model_optimization.acceleras.hailo_layers.hailo_crosscorrelation_dw import HailoCrossCorrelationDW
from hailo_model_optimization.acceleras.hailo_layers.hailo_deconv import HailoDeconv
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_depth_to_space import HailoDepthToSpace
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise_add import HailoDepthwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_sub import HailoElementwiseSub
from hailo_model_optimization.acceleras.hailo_layers.hailo_external_pad import HailoExternalPad
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_multiplier import HailoFeatureMultiplier
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_multiplier_on_mac import HailoFeatureMultiplierOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_shuffle import HailoFeatureShuffle
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_splitter import HailoFeatureSplitter
from hailo_model_optimization.acceleras.hailo_layers.hailo_format_conversion import HailoFormatConversion
from hailo_model_optimization.acceleras.hailo_layers.hailo_fused_bbox_decoder import HailoFusedBboxDecoder
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import (
    HailoCacheInputLayer,
    HailoCacheOutputLayer,
    HailoInputLayer,
    HailoOutputLayer,
)
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_norm import HailoLayerNorm
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_normalization import HailoLayerNormalization
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_maxpool import HailoMaxPool
from hailo_model_optimization.acceleras.hailo_layers.hailo_nms import HailoNMS
from hailo_model_optimization.acceleras.hailo_layers.hailo_non_nn_core_output_layer import HailoNonNNCoreOutputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_normalization import HailoNormalization
from hailo_model_optimization.acceleras.hailo_layers.hailo_normalization_add import HailoNormalizationAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_portal import HailoPortal
from hailo_model_optimization.acceleras.hailo_layers.hailo_postprocess import HailoPostprocess
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import (
    HailoPrecisionSplit,
    HailoPrecisionSplitPixels,
)
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split_signed import HailoPrecisionSplitSigned
from hailo_model_optimization.acceleras.hailo_layers.hailo_proposal_generator import HailoProposalGenerator
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_max import HailoReduceMax
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_mean import HailoReduceMean
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum import HailoReduceSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum_a16_pre_act_sum import HailoReduceSumA16PreActSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_resize_bilinear_mac import HailoResizeBilinearMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_resize_bilinear_ppu import HailoResizeBilinearPpu
from hailo_model_optimization.acceleras.hailo_layers.hailo_resize_nearest_neighbor import HailoResizeNearestNeighbor
from hailo_model_optimization.acceleras.hailo_layers.hailo_row_splitter import HailoRowSplitter
from hailo_model_optimization.acceleras.hailo_layers.hailo_shortcut import HailoShortcut
from hailo_model_optimization.acceleras.hailo_layers.hailo_slice import HailoSlice
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax import HailoSoftmax
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mars import HailoSoftmaxMars
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask import HailoSoftmaxMask
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask_on_mac import HailoSoftmaxMaskOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_space_to_depth import HailoSpaceToDepth
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.hailo_layers.hailo_width_splitter import HailoWidthSplitter
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EWMultType,
    IOType,
    LayerType,
    OptimizationTarget,
    PostprocessTarget,
    PrecisionMode,
    PrecisionSplitMode,
    ResizeBilinearPixelsMode,
    ResizeMethod,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasValueError,
)

# Base registry: maps the ``_hn_type`` attribute of every class listed here to that class.
# If two listed classes share the same ``_hn_type``, the later one in the list wins.
# Layer types whose class depends on extra parameters are registered by the
# ``_LAYER_BY_TYPE.update(...)`` call below, which replaces any entry stored here under the same key.
_LAYER_BY_TYPE = {
    layer._hn_type: layer
    for layer in [
        HailoStandaloneActivation,
        HailoReduceSum,
        HailoConcat,
        HailoDense,
        HailoDeconv,
        HailoSoftmax,
        HailoMaxPool,
        HailoAvgPool,
        HailoDepthToSpace,
        HailoSpaceToDepth,
        HailoSlice,
        HailoBBoxDecoder,
        HailoFusedBboxDecoder,
        HailoProposalGenerator,
        HailoNMS,
        HailoFeatureSplitter,
        HailoRowSplitter,
        HailoWidthSplitter,
        HailoElementwiseAdd,
        HailoFeatureShuffle,
        HailoInputLayer,
        HailoFormatConversion,
        HailoMatmul,
        HailoReduceMax,
        HailoArgMax,
        HailoConst,
        HailoElementwiseSub,
        HailoShortcut,
        HailoPostprocess,
        HailoExternalPad,
        HailoReduceMean,
        HailoPortal,
        HailoPrecisionSplitSigned,
        HailoLayerNormalization,
    ]
}

# Layer types that map to more than one class: each value is a lookup table, and the comment
# above each entry names the key that get_layer_type_from_hn_element uses to index it.
_LAYER_BY_TYPE.update(
    {
        # key: ew_mult_type
        LayerType.FEATURE_MULTIPLIER: {
            EWMultType.on_apu: HailoFeatureMultiplier,
            EWMultType.on_mac: HailoFeatureMultiplierOnMac,
        },
        # key: optimization_target
        LayerType.SOFTMAX: {
            OptimizationTarget.SAGE: HailoSoftmax,
            OptimizationTarget.PLUTO: HailoSoftmax,
            OptimizationTarget.MARS: HailoSoftmaxMars,
            OptimizationTarget.MERCURY: HailoSoftmax,
        },
        # key: (is_softmax_mask, ew_mult_type)
        LayerType.ELEMENTWISE_MULT: {
            (False, EWMultType.on_apu): HailoElementwiseMult,
            (True, EWMultType.on_apu): HailoSoftmaxMask,
            (False, EWMultType.on_mac): HailoElementwiseMultOnMac,
            (True, EWMultType.on_mac): HailoSoftmaxMaskOnMac,
        },
        # key: elwa (the "elementwise_add" param)
        LayerType.NORMALIZATION: {True: HailoNormalizationAdd, False: HailoNormalization},
        # key: optimization_target (layer norm with groups > 1 is always looked up as SAGE)
        LayerType.LAYER_NORM: {
            OptimizationTarget.MARS: HailoLayerNormalization,
            OptimizationTarget.SAGE: HailoLayerNormalization,
            OptimizationTarget.PLUTO: HailoLayerNorm,
            OptimizationTarget.MERCURY: HailoLayerNormalization,
        },
        # key: elwa (the "elementwise_add" param)
        LayerType.BATCH_NORM: {True: HailoBatchNormAdd, False: HailoBatchNorm},
        # key: (is_3d, decomp, elwa); the decomp entry is itself keyed by optimization_target
        # NOTE(review): the decomp table has no MARS entry, so that lookup raises KeyError - confirm intended
        LayerType.CONV: {
            (True, False, False): HailoConv3D,
            (False, True, False): {
                OptimizationTarget.SAGE: HailoConvDecompose,
                OptimizationTarget.PLUTO: HailoConvDecomposePluto,
                OptimizationTarget.MERCURY: HailoConvDecompose,
            },
            (False, False, True): HailoConvAdd,
            (False, False, False): HailoConv,
        },
        # key: (elwa, dynamic_weights)
        LayerType.DW: {
            (True, False): HailoDepthwiseAdd,
            (False, True): HailoCrossCorrelationDW,
            (False, False): HailoDepthwise,
        },
        # key: io_type
        LayerType.INPUT_LAYER: {
            IOType.STANDARD: HailoInputLayer,
            IOType.CACHE: HailoCacheInputLayer,
        },
        # key: (engine, io_type)
        LayerType.OUTPUT_LAYER: {
            (PostprocessTarget.NN_CORE, IOType.STANDARD): HailoOutputLayer,
            (PostprocessTarget.NN_CORE, IOType.CACHE): HailoCacheOutputLayer,
            (PostprocessTarget.CPU, IOType.STANDARD): HailoNonNNCoreOutputLayer,
        },
        # key: (method, bilinear pixels mode); the pixels mode is None for nearest neighbor
        LayerType.RESIZE: {
            (ResizeMethod.NEAREST_NEIGHBOR, None): HailoResizeNearestNeighbor,
            (ResizeMethod.BILINEAR, ResizeBilinearPixelsMode.HALF_PIXELS): HailoResizeBilinearMac,
            (ResizeMethod.BILINEAR, ResizeBilinearPixelsMode.DISABLED): HailoResizeBilinearMac,
            (ResizeMethod.BILINEAR, ResizeBilinearPixelsMode.ALIGN_CORNERS): HailoResizeBilinearPpu,
        },
        # key: precision_split_mode
        LayerType.PRECISION_SPLITTER: {
            PrecisionSplitMode.NORMAL: HailoPrecisionSplit,
            PrecisionSplitMode.PIXELS: HailoPrecisionSplitPixels,
        },
    },
)


def walk_dict(dict_to_walk):
    """Lazily yield every non-dict leaf of an arbitrarily nested dict, depth first.

    A non-dict argument is treated as a leaf and yielded as is. Containers other than
    dict (lists, tuples, ...) are leaves as well and are not descended into.
    """
    if not isinstance(dict_to_walk, dict):
        yield dict_to_walk
        return
    for nested in dict_to_walk.values():
        for leaf in walk_dict(nested):
            yield leaf


def iterate_layer_classes():
    """Yield every layer class stored in ``_LAYER_BY_TYPE``, including those in nested tables.

    A class that is registered under several keys is yielded once per occurrence.
    """
    # The registry is itself a dict, so walking it visits each value's leaves in registry order.
    yield from walk_dict(_LAYER_BY_TYPE)


def get_layer_type_from_hn_element(hn_element, optimization_target=None):
    """
    Resolve the acceleras layer class that implements a single HN layer element.

    The HN ``type`` selects an entry of ``_LAYER_BY_TYPE``. For most layer types that entry
    is the class itself; for the types handled explicitly below it is a lookup table indexed
    by further fields of the element and/or by the optimization target.

    Args:
        hn_element (dict): piece of HN describing the layer. ``type`` is required;
            ``params``, ``quantization_params``, ``io_type`` and ``engine`` are read when present.
        optimization_target (OptimizationTarget): target for optimization.
            Passing None (the default) raises.

    Raises:
        AccelerasValueError: if optimization_target is None
        AccelerasImplementationError: if ``LayerType`` rejects the HN type
        ValueError: if no class is registered for the layer type
        KeyError: if a required key is missing, or the parameter combination has no table entry

    Returns:
        type: the layer class to use for this element

    """
    if optimization_target is None:
        raise AccelerasValueError("optimization_target cannot be None")
    try:
        ltype = LayerType(hn_element["type"])
    except AccelerasValueError as err:
        # NOTE(review): a plain Enum lookup raises ValueError - confirm that LayerType raises
        # AccelerasValueError for unknown values, otherwise this handler never fires.
        raise AccelerasImplementationError(f"Layer type {hn_element['type']} is not yet supported..") from err

    params = hn_element.get("params", dict())
    # A null "quantization_params" / "precision_mode" is treated like a missing one,
    # which is what load_precision_config accepts as well.
    quantization_params = hn_element.get("quantization_params") or dict()
    raw_precision_mode = quantization_params.get("precision_mode")
    precision_mode = PrecisionMode(raw_precision_mode) if raw_precision_mode is not None else None
    elwa = params.get("elementwise_add", False)
    if ltype not in _LAYER_BY_TYPE:
        raise ValueError(f"{ltype} doesn't exist in op factories")
    layer_class = _LAYER_BY_TYPE[ltype]
    if ltype == LayerType.SOFTMAX:
        layer_class = layer_class[optimization_target]
    elif ltype == LayerType.LAYER_NORM:
        has_groups = params.get("groups", 1) > 1
        # groups layer norm are only supported on MAC for now
        optimization_target = OptimizationTarget.SAGE if has_groups else optimization_target
        layer_class = layer_class[optimization_target]
    elif ltype in (LayerType.NORMALIZATION, LayerType.BATCH_NORM):
        layer_class = layer_class[elwa]
    elif ltype == LayerType.REDUCE_SUM:
        # Same predicate as the CONV path; otherwise the registered class is kept.
        if should_use_a16_pre_act_sum(optimization_target, precision_mode):
            layer_class = HailoReduceSumA16PreActSum
    elif ltype == LayerType.CONV:
        layer_class = get_conv_layer_class(
            params, quantization_params, layer_class, elwa, optimization_target, precision_mode
        )
    elif ltype == LayerType.DW:
        dynamic_weights = params.get("dynamic_weights", False)
        layer_class = layer_class[(elwa, dynamic_weights)]
    elif ltype == LayerType.INPUT_LAYER:
        io_type = IOType(hn_element.get("io_type", IOType.STANDARD.value))
        layer_class = layer_class[io_type]
    elif ltype == LayerType.OUTPUT_LAYER:
        io_type = IOType(hn_element.get("io_type", IOType.STANDARD.value))
        engine = PostprocessTarget(hn_element.get("engine", PostprocessTarget.NN_CORE.value))
        layer_class = layer_class[(engine, io_type)]
    elif ltype == LayerType.RESIZE:
        method = ResizeMethod(params["method"])
        resize_bilinear_pixels_mode = ResizeBilinearPixelsMode(
            params.get("resize_bilinear_pixels_mode", ResizeBilinearPixelsMode.ALIGN_CORNERS),
        )
        # The pixels mode is parsed (and validated) for every method, but the table
        # indexes nearest neighbor with None.
        if method == ResizeMethod.NEAREST_NEIGHBOR:
            resize_bilinear_pixels_mode = None
        layer_class = layer_class[(method, resize_bilinear_pixels_mode)]
    elif ltype == LayerType.ELEMENTWISE_MULT:
        is_softmax_mask = params.get("is_softmax_mask", False)
        ew_mult_type = EWMultType(params.get("ew_mult_type", EWMultType.on_apu))
        layer_class = layer_class[(is_softmax_mask, ew_mult_type)]
    elif ltype == LayerType.FEATURE_MULTIPLIER:
        ew_mult_type = EWMultType(params.get("ew_mult_type", EWMultType.on_apu))
        layer_class = layer_class[ew_mult_type]
    elif ltype == LayerType.PRECISION_SPLITTER:
        precision_split_mode = PrecisionSplitMode(params.get("precision_split_mode", PrecisionSplitMode.NORMAL))
        layer_class = layer_class[precision_split_mode]
    return layer_class


def get_conv_layer_class(params, quantization_params, layer_class, elwa, optimization_target, precision_mode):
    """Pick the class for a CONV layer out of the CONV lookup table passed as ``layer_class``.

    The table is always indexed by (is_3d, decompose_weights, elwa) first, so a combination
    without an entry raises KeyError even when one of the overrides below would apply.
    Overrides, in priority order: quantization weight groups, A16 pre-act-sum, and finally
    the per-target class of a decomposed conv.
    """
    uses_weight_groups = quantization_params.get("quantization_weight_groups", -1) > 1
    disparity = params.get("disparity", params.get("input_disparity", 1))
    decompose_weights = params.get("decompose_weights", False)
    selected = layer_class[(disparity > 1, decompose_weights, elwa)]

    if uses_weight_groups:
        return get_conv_quant_weight_group_class(optimization_target, precision_mode)
    if should_use_a16_pre_act_sum(optimization_target, precision_mode):
        return HailoConvA16PreActSum
    # A decomposed conv resolves to a nested table keyed by the optimization target.
    return selected[optimization_target] if decompose_weights else selected


def get_conv_quant_weight_group_class(optimization_target, precision_mode):
    """Pick the conv class used when quantization weight groups are enabled.

    Raises AccelerasImplementationError for PLUTO. Otherwise a precision mode with
    16 input bits selects the A16W8 variant and anything else the regular one.
    """
    if optimization_target in (OptimizationTarget.PLUTO,):
        raise AccelerasImplementationError("Quantization weight groups is not supported for PLUTO")
    has_16_bit_input = bool(precision_mode) and precision_mode.input_bits() == 16
    return HailoConvA16W8QuantWeightGroup if has_16_bit_input else HailoConvQuantWeightGroup


def should_use_a16_pre_act_sum(optimization_target, precision_mode):
    """Check if A16 Pre Act Sum should be used.

    Args:
        optimization_target (OptimizationTarget): target for optimization
        precision_mode (PrecisionMode): precision mode of the layer, or None when not set

    Returns:
        bool: True only for PLUTO / MARS with a precision mode that has 16 input bits and
              weights that are not 16 bits. Always a real bool, also when precision_mode is None.

    """
    if optimization_target not in (OptimizationTarget.PLUTO, OptimizationTarget.MARS):
        return False
    if not precision_mode:
        # Without this guard the falsy precision_mode itself (e.g. None) would be returned.
        return False
    return precision_mode.input_bits() == 16 and precision_mode.weight_bits() != 16


def gen_acceleras_layers_from_hn(lname, hn_layer, optimization_target, logger=None):
    """
    Build the acceleras layer(s) that implement a single HN layer.

    The layer class is resolved with get_layer_type_from_hn_element and its ``from_hn``
    constructor is called; whatever it returns is normalized to a dict.

    Args:
        lname (str): layer name
        hn_layer (dict): piece of HN describing the layer
        optimization_target (OptimizationTarget): target for optimization
        logger: forwarded unchanged to the class's ``from_hn``

    Raises:
        AccelerasImplementationError

    Returns:
        dict: new layers to create, indexed by name. A dict returned by ``from_hn``
              (one HN layer split into several acceleras layers) is passed through as is;
              any other return value is wrapped as ``{lname: layer}``.

    """
    layer_cls = get_layer_type_from_hn_element(hn_layer, optimization_target=optimization_target)
    created = layer_cls.from_hn(lname, hn_layer, logger=logger)
    return created if isinstance(created, dict) else {lname: created}


def load_precision_config(layer_name, new_acceleras_layers, layer_data, optimization_target):
    """Import the precision config stored in the layer's HN data into its acceleras layer.

    Nothing is imported unless ``layer_data`` holds ``quantization_params`` with a non-None
    ``precision_mode``. Only the quantization params that are also keys of
    LayerPrecisionConfig are used. The layer is updated in place and the same
    ``new_acceleras_layers`` mapping is returned in every case.
    """
    template_cfg = LayerPrecisionConfig()
    quantization_params = layer_data.get("quantization_params", None)
    if quantization_params is None or quantization_params.get("precision_mode", None) is None:
        return new_acceleras_layers

    # Drop every entry that LayerPrecisionConfig does not know about.
    cfg_dict = {key: value for key, value in quantization_params.items() if key in template_cfg.keys()}
    precision_cfg = LayerPrecisionConfig(**cfg_dict)
    target_layer = new_acceleras_layers[layer_name]
    # Let the layer fill in whatever the stored params left unset before importing.
    precision_cfg.fill_default_config(target_layer)
    target_layer.import_precision_config(precision_cfg, optimization_target)
    return new_acceleras_layers
