"""
This module implements a template for layer-by-layer algorithms.
It can also be used by itself to infer a model layer-by-layer.
"""

import atexit
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Set

import numpy as np
import psutil
import tensorflow as tf
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoOutputLayer
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasResourceError
from hailo_model_optimization.acceleras.utils.dataset_util import verify_dataset_size
from hailo_model_optimization.algorithms.dali_utils import tf_unpad_input
from hailo_model_optimization.algorithms.dali_utils.mock_dali_dataset import cache_list_to_dataset
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class LayerByLayer(OptimizationAlgorithm, ABC):
    """
    Template for algorithms that are inferred layer-by-layer

    A template class for algorithms that optimize the model in a layer-by-layer manner.
    This class can be used by itself to infer a model layer-by-layer.

    The model layers are visited in topological order. Each layer is inferred on the whole dataset
    and its outputs are cached on disk (``.npz`` files) to serve as the inputs of its successors.
    Subclasses customize the process through the ``_lbl_*_logic`` hooks.

    Args:
        model: Model for the algorithm
        config_params: model configuration
        name: name of the algorithm
        dataset: unbatched, labeled data for the algorithm inference
        work_dir: directory for debug files and cache. When ``None``, a new temporary directory is created.
        logger_level: logging level forwarded to the base algorithm

    """

    def __init__(
        self,
        model: HailoModel,
        config_params: ModelOptimizationConfig,
        name: str,
        dataset: tf.data.Dataset,
        work_dir=None,
        logger_level=logging.DEBUG,
        **kwargs,
    ):
        super().__init__(model, config_params, name, logger_level, **kwargs)
        if work_dir is None:
            # os.makedirs(None) would raise TypeError - fall back to a private temporary work dir
            work_dir = tempfile.mkdtemp(prefix="lbl_work_dir_")
        self._work_dir = work_dir
        # NOTE(review): the whole work dir (including a caller-provided one) is removed at interpreter exit
        atexit.register(self.cleanup_workdir)
        os.makedirs(self._work_dir, exist_ok=True)
        self._cache_dir = None
        self._labeled_unbatched_dataset = dataset
        self._dataset = None  # label-less dataset, built in _setup

    @abstractmethod
    def get_dataset_size(self):
        """Return the number of dataset entries used by the algorithm."""

    def _run_int(self):
        """
        Infer the model's layers one at a time, in topological order.

        Non NN-core layers (and output layers fed by one) are skipped. The outputs of each inferred
        layer are cached under a temporary ``cache_`` dir in the work dir, and a layer's cache is
        deleted once all of its successors have been inferred. The temporary cache dir is removed
        when the iteration ends, also when it ends with an exception.

        Returns:
            dict mapping the layers whose results were still held at the end of the iteration to
            their cache dirs (note: these dirs have already been removed from disk on return)

        Raises:
            AccelerasResourceError: if TensorFlow runs out of memory while inferring a layer

        """
        interlayer_results = dict()
        inferred_layers = set()
        self._cache_dir = tempfile.mkdtemp(dir=self._work_dir, prefix="cache_")
        try:
            self._lbl_setup_logic()
            layers_to_infer = [layer for layer in self._model.flow.toposort() if not self._should_skip_layer(layer)]
            with tqdm(total=len(layers_to_infer), unit="layers", desc=self._name) as pbar:
                for layer in layers_to_infer:
                    acceleras_layer = self._model.layers[layer]
                    pbar.set_postfix({"Layer": layer})
                    curr_inputs_cache_list = self._get_layer_inputs_cache_list(
                        layer,
                        interlayer_results,
                    )

                    self._lbl_pre_layer_logic(acceleras_layer)

                    try:
                        curr_outputs_parent_cache = self._infer_layer(acceleras_layer, curr_inputs_cache_list)
                    except tf.errors.ResourceExhaustedError as err:
                        raise AccelerasResourceError(
                            f"GPU memory has been exhausted. Please try to use {self._name} with "
                            f"lower batch size or run on CPU.",
                        ) from err
                    inferred_layers.add(layer)
                    interlayer_results[layer] = curr_outputs_parent_cache

                    self._lbl_post_layer_logic(
                        acceleras_layer,
                        curr_inputs_cache_list,
                        curr_outputs_parent_cache,
                        inferred_layers,
                    )
                    # Free the cache of layers whose successors have all been inferred
                    self._clean_results(interlayer_results, inferred_layers)
                    pbar.update(1)
                pbar.refresh()  # flush the final pbar state
            self._lbl_finalize_logic()
        finally:
            # Best-effort cleanup; must not mask an exception raised during the iteration
            shutil.rmtree(self._cache_dir, ignore_errors=True)
            self._cache_dir = None
        return interlayer_results

    def _should_skip_layer(self, layer: str) -> bool:
        """
        Check whether a layer is excluded from the layer-by-layer inference

        Args:
            layer: layer name

        Returns:
            True for non NN-core layers, and for output layers whose last (sorted) predecessor is a
            non NN-core layer

        """
        acceleras_layer = self._model.layers[layer]
        if isinstance(acceleras_layer, BaseHailoNonNNCoreLayer):
            return True
        if not isinstance(acceleras_layer, HailoOutputLayer):
            return False
        last_pred = self._model.flow.predecessors_sorted(layer)[-1]
        return isinstance(self._model.layers[last_pred], BaseHailoNonNNCoreLayer)

    def _infer_layer(
        self,
        acceleras_layer: BaseHailoLayer,
        inputs_cache_list: List[str],
        count: int = None,
    ):
        """
        Infer the entire dataset on a single layer
        Args:
            acceleras_layer: layer to run
            inputs_cache_list: list of cache dirs for layer input
            count: limits the inference to get <count> items

        Returns
            parent cache dir of the layer results (contains a ``data_<i>`` sub dir for each output)

        """
        inputs_dataset, unpadded_shape = self._get_layer_inputs_dataset(
            acceleras_layer.full_name,
            inputs_cache_list,
            count,
        )

        # Wrap the layer in a Keras model fed by a mock input matching the dataset element;
        # tf_unpad_input presumably recovers the layer's real input(s) using unpadded_shape
        spec = inputs_dataset.element_spec[0]
        mock_input = tf.keras.Input(spec.shape[1:])
        inputs = tf_unpad_input(mock_input, unpadded_shape)

        mock_output = acceleras_layer(inputs)
        model = tf.keras.Model(inputs=mock_input, outputs=mock_output)

        cache_dir = self._get_cache_dir(acceleras_layer.full_name)

        cfg = self.get_algo_config()
        batch_size = cfg.batch_size
        cache_compression = cfg.cache_compression
        for batch_index, data_item in enumerate(inputs_dataset):
            result = model.predict_on_batch(data_item)
            if acceleras_layer.num_outputs == 1:
                # _save_batch indexes the result per output
                result = [result]
            self._save_batch(acceleras_layer, result, cache_dir, batch_index, batch_size, cache_compression)
        return cache_dir

    def get_algo_config(self):
        """Return the current algorithm configuration"""
        return self._model_config

    def _lbl_setup_logic(self):
        """Logic before the layers iteration starts"""

    def _lbl_pre_layer_logic(
        self,
        acceleras_layer: BaseHailoLayer,
    ):
        """
        Logic before each layer inference

        Args:
            acceleras_layer: the layer that will be inferred

        """

    def _lbl_post_layer_logic(
        self,
        acceleras_layer: BaseHailoLayer,
        curr_inputs_cache_list: List[str],
        curr_outputs_parent_cache: str,
        inferred_layers: Set[str],
    ):
        """
        Logic after each layer inference

        Args:
            acceleras_layer: the layer that was inferred
            curr_inputs_cache_list: list with the inputs cache dirs of the inferred layer
            curr_outputs_parent_cache: parent cache dir for the inferred output layer
                                        (contains sub dirs for each output)
            inferred_layers: set with layers names that have already been inferred

        """

    def _lbl_finalize_logic(self):
        """Logic after the layers iteration ends"""

    @staticmethod
    def get_cache_basename(index):
        """Return the name of the cache sub dir of the output with the given index"""
        return f"data_{index}"

    def _get_layer_inputs_cache_list(
        self,
        layer: str,
        interlayer_results: Dict[str, str],
    ) -> List[str]:
        """
        Get the inputs cache list of a given layer

        Args:
            layer: layer name
            interlayer_results: results of previously inferred layers

        Returns:
            tuple with the cache paths of the layer's inputs, ordered by the sorted predecessors

        """
        inputs_cache_list = []
        preds = self._model.flow.predecessors_sorted(layer)
        for pred in preds:
            out_ind = self._model.flow.get_edge_output_index(pred, layer)
            pred_layer = self._model.layers[pred]
            out_ind = pred_layer.resolve_output_index(out_ind)
            basename = self.get_cache_basename(out_ind)
            inp_cache = os.path.join(interlayer_results[pred], basename)
            inputs_cache_list.append(inp_cache)
        return tuple(inputs_cache_list)

    def _get_layer_inputs_dataset(
        self,
        layer: str,
        inputs_cache_list: List[str],
        count: int,
    ):
        """
        Get layer's inputs dataset from name and cache-list of the inputs

        Args:
            layer: layer name (used for input layers)
            inputs_cache_list: list of cache paths of the inputs data
            count: how many data samples to take from the dataset (None for the whole dataset)

        Returns:
            tuple of the batched inputs dataset and the unpadded shape of the inputs

        """
        cfg = self.get_algo_config()
        dataset_size = self.get_dataset_size()
        count = dataset_size if count is None else count
        if layer in self._model.flow.input_nodes:
            dataset = self._get_model_inputs_dataset(layer)
            unpadded_shape = np.array(
                [spec.shape.as_list() for spec in dataset.element_spec],
            )
        elif len(inputs_cache_list) == 0:  # special case for const layer
            # Similar behavior to HailoModel's call, just take the first input for batch reference
            fake_input_layer = self._model.flow.input_nodes[0]
            dataset = self._get_model_inputs_dataset(fake_input_layer)
            unpadded_shape = np.array(
                [spec.shape.as_list() for spec in dataset.element_spec],
            )
        else:
            # TODO: consider switching to DALI
            # (will require change in _infer_layer to limit iterations)
            dataset, unpadded_shape = cache_list_to_dataset(inputs_cache_list, count)
            dataset = tf.data.Dataset.zip((dataset,))
        batch_size = cfg.batch_size
        return dataset.take(count).batch(batch_size), unpadded_shape

    def _get_model_inputs_dataset(self, layer: str):
        """
        Gets the data for model's inputs layers

        Args:
            layer: input layer name

        Returns:
            dataset of 1-tuples holding the layer's data with an extra leading axis of size 1

        """
        # TODO: we could optimize it to use DALI as well
        single_input_dt = self._dataset.map(lambda x: x[layer])
        single_input_dt = single_input_dt.map(lambda x: tf.expand_dims(x, axis=0))
        datasets = (single_input_dt,)
        return tf.data.Dataset.zip(datasets)

    def _setup(self):
        """Prepare dataset - remove labels and take relevant amount"""
        dataset_size = self.get_dataset_size()
        dataset = self._labeled_unbatched_dataset.map(prep_label_less)
        dataset = dataset.map(self._model.inputs_as_dict)
        dataset = dataset.take(dataset_size)
        if self._model.preproc_cb is not None:
            dataset = dataset.batch(1).map(self._model.preproc_cb).unbatch()
        self._dataset = dataset
        verify_dataset_size(
            self._dataset,
            dataset_size,
            warning_if_larger=self.warning_if_larger_dataset,
            logger=self._logger,
        )
        self.check_storage_usage()
        self._logger.info(f"Using dataset with {dataset_size} entries for {self._name}")

    @property
    @abstractmethod
    def warning_if_larger_dataset(self) -> bool:
        """Whether verify_dataset_size should warn when the dataset is larger than required"""

    def check_storage_usage(self):
        """
        Check if the current work directory has sufficient storage for running the current algorithm

        Raises:
            OSError: if the estimated storage exceeds the free space and the cache is not compressed
                (with compression only a warning is logged, since the estimate is an upper bound)

        """
        compression = self.get_algo_config().cache_compression
        max_disk_usage = self.get_max_disk_usage()
        free_disk = psutil.disk_usage(self._work_dir).free
        be_verb = "might" if compression else "will"
        self._logger.info(
            f"The algorithm {self._name} {be_verb} use up to {max_disk_usage / (10 ** 9):.02f} GB of storage space",
        )
        if max_disk_usage > free_disk:
            be_verb = "might be" if compression else "is"
            # Use the same GB (10**9 bytes) unit as the usage message above
            msg = (
                f"The storage space {be_verb} insufficient. Free disk space {free_disk / (10 ** 9):0.02f} GB. "
                f"Consider reducing the dataset size if you encounter space issues."
            )
            if compression:
                self._logger.warning(msg)
            else:
                self._logger.error(msg)
                raise OSError(msg)

    def get_max_disk_usage(self, factor=2) -> float:
        """
        Calculate the maximal disk usage of the current algorithm

        Simulates the layers iteration in topological order and tracks which layer results
        would be held in the cache at each step.

        Args:
            factor: multiplier of stored data (usually for native & quant data)

        Returns: required storage in bytes

        """
        inferred_layers = set()
        active_layers = set()
        dataset_size = self.get_dataset_size()
        max_disk_usage = 0
        for layer in self._model.flow.toposort():
            inferred_layers.add(layer)

            # While inferring: cached results of earlier layers + the current layer's inputs
            tot_su = self._get_layers_output_storage_usage(active_layers) * factor
            tot_su += self._get_current_input_storage_usage(layer)
            max_disk_usage = max(max_disk_usage, tot_su * dataset_size)

            layers_with_result = set(active_layers)
            active_layers.add(layer)

            # After inferring: the current layer's results are cached as well
            tot_su = self._get_layers_output_storage_usage(active_layers) * factor
            max_disk_usage = max(max_disk_usage, tot_su * dataset_size)

            for lname in self._iter_redundant_layers(layers_with_result, inferred_layers):
                active_layers.remove(lname)
        return max_disk_usage

    def _get_layers_output_storage_usage(self, active_layers):
        """
        Calculate the storage need to store all the output tensors of the given layers

        Args:
            active_layers: layers for which the output storage should be calculated for

        Returns: required storage in bytes for 1 data sample

        """
        tot_su = 0
        for layer in active_layers:
            for out_sh in self._model.layers[layer].output_shapes:
                s1 = np.prod(out_sh[1:])  # skip the batch dimension
                tot_su += s1
        return tot_su * 4  # 4 bytes per element (presumably float32 data)

    def _get_current_input_storage_usage(self, lname):
        """
        Calculate the storage need to store all the input tensors of the given layer

        Args:
            lname: layer for which the input storage should be calculated for

        Returns: required storage in bytes for 1 data sample

        """
        tot_su = 0
        for inps in self._model.layers[lname].input_shapes:
            s1 = np.prod(inps[1:])  # skip the batch dimension
            tot_su += s1
        return tot_su * 4  # 4 bytes per element (presumably float32 data)

    # Functions related to caching the dataset start here ############

    @classmethod
    def _save_batch(cls, acceleras_layer, batch, cache_dir, batch_index, batch_size, compress):
        """
        Save a batch of layer results to the cache, one ``<sample index>.npz`` file per sample

        Args:
            acceleras_layer: the layer that produced the batch
            batch: sequence with one batched result per layer output
            cache_dir: parent cache dir of the layer
            batch_index: index of the batch within the dataset
            batch_size: configured batch size (used to compute the absolute sample index)
            compress: whether to use np.savez_compressed instead of np.savez

        """
        for out_ind in range(acceleras_layer.num_outputs):
            batch_of_output = batch[out_ind]
            basename = cls.get_cache_basename(out_ind)
            output_dir = os.path.join(cache_dir, basename)
            os.makedirs(output_dir, exist_ok=True)
            for i in range(batch_of_output.shape[0]):
                abs_index = batch_size * batch_index + i
                fname = os.path.join(output_dir, f"{abs_index}.npz")
                if compress:
                    np.savez_compressed(fname, arr=batch_of_output[i])
                else:
                    np.savez(fname, arr=batch_of_output[i])

    def _get_cache_dir(self, layer):
        """
        Create a new unique cache dir for a layer inside the current run's cache dir
        Args:
            layer: layer name

        Returns
            cache path

        """
        layer_no_slash = layer.replace("/", "___")
        cache_dir = tempfile.mkdtemp(dir=self._cache_dir, prefix=f"{layer_no_slash}_")
        return cache_dir

    def _iter_redundant_layers(
        self,
        active_layers: Set[str],
        inferred_layers: Set[str],
    ):
        """Yield the layers of ``active_layers`` whose successors (if any) have all been inferred"""
        for lname in active_layers:
            successors = set(self._model.flow.successors(lname))
            if successors.issubset(inferred_layers):
                yield lname

    def _clean_results(
        self,
        interlayer_results: Dict[str, str],
        inferred_layers: Set[str],
    ):
        """
        Delete the interlayer results (both from memory and cache) after they are no longer needed
        Args:
            interlayer_results: results of layers that have been previously inferred
            inferred_layers: set with all the layers that have been inferred

        """
        # Iterate a snapshot of the keys since entries are deleted during the iteration
        active_layers = set(interlayer_results.keys())
        for lname in self._iter_redundant_layers(active_layers, inferred_layers):
            self._delete_cache(interlayer_results[lname])
            del interlayer_results[lname]

    @staticmethod
    def _delete_cache(dirname: str):
        """
        Delete the cache files of a single layer
        Args:
            dirname: the cache directory to remove (errors are ignored)
        """
        shutil.rmtree(dirname, ignore_errors=True)

    def cleanup_workdir(self):
        """Remove the whole work directory (registered to run at interpreter exit)"""
        shutil.rmtree(self._work_dir, ignore_errors=True)

    # Functions related to caching the dataset end here ############


def prep_label_less(images, image_info):
    """
    Dataset map function that keeps only the images of a labeled element.

    Args:
        images: the data component of the dataset element
        image_info: the image info / label component, which is discarded

    Returns:
        ``images``, unchanged

    """
    # TODO: The train label is off by one, it needs to be fixed
    unlabeled = images  # image_info is intentionally ignored
    return unlabeled
