import os
import random
import sys
import tempfile
from dataclasses import dataclass
from typing import Iterator, List, Tuple

import numpy as np
import nvidia.dali.plugin.tf as dali_tf
import tensorflow as tf
from nvidia.dali import fn, pipeline_def

from hailo_model_optimization.algorithms.dali_utils.dataset_util import (
    generator_output_signature,
    pad_data_generator,
    padding_info,
)


@dataclass
class DALIDatasetInfo:
    """
    Data created during DALI data processing, required for Dataset initialization

    Instances are produced by `write_dali_np_datasets` after the individual npy
    files have been written; `dali_train_dataset` reads the fields back to build
    the DALI pipeline and to declare the output dtypes / shapes of the dataset.
    """

    file_list: str  # path of the txt file that lists the individual npy files (one file name per line)
    padded_shape: tuple  # shape of the padded DALI dataset, built as (None, *spec.shape) by the writer
    dtype: tf.DType  # dtype of the DALI dataset (taken from the writer's spec)


def write_dali_np_datasets(path: str, padded_iterator: Iterator, spec: tf.TensorSpec) -> DALIDatasetInfo:
    """
    Dump every sample of an iterator to its own npy file and index the files

    DALI can read npy data only when each data sample lives in a separate file.
    Sample number `n` of the iterator is saved as `<n>.npy` inside `path`, and a
    single `files.txt` (also inside `path`) lists all the created file names,
    one per line, in iteration order.

    Args:
        path: existing directory in which the npy files and `files.txt` are written
        padded_iterator: iterable that yields the (already padded) samples to save
        spec: spec of a single sample; its shape and dtype are recorded in the result

    Returns:
        DALIDatasetInfo for dataset initialization

    """
    index_path = os.path.join(path, "files.txt")
    with open(index_path, "w") as index_file:
        for sample_idx, sample in enumerate(padded_iterator):
            sample_name = f"{sample_idx}.npy"
            np.save(os.path.join(path, sample_name), sample)
            # Only the bare file name is listed, not the full path
            index_file.write(sample_name + "\n")
    # The leading None is added in front of the per-sample shape
    # (NOTE(review): presumably the batch dimension - confirm)
    return DALIDatasetInfo(
        file_list=index_path,
        padded_shape=(None, *spec.shape),
        dtype=spec.dtype,
    )


def process_tensors_cache(cache_list: List[str], count: int, base_dir: str) -> Tuple[DALIDatasetInfo, tuple]:
    """
    Processes caches list and converts them to data format that can be processed by DALI

    DALI has 2 flaws -
    1. the data readers expects a single npy for each data sample
    2. It doesn't support multiple input / outputs dataset for keras' fit function
    This function treats these flaws:
    1. It takes the files from multiple cache dirs and write an npy for each data
    2. It pads all the inputs and stacks them together (to a single tensor)
        so that DALI will be able to process multi input / output data

    The first sample of every cache dir (`0.npz`, key `arr`) is loaded to derive
    the padding information and the output spec. The npy files are written to a
    fresh temporary directory created under `base_dir`; that directory is not
    removed by this function.

    Args:
        cache_list: list of cache directories (cache dir for each input)
        count: forwarded to `pad_data_generator`
            (NOTE(review): presumably the number of samples to generate - confirm)
        base_dir: directory in which the individual npy files will be written to

    Return:
        Tuple(DALIDatasetInfo, unpadded_shape of the data)

    """
    # TODO: move the padding logic inside write_dali_np_datasets (and apply it to each sample).
    #       This should allow better RAM efficiency (but a bit slower 'preprocess')
    sample_data = []
    for cache_dir in cache_list:
        # np.load on an npz returns a file-backed archive object; close it
        # explicitly instead of leaving the handle to the garbage collector.
        with np.load(os.path.join(cache_dir, "0.npz")) as first_sample:
            sample_data.append(first_sample["arr"])
    stacked_shape, padded_shape, unpadded_shape = padding_info(sample_data, stack_axis=0)
    padded_np_iterator = pad_data_generator(cache_list, count, padded_shape)
    # Every call gets its own sub-directory, so several datasets can share base_dir
    output_dir = tempfile.mkdtemp(dir=base_dir)
    spec = generator_output_signature(sample_data, stacked_shape)
    return write_dali_np_datasets(output_dir, padded_np_iterator, spec), unpadded_shape


@pipeline_def(num_threads=4)
def pipe_npy(file_lists: List[str], has_gpu: bool, shuffle=False, seed_id=None) -> tuple:
    """
    Build a DALI pipeline with one numpy reader per file list

    Args:
        file_lists: list of txt files with names of npy files
        has_gpu: when True, the output of every reader is transferred to the GPU
        shuffle: passed to every reader as `random_shuffle`
        seed_id: seed passed to every reader; when None a random seed is drawn once

    Return:
        tuple with DALI's data node for each file_list (in the same order)

    """
    # A single seed is shared by all the readers
    # (NOTE(review): presumably so that shuffled inputs and labels stay aligned - confirm)
    shared_seed = random.randrange(sys.maxsize) if seed_id is None else seed_id
    outputs = []
    for list_path in file_lists:
        node = fn.readers.numpy(device="cpu", file_list=list_path, seed=shared_seed, random_shuffle=shuffle)
        # Reading is always done on the CPU; the data is moved to the GPU only when one is available
        outputs.append(node.gpu() if has_gpu else node)
    return tuple(outputs)


def dali_train_dataset(
    x_cache_list: List[str],
    y_cache_list: List[str],
    batch_size: int,
    count: int,
    base_dir: str,
    shuffle: bool = False,
    seed: int = None,
) -> Tuple[dali_tf.DALIDataset, tuple, tuple]:
    """
    Build a DALIDataset for keras' fit function out of input and label cache lists

    Both cache lists are converted with `process_tensors_cache` (see its docstring
    for the padding behavior), so every element of the dataset is a padded
    (x, y) pair; the dataset should be sliced with tf_unpad_input

    Args:
        x_cache_list: cache list for the fit's input
        y_cache_list: cache list for the fit's label / reference
        batch_size: batch size of the dataset
        count: forwarded to `process_tensors_cache` for both cache lists
        base_dir: will be used to cache DALI related files
        shuffle: forwarded to `pipe_npy` as its `shuffle` argument
        seed: forwarded to `pipe_npy` as its `seed_id` argument (None lets `pipe_npy` pick one)

    Return:
        Tuple with (DALI_dataset, x's unpadded shape, y's unpadded shape)

    """
    x_info, x_unpadded = process_tensors_cache(x_cache_list, count, base_dir)
    y_info, y_unpadded = process_tensors_cache(y_cache_list, count, base_dir)
    infos = (x_info, y_info)

    # Device 0 is used when TensorFlow reports any GPU, otherwise no device id is given
    gpu_available = bool(tf.config.list_physical_devices("GPU"))
    device_id = 0 if gpu_available else None

    pipeline = pipe_npy(
        [info.file_list for info in infos],
        has_gpu=gpu_available,
        shuffle=shuffle,
        seed_id=seed,
        batch_size=batch_size,
        device_id=device_id,
    )
    dataset = dali_tf.DALIDataset(
        pipeline,
        output_dtypes=tuple(info.dtype for info in infos),
        output_shapes=tuple(info.padded_shape for info in infos),
        batch_size=batch_size,
        prefetch_queue_depth=1,
        device_id=device_id,
        num_threads=4,
    )
    return dataset, x_unpadded, y_unpadded
