from functools import partial
from typing import Dict, List, Union

import tensorflow as tf

try:
    import nvidia.dali.fn as dali_fn
    import nvidia.dali.plugin.tf as dali_tf
    from nvidia.dali import pipeline_def, types

    DALI_INSTALLED = True
except ImportError:
    DALI_INSTALLED = False

from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeaturePolicy
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasException
from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.algorithms.block_by_block.cache_utils import get_layer_cache_dir


def has_gpu() -> bool:
    """
    Report whether TensorFlow can see at least one physical GPU device.

    Returns:
        bool: True if ``tf.config.list_physical_devices("GPU")`` is non-empty, False otherwise.
    """
    visible_gpus = tf.config.list_physical_devices("GPU")
    return bool(visible_gpus)


class DataFeederTFRecord:
    """
    Writes per-layer tensors to TFRecord files and reads them back as ``tf.data`` datasets,
    optionally routing training data through NVIDIA DALI.
    """

    # pip command suggested to the user when DALI is requested but not installed.
    dali_installed_command = [
        "pip",
        "install",
        "--extra-index-url",
        "https://developer.download.nvidia.com/compute/redist",
        "nvidia-dali-cuda110",
        "nvidia-dali-tf-plugin-cuda110",
    ]

    def __init__(
        self,
        device: str = "gpu",
        use_dali: Union[FeaturePolicy, bool] = True,
        device_id: int = 0,
        num_threads: int = 4,
        compression_type: str = "",
        logger=None,
    ):
        """
        This class handles the data feeding process for, and from, the block-by-block training.
            It uses tfrecords to save and load the data. The class can also use NVIDIA DALI for faster data loading.
            Note that NVIDIA DALI requires a GPU device to work properly. If a GPU is not available, the class will
            switch to the TF dataset backend on CPU and without DALI.

        Args:
            device (str, optional): 'cpu' or 'gpu' (any TF device string, e.g. '/GPU:0'). Defaults to 'gpu'.
            use_dali (Union[FeaturePolicy, bool], optional): use Dali (if installed and there is a gpu).
                Defaults to True.
            device_id (int, optional): Defaults to 0.
            num_threads (int, optional): Number of CPU threads used by the pipeline (for Dali). Defaults to 4.
            compression_type (str, optional): "GZIP", "ZLIB", or "" (no compression). Defaults to "".
            logger (Logger, optional): Defaults to None.

        """
        self.compression_type = compression_type
        self._filename_ext = ".tfrecords"
        # The logger must exist before the ``device``/``use_dali`` setters run, since both may warn.
        self._logger = logger or default_logger()

        # NVIDIA DALI Parameters:
        self.device = device
        self.device_id = device_id
        self.num_threads = num_threads

        # ``device`` must be assigned first: the ``use_dali`` setter inspects it.
        self.use_dali = use_dali
        if DALI_INSTALLED:
            self.create_dali_dataset = create_dali_dataset
        else:
            # Placeholder that raises an informative error if the DALI path is ever taken.
            self.create_dali_dataset = partial(
                mock_create_dali_dataset,
                dali_installed_command=self.dali_installed_command,
            )

    @property
    def device(self) -> str:
        return self._device

    @device.setter
    def device(self, device: str = "gpu"):
        # Falls back to the CPU when a GPU is requested but none is visible.
        # Matching is case-insensitive so both "gpu" and "/GPU:0"-style strings are recognized.
        if "gpu" in device.lower() and not has_gpu():
            device = "/cpu:0"
            self._logger.warning(
                f"GPU was not found, switching to {device = }. Note that Dali will not work without a GPU.",
            )
        self._device = device

    @property
    def use_dali(self) -> bool:
        return self._use_dali

    @use_dali.setter
    def use_dali(self, use_dali: Union[FeaturePolicy, bool] = True):
        # Accept either a plain bool or a FeaturePolicy (only FeaturePolicy.enabled turns DALI on).
        _use_dali = use_dali if isinstance(use_dali, bool) else use_dali == FeaturePolicy.enabled

        # DALI is silently downgraded (with a warning) whenever it cannot actually run.
        if _use_dali and not DALI_INSTALLED:
            self._logger.warning(
                "Trying to use DALI, but it is not installed. Set use_dali=False or install DALI using: "
                f"{' '.join(self.dali_installed_command)}",
            )
            _use_dali = False
        elif ("cpu" in self.device.lower()) and _use_dali:
            self._logger.warning(
                f"DALI must have a GPU device to work properly but {self.device = } was given. "
                "Switching to TF dataset backend without DALI.",
            )
            _use_dali = False
        elif (not has_gpu()) and _use_dali:
            self._logger.warning(
                f"DALI must have a GPU device to work properly but no GPU was detected ({has_gpu() = }). "
                "Switching to TF dataset backend without DALI.",
            )
            _use_dali = False

        self._use_dali = _use_dali

    def get_full_filename(self, cache_dir: str, ind: int = 0) -> str:
        """Return the path of the ``ind``-th record file in ``cache_dir`` (index zero-padded to 4 digits)."""
        return tf.io.gfile.join(cache_dir, f"{ind:04}_data" + self._filename_ext)

    def get_cache_list(self, layers, interlayer_results: Dict[str, str]) -> List[str]:
        """Return the cache directory of each layer in ``layers``, in order (KeyError if one is missing)."""
        return [interlayer_results[lname] for lname in layers]

    def read_cache_list(self, layers, interlayer_results: Dict[str, str]) -> List[str]:
        """
        Return the cache directory of each layer in ``layers``, in order.

        Raises:
            ValueError: If a layer has no entry in ``interlayer_results``.
        """
        for lname in layers:
            if lname not in interlayer_results:
                raise ValueError(f"Layer {lname} is not in interlayer_results.")
        return [interlayer_results[lname] for lname in layers]

    def _get_tf_options(self, compression_type: str = ""):
        # NOTE: a non-empty ``compression_type`` also overwrites ``self.compression_type``, so later
        # reads (which use ``self.compression_type``) decompress with the same setting.
        self.compression_type = compression_type if compression_type != "" else self.compression_type
        return tf.io.TFRecordOptions(compression_type=self.compression_type)

    # Start of Feature region
    @staticmethod
    def _create_float_feature(values, flatten: bool = True):
        values = tf.reshape(values, -1) if flatten else values
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    @staticmethod
    def _bytes_feature(value):
        # Eager tensors are serialized first; anything else is assumed to already be bytes.
        if isinstance(value, type(tf.constant(0))):
            value = tf.io.serialize_tensor(value).numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    @staticmethod
    def _int64_feature(values):
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    # End of Feature region

    # Start of dataset-to-cache region
    def tfrecord_writer(self, dataset: tf.data.Dataset, output_dir: str):
        """
        Write every element of ``dataset`` to its own TFRecord file in ``output_dir``,
        named by the element's index (see ``get_full_filename``).
        ToDo: Split dataset into stacks if it is too large.
        """
        tf_options = self._get_tf_options()
        for ind, batch in enumerate(dataset):
            record_file = self.get_full_filename(output_dir, ind=ind)
            with tf.io.TFRecordWriter(record_file, tf_options) as writer:
                writer.write(self._tf_example(batch).SerializeToString())

    def _tf_example(self, data):
        """
        Encodes the INPUT features for saving to tfrecord file.

        ``data`` is indexed as a rank-3 tensor: its first three dimensions are stored as
        height / width / channel next to the serialized tensor.
        """
        shape = data.shape
        feature = {
            "data": self._bytes_feature(data),
            "height": self._int64_feature(shape[0]),
            "width": self._int64_feature(shape[1]),
            "channel": self._int64_feature(shape[2]),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))

    def dataset_to_cache(
        self,
        dataset: tf.data.Dataset,
        cache_dir: str,
    ) -> Dict[str, str]:
        """
        Save dataset to cache dir.

        ``dataset`` is unbatched and each of its dict keys (layer names) is written, one file per
        sample, to that layer's cache directory.

        Returns:
            Dict[str, str]: Layer name -> directory its records were written to.
        """
        dataset = dataset.unbatch()
        cache_dir_by_layer = dict()
        for lname in dataset.element_spec.keys():
            cache_dir_by_layer[lname] = get_layer_cache_dir(cache_dir, lname)
            output_dir = cache_dir_by_layer[lname]
            tf.io.gfile.makedirs(output_dir)
            # Bind the key explicitly instead of closing over the loop variable.
            dataset_lname = dataset.map(lambda x, key=lname: x[key], num_parallel_calls=tf.data.AUTOTUNE)
            self.tfrecord_writer(dataset_lname, output_dir)
        return cache_dir_by_layer

    # End of dataset-to-cache region

    # Start of cache-to-dataset region
    def cache_to_dataset(
        self,
        input_nodes: List[str],
        input_interlayer_results: Union[Dict[str, str], List[str]],
        output_nodes: List[str] = None,
        output_interlayer_results: Union[Dict[str, str], List[str]] = None,
        batch_size=None,
    ) -> tf.data.Dataset:
        """
        Read data from cache dir to tensorflow dataset with TFRecorderDataset.

        Args:
            input_nodes: Layer names to read as inputs.
            input_interlayer_results: Cache directory per input node (dict by name, or list aligned
                with ``input_nodes``).
            output_nodes: Layer names to read as targets. When None, an inference dataset is built.
            output_interlayer_results: Cache directory per output node (same formats as the inputs).
            batch_size: Batch size. On the TF path, None means no batching; on the DALI path it is
                forwarded to ``self.create_dali_dataset``.

        Returns:
            tf.data.Dataset: For training, a zip of (inputs, outputs); for inference, a zip of the
            per-node dict, so each element is a dict keyed by node name.
        """
        # Load data from cache dir:
        input_dataset_pack = self._read_tfrecords_data(input_nodes, input_interlayer_results)
        if output_nodes is not None:
            output_dataset_pack = self._read_tfrecords_data(output_nodes, output_interlayer_results)

        if self.use_dali and output_nodes is not None:
            # Use NVIDIA DALI for training only.
            input_dataset_pack = self.create_dali_dataset(
                input_dataset_pack,
                batch_size=batch_size,
                device=self.device,
                device_id=self.device_id,
                num_threads=self.num_threads,
            )
            output_dataset_pack = self.create_dali_dataset(
                output_dataset_pack,
                batch_size=batch_size,
                device=self.device,
                device_id=self.device_id,
                num_threads=self.num_threads,
            )
            dataset = tf.data.Dataset.zip((input_dataset_pack, output_dataset_pack))
        else:
            # Use only tfrecords
            if output_nodes is not None:
                # dict -> tuple (for training only, not for inference)
                input_dataset_pack = tuple(input_dataset_pack.values())
                output_dataset_pack = tuple(output_dataset_pack.values())
                dataset = tf.data.Dataset.zip((input_dataset_pack, output_dataset_pack))
            else:
                dataset = tf.data.Dataset.zip(input_dataset_pack)
            dataset = dataset.batch(batch_size) if batch_size else dataset

        return dataset

    def _read_tfrecords_data(
        self,
        nodes: List[str],
        interlayer_results: Union[List[str], Dict[str, str]],
    ) -> Dict[str, tf.data.Dataset]:
        """
        Build one parsed, statically-shaped dataset per node from its cached TFRecord files.

        Raises:
            ValueError: If a node's cache directory contains no record files.
        """
        dataset = dict()
        for ind, lname in enumerate(nodes):
            folder_name = interlayer_results[ind] if isinstance(interlayer_results, list) else interlayer_results[lname]

            # Read from tfrecord files. Sorting gives every node the same file-name order, which keeps
            # samples aligned when the per-node datasets are zipped together.
            record_file = sorted(tf.io.gfile.glob(str(folder_name) + "/*" + self._filename_ext))
            if not record_file:
                raise ValueError(f"No '{self._filename_ext}' files were found for layer {lname} in {folder_name}.")
            ds = tf.data.TFRecordDataset(record_file, compression_type=self.compression_type)
            ds = ds.map(DataFeederTFRecord._parse_function, num_parallel_calls=tf.data.AUTOTUNE)
            # _parse_function reshapes with runtime values, leaving the static shape unknown; read one
            # sample to recover it and re-impose it so the dataset's element_spec is fully defined.
            sample = list(ds.take(1))[0]
            tensor_shape = sample.shape
            ds = ds.map(lambda x, shape=tensor_shape: tf.reshape(x, shape), num_parallel_calls=tf.data.AUTOTUNE)
            dataset[lname] = ds
        return dataset

    @staticmethod
    def _parse_function(example_proto):
        """
        Parse the OUTPUT features from the tfrecord file.

        Decodes the serialized float32 tensor and reshapes it to (height, width, channel).
        """
        feature_description = {
            "height": tf.io.FixedLenFeature([], tf.int64, default_value=None),
            "width": tf.io.FixedLenFeature([], tf.int64, default_value=None),
            "channel": tf.io.FixedLenFeature([], tf.int64, default_value=None),
            "data": tf.io.FixedLenFeature([], tf.string, default_value=""),
        }
        features = tf.io.parse_example(example_proto, feature_description)
        data = tf.io.parse_tensor(features["data"], out_type=tf.float32)
        height = tf.cast(features["height"], tf.int32)
        width = tf.cast(features["width"], tf.int32)
        channel = tf.cast(features["channel"], tf.int32)
        data = tf.reshape(data, (height, width, channel))
        return data

    # End of cache-to-dataset region


def mock_create_dali_dataset(
    dataset_dict: Dict[str, tf.data.Dataset],
    batch_size: int = None,
    *args,
    **kwargs,
) -> None:
    """
    Mock function for create_dali_dataset. This function is a placeholder and
    should not be called if Dali is not installed: every call raises.

    Args:
        dataset_dict: Ignored; accepted only to mirror ``create_dali_dataset``'s signature.
        batch_size: Ignored.
        *args: Ignored.
        **kwargs: ``dali_installed_command`` (list of str) is the install command shown in the
            error message; when absent, ``DataFeederTFRecord.dali_installed_command`` is used.
            Any other keyword is ignored.

    Raises:
        AccelerasException: Always, with instructions for installing DALI.
    """
    # Fall back to the class-level command so a missing kwarg still yields the intended error
    # instead of a KeyError.
    dali_installed_command = kwargs.get("dali_installed_command", DataFeederTFRecord.dali_installed_command)
    raise AccelerasException(
        "Trying to use DALI, but it is not installed. Set use_dali=False or install DALI using: "
        f"{' '.join(dali_installed_command)}",
    )


if DALI_INSTALLED:

    def create_dali_dataset(
        dataset_dict: Dict[str, tf.data.Dataset],
        batch_size: int = None,
        *args,
        device: str = "gpu",
        device_id: int = 0,
        num_threads: int = 4,
    ):
        """
        Wrap a dict of per-node TF datasets in a DALI-backed ``tf.data`` dataset.

        Each entry of ``dataset_dict`` feeds one ``external_source`` of a DALI pipeline. The result
        yields a tuple with one float32 batch per entry, in ``dataset_dict`` iteration order.

        Args:
            dataset_dict: Node name -> dataset of per-sample float tensors. Each dataset's
                ``element_spec.shape`` is used as the per-sample output shape.
            batch_size: DALI batch size; a falsy value (None/0) falls back to 1.
            *args: Ignored.
            device: Device for the external sources and the ``tf.device`` scope; must be a GPU device.
            device_id: GPU index passed to the DALI pipeline and dataset.
            num_threads: Number of CPU threads used by the DALI pipeline.

        Raises:
            AccelerasException: If ``device`` does not name a GPU device.

        Notes
        (1) A GPU must be available to use with dali_tf.experimental.DALIDatasetWithInputs.
        (2) When using DALI, you must impose a dataset_size over TF's dataset (dataset.take(dataset_size))
            before sampling!

        """
        # Previously the exception was only constructed, never raised, so this guard did nothing.
        # Matching is case-insensitive so both "gpu" and "/GPU:0"-style strings are accepted.
        if "gpu" not in device.lower():
            raise AccelerasException(f"DALI must have a GPU device to work properly but {device = } was given.")

        batch_size = batch_size if batch_size else 1

        @pipeline_def(device_id=device_id, num_threads=num_threads, batch_size=batch_size)
        def pipeline_with_inputs(device):
            # One per-sample (batch=False) external source per input dataset, named after its key.
            encoded = []
            for k in dataset_dict:
                encoded.append(dali_fn.external_source(name=k, device=device, batch=False, dtype=types.FLOAT))
            return tuple(encoded)

        pipe = pipeline_with_inputs(device)
        num_datasets = len(dataset_dict)
        # DALI emits batches, so prepend the batch dimension to each per-sample shape.
        output_shapes = tuple([batch_size, *tuple(v.element_spec.shape)] for v in dataset_dict.values())
        output_dtypes = tuple([tf.float32] * num_datasets)

        with tf.device(device):
            dali_dataset = dali_tf.experimental.DALIDatasetWithInputs(
                pipeline=pipe,
                input_datasets=dataset_dict,
                batch_size=batch_size,
                output_shapes=output_shapes,
                output_dtypes=output_dtypes,
                device_id=device_id,
                num_threads=num_threads,
            )

        return dali_dataset
