import os
import shutil
import tempfile
from typing import Dict, List, Set

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers.base_acceleras_layer import BaseAccelerasLayer
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.algorithms.dali_utils.mock_dali_dataset import cache_list_to_dataset


def save_batch(
    acceleras_layer: BaseAccelerasLayer,
    batch,
    cache_dir: str,
    result_index: int,
    compress: bool,
) -> None:
    """
    Write one batch of a single layer to disk, one ``.npz`` file per sample and per output

    Args:
        acceleras_layer: layer whose ``num_outputs`` sets how many entries of ``batch`` are stored
        batch: indexable by output index; each entry is split along its first axis into samples
        cache_dir: cache directory of the layer, a sub-folder is used for every output
        result_index: absolute index of the first sample in the batch, used for the file names
        compress: store with ``np.savez_compressed`` when True, with ``np.savez`` otherwise

    """
    # The writer depends only on the flag, so it is picked once for the whole batch
    writer = np.savez_compressed if compress else np.savez
    for output_index in range(acceleras_layer.num_outputs):
        samples = batch[output_index]
        target_dir = _get_output_cache_path(cache_dir, output_index)
        os.makedirs(target_dir, exist_ok=True)
        for offset in range(samples.shape[0]):
            # Files are named after the absolute sample index, not the index inside the batch
            sample_path = os.path.join(target_dir, f"{result_index + offset}.npz")
            writer(sample_path, arr=samples[offset])


def _get_output_cache_path(
    cache_dir: str,
    index: int,
) -> str:
    """
    Build the path of the sub-folder that holds the data of the given output index

    Args:
        cache_dir: cache directory of a layer
        index: output index

    Returns
        ``cache_dir`` joined with the folder name ``data_<index>``

    """
    folder_name = f"data_{index}"
    return os.path.join(cache_dir, folder_name)


def get_layer_cache_dir(
    cache_base_dir: str,
    lname: str,
) -> str:
    """
    Create a new, uniquely named cache directory for a layer
    Args:
        cache_base_dir: directory under which the new directory is created
        lname: layer name; every "/" in it is replaced by "___" to form the directory prefix

    Returns
        path of the newly created directory

    """
    # Slashes would be read as path separators, so they are swapped out of the prefix
    dir_prefix = "___".join(lname.split("/")) + "_"
    return tempfile.mkdtemp(prefix=dir_prefix, dir=cache_base_dir)


def dataset_from_cache(
    input_layers: List[str],
    cache_by_layer: Dict[str, str],
    count: int,
):
    """
    Build a dataset that feeds the given input layers from their cache directories

    Args:
        input_layers: names of the layers to read, looked up in ``cache_by_layer``
        cache_by_layer: layer name as keys, cache directory as values
        count: forwarded as is to ``cache_list_to_dataset``

    Returns
        the result of ``cache_list_to_dataset`` for the resolved cache folders

    """
    return cache_list_to_dataset(get_cache_list(input_layers, cache_by_layer), count)


def get_cache_list(input_layers: List[str], cache_by_layer: Dict[str, str]):
    """
    Collect, in the order of ``input_layers``, the cache folder of output 0 of every layer
    # TODO: add support for multiple outputs
    """
    return [_get_output_cache_path(cache_by_layer[name], 0) for name in input_layers]


def dataset_to_cache(
    dataset: tf.data.Dataset,
    model_layers: Dict[str, BaseAccelerasLayer],
    cache_dir: str,
    batch_size: int,
    compress: bool,
) -> Dict[str, str]:
    """
    Store every entry of a dataset (with a dictionary element spec) in per-layer cache dirs

    Args:
        dataset: unbatched dataset, it is batched here with ``batch_size``
        model_layers: layer name as keys, layer object as values
        cache_dir: base directory, a new directory is created in it for every dataset key
        batch_size: number of samples written per ``save_batch`` call
        compress: forwarded to ``save_batch``

    Returns:
        dictionary, layer name as keys, cache directory as values.
    """
    batched = dataset.batch(batch_size)
    cache_dir_by_layer = {
        name: get_layer_cache_dir(cache_dir, name) for name in batched.element_spec.keys()
    }
    # Absolute index of the first sample of the batch that is currently written
    first_index = 0
    for data_item in batched:
        for lname, batch in data_item.items():
            # Each dataset entry is passed as the one and only output of its layer
            save_batch(model_layers[lname], [batch], cache_dir_by_layer[lname], first_index, compress)
        first_index += batch_size
    return cache_dir_by_layer


def clean_cache(
    blocks: List[ModelFlow],
    cache_by_layer: Dict[str, str],
    force_cached: List[str],
):
    """
    Delete the cached results once no longer needed

    A cached layer is removed when it is not an input node of any of the given blocks and
    is not listed in ``force_cached``.

    Args:
        blocks: the blocks that still have to be processed; each one is iterated and its
            ``input_nodes`` are read
        cache_by_layer: layer name as keys, cache directory as values. Updated in place,
            the entries of the deleted layers are popped
        force_cached: layer names that must stay cached even if no block reads them

    """
    # A set (and not the keys view) matches the signature of _find_redundant_layers
    redundant_layers = _find_redundant_layers(blocks, set(cache_by_layer))
    redundant_layers = redundant_layers - set(force_cached)
    for lname in redundant_layers:
        cache_dir = cache_by_layer.pop(lname)
        _delete_cache(cache_dir)


def _find_redundant_layers(blocks: List[ModelFlow], current_cached: Set[str]):
    """
    Return the cached layers that are not an input node of any of the given blocks
    # TODO: find redundant output
    """
    still_needed = set()
    for remaining_block in blocks:
        still_needed |= set(remaining_block.input_nodes)
    return current_cached - still_needed


def _delete_cache(dirname: str):
    """
    Remove the cache directory of a single layer together with everything inside it
    Args:
        dirname: the directory to remove; failures during the removal are ignored
    """
    shutil.rmtree(path=dirname, ignore_errors=True)


def get_max_cache_size(model: HailoModel, blocks: List[ModelFlow], dali_cache=False) -> float:
    """
    Estimate the maximal cache usage when the given blocks are processed one after the other

    Sizes are accumulated as products of layer shapes without their first dimension, so the
    result counts elements (presumably per sample -- TODO confirm that the first dimension
    is the batch). It is not scaled by the element size or by the dataset length.

    Args:
        model: model whose layers provide the shapes
        blocks: blocks in processing order; the list itself is left unmodified
        dali_cache: when True, the inputs and outputs size of each block is added on top of
            the first sub-step of that block

    Returns: peak of the estimated cache size

    """
    # Multiplier of stored data (usually for native & quant data)
    factor = 2
    size_by_layer = {
        inp_layer: np.prod(model.layers[inp_layer].output_shape[1:])
        for inp_layer in model.flow.input_nodes
    }
    # Consume a private copy: popping from the argument would empty the caller's list
    pending_blocks = list(blocks)
    abs_peak = 0
    while pending_blocks:
        block = pending_blocks.pop(0)
        block_outputs_size_by_layer = _find_block_outputs_size(model, block)
        block_inputs_size_by_layer = _find_block_inputs_size(model, block)
        current_size = _get_size(size_by_layer)
        current_outputs_size = _get_size(block_outputs_size_by_layer)
        current_inputs_size = _get_size(block_inputs_size_by_layer)
        # Cached layers that none of the blocks after the current one has as an input node
        unused_layers = _find_redundant_layers(pending_blocks, size_by_layer.keys())
        unused_layers = {unused: size_by_layer[unused] for unused in unused_layers}
        freed_size = _get_size(unused_layers)
        # Does not depend on the sub-step, so it is computed once per block
        existing_cache_size = current_size * factor
        for i in range(factor):
            new_cache_size = current_outputs_size * (i + 1)
            freed_cache_size = freed_size * (factor - i)
            sub_peak = existing_cache_size + new_cache_size - freed_cache_size
            abs_peak = max(abs_peak, sub_peak)
            if i == 0 and dali_cache:
                # Only the first sub-step carries the extra inputs + outputs of the block
                sub_peak += current_inputs_size + current_outputs_size
                abs_peak = max(abs_peak, sub_peak)
        for unused in unused_layers:
            size_by_layer.pop(unused)
        size_by_layer.update(block_outputs_size_by_layer)
    return abs_peak


def _get_size(cached: Dict[str, int]):
    """
    Add up the sizes (in pixels) stored as the values of the given cache dictionary
    """
    total = 0
    for layer_size in cached.values():
        total += layer_size
    return total


def _find_block_outputs_size(model: HailoModel, block: ModelFlow) -> Dict[str, int]:
    """
    Map every output node of the block to the size (in pixels) of the data it receives

    The size is the product of the shape of the feeding layer output, without the first
    dimension of that shape.
    """
    sizes = {}
    for out_node in block.output_nodes:
        # The first predecessor is the layer that produces the data of the output node
        producer = block.predecessors_sorted(out_node)[0]
        edge_index = block.get_edge_output_index(producer, out_node)
        producer_layer = model.layers[producer]
        shape = producer_layer.output_shapes[producer_layer.resolve_output_index(edge_index)]
        sizes[out_node] = np.prod(shape[1:])
    return sizes


def _find_block_inputs_size(model: HailoModel, block: ModelFlow) -> Dict[str, int]:
    """
    Map every input node of the block to the size (in pixels) of the data it provides

    The size is the product of the shape of the matching input of the first successor,
    without the first dimension of that shape.
    """
    sizes = {}
    for in_node in block.input_nodes:
        # The first successor is the layer that consumes the data of the input node
        consumer = block.successors_sorted(in_node)[0]
        edge_index = block.get_edge_input_index(in_node, consumer)
        shape = model.layers[consumer].input_shapes[edge_index]
        sizes[in_node] = np.prod(shape[1:])
    return sizes
