import glob
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import tensorflow as tf
from tensorflow import TensorSpec
from tensorflow.python.eager.context import graph_mode
from tensorflow.python.framework.errors_impl import OutOfRangeError

from hailo_model_optimization.acceleras.utils.acceleras_definitions import CalibrationDataType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AcclerasDataError, DatasetException
from hailo_model_optimization.acceleras.utils.logger import default_logger


@dataclass
class DatasetContianer:
    """
    Simple holder pairing calibration ``data`` with the ``CalibrationDataType`` describing it.

    Attributes:
        data: the calibration payload; any type, defaults to ``None``.
        data_type: the ``CalibrationDataType`` of ``data``; defaults to ``CalibrationDataType.auto``.
    """

    data: Any = None
    data_type: CalibrationDataType = CalibrationDataType.auto


# Correctly spelled alias; the misspelled class name is kept so existing imports keep working.
DatasetContainer = DatasetContianer


def get_dataset_length(dataset, threshold=None):
    """
    Count the elements of ``dataset``, optionally looking at no more than ``threshold`` of them.

    The static cardinality is returned when TensorFlow knows it; otherwise the dataset is iterated
    to count its elements, which is slow - avoid calling this when the length is not needed.

    Args:
        dataset: the ``tf.data.Dataset`` to measure.
        threshold: optional cap; when given only the first ``threshold`` elements are considered,
            so the result is always less than or equal to it.

    Returns:
        The element count as a numpy integer. Note that when no threshold is given and the
        cardinality is infinite, TensorFlow's infinite-cardinality sentinel is returned as-is.
    """
    bounded = dataset if threshold is None else dataset.take(threshold)
    static_count = bounded.cardinality()
    if (static_count != tf.data.experimental.UNKNOWN_CARDINALITY).numpy():
        return static_count.numpy()
    # Cardinality unknown (e.g. generator-based pipelines): count by iterating.
    return bounded.reduce(0, lambda running, _: running + 1).numpy()


def verify_dataset_size(dataset, expected_size, warning_if_larger=False, logger=None):
    """
    Ensure ``dataset`` holds at least ``expected_size`` elements.

    Args:
        dataset: dataset to inspect.
        expected_size (int): minimal required number of elements.
        warning_if_larger (bool): when truthy, also log a warning if the dataset has more than
            ``expected_size`` elements. Requires ``logger``.
        logger: logger used for the "larger than expected" warning.

    Raises:
        ValueError: if ``warning_if_larger`` is set but no ``logger`` is provided.
        AcclerasDataError: if the dataset has fewer than ``expected_size`` elements.
    """
    if logger is None and warning_if_larger:
        raise ValueError("Logger is required if 'warning_if_larger' is enabled")
    # Count one element past expected_size only when the "larger" check is wanted; otherwise stop
    # exactly at expected_size so counting never goes further than necessary.
    counted = get_dataset_length(dataset, threshold=expected_size + int(warning_if_larger))
    if counted < expected_size:
        raise AcclerasDataError(
            f"Insufficient dataset, expected dataset of size {expected_size} but got {counted}",
        )
    if counted > expected_size:
        logger.warning(
            "Dataset is larger than expected size. Increasing the algorithm dataset size might improve the results",
        )


def _get_dataset_item_from_array(data, index):
    return data[index]


def _get_dataset_item_from_dict(data, index):
    return {k: v[index] for k, v in data.items()}


def get_dataset_batch_size(dataset):
    """
    Return the batch size of ``dataset`` as a numpy value.

    Reads the private ``_batch_size`` tensor attribute, presumably only present on batched
    datasets - other datasets raise ``AttributeError``.
    """
    batch_size_tensor = getattr(dataset, "_batch_size")
    return batch_size_tensor.numpy()


def get_data_count_from_data_sample(data_batch):
    """
    Return the number of samples (length of the first axis) in ``data_batch``.

    For a dict or ``NpzFile`` the count is read from its first value; the values are assumed to
    share the same leading dimension (this is not checked here).
    """
    is_mapping = isinstance(data_batch, (np.lib.npyio.NpzFile, dict))
    representative = next(iter(data_batch.values())) if is_mapping else data_batch
    return representative.shape[0]


def _add_batch_dim(np_data):
    if isinstance(np_data, (np.lib.npyio.NpzFile, dict)):
        new_data = {k: np.expand_dims(v, 0) for k, v in np_data.items()}
    else:
        new_data = np.expand_dims(np_data, 0)
    return new_data


def _get_signature_and_fetch_callback(np_data, force_dtype=None, batch=True):
    """
    Describe numpy data as a ``TensorSpec`` signature and pick the matching per-item fetcher.

    Args:
        np_data: a numpy array, a dict of arrays, or an ``NpzFile``.
        force_dtype: dtype to put in the signature instead of the data's own dtype (used when truthy).
        batch: whether the first axis of ``np_data`` already indexes samples. When False a leading
            batch axis is added first, so the whole input is described as a single sample.

    Returns:
        tuple: ``(signature, fetch_callback)``. The signature drops the leading batch axis, and
        ``fetch_callback(data, index)`` extracts one sample from data of the same structure.
    """

    def pick_dtype(native_dtype):
        return force_dtype or native_dtype

    data = np_data if batch else _add_batch_dim(np_data)

    if not isinstance(data, (np.lib.npyio.NpzFile, dict)):
        return TensorSpec(data.shape[1:], pick_dtype(data.dtype)), _get_dataset_item_from_array

    signature = {name: TensorSpec(array.shape[1:], pick_dtype(array.dtype)) for name, array in data.items()}
    return signature, _get_dataset_item_from_dict


def get_dataset_from_np(preprocessed_data, image_info_mock=True):
    """
    Build a ``tf.data.Dataset`` that yields one sample at a time from numpy data.

    Args:
        preprocessed_data (`numpy.ndarray` or dict of `numpy.ndarray`): either a single array whose
            first axis indexes samples (e.g. preprocessed images with shape (calib_size, h, w, c)),
            or a dict in which each value is such an array and the keys are the input names.
            The sample count is taken from the first array. Samples are emitted as float32.
        image_info_mock (bool): when True, every element becomes a ``(sample, {})`` pair, i.e. the
            data is paired with an empty image-info dict.

    Returns:
        tf.data.Dataset: a dataset over the samples of ``preprocessed_data``.
    """
    data_signature, fetch_single_data = _get_signature_and_fetch_callback(preprocessed_data, np.float32)
    sample_count = get_data_count_from_data_sample(preprocessed_data)

    def generator():
        for index in range(sample_count):
            yield fetch_single_data(preprocessed_data, index)

    dataset = tf.data.Dataset.from_generator(generator, output_signature=data_signature)
    if image_info_mock:
        # Datasets are expected to yield (data, image_info); preprocessed data carries no image info.
        dataset = dataset.map(lambda x: (x, dict()))
    return dataset


def get_dataset_from_npy_file(npy_file, image_info_mock=True):
    """
    Build a ``tf.data.Dataset`` that streams samples from a numpy file on disk.

    The file is loaded with ``mmap_mode="r"`` once up front to infer the signature and sample
    count, and again each time the dataset's generator is invoked, so samples are read lazily.

    Args:
        npy_file: path of the numpy file to read.
        image_info_mock (bool): when True, elements are ``(sample, {})`` pairs.

    Returns:
        tf.data.Dataset yielding float32 samples.
    """
    probe = np.load(npy_file, mmap_mode="r")
    data_signature, fetch_single_data = _get_signature_and_fetch_callback(probe, np.float32)
    sample_count = get_data_count_from_data_sample(probe)

    def generator_file():
        mapped = np.load(npy_file, mmap_mode="r")
        for index in range(sample_count):
            yield fetch_single_data(mapped, index)

    dataset = tf.data.Dataset.from_generator(generator_file, output_signature=data_signature)
    return dataset.map(lambda x: (x, dict())) if image_info_mock else dataset


def get_dataset_from_iterator(calib_data_callback):
    """
    Wrap a TF1-style iterator factory in an unbatched ``tf.data.Dataset``.

    ``calib_data_callback`` is called with no arguments and must return an initializable TF1
    iterator whose ``get_next()`` produces ``(data, image_info)`` batches. One batch is pulled
    first to infer the element signatures; the returned dataset then re-runs the callback on every
    pass and yields the batches' samples one at a time.

    Args:
        calib_data_callback: zero-argument callable returning an initializable TF1 iterator.

    Returns:
        tf.data.Dataset of ``(data, image_info)`` pairs, with ``data`` emitted as float32.
    """
    # Pull one batch in a throw-away session just to infer signatures, and close it afterwards so
    # its graph and resources are released.
    sample_sess = tf.compat.v1.Session(graph=tf.Graph())
    try:
        with sample_sess.as_default(), sample_sess.graph.as_default(), graph_mode():
            sample_iterator = calib_data_callback()
            [sample_data, sample_image_info] = sample_iterator.get_next()
            sample_sess.run([sample_iterator.initializer])
            sample_data, sample_image_info = sample_sess.run([sample_data, sample_image_info])
    finally:
        sample_sess.close()

    data_signature, fetch_single_data = _get_signature_and_fetch_callback(sample_data, np.float32)
    image_info_signature, fetch_single_image_info = _get_signature_and_fetch_callback(sample_image_info)

    def generator():
        # A fresh graph/session per pass: ops do not accumulate in one shared graph across epochs,
        # and the session is closed when the pass ends (or the generator is discarded).
        sess = tf.compat.v1.Session(graph=tf.Graph())
        try:
            with sess.as_default(), sess.graph.as_default(), graph_mode():
                iterator = calib_data_callback()
                [data_tensor, image_info_tensor] = iterator.get_next()
                sess.run([iterator.initializer])
                while True:
                    try:
                        data, image_info = sess.run([data_tensor, image_info_tensor])
                    except OutOfRangeError:
                        break
                    image_info = dict() if image_info is None else image_info
                    # Use each batch's own length: bounding by the sampled first batch would silently
                    # drop samples whenever a later batch is larger than it.
                    for i in range(get_data_count_from_data_sample(data)):
                        yield fetch_single_data(data, i), fetch_single_image_info(image_info, i)
        finally:
            sess.close()

    dataset = tf.data.Dataset.from_generator(generator, output_signature=(data_signature, image_info_signature))
    return dataset


def np_array_to_datafeed_callback(np_array, calib_num_batch=None, batch_size=None):
    """
    Wrap calibration numpy data in a dataset and resolve the batching parameters.

    Missing batching parameters are derived from the number of samples:
    given only ``batch_size``, ``calib_num_batch = num_of_images // batch_size``;
    given only ``calib_num_batch``, ``batch_size = num_of_images // calib_num_batch``;
    given neither, ``batch_size = min(8, num_of_images)`` and ``calib_num_batch`` follows from it.

    Args:
        np_array: a numpy array whose first axis indexes samples, or a dict of such arrays (one per
            input) that must all contain the same number of samples.
        calib_num_batch: optional number of batches.
        batch_size: optional batch size.

    Returns:
        tuple: ``(dataset, calib_num_batch, batch_size)``.

    Raises:
        DatasetException: if the data is empty, the dict inputs disagree on their sample count, or
            there are fewer samples than the requested batching needs.
    """
    if isinstance(np_array, dict):
        per_input_counts = [len(np_array_data) for np_array_data in np_array.values()]
        # An empty dict has no samples; it is rejected by the emptiness check below.
        num_of_images = per_input_counts[0] if per_input_counts else 0
        if any(count != num_of_images for count in per_input_counts):
            raise DatasetException("The number of images in the calibration set is not the same for all the inputs")
    else:
        num_of_images = len(np_array)

    if num_of_images == 0:
        # Without this guard the default branch below would divide by zero.
        raise DatasetException("The calibration set is empty")

    if batch_size is not None and calib_num_batch is not None:
        if batch_size * calib_num_batch > num_of_images:
            raise DatasetException(
                "The number of images in the calibration set is smaller than batch_size * calib_num_batch.",
            )
    elif batch_size:
        if batch_size > num_of_images:
            raise DatasetException("The number of images in the calibration set is smaller than batch_size.")
        calib_num_batch = num_of_images // batch_size
    elif calib_num_batch:
        if calib_num_batch > num_of_images:
            raise DatasetException("The number of images in the calibration set is smaller than calib_num_batch.")
        batch_size = num_of_images // calib_num_batch
    else:
        batch_size = min(8, num_of_images)
        calib_num_batch = num_of_images // batch_size

    dataset = get_dataset_from_np(np_array)

    return dataset, calib_num_batch, batch_size


def npy_file_to_dataset(path):
    """
    Return a dataset over the numpy file at ``path``.

    The file is loaded with ``mmap_mode="r"`` (read-only memory map for ``.npy`` files) rather than
    read eagerly, and handed to ``get_dataset_from_np``.
    """
    return get_dataset_from_np(np.load(path, mmap_mode="r"))


def npy_dir_to_dataset(path_name):
    """
    Build a dataset from a directory holding one sample per ``.npy`` / ``.npz`` file.

    The first matching file is loaded to infer the element signature. Each file is treated as a
    single (unbatched) sample, emitted as float32 and paired with an empty image-info dict.

    Args:
        path_name: directory path (``str``, ``bytes`` or path-like).

    Returns:
        tf.data.Dataset of ``(sample, {})`` pairs.

    Raises:
        DatasetException: if the directory contains no ``.npy`` / ``.npz`` files.
    """
    try:
        path_name = bytes(path_name)
    except TypeError:
        path_name = bytes(path_name, "utf-8")
    glob_path = os.path.join(path_name, b"*.np[yz]")

    sample_file = next(glob.iglob(glob_path), None)
    if sample_file is None:
        raise DatasetException(f"No .npy or .npz files were found in {os.fsdecode(path_name)}")
    sample_np = np.load(sample_file, mmap_mode="r")
    try:
        # batch=False: each file holds one sample, so the whole array describes a single element.
        signature, _ = _get_signature_and_fetch_callback(sample_np, np.float32, batch=False)
    finally:
        # .npz archives keep an open file handle until closed; .npy memmaps need no explicit close.
        if isinstance(sample_np, np.lib.npyio.NpzFile):
            sample_np.close()

    def dir_generator(glob_path, npz=False):
        for image_path in glob.iglob(glob_path):
            if npz:
                # Materialize every array, then close the archive so no file handle is leaked.
                with np.load(image_path) as archive:
                    image = dict(archive)
            else:
                # Read-only map (as the other loaders in this module use): reading a sample must not
                # require write access to the dataset files.
                image = np.load(image_path, mmap_mode="r")
            # Dataset is expected to yield (image, image_info) whereas image_info is a dict with preprocessing info.
            # Since the images are already preprocessed, image_info doesn't exist and isn't needed.
            yield image, dict()

    dataset = tf.data.Dataset.from_generator(
        dir_generator,
        args=[glob_path, isinstance(signature, dict)],
        output_signature=(signature, {}),
    )
    return dataset


def data_to_dataset(data, data_type: CalibrationDataType, logger=None):
    """
    Convert calibration data of any supported form into a ``tf.data.Dataset``.

    Args:
        data: the calibration data. One of:
            - a numpy array, a dict of arrays, or an ``NpzFile``;
            - a ``tf.data.Dataset``;
            - a path to a numpy file;
            - a path to a directory of numpy files;
            - a zero-argument callable returning a ``tf.data.Dataset``.
        data_type (CalibrationDataType): how to interpret ``data``. ``None`` or
            ``CalibrationDataType.auto`` detects the type from ``data`` itself.
        logger: logger for the detected-type message; defaults to ``default_logger()``.

    Returns:
        tuple: ``(dataset, image_count)``. ``image_count`` is the sample count for numpy-array and
        numpy-file inputs (and is asserted as the dataset's cardinality); otherwise it is ``None``.

    Raises:
        ValueError: if the type cannot be detected or is unexpected, or a callable returns a
            non-dataset value.
    """
    image_count = None
    logger = default_logger() if logger is None else logger
    data_type = CalibrationDataType.auto if data_type is None else data_type

    data_type = CalibrationDataType(data_type)
    if data_type == CalibrationDataType.auto:
        # An NpzFile maps input names to arrays and is handled exactly like a dict of arrays
        # downstream (the npy_file branch already feeds .npz archives through the same path).
        if isinstance(data, (np.ndarray, dict, np.lib.npyio.NpzFile)):
            data_type = CalibrationDataType.np_array
        elif isinstance(data, tf.data.Dataset):
            data_type = CalibrationDataType.dataset
        elif isinstance(data, (bytes, str, Path)) and os.path.isfile(data):
            data_type = CalibrationDataType.npy_file
        elif isinstance(data, (bytes, str, Path)) and os.path.isdir(data):
            data_type = CalibrationDataType.npy_dir
        elif callable(data):
            data_type = CalibrationDataType.callable
        else:
            raise ValueError("Couldn't detect CalibrationDataType")
        logger.verbose(f"Using data type {data_type.value} for calibration")

    if data_type == CalibrationDataType.np_array:
        dataset = get_dataset_from_np(data)
        image_count = get_data_count_from_data_sample(data)
    elif data_type == CalibrationDataType.dataset:
        dataset = data
        # TODO: validate dataset output signature
    elif data_type == CalibrationDataType.npy_file:
        npy_view = np.load(data, mmap_mode="r")
        dataset = get_dataset_from_np(npy_view)
        image_count = get_data_count_from_data_sample(npy_view)
    elif data_type == CalibrationDataType.npy_dir:
        dataset = npy_dir_to_dataset(data)
    elif data_type == CalibrationDataType.callable:
        dataset = data()
        if not isinstance(dataset, tf.data.Dataset):
            raise ValueError("Callback type returned non-dataset value")
        # TODO: validate dataset output signature
    else:
        raise ValueError(f"Unexpected CalibrationDataType {data_type.value}")
    if image_count is not None:
        # The generator-based datasets have unknown cardinality; pin it to the known sample count.
        dataset = dataset.apply(tf.data.experimental.assert_cardinality(image_count))
    return dataset, image_count


def rebuild_dataset_v2(dataset):
    """
    Re-create a TF2 ``tf.data.Dataset`` (DatasetV2) as a ``tf.compat.v1`` (DatasetV1) dataset.

    The elements of ``dataset`` are iterated inside a Python generator and re-emitted with the same
    ``element_spec``.

    Args:
        dataset: a dataset whose elements are ``(data, image_info)`` pairs.

    Returns:
        tf.compat.v1.data.Dataset with the same element spec as ``dataset``.
    """

    def generator():
        for data, image_info in dataset:
            yield data, image_info

    return tf.compat.v1.data.Dataset.from_generator(generator, output_signature=dataset.element_spec)
