import os
from typing import List, Tuple

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.algorithms.dali_utils.dataset_util import (
    generator_output_signature,
    pad_data_generator,
    padding_info,
)

# Module-level flag: True until mock_dali_train_dataset has logged its
# "DALI is not installed" warnings once; the function then sets it to False
# so later calls in the same process do not repeat them.
_first_usage = True


def mock_dali_train_dataset(
    x_cache_list: List[str],
    y_cache_list: List[str],
    batch_size: int,
    count: int,
    base_dir: str,
    shuffle: bool = False,
    seed: int = None,
) -> Tuple[tf.data.Dataset, tuple, tuple]:
    """Build a plain ``tf.data`` training dataset used when DALI is unavailable.

    On the first call in a process, a warning is logged explaining that DALI
    is not installed (including the pip command to install it). If ``shuffle``
    was requested, a second warning states that the data is not shuffled here.

    Args:
        x_cache_list: Cache directories for the inputs; forwarded to
            ``cache_list_to_dataset``.
        y_cache_list: Cache directories for the targets; forwarded to
            ``cache_list_to_dataset``.
        batch_size: Batch size applied to the zipped (x, y) dataset.
        count: Forwarded unchanged to ``cache_list_to_dataset``
            (presumably the number of cached samples -- confirm against
            ``pad_data_generator``).
        base_dir: Not used by this implementation (presumably kept so the
            signature mirrors the DALI-backed variant -- confirm).
        shuffle: Only affects logging; this implementation never shuffles.
        seed: Not used by this implementation.

    Returns:
        A tuple ``(dataset, x_unpadded_shape, y_unpadded_shape)`` where
        ``dataset`` yields batched ``(x, y)`` pairs and repeats indefinitely,
        and the two shapes are the values returned by
        ``cache_list_to_dataset`` for the x and y cache lists.
    """
    global _first_usage
    if _first_usage:
        install_command = [
            "pip",
            "install",
            "--extra-index-url",
            "https://developer.download.nvidia.com/compute/redist",
            "nvidia-dali-cuda110",
            "nvidia-dali-tf-plugin-cuda110",
        ]

        logger = default_logger()
        logger.warning(
            "DALI is not installed, using tensorflow dataset for layer by layer train. "
            f"Using DALI will improve train time significantly. To install it use: {' '.join(install_command)}",
        )
        if shuffle:
            logger.warning(
                "Dataset isn't shuffled without DALI. "
                "To remove this warning add the following model script command: "
                "`post_quantization_optimization(adaround, shuffle=False)`",
            )
        # Emit the warnings only once per process.
        _first_usage = False

    x_dataset, x_unpadded_shape = cache_list_to_dataset(x_cache_list, count)
    y_dataset, y_unpadded_shape = cache_list_to_dataset(y_cache_list, count)
    batched_dataset = tf.data.Dataset.zip((x_dataset, y_dataset)).batch(batch_size)
    # Cache the batched data first, then repeat, and prefetch as the final
    # stage: a prefetch placed upstream of cache() stops doing any work once
    # the cache is filled, whereas a trailing prefetch keeps overlapping
    # input delivery with the consumer on every pass.
    train_dataset = batched_dataset.cache().repeat().prefetch(tf.data.AUTOTUNE)
    return train_dataset, x_unpadded_shape, y_unpadded_shape


def cache_list_to_dataset(cache_list, count):
    """Create a generator-backed ``tf.data.Dataset`` from a list of cache dirs.

    The file ``0.npz`` (array key ``"arr"``) is read from every directory in
    ``cache_list`` as a sample; those samples are passed to ``padding_info``
    (with ``stack_axis=0``) and ``generator_output_signature`` to derive the
    padding shapes and the dataset's output signature.

    Args:
        cache_list: Directories, each expected to contain a ``0.npz`` file
            holding an ``"arr"`` entry.
        count: Forwarded unchanged to ``pad_data_generator`` (presumably the
            number of samples to generate -- confirm against that helper).

    Returns:
        A tuple ``(dataset, unpadded_shape)``: the dataset built with
        ``tf.data.Dataset.from_generator`` over ``pad_data_generator``, and
        the unpadded shape reported by ``padding_info``.
    """
    sample_data = []
    for cache_dir in cache_list:
        sample_path = os.path.join(cache_dir, "0.npz")
        # np.load on an .npz returns an NpzFile that holds an open file
        # handle; the context manager closes it once the array is read.
        with np.load(sample_path) as sample_file:
            sample_data.append(sample_file["arr"])

    stacked_shape, padded_shape, unpadded_shape = padding_info(sample_data, stack_axis=0)
    output_signature = generator_output_signature(sample_data, stacked_shape)
    dataset = tf.data.Dataset.from_generator(
        pad_data_generator,
        output_signature=output_signature,
        args=(cache_list, count, padded_shape),
    )
    return dataset, unpadded_shape
