from enum import Enum
from functools import partial

import numpy as np
import pandas as pd
import tensorflow as tf


class AnalysisTarget(Enum):
    """Enumeration of the targets an analysis can be evaluated on.

    Each member's value is the same string as its name.
    """

    logits = "logits"
    activations = "activations"


class AnalysisMode(Enum):
    """Enumeration of the available analysis modes.

    Each member's value is the same string as its name.
    """

    simple = "simple"
    advanced = "advanced"


class AnalysisType(Enum):
    """Enumeration of the available analysis types.

    Each member's value is the same string as its name.
    """

    # NOTE(review): presumably "one quantized layer at a time" vs. "whole network quantized" - confirm
    layer_by_layer = "layer_by_layer"
    full_quant_net = "full_quant_net"


class ModelMode(Enum):
    """Enumeration of the modes a model can run in.

    Each member's value is the same string as its name.
    """

    # NOTE(review): presumably full-precision vs. quantized execution - confirm
    native = "native"
    numeric = "numeric"


def save_results(sampled_tensors, noise_results, work_dir, logger):
    """
    Persist the analysis outputs under ``work_dir``.

    The noise results are written first (as csv files), then the sampled tensors of every
    analyze mode are stored in a file named ``sampled_tensors_<analyze_mode>.npz``.

    Args:
        sampled_tensors: mapping of analyze mode to the mapping of arrays saved for that mode.
        noise_results: nested noise results, forwarded as-is to ``_noises_to_csv``.
        work_dir: directory the files are written to.
        logger: logger forwarded to the npz writer.

    """
    _noises_to_csv(noise_results, work_dir)

    for analyze_mode, tensors in sampled_tensors.items():
        _save_data_to_npz(tensors, work_dir, name=f"sampled_tensors_{analyze_mode}.npz", logger=logger)


def _save_data_to_npz(data, work_dir, name, logger):
    """
    Save the entries of ``data`` as named arrays in ``<work_dir>/<name>``.

    Args:
        data: mapping whose items are passed to the writer as keyword arguments.
        work_dir: directory of the target file.
        name: file name of the target file.
        logger: logger forwarded to ``hailo_np_savez``.

    """
    hailo_np_savez(f"{work_dir}/{name}", logger, **data)


def _noises_to_csv(results, work_dir):
    for analysis_type in results.keys():
        for analaysis_target in results[analysis_type].keys():
            if (
                analaysis_target == AnalysisTarget.logits.full_name
                and analysis_type == AnalysisType.layer_by_layer.full_name
            ):
                for output in results[analysis_type][analaysis_target]:
                    snr = results[analysis_type][analaysis_target][output].get()
                    data = {"layer": list(snr.keys()), "snr": [v[0] for v in snr.values()]}
                    noise_df = pd.DataFrame(data=data)
                    output_formated = output.replace("/", ".")
                    file_path = f"{work_dir}/{analaysis_target}_{analysis_type}_{output_formated}_noise_results.csv"
                    noise_df.to_csv(file_path, index=False)
            else:
                snr = results[analysis_type][analaysis_target].get()
                data = {"layer": list(snr.keys()), "snr": [v[0] for v in snr.values()]}
                noise_df = pd.DataFrame(data=data)
                file_path = f"{work_dir}/{analaysis_target}_{analysis_type}_noise_results.csv"
                noise_df.to_csv(file_path, index=False)


def _get_iterator(np_array, batch_size):
    """
    Wraps a numpy array in a batched tf.data.Dataset

    Args:
        np_array: iterable of already preprocessed images.
        batch_size: number of consecutive elements grouped into one batch.

    Returns
        tf.data.Dataset yielding batches of (image, image_info) pairs

    """

    def _samples():
        # The consumer expects (image, image_info) pairs, image_info being a dict with preprocessing info.
        # The images are already preprocessed, so there is no image_info and None is yielded instead.
        for sample in np_array:
            yield sample, None

    unbatched = tf.data.Dataset.from_generator(_samples, output_types=(tf.float32, tf.float32))
    return unbatched.batch(batch_size)


def init_dataset(data_npy, batch_size=8):
    """
    Create a data feed callback for ``data_npy``.

    Args:
        data_npy: the data the dataset is built from.
        batch_size: batch size of the resulting dataset.

    Returns
        A callable that takes no arguments and returns ``_get_iterator(data_npy, batch_size)``.

    """
    return partial(_get_iterator, data_npy, batch_size)


def get_weights_noise(model):
    """
    Measure, per layer, the error between the native kernel and its quantized counterpart.

    Layers are visited in the order of ``model.flow.toposort()``; layers without a ``conv_op``
    attribute have no kernel and are skipped. ``enable_lossy()`` is called on every visited
    layer that has a kernel and is not reverted here.

    Args:
        model: object exposing ``flow.toposort()`` and a ``layers`` mapping keyed by layer name.

    Returns
        dict mapping layer name to ``10 * log10(mean(error ** 2) / mean(native ** 2))``, where
        ``error`` is the native kernel minus the rescaled numeric kernel.

    """
    ordered_names = list(model.flow.toposort())
    noise_per_layer = {}

    for lname in ordered_names:
        layer = model.layers[lname]
        if hasattr(layer, "conv_op"):
            native = layer.conv_op.kernel.numpy()

            # The quantized kernel is read only after lossy mode is switched on.
            layer.enable_lossy()
            numeric = layer.conv_op.final_numeric_kernel.numpy()

            # Scale that brings the numeric kernel back to the native range.
            scale = layer.conv_op.kernel_scale.numpy() / 2**layer.conv_op.weight_placement_shift
            error = native - numeric * scale

            error_power = np.mean(error**2)
            signal_power = np.mean(native**2)
            noise_per_layer[lname] = 10 * np.log10(error_power / signal_power)

    return noise_per_layer


def hailo_np_savez(file, logger, *args, **kwds):
    """
    Save arrays with ``np.savez``, warning first when any of them contains None.

    Args:
        file: target passed to ``np.savez``.
        logger: logger used for the warning.
        *args: arrays saved positionally.
        **kwds: arrays saved by name.

    """
    # TODO: SDK-10099
    none_in_args = is_containing_none(args)
    if none_in_args or is_containing_none(list(kwds.values())):
        logger.warning("np.ndarray contains Nones, so Numpy saves it using pickle")

    np.savez(file, *args, **kwds)


def is_containing_none(item):
    """
    Check recursively whether an object is None or holds a None somewhere inside it.

    Numpy arrays are searched only when their dtype is object, dicts are searched by their
    values, and sets, lists and tuples are searched item by item.

    Args:
        item: Any object to test for None

    Return:
        True if the item is None or contains None, otherwise False.

    """
    if item is None:
        return True

    if isinstance(item, np.ndarray):
        # Only object arrays are able to hold None.
        if item.dtype != object:
            return False
        return is_containing_none(item.tolist())

    if isinstance(item, dict):
        # TODO: SDK-10099
        item = list(item.values())

    if isinstance(item, (set, list, tuple)):
        return any(is_containing_none(member) for member in item)

    return False
