from collections import OrderedDict
from typing import Dict, Union

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_concat import HailoConcat
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_quant_weight_group import HailoConvQuantWeightGroup
from hailo_model_optimization.acceleras.hailo_layers.hailo_crosscorrelation_dw import HailoCrossCorrelationDW
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer, HailoOutputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_maxpool import HailoMaxPool
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax import HailoSoftmax


def extract_stats_sdk_format(acceleras_model, layers_to_clip=None):
    """
    SDK interface for collecting statistics.

    Args:
        acceleras_model: model that is assumed to have all the stats already collected.
        layers_to_clip: optional list of layer names to do clipping on. When it is non-empty,
            only the activation histograms of those layers are extracted; otherwise (None or
            empty) the basic per-layer statistics of the whole model are extracted.
    Returns:
        The OrderedDict produced by ``extract_stats_histogram`` or ``extract_basic_stats``.
    """
    wants_histograms = bool(layers_to_clip)
    return (
        extract_stats_histogram(acceleras_model, layers_to_clip)
        if wants_histograms
        else extract_basic_stats(acceleras_model)
    )


def extract_basic_stats(acceleras_model):
    """
    SDK interface for collecting the basic (non-histogram) statistics of every layer.

    Args:
        acceleras_model: model that is assumed to have all the stats already collected.
    Returns:
        OrderedDict keyed by ``"<layer_name>/<stat_name>:0"`` with each value converted via
        ``np.float32``, in model-layer order. Layers whose type is in the excluded list below
        contribute nothing.
    """
    # TODO - add dict of layers to not do statistics on
    excluded_types = (
        HailoSoftmax,
        HailoMaxPool,
        HailoConcat,
        HailoInputLayer,
        HailoOutputLayer,
        BaseHailoNonNNCoreLayer,
    )
    sdk_stats = OrderedDict()
    for name, layer in acceleras_model.layers.items():
        if isinstance(layer, excluded_types):
            continue
        for stat_name, stat_value in get_all_stats(layer).items():
            sdk_stats[f"{name}/{stat_name}:0"] = np.float32(stat_value)
    return sdk_stats


def extract_stats_histogram(acceleras_model, clip_layers):
    """
    SDK interface for histogram statistics.

    Args:
        acceleras_model: model that is assumed to have all the stats already collected.
        clip_layers: list of layer names to do clipping on.
    Returns:
        OrderedDict mapping ``"<layer_name>/activation_hist:0"`` to the ``histogram`` of the
        layer's first output stats, in the order given by ``clip_layers``.
    """
    model_layers = acceleras_model.layers
    return OrderedDict(
        (f"{name}/activation_hist:0", model_layers[name].get_output_stats()[0].histogram)
        for name in clip_layers
    )


def _reshape_dense_input(acceleras_layer, input_energy):
    """
    Reshape a dense layer's flat input energy back to the shape the layer was built with.

    This applies when a dense layer comes after a conv layer. For example, an input_energy of
    shape (256,) becomes (2, 2, 64) when the layer's ``_build_input_shape[1:]`` is (2, 2, 64).
    Non-dense layers, rank-1 build shapes, and build shapes starting with (1, 1, ...) are
    returned unchanged.
    """
    # TODO we now mimic the sdk -
    #  that for dense saves the energy - (2,2,64) instead of (64). We may change it in the future
    if not isinstance(acceleras_layer, HailoDense):
        return input_energy

    target_shape = acceleras_layer._build_input_shape[1:]
    # TODO - this is a hack for mimic the sdk = when the layer before is
    #  (1,1,64) we dont want to do reshape and we want to keep it (64). We may change it in the future
    keep_flat = len(target_shape) <= 1 or (target_shape[0] == 1 and target_shape[1] == 1)
    if keep_flat:
        return input_energy
    return np.reshape(input_energy, target_shape)


def _reshape_dense_kernel_for_equalization(acceleras_layer, kernel):
    """
    Split a dense kernel's input axis by the spatial size of the layer's input.

    Used when a dense layer comes after a conv layer. The product of the input dims between the
    batch and last axes (``input_shape[1:-1]``) is computed. When it is greater than 1, the
    kernel is reshaped to ``[1, spatial, -1, kernel.shape[-1]]``. For example, with
    ``input_shape == (None, 2, 2, 64)`` a (1, 1, 256, out) kernel becomes (1, 4, 64, out).
    Otherwise the kernel is returned unchanged.
    """
    # np.prod of an empty slice is 1.0, so rank-2 inputs fall through unchanged.
    spatial = np.prod(acceleras_layer.input_shape[1:-1])
    if spatial <= 1:
        return kernel
    return np.reshape(kernel, [1, spatial, -1, kernel.shape[-1]])


def _reshape_ew_add_kernel_for_equalization(kernel, index):
    """
    Take one operand's vector from an elementwise-add kernel and lay it out as a consumer kernel.

    An ew_add kernel holds two vectors. ``kernel[index]`` is reshaped to (1, 1, N, 1); for
    example, a (2, 256) kernel yields a (1, 1, 256, 1) array.
    """
    operand_vector = kernel[index]
    return np.reshape(operand_vector, (1, 1, -1, 1))


def _reshape_kernel_for_equalization(acceleras_layer, kernel, index):
    """
    Apply the consumer-side kernel reshape that matches the layer type.

    Dense layers use the dense reshape. Elementwise-add layers select operand ``index``.
    Any other layer's kernel is returned unchanged.
    """
    if isinstance(acceleras_layer, HailoDense):
        return _reshape_dense_kernel_for_equalization(acceleras_layer, kernel)
    if isinstance(acceleras_layer, HailoElementwiseAdd):
        return _reshape_ew_add_kernel_for_equalization(kernel, index)
    return kernel


def get_all_stats(acceleras_layer) -> Dict[str, Union[float, np.array]]:
    """
    Flatten the collected statistics of one layer into the SDK naming scheme.

    Args:
        acceleras_layer: (Union [CompositeOp, AtomicOp]): layer whose stats were already collected.

    Returns: dict in the format of the sdk. The ``*_features`` entries hold the stats objects'
        ``min`` / ``max`` / ``energy`` / ``non_zero_percent`` values as-is. The other min/max
        entries are reduced with ``np.min`` / ``np.max``. Input entries exist only when the
        layer has at least one input stats object; the ``*_elwa`` entries need a second one.

    """
    inputs = acceleras_layer.get_input_stats()
    preact = acceleras_layer.get_preact_stats()[0]
    output = acceleras_layer.get_output_stats()[0]

    # pre act stats
    stats = {
        "stats_min_pre_act_features": preact.min,
        "stats_max_pre_act_features": preact.max,
        "stats_min_pre_act": np.min(preact.min),
        "stats_max_pre_act": np.max(preact.max),
        "stats_pre_energy_features": preact.energy,
    }

    # first input stats
    if len(inputs) >= 1:
        first = inputs[0]
        stats.update(
            {
                "stats_min_inp": np.min(first.min),
                "stats_max_inp": np.max(first.max),
                "stats_min_inp_features": first.min,
                "stats_max_inp_features": first.max,
                "stats_input_energy_features": _reshape_dense_input(acceleras_layer, first.energy),
            }
        )

    # second input stats, exported under the "elwa" names
    if len(inputs) >= 2:
        second = inputs[1]
        stats.update(
            {
                "stats_min_out_features_elwa": second.min,
                "stats_max_out_features_elwa": second.max,
                "stats_min_elwa": np.min(second.min),
                "stats_max_elwa": np.max(second.max),
            }
        )

    # output stats
    stats.update(
        {
            "stats_min_out": np.min(output.min),
            "stats_max_out": np.max(output.max),
            "stats_min_out_features": output.min,
            "stats_max_out_features": output.max,
            "stats_output_energy_features": output.energy,
            "stats_non_zero_percent_features": output.non_zero_percent,  # in the sdk we assume it's an int.
        }
    )

    if isinstance(acceleras_layer, (HailoMatmul, HailoCrossCorrelationDW)):
        # NOTE(review): the weights-input stats are read from input 0, the same source as
        # stats_*_inp. Confirm this is intended rather than input 1.
        stats["stats_min_weights_in"] = np.min(inputs[0].min)
        stats["stats_max_weights_in"] = np.max(inputs[0].max)

    return stats


def get_input_layer_equalization_stats(acceleras_layer, is_consumer, index) -> Dict[str, Union[float, np.array]]:
    """
    Build the equalization stats dict for an input layer.

    Only valid for a ``HailoInputLayer`` that is not a consumer and has ``index == 0``;
    otherwise a ValueError is raised. Both the pre- and post-activation entries, and both
    energy entries, are taken from the layer's first output stats. The kernel is an array of
    ones with shape (1, 1, 1, C) divided by ``output_scale``, where C is the last output dim.
    """
    is_valid_request = isinstance(acceleras_layer, HailoInputLayer) and not is_consumer and index == 0
    if not is_valid_request:
        raise ValueError("get_input_layer_equalization_stats is only for input layer that is not consumer and index 0")

    out_stats = acceleras_layer.get_output_stats()[0]
    ones_shape = (1, 1, 1, acceleras_layer.output_shape[-1])
    # kernel and kernel_before are computed separately so they are distinct objects
    return {
        "kernel": np.ones(ones_shape) / acceleras_layer.output_scale,
        "kernel_before": np.ones(ones_shape) / acceleras_layer.output_scale,
        "full_layer_name": acceleras_layer.full_name,
        "pre_activation_min": out_stats.min,
        "pre_activation_max": out_stats.max,
        "post_activation_max": out_stats.max,
        "post_activation_min": out_stats.min,
        "input_energy": out_stats.energy,
        "output_energy": out_stats.energy,
        "non_zero_percent": out_stats.non_zero_percent,
        "axes_to_max": [axis in (0, 1, 2) for axis in range(4)],
        "number_bits": 8,
    }


def get_equalization_stats(acceleras_layer, is_consumer, index) -> Dict[str, Union[float, np.array]]:
    """
    Args:
        acceleras_layer: (Union [CompositeOp, AtomicOp]):
        is_consumer: whether the layer is consumer
        index: operand index forwarded to ``get_kernel_for_equalization``. For a consumer
            elementwise-add it selects which of the kernel's vectors is used; input layers
            require it to be 0.
    Returns: dict in the format of the sdk

    Input layers are delegated to ``get_input_layer_equalization_stats``.
    """
    if isinstance(acceleras_layer, HailoInputLayer):
        return get_input_layer_equalization_stats(acceleras_layer, is_consumer, index)

    in_stats = acceleras_layer.get_input_stats()
    preact = acceleras_layer.get_preact_stats()[0]
    out_stats = acceleras_layer.get_output_stats()[0]
    kernel, kernel_before = get_kernel_for_equalization(acceleras_layer, is_consumer, index)
    input_energy = _reshape_dense_input(acceleras_layer, in_stats[0].energy)

    # Axis choice: consumers use (0, 1, 3); a non-consumer elementwise-add uses (0, 1, 2);
    # any other non-consumer uses its conv op's axes2reduce.
    if is_consumer:
        reduce_axes = (0, 1, 3)
    elif isinstance(acceleras_layer, HailoElementwiseAdd):
        reduce_axes = (0, 1, 2)
    else:
        reduce_axes = acceleras_layer.conv_op.axes2reduce

    return {
        "kernel": kernel,
        "kernel_before": kernel_before,
        "full_layer_name": acceleras_layer.full_name,
        "pre_activation_min": preact.min,
        "pre_activation_max": preact.max,
        "post_activation_max": out_stats.max,
        "post_activation_min": out_stats.min,
        "input_energy": input_energy,
        "output_energy": out_stats.energy,
        "non_zero_percent": out_stats.non_zero_percent,
        "axes_to_max": [axis in reduce_axes for axis in range(4)],
        "number_bits": acceleras_layer._get_kernel_bits(),
    }


def get_kernel_for_equalization(acceleras_layer, is_consumer, index):
    """
    Return the layer kernel (divided by its scale matrix component) prepared for equalization.

    Consumers get the type-specific consumer reshape. A non-consumer elementwise-add kernel is
    reshaped to (1, 1, 2, -1). ``HailoConvQuantWeightGroup`` kernels are passed through
    ``revert_kernel_shape``.

    Returns:
        tuple ``(kernel, kernel_before)``. With a single group both items are the same array.
        Otherwise the first item is a zero-filled array of shape
        (kh, kw, in_bounds[-1], out_bounds[-1]). In it, each group's kernel slice is copied to
        that group's input/output index range, as given by ``conv_op.group_kernel_indices()``.
        The second item is the un-expanded grouped kernel.
    """
    scale_matrix = acceleras_layer.get_kernel_scale_matrix_component().numpy()
    weights = acceleras_layer.get_kernel().numpy() / scale_matrix

    if is_consumer:
        weights = _reshape_kernel_for_equalization(acceleras_layer, weights, index)
    elif isinstance(acceleras_layer, HailoElementwiseAdd):
        weights = np.reshape(weights, [1, 1, 2, -1])

    n_groups = acceleras_layer.groups
    if isinstance(acceleras_layer, HailoConvQuantWeightGroup):
        weights = acceleras_layer.revert_kernel_shape(weights)
    if n_groups == 1:
        # In the case of no groups
        return weights, weights

    in_bounds, out_bounds = acceleras_layer.conv_op.group_kernel_indices()
    # expanded: each group is copied to a patch along its original input and output indices.
    expanded = np.zeros((weights.shape[0], weights.shape[1], in_bounds[-1], out_bounds[-1]))
    for g in range(n_groups):
        in_lo, in_hi = in_bounds[g], in_bounds[g + 1]
        out_lo, out_hi = out_bounds[g], out_bounds[g + 1]
        # the grouped kernel stores each group's inputs starting at 0 along axis 2
        expanded[:, :, in_lo:in_hi, out_lo:out_hi] = weights[:, :, : (in_hi - in_lo), out_lo:out_hi]

    return expanded, weights
